Support for unsharded parameters in state_dict APIs #2023

Merged · 4 commits · Nov 19, 2024
Changes from 2 commits
18 changes: 15 additions & 3 deletions torchtune/training/_distributed.py
@@ -207,6 +207,9 @@ def load_from_full_model_state_dict(
                 requires_grad=sharded_meta_param.requires_grad,
             )

+        elif not hasattr(sharded_meta_param, "device_mesh"):
+            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
+            sharded_tensor = full_tensor
         else:
             sharded_tensor = distribute_tensor(
                 full_tensor,
@@ -242,7 +245,9 @@ def gather_cpu_state_dict(
         if sharded_param.is_cpu:
             # Move back to device if offloaded to CPU
             sharded_param = sharded_param.to(device)
-        if isinstance(sharded_param._local_tensor, NF4Tensor):
+        if hasattr(sharded_param, "_local_tensor") and isinstance(
Contributor

If the tensor isn't sharded but is still an NF4Tensor, we still need to upcast the dtype, as is done on line 271.

Contributor Author

So maybe add another elif for the case where the tensor is unsharded but still an NF4Tensor, and just apply the upcast from line 271?
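For reference, a minimal sketch of that branch (names follow the surrounding diff; the updated hunk below ends up adding essentially this):

elif isinstance(sharded_param, NF4Tensor):
    # Unsharded NF4 parameter: upcast back to the original dtype directly,
    # mirroring the upcast applied after the manual all_gather path above
    full_param = sharded_param.to(sharded_param.dtype)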

Contributor

I think this would be much cleaner if at the top of the loop we do:

if hasattr(...):
    # get full tensor (NF4 or .full_tensor)
if full_param is NF4:
    # upcast
if is_rank_zero:
    # the rest stays the same

This code still seems to assume that the unsharded param is on rank 0, which isn't guaranteed.
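One possible expansion of that structure (a hypothetical sketch only: _all_gather_nf4 and sharded_sd are stand-in names, not part of the diff; the other names follow the surrounding code):

for param_name, sharded_param in sharded_sd.items():
    if hasattr(sharded_param, "_local_tensor") and isinstance(
        sharded_param._local_tensor, NF4Tensor
    ):
        # Sharded NF4: DTensor can't all_gather NF4, so gather manually.
        # Placeholder for the existing manual all_gather branch, assumed here
        # to return the reconstructed full NF4Tensor.
        full_param = _all_gather_nf4(sharded_param)
    elif hasattr(sharded_param, "full_tensor"):
        # Regular DTensor: let DTensor perform the gather
        full_param = sharded_param.full_tensor()
    else:
        # Unsharded parameter: already a plain (possibly NF4) tensor
        full_param = sharded_param

    if isinstance(full_param, NF4Tensor):
        # Single upcast path for both the sharded and unsharded NF4 cases
        full_param = full_param.to(full_param.dtype)

    if is_rank_zero:
        cpu_state_dict[param_name] = full_param.cpu()

This keeps one NF4 upcast path and one rank-zero branch, regardless of whether the parameter was sharded.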

+            sharded_param._local_tensor, NF4Tensor
+        ):
             # NF4Tensor does not support all_gather from DTensor
             # so we need to manually all_gather
             mesh = sharded_param.device_mesh
@@ -264,9 +269,16 @@
             )
             # upcasting NF4 to original dtype
             full_param = full_param.to(full_param.dtype)
+        elif isinstance(sharded_param, NF4Tensor):
+            # upcasting NF4 to original dtype
+            full_param = sharded_param.to(sharded_param.dtype)
         else:
-            # Gather DTensor
-            full_param = sharded_param.full_tensor()
+            if hasattr(sharded_param, "full_tensor"):
+                # Gather DTensor
+                full_param = sharded_param.full_tensor()
+            else:
+                # In cases where parts of the model aren't sharded, some parameters will be plain tensors
+                full_param = sharded_param
         if is_rank_zero:
             cpu_state_dict[param_name] = full_param.cpu()
         else: