#!/usr/bin/env python3
# Copyright 2025 The Emscripten Authors.  All rights reserved.
# Emscripten is available under two separate licenses, the MIT license and the
# University of Illinois/NCSA Open Source License.  Both these licenses can be
# found in the LICENSE file.

"""
Wrapper for 'wasm-split --multi-split' functionality. This script generates a
.manifest file based on the list of user source paths, using source map
information.

This assumes the name section exists in the input wasm file, and also assumes
the sourceMappingURL section exists in the input or a source map file is
separately supplied with --sourcemap. For example, given two files a.c and b.c,
to generate a source map and the name section when compiling and linking in a
single command, you can run
$ emcc -g2 -gsource-map a.c b.c -o result.js
If you want to compile and link in separate commands, you can do
$ emcc -gsource-map a.c -o a.o
$ emcc -gsource-map b.c -o b.o
$ emcc -g2 -gsource-map a.o b.o -o result.js
See https://emscripten.org/docs/porting/Debugging.html for more details.

This takes a wasm file and a paths file as inputs. The paths file defines how
to split modules. The format is similar to the manifest file for wasm-split, but
with paths instead of function names. A module is defined by a name on a line,
followed by paths on subsequent lines. Modules are separated by empty lines.
Each module name must end with a colon (:).
For example:
module1:
path/to/a
path/to/b

module2:
path/to/c

This will create two modules, 'module1' and 'module2'. 'module1' will contain
functions from source files under path/to/a and path/to/b. 'module2' will
contain functions from source files under path/to/c.

If a specified path contains another specified path, functions contained in the
inner path will be split as the inner path's module, and the rest of the
functions will be split as the outer path's module. Functions that do not belong
to any of the specified paths will remain in the primary module.

The paths in the paths file can be either absolute or relative, but they should
match those in the 'sources' field of the source map file. Sometimes a source
map's 'sources' field contains paths relative to a build directory, so source
files may be recorded as '../src/subdir/test.c', for example. In this case, if
you want to split the directory src/subdir, you should list it as
../src/subdir. You can manually open the source map file and check the
'sources' field, but we also provide an option to help with that. You can run
$ empath-split --print-sources test.wasm
or
$ empath-split --print-sources --source-map test.wasm.map
to print the entries of the source map's 'sources' field. Note that
emscripten's libraries' source files have /emsdk/emscripten prefix, which is a
fake deterministic prefix to produce reproducible builds across platforms.
"""

import argparse
import json
import os
import sys
import tempfile
from pathlib import PurePath

__scriptdir__ = os.path.dirname(os.path.abspath(__file__))
__rootdir__ = os.path.dirname(__scriptdir__)
sys.path.insert(0, __rootdir__)

from tools import building, diagnostics, emsymbolizer, utils, webassembly
from tools.utils import exit_with_error


def parse_args():
  """Parse this script's command line.

  Options this script does not recognize are not rejected; they are returned
  as `forwarded_args` so they can be passed through to wasm-split.

  Returns:
    A tuple (args, forwarded_args): the argparse namespace and the list of
    unrecognized command line arguments.

  Exits with a usage error (via parser.error) when the arguments are
  inconsistent: --manifest is given, required positionals are missing, or no
  -o/--output is forwarded when actually splitting.
  """

  def has_option(arg_list, names):
    # wasm-split (Binaryen's option parser) also accepts the '--opt=value'
    # spelling, so match both '--opt' and '--opt=...'.
    return any(arg == name or arg.startswith(name + '=')
               for arg in arg_list for name in names)

  parser = argparse.ArgumentParser(
      description='Split a wasm file based on user paths',
      epilog="""
This is a wrapper for 'wasm-split --multi-split' functionality, so you should
add wasm-split's command line options as well. You should or may want to add
wasm-split options like -o (--output), --out-prefix, -g, and feature
enabling/disabling options. Run 'wasm-split -h' for the list of options. But you
should NOT add --manifest, because this will be generated from this script.
""")
  parser.add_argument('wasm', nargs='?', help='Path to the input wasm file')
  parser.add_argument('paths_file', nargs='?', help='Path to the input file containing paths')
  parser.add_argument('-s', '--sourcemap', help='Force source map file')
  parser.add_argument('-v', '--verbose', action='store_true',
                      help='Print verbose info for debugging this script')
  parser.add_argument('--wasm-split', help='Path to wasm-split executable')
  parser.add_argument('--preserve-manifest', action='store_true',
                      help='Preserve generated manifest file. This sets --verbose too.')
  parser.add_argument('--print-sources', action='store_true',
                      help='Print the list of sources in the source map to help figure out splitting boundaries. Does NOT perform the splitting.')

  args, forwarded_args = parser.parse_known_args()
  if args.preserve_manifest:
    args.verbose = True
  if not args.wasm_split:
    args.wasm_split = utils.find_exe(building.get_binaryen_bin(), 'wasm-split')

  if has_option(forwarded_args, ['--manifest']):
    parser.error('manifest file will be generated by this script and should not be given')

  if args.print_sources:
    if not args.wasm and not args.sourcemap:
      parser.error('--print-sources requires either wasm or --sourcemap')
    return args, forwarded_args

  # Positionals are filled left to right, so a missing wasm implies a missing
  # paths_file as well.
  if not args.wasm and not args.paths_file:
    parser.error("the following arguments are required: wasm, paths_file")
  if not args.paths_file:
    parser.error("the following arguments are required: paths_file")
  if not has_option(forwarded_args, ['-o', '--output']):
    parser.error('-o (--output) is required')
  return args, forwarded_args


def check_errors(args):
  """Validate the inputs before doing any real work.

  Checks that:
    - the wasm and paths files (when given) exist;
    - a source map is available, either via --sourcemap or via the wasm
      file's sourceMappingURL section, and that it exists;
    - the wasm file (when given) has a name section;
    - the wasm-split executable exists, unless only --print-sources was
      requested (no splitting happens in that mode);
    - the source map is a JSON object with the mandatory fields.

  Calls exit_with_error on the first problem found.
  """
  if args.wasm and not os.path.isfile(args.wasm):
    exit_with_error(f"'{args.wasm}' was not found or not a file")
  if args.paths_file and not os.path.isfile(args.paths_file):
    exit_with_error(f"'{args.paths_file}' was not found or not a file")

  # Always bind the local; it is replaced below by the wasm's own
  # sourceMappingURL when --sourcemap was not given.
  sourcemap = args.sourcemap

  if args.wasm:
    with webassembly.Module(args.wasm) as module:
      if not sourcemap:
        if not emsymbolizer.get_sourceMappingURL_section(module):
          exit_with_error('sourceMappingURL section does not exist')
        sourcemap = module.get_sourceMappingURL()
      if not module.has_name_section():
        exit_with_error('Name section does not exist')

  if not sourcemap or not os.path.isfile(sourcemap):
    exit_with_error(f"'{sourcemap}' was not found or not a file")
  # wasm-split is only invoked when actually splitting.
  if not args.print_sources and not os.path.isfile(args.wasm_split):
    exit_with_error(f"'{args.wasm_split}' was not found or not a file")

  # Check source map validity. Just perform simple checks to make sure mandatory
  # fields exist.
  try:
    with open(sourcemap) as f:
      source_map_data = json.load(f)
  except ValueError:
    # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueErrors).
    exit_with_error(f'Invalid JSON format in file {sourcemap}')
  if not isinstance(source_map_data, dict):
    exit_with_error(f'Invalid source map format in file {sourcemap}')
  for field in ['version', 'sources', 'mappings']:
    if field not in source_map_data:
      exit_with_error(f"Field '{field}' is missing in the source map")


def get_sourceMappingURL(wasm, arg_sourcemap):
  """Return the source map path to use.

  An explicitly supplied --sourcemap value takes precedence; otherwise the URL
  stored in the wasm file's sourceMappingURL section is returned.
  """
  if not arg_sourcemap:
    with webassembly.Module(wasm) as wasm_module:
      return wasm_module.get_sourceMappingURL()
  return arg_sourcemap


def print_sources(sourcemap):
  """Print every entry of the source map's 'sources' field, one per line."""
  with open(sourcemap) as map_file:
    source_map = json.load(map_file)
  source_list = source_map.get('sources')
  assert isinstance(source_list, list)
  for entry in source_list:
    print(entry)


def get_path_to_functions_map(wasm, sourcemap, paths):
  """Map each user-specified path to the functions whose source lies under it.

  Args:
    wasm: Path to the input wasm file; its name section supplies function
      names.
    sourcemap: Path to the source map file for `wasm`.
    paths: Normalized paths from the paths file.

  Returns:
    A dict {path: [function names]} with an entry for every element of
    `paths`. Each function is assigned to at most one path: the innermost
    specified path containing its source file. Functions whose source file is
    not under any of `paths` do not appear.
  """
  def is_synthesized_func(func):
    # TODO There can be more
    synthesized_names = [
      'main',
      '__wasm_call_ctors',
      '__clang_call_terminate',
    ]
    synthesized_prefixes = [
      'legalstub$',
      'legalfunc$',
      '__cxx_global_',
      '_GLOBAL__',
      'virtual thunk to ',
    ]
    if func in synthesized_names:
      return True
    return func.startswith(tuple(synthesized_prefixes))

  # Compute {func_name: src file} map; it is inverted below.
  func_to_src = {}
  with webassembly.Module(wasm) as module:
    funcs = module.get_functions()
    func_names = module.get_function_names()
    assert len(funcs) == len(func_names)

    sm = emsymbolizer.WasmSourceMap()
    sm.parse(sourcemap)

    for func_name, func in zip(func_names, funcs, strict=True):
      if func_name is None:
        # NOTE(review): guards against a name section without an entry for
        # this function. Such a function cannot be listed in a manifest by
        # name, and would otherwise crash when the manifest is written.
        continue

      # From the last address, decrement the address by 1 until we find location
      # info with source file information. The reason we do this is to reduce
      # the probability of picking an address where another function is inlined
      # into, picking the inlined function's source.
      # We start from the end because it is simpler; it is harder to compute the
      # first instruction's address, because there is a gap for local types
      # between function offset and the first instruction.
      # Reset per function so a previous function's lookup can never leak into
      # this one (e.g. if the loop body does not execute at all).
      loc = None
      addr = func.offset + func.size - 1
      while addr > func.offset:
        loc = sm.lookup(addr, func.offset)
        # This means there is no source map mappings for the entire function
        # (because we give func.offset as a lower bound). Exit the loop.
        if not loc:
          break
        # Exit the loop only if a location info with source file information is
        # found. If not, continue the search.
        if loc.source:
          break
        addr -= 1

      if loc and loc.source:
        func_to_src[func_name] = utils.normalize_path(loc.source)
      elif not is_synthesized_func(func_name):
        diagnostics.warn(f"No source file information found in the source map for function '{func_name}'")

  # Invert into {src file: list of functions}, converting each source file to
  # a PurePath once instead of once per user path.
  src_to_funcs = {}
  for func_name, src in func_to_src.items():
    src_to_funcs.setdefault(src, []).append(func_name)
  src_entries = [(PurePath(src), src_funcs) for src, src_funcs in src_to_funcs.items()]

  # Visit paths in the reverse sorting order, so that we can process inner paths
  # first.
  # e.g. If we have /a/b and /a/b/c, /a/b/c will come first, so we can assign
  # functions contained in /a/b/c to it first and assign the remaining functions
  # to /a/b.
  visited_funcs = set()
  path_to_funcs = {}
  for path in sorted(paths, reverse=True):
    ppath = PurePath(path)
    matched = path_to_funcs[path] = []
    for psrc, src_funcs in src_entries:
      if ppath == psrc or ppath in psrc.parents:
        for func in src_funcs:
          if func not in visited_funcs:
            visited_funcs.add(func)
            matched.append(func)
  return path_to_funcs


# 1. Strip whitespaces
# 2. Normalize separators
# 3. Make /a/b/c and /a/b/c/ equivalent
def normalize_path(path):
  """Normalize a path read from the paths file.

  Strips surrounding whitespace, normalizes separators via
  utils.normalize_path, and removes trailing separators so that /a/b/c and
  /a/b/c/ compare equal. A path made only of separators (the root) is kept as
  a single separator instead of collapsing to ''.
  """
  normalized = utils.normalize_path(path.strip())
  # Strip both '/' and os.sep: the separator produced by utils.normalize_path
  # need not be os.sep (e.g. '/' on Windows, where os.sep is '\\').
  stripped = normalized.rstrip('/' + os.sep)
  return stripped or normalized[:1]


def parse_paths_file(paths_file_content):
  """Parse the paths file content into an ordered {module name: [paths]} dict.

  The content consists of blank-line-separated groups. Each group starts with
  a module name terminated by a colon, followed by one path per line. Paths
  are normalized with normalize_path().

  Warns about modules with no paths. Calls exit_with_error if a module header
  does not end with a colon or is empty, if a module name is defined more
  than once, if a path is listed more than once, or if no module is defined.
  """
  module_to_paths = {}
  path_to_module = {}
  cur_module = None
  cur_paths = []

  def finish_module():
    # Record the module currently being parsed (if any) and reset the state.
    nonlocal cur_module, cur_paths
    if cur_module is None:
      return
    if not cur_paths:
      diagnostics.warn(f"Module '{cur_module}' has no paths specified.")
    module_to_paths[cur_module] = cur_paths
    cur_module = None
    cur_paths = []

  for line in paths_file_content.splitlines():
    line = line.strip()
    if not line:
      finish_module()
      continue

    if cur_module is None:
      if line[-1] != ':':
        exit_with_error(f'Module name should end with a colon: {line}')
      if len(line) == 1:
        exit_with_error('Module name is empty')
      cur_module = line[:-1]
      # A second definition would otherwise silently replace the first one's
      # paths.
      if cur_module in module_to_paths:
        exit_with_error(f"Module '{cur_module}' is defined more than once")
    else:
      path = normalize_path(line)
      if path in path_to_module:
        exit_with_error(f"Path '{path}' cannot be assigned to module '{cur_module}'; it is already assigned to module '{path_to_module[path]}'")
      cur_paths.append(path)
      path_to_module[path] = cur_module

  # The last module may not be followed by a blank line.
  finish_module()

  if not module_to_paths:
    exit_with_error('The paths file is empty or invalid.')

  return module_to_paths


def main():
  """Entry point.

  With --print-sources, prints the source map's 'sources' entries and returns.
  Otherwise generates a wasm-split manifest from the paths file and runs
  'wasm-split --multi-split' with it plus any forwarded arguments. The
  temporary manifest is deleted afterwards unless --preserve-manifest is set.
  """
  args, forwarded_args = parse_args()
  check_errors(args)

  sourcemap = get_sourceMappingURL(args.wasm, args.sourcemap)
  if args.print_sources:
    print_sources(sourcemap)
    return

  content = utils.read_file(args.paths_file)
  module_to_paths = parse_paths_file(content)

  # Compute {path: list of functions} map
  all_paths = [path for paths in module_to_paths.values() for path in paths]
  path_to_funcs = get_path_to_functions_map(args.wasm, sourcemap, all_paths)

  # Write .manifest file. delete=False because the file must outlive this
  # handle: wasm-split reads it by name, and --preserve-manifest keeps it.
  f = tempfile.NamedTemporaryFile(suffix=".manifest", mode='w+', delete=False)
  manifest = f.name
  try:
    # The context manager closes the file even if writing fails, so the
    # os.remove() in `finally` never operates on a still-open handle (which
    # fails on Windows).
    with f:
      for i, (module, paths) in enumerate(module_to_paths.items()):
        if i != 0:  # Unless we are the first entry add a newline separator
          f.write('\n')
        funcs = []
        for path in paths:
          if not path_to_funcs[path]:
            diagnostics.warn(f'{path} does not match any functions')
          funcs += path_to_funcs[path]
        if not funcs:
          diagnostics.warn(f"Module '{module}' does not match any functions")

        if args.verbose:
          print(f'{module}: {len(funcs)} functions')
          for path in paths:
            print(f'  {path}: {len(path_to_funcs[path])} functions')
            for func in path_to_funcs[path]:
              print('    ' + func)
          print()

        f.write(f'{module}:\n')
        for func in funcs:
          f.write(func + '\n')

    cmd = [args.wasm_split, '--multi-split', args.wasm, '--manifest', manifest]
    if args.verbose:
      # This option is used both in this script and wasm-split
      cmd.append('-v')
    cmd += forwarded_args
    if args.verbose:
      print('\n' + ' '.join(cmd))
    utils.run_process(cmd)
  finally:
    if not args.preserve_manifest:
      os.remove(manifest)


# Script entry point. main() returns None, which sys.exit() treats as exit
# status 0; errors exit earlier through exit_with_error / parser.error.
if __name__ == '__main__':
  sys.exit(main())
