Keep a chosen number of the model's layers on the GPU and offload the rest to the CPU.
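The PartialOffloadMixin below splits the wrapped model's layer stack: up to LAYERS_KEEP_GPU layers stay resident on the GPU, while the remaining layers are wrapped with accelerate's cpu_offload hooks so their weights are streamed to the GPU only for the duration of each forward pass. go_gpu(True) / go_gpu(False) shuttle the resident layers and the non-layer modules (embeddings, norms, heads) between devices around each inference call.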
import gc

import torch


class PartialOffloadMixin:
    '''
    Usage example (list the mixin first so its __call__/generate overrides
    take precedence over the base model's methods in the MRO):

        class MyQwen3ForCausalLM(PartialOffloadMixin, Qwen3ForCausalLM):
            LAYERS_KEEP_GPU = 22
            MODEL_ATTR_NAME = "model"
            MODEL_LAYERS_ATTR_NAME = "layers"
            OFFLOAD_ON_CALL = True

        model = MyQwen3ForCausalLM.from_pretrained(
            repo_id,
            subfolder="text_encoder",
            local_files_only=True,
            torch_dtype=torch.bfloat16,
        )
        model.eval()
        model.enable_partial_cpu_offload()

        # pseudo-code of inference
        result = model(...)  # __call__ is overridden and wraps the call in go_gpu(True) / go_gpu(False)

    Example transformer:

        class MyZImageTransformer(PartialOffloadMixin, ZImageTransformer2DModel):
            MODEL_LAYERS_ATTR_NAME = "layers"
            LAYERS_KEEP_GPU = 22

        model = MyZImageTransformer.from_pretrained(
            repo_id,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        model.eval()
        model.enable_partial_cpu_offload()

        # denoise step
        model.go_gpu(True)
        while denoising:  # pseudo-code
            predicted = model(...)
        model.go_gpu(False)
        # vae decode, etc.
    '''

    LAYERS_KEEP_GPU = 22
    MODEL_ATTR_NAME = ""
    MODEL_LAYERS_ATTR_NAME = "layers"
    OFFLOAD_ON_CALL = False

    def __call__(self, *args, **kwds):
        if self.OFFLOAD_ON_CALL:
            self.go_gpu(to_gpu=True)
        result = super().__call__(*args, **kwds)
        if self.OFFLOAD_ON_CALL:
            self.go_gpu(to_gpu=False)
        return result

    def generate(self, *args, **kwds):
        if self.OFFLOAD_ON_CALL:
            self.go_gpu(to_gpu=True)
        result = super().generate(*args, **kwds)
        if self.OFFLOAD_ON_CALL:
            self.go_gpu(to_gpu=False)
        return result

    def get_model(self):
        # The inner module that owns the layer list; self when MODEL_ATTR_NAME is empty.
        if not self.MODEL_ATTR_NAME:
            return self
        return getattr(self, self.MODEL_ATTR_NAME)

    def all_nn_modules_to_device(self, device):
        '''
        Move every nn.Module attribute of the wrapped model to the given
        device, leaving the transformer layers where they are.
        '''
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        # Temporarily swap in an empty ModuleList so .to() skips the layers.
        setattr(model, self.MODEL_LAYERS_ATTR_NAME, torch.nn.ModuleList())
        if model is not self:
            self.to(device)
        model.to(device)
        setattr(model, self.MODEL_LAYERS_ATTR_NAME, layers_obj)

    def offload_layers_to_device(self, device):
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        # Move at most LAYERS_KEEP_GPU layers, capped at 90% (18/20) of the
        # stack so at least a tenth of the layers always stays offloaded.
        count = int((len(layers_obj) / 20) * 18)
        if count > self.LAYERS_KEEP_GPU:
            count = self.LAYERS_KEEP_GPU
        for i in range(count):
            layers_obj[i].to(device)

    def go_gpu(self, to_gpu: bool):
        if to_gpu:
            self.all_nn_modules_to_device(torch.device("cuda"))
            self.offload_layers_to_device(torch.device("cuda"))
        else:
            self.all_nn_modules_to_device(torch.device("cpu"))
            self.offload_layers_to_device(torch.device("cpu"))
            gc.collect()
            torch.cuda.empty_cache()

    def enable_partial_cpu_offload(self):
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        layer_count = len(layers_obj)
        count = int((layer_count / 20) * 18)
        if count > self.LAYERS_KEEP_GPU:
            count = self.LAYERS_KEEP_GPU
        # Layers beyond the resident set get accelerate's per-forward CPU offload.
        for i in range(count, layer_count):
            self._enable_sequential_cpu_offload(layers_obj[i])
        self.sequential_offloaded = True

    def _enable_sequential_cpu_offload(self, module):
        from accelerate import cpu_offload

        device = torch.device("cuda:0")
        # Offload buffers as well when the module holds direct parameters.
        offload_buffers = len(module._parameters) > 0
        cpu_offload(module, device, offload_buffers=offload_buffers)
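

# ---------------------------------------------------------------------------
# Minimal self-test sketch (not part of the original gist): the toy classes
# below are illustrative assumptions, used only to exercise the mixin on a
# small stack of linear layers without downloading a real checkpoint.
# Requires CUDA and the `accelerate` package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class ToyBackbone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # 40 "transformer layers", stood in for by Linear modules.
            self.layers = torch.nn.ModuleList(
                torch.nn.Linear(64, 64) for _ in range(40)
            )

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    # Mixin listed first so its __call__ override wins in the MRO.
    class ToyModel(PartialOffloadMixin, torch.nn.Module):
        LAYERS_KEEP_GPU = 8
        MODEL_ATTR_NAME = "model"
        MODEL_LAYERS_ATTR_NAME = "layers"
        OFFLOAD_ON_CALL = True

        def __init__(self):
            super().__init__()
            self.model = ToyBackbone()

        def forward(self, x):
            return self.model(x)

    if torch.cuda.is_available():
        toy = ToyModel().eval()
        toy.enable_partial_cpu_offload()  # layers 8..39 get cpu_offload hooks
        with torch.no_grad():
            # __call__ wraps this in go_gpu(True) / go_gpu(False).
            out = toy(torch.randn(1, 64, device="cuda"))
        print(out.shape)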