Last active
December 7, 2024 01:34
-
-
Save nboyd/99a575009d39029e5a8c23e308824429 to your computer and use it in GitHub Desktop.
t2j + randomness
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FunctionWrapperContextManager:
    """Temporarily redirect named functions on a module to a single callback.

    While the context is active, ``module.<name>(*args, **kwargs)`` for every
    name in ``function_names`` calls ``callback(name, *args, **kwargs)`` and
    returns its result. The original attributes are put back on exit, and also
    if patching fails partway through ``__enter__``.
    """

    def __init__(self, module, function_names, callback):
        """
        Create a context manager that wraps multiple functions in a module.
        Calls to those functions are dispatched to a callback.

        Args:
            module (module): The module containing the functions to wrap
            function_names (iterable): Names of the functions to wrap. It is
                iterated once, in ``__enter__``, so a one-shot iterable is fine.
            callback: Called as ``callback(name, *args, **kwargs)`` whenever a
                wrapped function is invoked; its return value is returned.
        """
        self.module = module
        self.function_names = function_names
        self.custom_callback = callback
        # Originals saved by __enter__ (name -> function) so they can be restored.
        self.original_implementations = {}

    def _make_wrapper(self, func_name, original_func):
        """Return a stub for ``func_name`` that forwards to the callback.

        Building the stub in its own call binds ``func_name`` per function. A
        closure defined directly inside the ``__enter__`` loop would late-bind
        and report the last name for every wrapped function.
        """
        @functools.wraps(original_func)
        def wrapped_func(*args, **kwargs):
            return self.custom_callback(func_name, *args, **kwargs)

        return wrapped_func

    def _restore(self):
        """Put back every function that __enter__ replaced."""
        for func_name, original_func in self.original_implementations.items():
            setattr(self.module, func_name, original_func)

    def __enter__(self):
        self.original_implementations = {}
        try:
            for func_name in self.function_names:
                original_func = getattr(self.module, func_name)
                # Record the original before patching so a later failure can undo it.
                self.original_implementations[func_name] = original_func
                setattr(self.module, func_name, self._make_wrapper(func_name, original_func))
        except BaseException:
            # __exit__ does not run when __enter__ raises; undo partial patching here.
            self._restore()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore from the saved mapping rather than re-iterating function_names,
        # which may be a one-shot iterable that is already exhausted.
        self._restore()
        return False  # Propagate any exceptions
def _rand_like(key, *args, size=None, device=None):
    """Draw uniform samples with ``jax.random.uniform`` in place of ``torch.rand_like``.

    Args:
        key: JAX PRNG key to sample with.
        *args: ``args[0]`` is either a literal shape (a tuple or list of ints,
            which includes ``torch.Size``) or an array-like whose shape is taken
            from ``coerce(args[0]).shape``.
        size: If given, used directly as the output shape and ``args`` is ignored.
        device: Accepted for torch signature compatibility; ignored.

    Returns:
        A JAX array of uniform samples with the resolved shape.
    """
    if size is not None:
        return jax.random.uniform(key, size)
    source = args[0]
    # Only a tuple/list can be a literal shape. Probing an arbitrary array by
    # iterating it is fragile: an empty array yields no elements, so all(...) is
    # vacuously True and the array would be mistaken for a shape.
    if isinstance(source, (tuple, list)) and all(isinstance(s, int) for s in source):
        return jax.random.uniform(key, source)
    return jax.random.uniform(key, coerce(source).shape)
def _randn_like(key, *args, size=None, device=None):
    """Draw standard-normal samples with ``jax.random.normal`` in place of ``torch.randn_like``.

    Args:
        key: JAX PRNG key to sample with.
        *args: ``args[0]`` is either a literal shape (a tuple or list of ints,
            which includes ``torch.Size``) or an array-like whose shape is taken
            from ``coerce(args[0]).shape``.
        size: If given, used directly as the output shape and ``args`` is ignored.
        device: Accepted for torch signature compatibility; ignored.

    Returns:
        A JAX array of normal samples with the resolved shape.
    """
    if size is not None:
        return jax.random.normal(key, size)
    source = args[0]
    # Only a tuple/list can be a literal shape. Probing an arbitrary array by
    # iterating it is fragile: an empty array yields no elements, so all(...) is
    # vacuously True and the array would be mistaken for a shape.
    if isinstance(source, (tuple, list)) and all(isinstance(s, int) for s in source):
        return jax.random.normal(key, source)
    return jax.random.normal(key, coerce(source).shape)
# Name of each torch sampling function to intercept -> the JAX-side callable
# that replaces it. Each entry is invoked as ``fn(key, *args, **kwargs)``.
_RAND_METHODS = dict(
    rand=jax.random.uniform,
    randn=jax.random.normal,
    rand_like=_rand_like,
    randn_like=_randn_like,
)
@dataclass
class KeyTracker:
    """Callback that answers intercepted torch random calls with JAX sampling.

    Every call splits the stored key, uses one half for the draw, and keeps
    the other half for the next call.

    Attributes:
        key: PRNG key to split from; ``None`` means no key was supplied.
        counter: Number of calls received so far (shown in the trace output).
    """
    key : jax.random.PRNGKey
    counter : int = dataclasses.field(default=0)

    def __call__(self, func_name, *args, **kwargs):
        self.counter = self.counter + 1
        sampler = _RAND_METHODS[func_name]
        print(f"'[{self.counter}] {func_name}' called with args: {args} \t kwargs: {kwargs}")
        print(f"inserting {sampler} call!")
        if self.key is None:
            raise ValueError("No PRNGKey provided, and randomness detected in torch call!")
        subkey, self.key = jax.random.split(self.key, 2)
        # Calls like torch.rand(2, 3) pass the size as separate ints. Pack them
        # into one shape tuple for the JAX sampler. An empty ``args`` is packed
        # too, becoming ``()``.
        sample_args = (args,) if all(isinstance(s, int) for s in args) else args
        return tree.map(Torchish, sampler(subkey, *sample_args, **kwargs))
def t2j_function(f):
    """Wrap a torch-style function so it can be called with JAX values.

    Non-scalar inputs are wrapped in ``Torchish``; Python ints and floats pass
    through unchanged. ``f`` runs with the ``torch`` functions named in
    ``_RAND_METHODS`` redirected to ``KeyTracker``. The ``Torchish`` leaves of
    its result are unwrapped to ``.value``.

    Args:
        f: Callable written against the torch API.

    Returns:
        ``transformed(*args, key=None, **kwargs)``. ``key`` is consumed by the
        wrapper and never forwarded to ``f``. It is split with
        ``jax.random.split`` for each random call, and ``KeyTracker`` raises
        ValueError if ``f`` draws randomness while ``key`` is None.
    """
    def torchish_non_scalar(pytree):
        return tree.map(lambda x: Torchish(x) if not isinstance(x, (int, float)) else x, pytree)

    def unwrap(leaf):
        # Leaves that are not Torchish (e.g. Python scalars returned by f) have
        # no ``.value``; pass them through instead of raising AttributeError.
        return leaf.value if isinstance(leaf, Torchish) else leaf

    @functools.wraps(f)
    def transformed(*args, key=None, **kwargs):
        with FunctionWrapperContextManager(
            module=torch,
            function_names=_RAND_METHODS.keys(),
            callback=KeyTracker(key)
        ):
            return jax.tree.map(unwrap, f(*torchish_non_scalar(args), **torchish_non_scalar(kwargs)))

    return transformed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment