Skip to content

Instantly share code, notes, and snippets.

@vanbasten23
Created November 14, 2025 19:00
Show Gist options
  • Select an option

  • Save vanbasten23/1dd9a30a26fda4d7d112cf7611bcbaa4 to your computer and use it in GitHub Desktop.

Select an option

Save vanbasten23/1dd9a30a26fda4d7d112cf7611bcbaa4 to your computer and use it in GitHub Desktop.
import jax
from jax import export
import jax.numpy as jnp
import pickle
import time
import statistics
# Paths to the serialized artifacts from the "old" JAX export run.
# NOTE(review): both files are loaded with pickle, which can execute arbitrary
# code on load; only point these at files you produced yourself.
EXPORT_PATH = "/home/xiowei_google_com/old_exports.pkl"
WEIGHTS_PATH = "/home/xiowei_google_com/old_weights.pkl"

# Token count shared by `positions` and `hidden_states`; keeping it in one
# place stops the two inputs from drifting out of sync.
NUM_TOKENS = 16
HIDDEN_SIZE = 1536

with open(EXPORT_PATH, "rb") as f:
    # presumably the blob produced by `Exported.serialize()` — TODO confirm
    data = pickle.load(f)
# Only the in-memory blob is needed, so deserialize after the file is closed.
exported = export.deserialize(data)

with open(WEIGHTS_PATH, "rb") as f:
    weights = pickle.load(f)

# Synthetic inputs: positions 0..NUM_TOKENS-1 and random bfloat16 activations
# drawn from a fixed PRNG key, so runs are reproducible.
positions = jnp.arange(NUM_TOKENS)
key = jax.random.key(0)
hidden_states = jax.random.normal(
    key, (NUM_TOKENS, HIDDEN_SIZE), dtype=jnp.bfloat16
)
# Time it.
# The first NUM_WARMUP samples are excluded from the reported mean, presumably
# because they include one-time costs such as compilation.
NUM_ITERS = 20
NUM_WARMUP = 5

all_time = []
for _ in range(NUM_ITERS):
    start = time.perf_counter_ns()
    # JAX dispatches asynchronously, so wait for the result before stopping
    # the clock. jax.block_until_ready accepts any pytree, so this works
    # whether the exported function returns one array or a tuple of arrays.
    jax.block_until_ready(exported.call(weights, (positions, hidden_states)))
    end = time.perf_counter_ns()
    all_time.append(end - start)
# Previously recorded result: 4548131.9 ns.
print("Running old jax finished in [ns] ", statistics.mean(all_time[NUM_WARMUP:]))
# Profile it
# profile_path='/home/xiowei_google_com/myprofiles'
# for i in range(20):
#     if i == 5:
#         jax.profiler.start_trace(profile_path)
#     exported.call(weights, (positions, hidden_states)).block_until_ready()
#
# jax.profiler.stop_trace()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment