Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save chenbaiyujason/4c4d31b12e23d300712017dd77c23d19 to your computer and use it in GitHub Desktop.

Select an option

Save chenbaiyujason/4c4d31b12e23d300712017dd77c23d19 to your computer and use it in GitHub Desktop.
from safetensors.torch import load_file,save_file
import mmap
import torch
import json
import os
def load_metadata(filename):
    """Return the ``__metadata__`` mapping from a safetensors-style file header.

    The file is expected to begin with an 8-byte little-endian integer giving
    the byte length of a JSON header, immediately followed by that header.
    The whole decoded header is printed for inspection.

    Args:
        filename: Path of the file to inspect.

    Returns:
        The dict stored under the header's ``"__metadata__"`` key, or an
        empty dict when the header has no such entry.

    Raises:
        ValueError: If the file is too short to hold the length prefix or the
            declared header, or if the header is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        OSError: If the file cannot be opened or read.
    """
    # Binary mode: the length prefix is raw bytes, not text, so the file must
    # not go through a text decoder.
    with open(filename, mode="rb") as file_obj:
        file_size = os.fstat(file_obj.fileno()).st_size
        prefix = file_obj.read(8)
        if len(prefix) < 8:
            raise ValueError(
                f"{filename!r} is too short to contain a header length prefix"
            )
        n = int.from_bytes(prefix, "little")
        # Guard against a corrupt prefix before trying to read (and allocate)
        # that many bytes.
        if n > file_size - 8:
            raise ValueError(
                f"{filename!r} declares a {n}-byte header but only "
                f"{file_size - 8} bytes follow the length prefix"
            )
        metadata_bytes = file_obj.read(n)
    metadata = json.loads(metadata_bytes)
    print("_____________________")
    print(metadata)
    print("_____________________")
    # The "__metadata__" entry may be absent; treat that as "no metadata"
    # instead of raising KeyError.
    return metadata.get("__metadata__") or {}
# Script configuration. Edit the paths and the whitelist before running.

# Short dtype names mapped to torch dtypes.
DTYPES = {
    "bf16": torch.bfloat16,
}

# Device name for tensor loading.
device = "cpu"

# Input file to read and output file to write.
file_path = "inputmodel.safetensors"
out_path = "outmodel.safetensors"

# Metadata keys that survive the rewrite; every other key is dropped.
whitelist = [
    "ss_base_model_version",
    "author",
    "about",
]
def main():
    """Rewrite ``file_path`` to ``out_path`` keeping only whitelisted metadata.

    Reads the existing header metadata, sets the ``author`` and ``about``
    entries, discards every key that is not listed in ``whitelist``, and
    saves the unchanged tensors together with the filtered metadata.
    """
    # Parse the header first: it is cheap, so a malformed input fails before
    # the (potentially large) tensors are loaded.
    try:
        meta = dict(load_metadata(file_path))
    except KeyError:
        # The input header has no "__metadata__" entry; start from scratch so
        # the script can still attach the fields below.
        meta = {}

    # Use the module-level device setting rather than a second literal.
    loaded = load_file(file_path, device=device)

    meta["author"] = "ShiChen https://huggingface.co/shichen1231"
    meta["about"] = "You can describe what you want here"

    # Keep only whitelisted keys; building a new dict avoids deleting from
    # the mapping while walking it.
    kept = {key: value for key, value in meta.items() if key in whitelist}
    save_file(loaded, out_path, metadata=kept)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment