Last active
March 12, 2026 19:50
-
-
Save aurotripathy/e48f20c0f5a366d80d53626c17cb80f4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import os | |
| import onnx | |
| import onnx_safetensors | |
# Directory (relative to the current working directory) that input ONNX
# models are read from.
INPUT_DIR: str = "onnx-files"
# Default destination for generated files; used as the default of --output-dir.
OUTPUT_DIR: str = "safetensors-files"
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for the ONNX -> safetensors conversion.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` callers are unaffected.

    Returns:
        Namespace with ``onnx_model_name``, ``output_dir`` and
        ``only_create_safetensors`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Convert an ONNX model to a safetensors file and optionally "
        "create an ONNX model that uses the safetensors file as external data."
    )
    parser.add_argument(
        "--onnx-model-name",
        type=str,
        required=True,
        help=(
            "Name of the input ONNX model (with or without .onnx extension). "
            # The output location is controlled by --output-dir; the old text
            # claimed outputs always went to OUTPUT_DIR, which was misleading.
            f"The model is loaded from `{INPUT_DIR}` and all outputs are written "
            "to --output-dir."
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=OUTPUT_DIR,
        help=f"Directory where the safetensors file and output ONNX model will be written. Defaults to `{OUTPUT_DIR}`.",
    )
    parser.add_argument(
        "--only-create-safetensors",
        action="store_true",
        help="If set, only create the safetensors file and do not write a modified ONNX model.",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Offload an ONNX model's weights into a safetensors file.

    Loads ``<INPUT_DIR>/<onnx-model-name>``, appending ``.onnx`` when the
    name does not already end with it. Writes ``<onnx-model-name>.safetensors``
    into the output directory (``--output-dir``). Unless
    ``--only-create-safetensors`` is set, also writes
    ``<model file>.model_using_safetensors.onnx`` into the same directory; that
    model references the safetensors file as external data.
    """
    args = parse_args()
    print(args)

    # Resolve the output directory once and use it for *every* output. The
    # ONNX model's external-data location is relative to this directory, so
    # the model and its safetensors file must be written to the same place.
    base_dir = args.output_dir or OUTPUT_DIR
    os.makedirs(base_dir, exist_ok=True)

    # The CLI documents the ".onnx" extension as optional; add it if missing.
    model_file = args.onnx_model_name
    if not model_file.lower().endswith(".onnx"):
        model_file += ".onnx"
    model_path = os.path.join(INPUT_DIR, model_file)

    # Load the model into memory (ModelProto object).
    model = onnx.load(model_path)

    # The safetensors file name is keyed off the name given on the command
    # line. onnx_safetensors.save_file interprets it relative to base_dir.
    data_path = f"{args.onnx_model_name}.safetensors"
    safetensors_path = os.path.join(base_dir, data_path)
    # The model name may contain a sub-directory; make sure it exists.
    os.makedirs(os.path.dirname(safetensors_path), exist_ok=True)
    print(f"data-path: {data_path}")
    print(f"out-dir: {base_dir}")

    # A single save_file call writes the weights. replace_data=True also
    # returns a model whose initializers point at the file as external data.
    # Previously the file was written twice (replace_data=False, then True).
    only_safetensors = args.only_create_safetensors
    model_with_external_data = onnx_safetensors.save_file(
        model,
        data_path,
        base_dir=base_dir,
        replace_data=not only_safetensors,
    )
    print(f"Safetensors file written to: {safetensors_path}")
    if only_safetensors:
        return

    # Output ONNX path, keyed off the model file name, placed next to the
    # safetensors file so the relative external-data reference resolves.
    base_name = os.path.basename(model_path)
    output_name = f"{base_name}.model_using_safetensors.onnx"
    output_model_path = os.path.join(base_dir, output_name)

    onnx.save(model_with_external_data, output_model_path)
    print(f"ONNX model with external safetensors data written to: {output_model_path}")


if __name__ == "__main__":
    main()
Author
aurotripathy
commented
Mar 12, 2026
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment