diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 206172b624..96c9e6f65b 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -207,6 +207,9 @@ def load_from_full_model_state_dict(
                 requires_grad=sharded_meta_param.requires_grad,
             )
 
+        elif not hasattr(sharded_meta_param, "device_mesh"):
+            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
+            sharded_tensor = full_tensor
         else:
             sharded_tensor = distribute_tensor(
                 full_tensor,
@@ -220,6 +223,30 @@ def load_from_full_model_state_dict(
     return model.load_state_dict(sharded_sd, strict=strict, assign=True)
 
 
+def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
+    """
+    Manually gather NF4Tensor parameter since it does not support all_gather
+    """
+    mesh = sharded_param.device_mesh
+    nf4_tensor = sharded_param._local_tensor
+    quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
+    full_quant_params = []
+    for quant_param in quant_params:
+        d0, *dn = quant_param.shape
+        shape = (d0 * mesh.get_group().size(), *dn)
+        full_quant_param = torch.empty(
+            shape, device=quant_param.device, dtype=quant_param.dtype
+        )
+        dist.all_gather_into_tensor(
+            full_quant_param, quant_param, mesh.get_group(), async_op=False
+        )
+        full_quant_params.append(full_quant_param)
+    full_param, _ = nf4_tensor.fsdp_post_all_gather(
+        full_quant_params, metadata, nf4_tensor.dtype
+    )
+    return full_param
+
+
 def gather_cpu_state_dict(
     sharded_sd: Dict[str, DTensor],  # noqa
     is_rank_zero: bool,
@@ -238,39 +265,21 @@ def gather_cpu_state_dict(
         Dict[str, Any]: State dict on CPU
     """
     cpu_state_dict = {}
-    for param_name, sharded_param in sharded_sd.items():
-        if sharded_param.is_cpu:
+    for param_name, param in sharded_sd.items():
+        if param.is_cpu:
             # Move back to device if offloaded to CPU
-            sharded_param = sharded_param.to(device)
-        if isinstance(sharded_param._local_tensor, NF4Tensor):
-            # NF4Tensor does not support all_gather from DTensor
-            # so we need to manually all_gather
-            mesh = sharded_param.device_mesh
-            nf4_tensor = sharded_param._local_tensor
-            quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
-            full_quant_params = []
-            for quant_param in quant_params:
-                d0, *dn = quant_param.shape
-                shape = (d0 * mesh.get_group().size(), *dn)
-                full_quant_param = torch.empty(
-                    shape, device=quant_param.device, dtype=quant_param.dtype
-                )
-                dist.all_gather_into_tensor(
-                    full_quant_param, quant_param, mesh.get_group(), async_op=False
-                )
-                full_quant_params.append(full_quant_param)
-            full_param, _ = nf4_tensor.fsdp_post_all_gather(
-                full_quant_params, metadata, nf4_tensor.dtype
-            )
+            param = param.to(device)
+        if hasattr(param, "_local_tensor"):
+            if isinstance(param._local_tensor, NF4Tensor):
+                param = _gather_nf4_tensor(param)
+            else:
+                # Gather DTensor
+                param = param.full_tensor()
+        if isinstance(param, NF4Tensor):
             # upcasting NF4 to original dtype
-            full_param = full_param.to(full_param.dtype)
-        else:
-            # Gather DTensor
-            full_param = sharded_param.full_tensor()
+            param = param.to(param.dtype)
         if is_rank_zero:
-            cpu_state_dict[param_name] = full_param.cpu()
-        else:
-            del full_param
+            cpu_state_dict[param_name] = param.cpu()
         torch.distributed.barrier()
     return cpu_state_dict
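
Usage sketch (illustrative, not part of the diff): the refactored gather_cpu_state_dict is meant to also cover models where only some submodules are wrapped with fully_shard, so the state dict mixes DTensor shards with plain tensors. The script below is a minimal, hypothetical example of that scenario; it assumes a torchrun launch, one CUDA device per rank, and that gather_cpu_state_dict is imported from torchtune.training as the recipes do. The toy model and the choice to shard only one layer are assumptions for illustration, and the fully_shard import path may differ across PyTorch versions.

# Hypothetical sketch of the mixed DTensor / plain-tensor path; not from the PR.
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard  # path varies by PyTorch version

from torchtune.training import gather_cpu_state_dict


def main() -> None:
    # Assumes torchrun has set the process-group environment variables.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)

    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4)).to(device)
    # Shard only the first layer: its parameters become DTensors, while the
    # second layer's stay plain tensors. This is the case the new
    # hasattr(param, "_local_tensor") check and the elif branch above cover.
    fully_shard(model[0])

    cpu_sd = gather_cpu_state_dict(
        model.state_dict(), is_rank_zero=(rank == 0), device=device
    )
    if rank == 0:
        # Full (unsharded) shapes, materialized on CPU only on rank zero.
        print({k: tuple(v.shape) for k, v in cpu_sd.items()})

    dist.destroy_process_group()


if __name__ == "__main__":
    main()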