Created
April 2, 2025 15:38
-
-
Save ev-br/ea880b7ed114e977202c5121caae99ad to your computer and use it in GitHub Desktop.
Parse/output the backend listings from instrumented test runs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Parse the backend registry JSON, from a pytest run with instrumented imports. | |
| The instrumentation is at https://github.com/ev-br/scipy/tree/messing_with_import_hooks | |
| """ | |
| import json | |
| import argparse | |
| import importlib | |
| # Dotted name of the module that is imported at run time; registry entries | |
| # are matched against this module's attributes and its docstring is parsed. | |
| MODULE_NAME = "scipy.signal" | |
def is_public_function(full_name, module):
    """Look up the last component of `full_name` as an attribute of `module`.

    Returns the attribute object itself when `module` has one under that
    short name, and None otherwise, so callers can both test the result for
    truthiness and use the returned object.
    """
    short_name = shorten(full_name)
    attribute = getattr(module, short_name, None)
    return attribute
def shorten(full_name):
    """Return the part of a dotted name that follows its last dot.

    A name without any dot is returned unchanged.
    """
    _, _, tail = full_name.rpartition(".")
    return tail
def get_backend_name(full_name):
    """Return the display name of a backend given its dotted name.

    Names ending in "jax.numpy" or "dask.array" keep that two-component
    suffix; any other name is reduced to its last dotted component.
    """
    # These backends are identified by two components, so the generic
    # "last component" shortening would lose information for them.
    for dotted_backend in ("jax.numpy", "dask.array"):
        if full_name.endswith(dotted_backend):
            return dotted_backend
    return shorten(full_name)
def parse_docstring(docstring):
    """Get function names from the module docstring, grouped by heading.

    Used for pretty-printing. The first three lines of the docstring are
    skipped. An unindented line immediately followed by a line starting with
    "=" opens a new group named after that line. Every indented line that
    does not contain ":toctree:" contributes its first whitespace-separated
    token as a name, unless that token is "--".

    Parameters
    ----------
    docstring : str or None
        The module docstring. None is treated as an empty docstring.

    Returns
    -------
    dict
        Maps each heading to the list of names found under it, in order of
        appearance. Names found before the first heading are stored under
        the key "empty", which is always created once a heading is seen or
        when no heading is found at all.
    """
    lines = (docstring or "").split("\n")
    n_lines = len(lines)

    groups = {}
    current_group = []
    heading = "empty"

    for j in range(3, n_lines):  # skip top three lines
        line = lines[j]

        # Skip blank lines, including whitespace-only ones: those have no
        # first token to extract.
        if not line.strip():
            continue

        # Detect a heading: an unindented line underlined with "=". The
        # bounds check keeps a trailing unindented line from reading past
        # the end of the docstring.
        is_heading = (
            not line.startswith(" ")
            and j + 1 < n_lines
            and lines[j + 1].startswith("=")
        )
        if is_heading:
            # Close the group collected so far and start a new one.
            groups[heading] = current_group
            heading = line.strip()
            current_group = []

        if line.startswith(" ") and ":toctree:" not in line:
            fname = line.split()[0]
            if fname != '--':
                current_group.append(fname)

    # add the last group
    groups[heading] = current_group
    return groups
if __name__ == "__main__":
    # Command line: one or more JSON registry files. Each file is read as a
    # mapping from a dotted function name to a list of backend names.
    parser = argparse.ArgumentParser()
    parser.add_argument('items', nargs='+', help="List of json files (need at least one).")
    args = parser.parse_args()

    # load and combine the registries: entries for the same key are
    # concatenated across files
    reg = {}
    for fname in args.items:
        with open(fname, encoding="utf-8") as f:
            dct = json.load(f)
        for key, backend_names in dct.items():
            reg.setdefault(key, []).extend(backend_names)

    # filter public names: keep only registry keys whose last component is an
    # attribute of the module
    module = importlib.import_module(MODULE_NAME)
    backend_sets = {}
    for full_name, backend_names in reg.items():
        if func := is_public_function(full_name, module):
            # Several registry keys may resolve to the same attribute, so
            # merge their backends instead of letting the last key win.
            backend_sets.setdefault(func, set()).update(
                get_backend_name(backend)
                for backend in backend_names
                if backend != "multiple"
            )
    # Sort so the printed lists are reproducible from run to run.
    reg_func = {func: sorted(names) for func, names in backend_sets.items()}

    # output
    for func, backends in reg_func.items():
        print(f"{module.__name__}.{func.__name__} : {backends}")

    # pretty-print now, grouped under the headings of the module docstring
    groups = parse_docstring(module.__doc__)
    for heading, names in groups.items():
        print(f"\n\n{heading}")
        print("="*len(heading), "\n")
        for short_name in names:
            func = getattr(module, short_name)
            # Functions missing from the registry are shown as '---'; an
            # empty backend list is falsy, so such functions are not printed.
            if backends := reg_func.get(func, '---'):
                print(f"{module.__name__}.{func.__name__} : {backends}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment