# License: Apache 2.0
from typing import Dict, Optional
import struct
import json
import numpy as np
import torch


class MemoryEfficientSafeOpen:
    """Memory-efficient reader for safetensors files.

    This class provides a memory-efficient way to read tensors from safetensors files
    by using memory mapping for large tensors and avoiding unnecessary copies.
    """

    def __init__(self, filename):
        """Initialize the SafeTensor reader.

        Args:
            filename (str): Path to the safetensors file to read.
        """
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager and close the file."""
        self.file.close()

    def keys(self):
        """Get all tensor keys in the file.

        Returns:
            list: List of tensor names (excludes metadata).
        """
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Get metadata from the file.

        Returns:
            Dict[str, str]: Metadata dictionary.
        """
        return self.header.get("__metadata__", {})

    def _read_header(self):
        """Read and parse the header of the safetensors file.

        Returns:
            tuple: (header_dict, header_size) containing the parsed header and its size.
        """
        # Read header size (8 bytes, little-endian unsigned long long)
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        # Read and decode the header JSON
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size
    def get_tensor(self, key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """Load a tensor from the file with memory-efficient strategies.

        **Note:**
            If device is 'cuda', the transfer to the GPU is done efficiently using pinned memory and a non-blocking copy,
            so you must ensure that the transfer has completed before using the tensor (e.g., by calling `torch.cuda.synchronize()`).
            If the tensor is large (>10MB) and the target device is CUDA, memory mapping with numpy.memmap is used to avoid intermediate copies.

        Args:
            key (str): Name of the tensor to load.
            device (Optional[torch.device]): Target device for the tensor.
            dtype (Optional[torch.dtype]): Target dtype for the tensor.

        Returns:
            torch.Tensor: The loaded tensor.

        Raises:
            KeyError: If the tensor key is not found in the file.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        num_bytes = offset_end - offset_start
        original_dtype = self._get_torch_dtype(metadata["dtype"])
        target_dtype = dtype if dtype is not None else original_dtype

        # Handle empty tensors
        if num_bytes == 0:
            return torch.empty(metadata["shape"], dtype=target_dtype, device=device)

        # Determine whether to use pinned memory for the GPU transfer
        pin_to_gpu = device is not None and device.type == "cuda"  # *** Set to False here to avoid using shared GPU memory ***
        non_blocking = device is not None and device.type == "cuda"

        # Calculate the absolute file offset
        tensor_offset = self.header_size + 8 + offset_start  # adjust offset by header size

        # Memory-mapping strategy for large tensors going to the GPU.
        # Use memmap for large tensors to avoid intermediate copies.
        # If the device is cpu, the tensor is never copied to the GPU, so using memmap would keep the file mapped,
        # which is not desired. Therefore memmap is only used when the device is not cpu.
        if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
            # Create a memory map for zero-copy reading
            mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
            byte_tensor = torch.from_numpy(mm)  # zero copy
            del mm

            # Deserialize the tensor (view and reshape)
            cpu_tensor = self._deserialize_tensor(byte_tensor, metadata)  # view and reshape
            del byte_tensor

            # Pin memory for a faster GPU transfer
            if pin_to_gpu:
                cpu_tensor = cpu_tensor.pin_memory()

            # Transfer to the target device and dtype
            gpu_tensor = cpu_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
            del cpu_tensor
            return gpu_tensor

        # Standard file-reading strategy for smaller tensors or a CPU target
        # seek to the specified position
        self.file.seek(tensor_offset)

        # read directly into a numpy array with numpy.fromfile, without an intermediate copy
        numpy_array = np.fromfile(self.file, dtype=np.uint8, count=num_bytes)
        byte_tensor = torch.from_numpy(numpy_array)
        del numpy_array

        # deserialize (view and reshape)
        deserialized_tensor = self._deserialize_tensor(byte_tensor, metadata)
        del byte_tensor

        # Pin memory for the GPU transfer if needed
        if pin_to_gpu:
            deserialized_tensor = deserialized_tensor.pin_memory()

        # cast to the target dtype and move to the device
        return deserialized_tensor.to(device=device, dtype=target_dtype, non_blocking=non_blocking)
    def _deserialize_tensor(self, byte_tensor: torch.Tensor, metadata: Dict):
        """Deserialize a byte tensor to the correct shape and dtype.

        Args:
            byte_tensor (torch.Tensor): Raw byte tensor from the file.
            metadata (Dict): Tensor metadata containing dtype and shape info.

        Returns:
            torch.Tensor: Deserialized tensor with the correct shape and dtype.
        """
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        # Handle special float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # Standard conversion: view as the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Convert a string dtype to a PyTorch dtype.

        Args:
            dtype_str (str): String representation of the dtype.

        Returns:
            torch.dtype: Corresponding PyTorch dtype.
        """
        # Standard dtype mappings
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # Add float8 types if available in this PyTorch version
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Convert a byte tensor to a float8 format if supported.

        Args:
            byte_tensor (torch.Tensor): Raw byte tensor.
            dtype_str (str): Float8 dtype string ("F8_E5M2" or "F8_E4M3").
            shape (tuple): Target tensor shape.

        Returns:
            torch.Tensor: Tensor with a float8 dtype.

        Raises:
            ValueError: If the float8 type is not supported in the current PyTorch version.
        """
        # Convert to the specific float8 type if available
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # Float8 is not supported in this PyTorch version.
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
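
A minimal usage sketch for the reader above, assuming a file named "model.safetensors" (the file name and the float16 cast are placeholders, not part of the original gist): open the file as a context manager, enumerate the keys, and load each tensor. When only dtype is given and no CUDA device is requested, no synchronize call is needed.

import torch
from mem_eff_safeopen import MemoryEfficientSafeOpen

# Hypothetical example: read every tensor on the CPU, casting to float16 while loading.
with MemoryEfficientSafeOpen("model.safetensors") as f:
    print(f.metadata())  # free-form metadata stored in the header, if any
    state_dict = {key: f.get_tensor(key, dtype=torch.float16) for key in f.keys()}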

import unittest
import torch
import tempfile
import os

from safetensors import safe_open
from safetensors.torch import save_file

from mem_eff_safeopen import MemoryEfficientSafeOpen


class TestMemoryEfficientSafeOpen(unittest.TestCase):
    def setUp(self):
        self.test_tensors = {
            "float32": torch.randn(10, 20).float(),
            "float16": torch.randn(5, 15).half(),
            "int64": torch.randint(-100, 100, (8, 12)).long(),
            "bool": torch.randint(0, 2, (6, 6)).bool(),
            "empty": torch.empty(0, 10),
            "scalar": torch.tensor(3.14),
        }
        if hasattr(torch, "bfloat16"):
            self.test_tensors["bfloat16"] = torch.randn(7, 9).to(torch.bfloat16)
        if hasattr(torch, "float8_e5m2"):
            self.test_tensors["float8_e5m2"] = torch.randn(4, 8).to(torch.float8_e5m2)
        if hasattr(torch, "float8_e4m3fn"):
            self.test_tensors["float8_e4m3fn"] = torch.randn(3, 7).to(torch.float8_e4m3fn)
    def test_tensor_loading(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        try:
            # 1. Create a .safetensors file for the test
            save_file(self.test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key) for key in f.keys()}
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key) for key in f.keys()}

            # 3. Compare each tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                if "float8" in str(dtype):
                    # torch.allclose does not support float8, so compare element by element
                    for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                        self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)
    def test_tensor_loading_dtype(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        dtype = torch.float16
        try:
            # 1. Create a .safetensors file for the test
            save_file(self.test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key).to(dtype) for key in f.keys()}
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key, dtype=dtype) for key in f.keys()}

            # 3. Compare each tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                self.assertEqual(efficient_tensors[key].dtype, torch.float16)
                self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)
    def test_memory_efficiency(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        try:
            # Create large tensors
            num_tensors = 10
            large_tensors = {f"large_{i}": torch.randn(10000, 1000) for i in range(num_tensors)}
            save_file(large_tensors, tmp_filename)

            # Measure memory usage (a simple approach)
            import psutil
            import gc

            process = psutil.Process()

            def get_memory_usage():
                return process.memory_info().rss / 1024 / 1024  # in MB

            # Loading with the official safetensors
            gc.collect()
            mem_before = get_memory_usage()
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # do some operation so the tensor is actually read into memory
                    del t
            gc.collect()
            mem_after_official = get_memory_usage()

            # Loading with MemoryEfficientSafeOpen
            gc.collect()
            mem_before = get_memory_usage()
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # already read into memory
                    del t
            gc.collect()
            mem_after_efficient = get_memory_usage()

            # Compare memory usage
            self.assertLess(mem_after_efficient - mem_before, mem_after_official - mem_before)
        finally:
            os.unlink(tmp_filename)
    def test_cuda_device(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name

        if not torch.cuda.is_available():
            self.skipTest("Skipped because CUDA is not available")

        device = torch.device("cuda")
        try:
            # 1. Create large tensors: MemoryEfficientSafeOpen's CUDA path is only taken for large tensors
            test_tensors = {}
            for i, (key, tensor) in enumerate(self.test_tensors.items()):
                test_tensors[f"large_{i}"] = (
                    torch.randn(10000, 1000, dtype=tensor.dtype)
                    if tensor.dtype.is_floating_point and tensor.dtype.itemsize >= 2
                    else torch.randint(-100, 100, (10000, 1000)).to(tensor.dtype)  # supports int, fp8 and bool
                )
            # Also add a few small tensors
            test_tensors.update({f"small_{i}": torch.randn(10, 10) for i in range(5)})
            save_file(test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key).to(device) for key in f.keys()}
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key, device=device) for key in f.keys()}

            # 3. Compare each tensor
            for key in test_tensors.keys():
                dtype = test_tensors[key].dtype
                if "float8" in str(dtype):
                    # # torch.allclose does not support float8, so compare element by element
                    # for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                    #     self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                    # Element-wise comparison is slow for large tensors, so convert to float16 and compare
                    official_fp16 = official_tensors[key].to(torch.float16)
                    efficient_fp16 = efficient_tensors[key].to(torch.float16)
                    self.assertTrue(torch.allclose(official_fp16, efficient_fp16, atol=1e-2, rtol=1e-2))
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
                self.assertEqual(official_tensors[key].device, efficient_tensors[key].device)
                self.assertEqual(official_tensors[key].device.type, "cuda")
        finally:
            os.unlink(tmp_filename)


if __name__ == "__main__":
    unittest.main()

# License: Apache 2.0
import torch
import json
import struct
from typing import Dict, Any


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    # Build the header with dtype, shape and byte offsets for each tensor
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)  # pad the header so the data section is aligned

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU-to-disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)


# Usage example
if __name__ == "__main__":
    # Create some example tensors on the GPU
    tensors = {"weight": torch.randn(1000, 1000, device="cuda"), "bias": torch.randn(1000, device="cuda")}
    metadata = {"model_type": "example", "version": "1.0"}
    mem_eff_save_file(tensors, "model.safetensors", metadata)

# License: Apache 2.0
import unittest
import torch
import os
import tempfile

from safetensors.torch import load_file as official_load_file
from safetensors import safe_open

from mem_eff_save_file import mem_eff_save_file  # your implementation


class TestCompatibilityWithOfficialSafetensors(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        for file in os.listdir(self.temp_dir):
            os.remove(os.path.join(self.temp_dir, file))
        os.rmdir(self.temp_dir)

    def assert_tensors_equal(self, tensor1, tensor2):
        self.assertTrue(torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8), f"Tensors are not equal: {tensor1} vs {tensor2}")
    def test_compatibility_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_not_contiguous_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_not_contiguous_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])
    def test_compatibility_multiple_tensors(self):
        tensors = {"weight": torch.randn(100, 100), "bias": torch.randn(100)}
        file_path = os.path.join(self.temp_dir, "custom_multiple.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_with_empty_tensors(self):
        tensors = {"empty": torch.tensor([]), "zero_dim": torch.tensor(1)}
        file_path = os.path.join(self.temp_dir, "custom_empty.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_different_dtypes(self):
        tensors = {
            "float32": torch.randn(10, 10, dtype=torch.float32),
            "float16": torch.randn(10, 10, dtype=torch.float16),
            "int32": torch.randint(0, 10, (10, 10), dtype=torch.int32),
        }
        file_path = os.path.join(self.temp_dir, "custom_dtypes.safetensors")
        mem_eff_save_file(tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
            self.assertEqual(tensors[key].dtype, loaded_tensors[key].dtype)
    def test_compatibility_with_metadata(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": "1.0"}
        file_path = os.path.join(self.temp_dir, "custom_metadata.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

        # load the metadata from the .safetensors file with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual(metadata, official_metadata)

    def test_compatibility_with_metadata_not_str_to_str(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": 1.0}
        file_path = os.path.join(self.temp_dir, "custom_metadata_not_str_to_str.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

        # load the metadata from the .safetensors file with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual({"model_type": "test", "version": "1.0"}, official_metadata)
    def test_large_model_compatibility(self):
        # Simulate a large model
        large_tensors = {f"layer_{i}": torch.randn(1000, 1000) for i in range(10)}
        file_path = os.path.join(self.temp_dir, "large_model.safetensors")
        mem_eff_save_file(large_tensors, file_path)

        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(large_tensors.keys()), set(loaded_tensors.keys()))
        for key in large_tensors:
            self.assert_tensors_equal(large_tensors[key], loaded_tensors[key])


if __name__ == "__main__":
    unittest.main()

I'm glad it was helpful. I also have a script that reduces memory consumption when loading, so I can publish it if needed. Please let me know.
I've added the loading side as well, and included license information in each file. Please note that this code comes with no warranty; use it at your own discretion.
Thank you very much! I've incorporated it right away!
I've added metadata reading to MemoryEfficientSafeOpen.
I also sped up loading by using NumPy.
If you want to move a tensor to the GPU right away, pass device to get_tensor. The transfer is performed asynchronously, so call torch.cuda.synchronize() once loading has finished.
If you want to cast a tensor while loading, specify dtype (if you only specify dtype, no synchronize is needed).
With a device transfer this should be roughly 5x faster, and about 1.5x faster without one.
Pinning memory consumes shared GPU memory; to avoid that, change the line pin_to_gpu = device is not None and device.type == "cuda"  # *** Set to False here to avoid using shared GPU memory *** to pin_to_gpu = False.
(The speedup will be more limited in that case.)
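
A short sketch of the GPU loading flow described above (the file name "model.safetensors" is a placeholder, not part of the gist): pass device to get_tensor to start the non-blocking transfers, then synchronize once before using the tensors.

import torch
from mem_eff_safeopen import MemoryEfficientSafeOpen

device = torch.device("cuda")
# Hypothetical example: load all tensors straight to the GPU.
with MemoryEfficientSafeOpen("model.safetensors") as f:
    tensors = {key: f.get_tensor(key, device=device) for key in f.keys()}
torch.cuda.synchronize()  # make sure all asynchronous copies have finished before using the tensors
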
I combined this code with Kijai's fp8 conversion code to create a memory-efficient fp8 conversion script. Thank you for sharing the code!