Created September 6, 2023 10:16
-
-
Save chenbaiyujason/4c4d31b12e23d300712017dd77c23d19 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from safetensors.torch import load_file,save_file | |
| import mmap | |
| import torch | |
| import json | |
| import os | |
def load_metadata(filename):
    """Return the ``__metadata__`` dict stored in a safetensors file header.

    The file is expected to start with an 8-byte little-endian unsigned
    integer N, followed by N bytes of JSON header. Only that header is read;
    the tensor payload after it is never loaded.

    Args:
        filename: Path to the ``.safetensors`` file.

    Returns:
        A new dict holding the header's ``__metadata__`` entries, or an empty
        dict when the header has no ``__metadata__`` key (it is optional).

    Raises:
        ValueError: If the file is too short to hold the 8-byte length prefix,
            if the declared header length exceeds the file size, or if the
            header is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
    """
    # Binary mode: the length prefix and header are raw bytes, not text.
    with open(filename, mode="rb") as file_obj:
        file_size = os.fstat(file_obj.fileno()).st_size
        prefix = file_obj.read(8)
        if len(prefix) != 8:
            raise ValueError(
                f"{filename!r}: file too short to contain a safetensors header"
            )
        header_len = int.from_bytes(prefix, "little")
        # Bound the declared length by the real file size before reading, so a
        # corrupt prefix cannot trigger a huge allocation in read().
        if header_len > file_size - 8:
            raise ValueError(
                f"{filename!r}: declared header length {header_len} exceeds file size"
            )
        header_bytes = file_obj.read(header_len)
    header = json.loads(header_bytes)
    # "__metadata__" may be missing for files saved without metadata.
    return dict(header.get("__metadata__") or {})
# Lookup from a short dtype alias to the torch dtype; no use of it is visible in this script.
DTYPES = {"bf16": torch.bfloat16}

# Device string intended for tensor loading — confirm main() actually uses it.
device = "cpu"

# Model file that is read, and the file the rewritten model is saved to.
file_path = "inputmodel.safetensors"
out_path = "outmodel.safetensors"

# Metadata keys kept in the output file; every other key is removed.
whitelist = [
    "ss_base_model_version",
    "author",
    "about",
]
def main():
    """Rewrite the model at ``file_path`` to ``out_path`` with filtered metadata.

    Loads every tensor from ``file_path`` onto ``device`` and reads the file's
    header metadata. It then sets the ``author`` and ``about`` entries and
    discards every metadata key that is not listed in ``whitelist``. Finally,
    it saves the unchanged tensors together with the filtered metadata to
    ``out_path``.
    """
    tensors = load_file(file_path, device=device)
    meta = load_metadata(file_path)
    meta["author"] = "ShiChen https://huggingface.co/shichen1231"
    meta["about"] = "You can describe what you want here"
    # Build a filtered copy instead of deleting keys while walking the dict;
    # insertion order of the surviving keys is preserved.
    allowed = set(whitelist)
    meta = {key: value for key, value in meta.items() if key in allowed}
    save_file(tensors, out_path, metadata=meta)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment