diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@
import torch
import torch.distributed as dist
from torch import nn
-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+    from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
from torch.distributed._tensor import distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,11 @@ def shard_model(
"""
fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
+ try:
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy
+ except (ImportError, ModuleNotFoundError):
+ from torch.distributed._composable.fsdp._fsdp_api import MixedPrecisionPolicy, CPUOffloadPolicy
+ # from torch.distributed._composable import CPUOffloadPolicy
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
# Shard the model with FSDP, iterating in reverse to start with
Torchtune fix on Aurora
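
The patch above wraps torchtune's FSDP2 imports in try/except fallbacks so the same code works on PyTorch builds (such as the one on Aurora) that do not re-export fully_shard and CPUOffloadPolicy from the torch.distributed._composable.fsdp package. For reference, here is a minimal, self-contained sketch of the code path the patch fixes; ToyModel and shard_toy_model are invented stand-ins for this example, not torchtune code, and the script assumes it is launched with torchrun on a machine whose default backend supports FSDP2.

# Minimal sketch (not torchtune code): shard a toy model roughly the way
# shard_model does, using the same fallback imports as the patch above.
import torch
import torch.distributed as dist
from torch import nn

try:
    from torch.distributed._composable.fsdp import fully_shard
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
try:
    from torch.distributed._composable.fsdp import CPUOffloadPolicy
except (ImportError, ModuleNotFoundError):
    from torch.distributed._composable.fsdp._fsdp_api import CPUOffloadPolicy


class ToyModel(nn.Module):
    """Stand-in for a transformer: a few blocks plus a head."""

    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
        self.head = nn.Linear(64, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = torch.relu(block(x))
        return self.head(x)


def shard_toy_model(model: nn.Module, *, cpu_offload: bool = False) -> nn.Module:
    # Mirrors the patched shard_model: build the FSDP kwargs once, then apply
    # fully_shard to each block before wrapping the root module.
    fsdp_kwargs = {"reshard_after_forward": True}
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
    for block in model.blocks:
        fully_shard(block, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)
    return model


if __name__ == "__main__":
    # Expects torchrun to have set the rendezvous environment variables.
    dist.init_process_group()
    model = shard_toy_model(ToyModel(), cpu_offload=True)
    print(type(model))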