| """ | |
| # Kimmy's Torch Hardware Survey | |
| I'm seeing differences between cpu and mps outputs on MacOS on torch versions >=2.8.0. | |
| Please run the following commands and send me the output, over either signal or via | |
| email to <username kimmy at kjwilber.org>. | |

```
# Install uv if you don't have it already:
which uv || curl -LsSf https://astral.sh/uv/install.sh | sh

# Download this script (saved as main.py so the commands below can find it):
cd /tmp
wget -O main.py https://gist.githubusercontent.com/gcr/4d8833bb63a85fc8ef1fd77de6622770/raw/b9757c592313423c4d99cb4879a09a49cbdb9041/example_calculate_embedding.py

# Newer torch:
uv run --isolated --python=3.11.12 --with torch==2.8.0,transformer-lens==2.9.1,transformers==4.47.0 main.py cpu float32
uv run --isolated --python=3.11.12 --with torch==2.8.0,transformer-lens==2.9.1,transformers==4.47.0 main.py mps float32

# Older torch:
uv run --isolated --python=3.11.12 --with torch==2.7.0,transformer-lens==2.9.1,transformers==4.47.0 main.py cpu float32
uv run --isolated --python=3.11.12 --with torch==2.7.0,transformer-lens==2.9.1,transformers==4.47.0 main.py mps float32
```

On my Mac, I get the following output:
```
Platform: darwin
Python version: 3.11.12 (main, Apr 8 2025, 14:15:29) [Clang 17.0.0 (clang-1700.0.13.3)]
Torch version: 2.8.0
Transformerlens version: 2.9.1
Device: cpu
Dtype: float32
Layer: blocks.13.hook_resid_pre
Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer
[[ 0.019 -0.277 1.28 ... -0.225 0.829 -0.048]
 [-0.023 0.136 0.275 ... 0.003 0.119 -0.224]
 [-0.007 -0.008 -0.43 ... -0.07 -0. -0.083]
 ...
 [ 0.06 0.103 -0.135 ... 0.019 -0.082 -0.213]
 [-0.169 0.227 0.387 ... 0.06 0.049 0.075]
 [-0.177 0.107 -0.044 ... 0.168 -0.159 -0.137]] torch.float32 cpu
Platform: darwin
Python version: 3.11.12 (main, Apr 8 2025, 14:15:29) [Clang 17.0.0 (clang-1700.0.13.3)]
Torch version: 2.8.0
Transformerlens version: 2.9.1
Device: mps
Dtype: float32
Layer: blocks.13.hook_resid_pre
Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer
[[-0.189 -0.061 0.937 ... -0.25 0.437 -0.198]
 [-0.13 0.023 0.086 ... -0.07 0.17 -0.191]
 [ 0.053 -0.011 -0.494 ... -0.073 0.17 0.004]
 ...
 [ 0.001 -0.235 -0.295 ... -0.032 0.023 -0.169]
 [-0.054 -0.152 -0.142 ... -0.004 0.019 -0.23 ]
 [-0.022 -0.134 0.128 ... -0.101 0.108 -0.098]] torch.float32 mps:0
```
"""
from importlib.metadata import version
import sys

import numpy as np
import torch
import transformer_lens

# Device ("cpu" or "mps") and dtype (e.g. "float32") are taken from the command line.
_DEVICE = sys.argv[1]
_DTYPE = sys.argv[2]
# Residual-stream activations are captured at this hook point.
_LAYER = "blocks.13.hook_resid_pre"

# Report the environment so outputs from different machines can be compared.
print("Platform: ", sys.platform)
print("Python version: ", sys.version)
print("Torch version: ", torch.__version__)
print("Transformerlens version:", version("transformer_lens"))
print("Device: ", _DEVICE)
print("Dtype: ", _DTYPE)
print("Layer: ", _LAYER)

# A fixed, pre-tokenized Llama-3.2 chat prompt so every machine sees identical input.
tokens = [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 2318, 5033, 220, 2366, 20, 271, 128009, 128006, 882, 128007, 271, 9241, 596, 279, 1888, 4717, 30, 128009, 128006, 78191, 128007, 271]

# Load the model without activation preprocessing so values are directly comparable.
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(
    "meta-llama/Llama-3.2-1B-Instruct",
    device=_DEVICE,
    dtype=_DTYPE,
    default_padding_side='left',
)
model.eval()
torch.set_grad_enabled(False)

# Run the prompt and cache only the layer of interest.
logits, cache = model.run_with_cache(
    torch.tensor([tokens]),
    remove_batch_dim=True,
    names_filter=lambda name: _LAYER == name,
)
arr = cache[_LAYER]

# Print the cached activations, rounded so outputs are easy to diff by eye.
print(
    np.array2string(arr.cpu().numpy(), precision=3, suppress_small=True),
    arr.dtype,
    arr.device,
)
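

# ---------------------------------------------------------------------------
# Optional extra check (a minimal sketch, not part of the survey output above):
# to quantify the cpu-vs-mps divergence directly, run the same tokens through
# the model on both devices in a single process and print the largest
# per-element difference. This assumes an MPS-capable Mac with enough memory
# for two copies of the model; skip it otherwise.
# ---------------------------------------------------------------------------
def _compare_cpu_vs_mps():
    activations = {}
    for device in ("cpu", "mps"):
        m = transformer_lens.HookedTransformer.from_pretrained_no_processing(
            "meta-llama/Llama-3.2-1B-Instruct",
            device=device,
            dtype=_DTYPE,
            default_padding_side='left',
        )
        m.eval()
        _, c = m.run_with_cache(
            torch.tensor([tokens]),
            remove_batch_dim=True,
            names_filter=lambda name: _LAYER == name,
        )
        # Move everything back to cpu float32 so the comparison itself is exact.
        activations[device] = c[_LAYER].float().cpu()
    max_abs_diff = (activations["cpu"] - activations["mps"]).abs().max().item()
    print("Max abs cpu-vs-mps difference at", _LAYER, ":", max_abs_diff)

# Uncomment to run the comparison (requires the mps backend):
# if torch.backends.mps.is_available():
#     _compare_cpu_vs_mps()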