Skip to content

Instantly share code, notes, and snippets.

@saforem2
Created January 31, 2025 22:54
Show Gist options
  • Select an option

  • Save saforem2/39d52279a33557adc0972073b2c1a950 to your computer and use it in GitHub Desktop.

Select an option

Save saforem2/39d52279a33557adc0972073b2c1a950 to your computer and use it in GitHub Desktop.
Torchtune fix on Aurora

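Patch to get torchtune working on Aurora: `torchtune/training/_distributed.py` now falls back to the `torch.distributed._composable.fsdp.fully_shard` and `torch.distributed._composable.fsdp._fsdp_api` submodules when `fully_shard` or `CPUOffloadPolicy` cannot be imported from `torch.distributed._composable.fsdp`, and `CPUOffloadPolicy` is imported only when `cpu_offload` is enabled.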

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@ import torch
 import torch.distributed as dist
 from torch import nn

-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+    from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
 from torch.distributed._tensor import distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,11 @@ def shard_model(
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
+        try:
+            from torch.distributed._composable.fsdp import CPUOffloadPolicy
+        except (ImportError, ModuleNotFoundError):
+            from torch.distributed._composable.fsdp._fsdp_api import MixedPrecisionPolicy, CPUOffloadPolicy
+            # from torch.distributed._composable import CPUOffloadPolicy
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

     # Shard the model with FSDP, iterating in reverse to start with
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment