Support for unsharded parameters in state_dict APIs #2023

Merged · 4 commits · Nov 19, 2024
Changes from all commits
69 changes: 39 additions & 30 deletions torchtune/training/_distributed.py
@@ -207,6 +207,9 @@ def load_from_full_model_state_dict(
                 requires_grad=sharded_meta_param.requires_grad,
             )

+        elif not hasattr(sharded_meta_param, "device_mesh"):
+            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
+            sharded_tensor = full_tensor
         else:
             sharded_tensor = distribute_tensor(
                 full_tensor,
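
For intuition, here is a minimal single-process sketch of the dispatch the new `elif` introduces (not the torchtune API; `_shard_like` is a hypothetical helper): a meta parameter that was never sharded has no `device_mesh`, so the full tensor is kept as a plain tensor instead of being passed to `distribute_tensor`.

```python
import torch
import torch.nn as nn

def _shard_like(meta_param: torch.Tensor, full_tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the branch above: DTensor params carry a device_mesh, plain ones do not."""
    if hasattr(meta_param, "device_mesh"):
        # Sharded case: re-shard the full tensor onto the same mesh/placements.
        # This branch needs an initialized device mesh, so it is not exercised in this demo.
        from torch.distributed._tensor import distribute_tensor
        return distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)
    # Unsharded case: the parameter was never wrapped by FSDP, so the full tensor is used as-is.
    return full_tensor

# Single-process demo: a plain (unsharded) meta parameter takes the new elif path.
meta_param = nn.Parameter(torch.empty(4, 4, device="meta"))
full_tensor = torch.randn(4, 4)
print(_shard_like(meta_param, full_tensor).shape)  # torch.Size([4, 4])
```
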
@@ -220,6 +223,30 @@ def load_from_full_model_state_dict(
     return model.load_state_dict(sharded_sd, strict=strict, assign=True)


+def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
+    """
+    Manually gather NF4Tensor parameter since it does not support all_gather
+    """
+    mesh = sharded_param.device_mesh
+    nf4_tensor = sharded_param._local_tensor
+    quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
+    full_quant_params = []
+    for quant_param in quant_params:
+        d0, *dn = quant_param.shape
+        shape = (d0 * mesh.get_group().size(), *dn)
+        full_quant_param = torch.empty(
+            shape, device=quant_param.device, dtype=quant_param.dtype
+        )
+        dist.all_gather_into_tensor(
+            full_quant_param, quant_param, mesh.get_group(), async_op=False
+        )
+        full_quant_params.append(full_quant_param)
+    full_param, _ = nf4_tensor.fsdp_post_all_gather(
+        full_quant_params, metadata, nf4_tensor.dtype
+    )
+    return full_param
+
+
 def gather_cpu_state_dict(
     sharded_sd: Dict[str, DTensor],  # noqa
     is_rank_zero: bool,
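
The helper relies on the FSDP sharding invariant that each rank holds a dim-0 slice of every quantized component, so the full tensor is recovered by gathering along dim 0 into a buffer whose first dimension is the local size times the group size. A rough single-process sketch of that shape math, using a plain tensor, `all_gather` plus `cat`, and a gloo group as a stand-in for the FSDP mesh's group (the PR itself uses `all_gather_into_tensor` into a preallocated buffer):

```python
import os
import torch
import torch.distributed as dist

# Single-process "distributed" setup, just so there is a process group to gather over.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

local_shard = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # this rank's dim-0 slice
world_size = dist.get_world_size()

# Same shape math as _gather_nf4_tensor: dim 0 grows by the group size after the gather.
gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
dist.all_gather(gathered, local_shard)
full = torch.cat(gathered, dim=0)
assert full.shape == (local_shard.shape[0] * world_size, *local_shard.shape[1:])

dist.destroy_process_group()
```
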
@@ -238,39 +265,21 @@ def gather_cpu_state_dict(
         Dict[str, Any]: State dict on CPU
     """
     cpu_state_dict = {}
-    for param_name, sharded_param in sharded_sd.items():
-        if sharded_param.is_cpu:
+    for param_name, param in sharded_sd.items():
+        if param.is_cpu:
             # Move back to device if offloaded to CPU
-            sharded_param = sharded_param.to(device)
-        if isinstance(sharded_param._local_tensor, NF4Tensor):
-            # NF4Tensor does not support all_gather from DTensor
-            # so we need to manually all_gather
-            mesh = sharded_param.device_mesh
-            nf4_tensor = sharded_param._local_tensor
-            quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
-            full_quant_params = []
-            for quant_param in quant_params:
-                d0, *dn = quant_param.shape
-                shape = (d0 * mesh.get_group().size(), *dn)
-                full_quant_param = torch.empty(
-                    shape, device=quant_param.device, dtype=quant_param.dtype
-                )
-                dist.all_gather_into_tensor(
-                    full_quant_param, quant_param, mesh.get_group(), async_op=False
-                )
-                full_quant_params.append(full_quant_param)
-            full_param, _ = nf4_tensor.fsdp_post_all_gather(
-                full_quant_params, metadata, nf4_tensor.dtype
-            )
+            param = param.to(device)
+        if hasattr(param, "_local_tensor"):
+            if isinstance(param._local_tensor, NF4Tensor):
+                param = _gather_nf4_tensor(param)
+            else:
+                # Gather DTensor
+                param = param.full_tensor()
+        if isinstance(param, NF4Tensor):
             # upcasting NF4 to original dtype
-            full_param = full_param.to(full_param.dtype)
-        else:
-            # Gather DTensor
-            full_param = sharded_param.full_tensor()
+            param = param.to(param.dtype)
         if is_rank_zero:
-            cpu_state_dict[param_name] = full_param.cpu()
-        else:
-            del full_param
+            cpu_state_dict[param_name] = param.cpu()
         torch.distributed.barrier()
     return cpu_state_dict

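With the `hasattr(param, "_local_tensor")` check, `gather_cpu_state_dict` now covers three cases in one pass: NF4 DTensors (manual gather), regular DTensors (`full_tensor()`), and plain tensors from unsharded modules (nothing to gather). A minimal single-process sketch of that dispatch, where `gather_param_to_cpu` is a hypothetical stand-in rather than the torchtune function:

```python
import torch

def gather_param_to_cpu(param: torch.Tensor) -> torch.Tensor:
    # DTensors expose a `_local_tensor`; plain tensors from unsharded modules do not.
    if hasattr(param, "_local_tensor"):
        # Sharded case: materialize the full tensor across the mesh before copying off-device.
        param = param.full_tensor()
    # Unsharded case falls straight through: there is nothing to gather.
    return param.cpu()

# Single-process demo with an unsharded parameter.
sharded_sd = {"output.weight": torch.randn(8, 4)}
cpu_state_dict = {name: gather_param_to_cpu(p) for name, p in sharded_sd.items()}
print(cpu_state_dict["output.weight"].device)  # cpu
```
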