Skip to content

Commit

Permalink
Address #4986
Browse files Browse the repository at this point in the history
Signed-off-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
tjruwase committed Feb 6, 2025
1 parent e8fc098 commit 2bbb7b4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
1 change: 0 additions & 1 deletion deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self,
self.grad_acc_dtype = grad_acc_dtype

self.immediate_grad_update = bfloat16_config.immediate_grad_update
self.check_overflow = bfloat16_config.check_overflow

self.clip_grad = clip_grad
self.norm_type = norm_type
Expand Down
16 changes: 7 additions & 9 deletions deepspeed/runtime/fp16/loss_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,17 @@ class DynamicLossScaler(LossScalerBase):
"""

def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False,
init_scale,
scale_window,
min_scale,
delayed_shift,
consecutive_hysteresis,
raise_error_at_min_scale=True,
dtype=torch.half):
super(DynamicLossScaler, self).__init__(init_scale)
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_factor = 2.0
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
Expand Down Expand Up @@ -209,8 +208,7 @@ def update_scale(self, overflow):
# we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling.
def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
if dtype == torch.half and dynamic_scaling:
if dynamic_loss_args is None:
return DynamicLossScaler(dtype=dtype)
assert dynamic_loss_args is not None, f"Dynamic loss scaling parameters must be defined."
return DynamicLossScaler(dtype=dtype, **dynamic_loss_args)

loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
Expand Down

0 comments on commit 2bbb7b4

Please sign in to comment.