Skip to content

Instantly share code, notes, and snippets.

@grave0x
Forked from nboyd/hack.py
Created December 7, 2024 01:33
Show Gist options
  • Select an option

  • Save grave0x/4e66dd7063435ed04263eca856a6749b to your computer and use it in GitHub Desktop.

Select an option

Save grave0x/4e66dd7063435ed04263eca856a6749b to your computer and use it in GitHub Desktop.
t2j + randomness
class FunctionWrapperContextManager:
    """Temporarily replace named functions on a module with a dispatching wrapper.

    While the context is active, every call to one of the wrapped functions is
    forwarded to ``callback(func_name, *args, **kwargs)`` instead of the
    original implementation. The originals are restored on exit, including
    when the ``with`` body raises.
    """

    def __init__(self, module, function_names, callback):
        """
        Create a context manager that wraps multiple functions in a module.

        Calls to those functions are dispatched to a callback.

        Args:
            module (module): The module (or any attribute holder) containing
                the functions to wrap.
            function_names (iterable of str): Names of the functions to wrap.
                Materialized once here, so one-shot iterables such as
                generators are safe to pass.
            callback: Called as ``callback(func_name, *args, **kwargs)`` when
                a wrapped function is invoked; its return value is returned
                to the caller of the wrapped function.
        """
        self.module = module
        self.function_names = tuple(function_names)
        self.custom_callback = callback
        # Original implementations, keyed by name, captured in __enter__.
        self.original_implementations = {}

    def _make_wrapper(self, func_name, original_func):
        """Build a wrapper that dispatches calls for ``func_name`` to the callback."""
        # Built in its own scope so each wrapper captures its own
        # ``func_name``; a closure created directly inside the __enter__ loop
        # would late-bind to the last name iterated.
        @functools.wraps(original_func)
        def wrapped_func(*args, **kwargs):
            return self.custom_callback(func_name, *args, **kwargs)

        return wrapped_func

    def _restore(self):
        """Put back every original implementation recorded so far."""
        for func_name, original_func in self.original_implementations.items():
            setattr(self.module, func_name, original_func)
        self.original_implementations.clear()

    def __enter__(self):
        try:
            for func_name in self.function_names:
                # Skip duplicate names: a second lookup would return our own
                # wrapper and overwrite the saved original with it.
                if func_name in self.original_implementations:
                    continue
                original_func = getattr(self.module, func_name)
                self.original_implementations[func_name] = original_func
                setattr(self.module, func_name, self._make_wrapper(func_name, original_func))
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo any
            # partial patching here before propagating.
            self._restore()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore from what was actually patched rather than re-iterating
        # the requested names.
        self._restore()
        return False  # Propagate any exceptions
def _rand_like(key, *args, size = None, device=None):
if size is not None:
return jax.random.uniform(key, size)
input = args[0]
if all(isinstance(s, int) for s in input):
return jax.random.uniform(key, input)
else:
return jax.random.uniform(key, coerce(input).shape)
def _randn_like(key, *args, size = None, device=None):
if size is not None:
return jax.random.normal(key, size)
input = args[0]
if all(isinstance(s, int) for s in input):
return jax.random.normal(key, input)
else:
return jax.random.normal(key, coerce(input).shape)
# Dispatch table: name of each intercepted torch random function -> the
# JAX-backed implementation that KeyTracker invokes instead, passing a fresh
# PRNG key as the first argument.
_RAND_METHODS = dict(
    rand=jax.random.uniform,
    randn=jax.random.normal,
    rand_like=_rand_like,
    randn_like=_randn_like,
)
@dataclass
class KeyTracker:
    """Callback that serves intercepted torch random calls from a JAX PRNG key.

    Each call splits ``self.key``: one half feeds the sampler looked up in
    ``_RAND_METHODS``, and the other half is kept for the next call, so
    successive random calls use distinct keys.
    """

    # PRNG key to draw from; may be None, in which case any random call
    # raises ValueError.
    key : jax.random.PRNGKey
    # Number of random calls served so far (shown in the debug output).
    counter : int = dataclasses.field(default=0)

    @staticmethod
    def _normalize_sized_call(args, kwargs):
        """Adapt torch-only keywords of ``rand``/``randn`` for ``jax.random``.

        A ``size=`` keyword (with no positional shape) becomes the positional
        shape. ``device`` is dropped, matching ``_rand_like``/``_randn_like``,
        which accept and ignore it. Other keywords pass through unchanged.
        """
        kwargs = dict(kwargs)
        if "size" in kwargs and not args:
            args = (kwargs.pop("size"),)
        kwargs.pop("device", None)
        return args, kwargs

    def __call__(self, func_name, *args, **kwargs):
        """Run the JAX equivalent of torch's ``func_name`` and wrap the result.

        Args:
            func_name: Key into ``_RAND_METHODS`` (e.g. ``"rand"``).
            *args, **kwargs: Arguments of the intercepted torch call.

        Returns:
            The sampled value, mapped through ``Torchish``.

        Raises:
            ValueError: If this tracker was created without a PRNG key.
        """
        # Fail before bumping the counter or printing, so the debug log never
        # announces a call that is not going to happen.
        if self.key is None:
            raise ValueError("No PRNGKey provided, and randomness detected in torch call!")
        self.counter += 1
        jax_fn = _RAND_METHODS[func_name]
        print(f"'[{self.counter}] {func_name}' called with args: {args} \t kwargs: {kwargs}")
        print(f"inserting {jax_fn} call!")
        if func_name in ("rand", "randn"):
            # jax.random.uniform/normal take neither ``size`` nor ``device``.
            args, kwargs = self._normalize_sized_call(args, kwargs)
        key, self.key = jax.random.split(self.key, 2)
        # hack to handle int sizes: torch.rand(2, 3) -> shape argument (2, 3)
        if all(isinstance(s, int) for s in args):
            args = [args]
        return tree.map(Torchish, jax_fn(key, *args, **kwargs))
def t2j_function(f):
    """Wrap torch-style function ``f`` so it runs on ``Torchish`` values.

    Positional and keyword arguments are mapped through ``Torchish`` (Python
    ``int``/``float`` leaves are passed through unchanged). While ``f`` runs,
    the torch functions named in ``_RAND_METHODS`` are patched to draw from
    ``key`` via ``KeyTracker``. ``Torchish`` leaves of the result are
    unwrapped to their ``.value``; other leaves are returned as-is.

    Args:
        f: The function to transform.

    Returns:
        ``transformed(*args, key=None, **kwargs)``. Pass ``key`` (a JAX PRNG
        key) if ``f`` calls any of the patched random functions; without it
        the first such call raises ValueError.

    NOTE(review): the patch is applied to the global ``torch`` module, so it
    is visible to any other code using torch concurrently.
    """
    def torchish_non_scalar(values):
        # Wrap every leaf except Python scalars, which stay plain numbers.
        return tree.map(lambda x: x if isinstance(x, (int, float)) else Torchish(x), values)

    def unwrap(leaf):
        # Leaves that are not Torchish (e.g. a Python int returned next to
        # tensors) have no ``.value`` and are returned unchanged.
        return leaf.value if isinstance(leaf, Torchish) else leaf

    @functools.wraps(f)
    def transformed(*args, key=None, **kwargs):
        with FunctionWrapperContextManager(
            module=torch,
            function_names=_RAND_METHODS.keys(),
            callback=KeyTracker(key),
        ):
            out = f(*torchish_non_scalar(args), **torchish_non_scalar(kwargs))
            return jax.tree.map(unwrap, out)

    return transformed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment