Skip to content

Commit

Permalink
Fix more typos
Browse files Browse the repository at this point in the history
Signed-off-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
tjruwase committed Feb 6, 2025
1 parent 7d5be07 commit e8fc098
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ def _configure_bf16_optimizer(self, optimizer):
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = BF16_Optimizer(optimizer,
self.param_names,
bfloat16_config=self._config.bfloat_config,
bfloat16_config=self._config.bfloat16_config,
mpu=self.mpu,
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,12 @@ def _post_init_method(self, module):

def _set_dtype(self, ds_config, dtype):
if ds_config is not None and dtype is None:
if ds_config.bfloat16_config.enabled and ds_config.fp16_enabled:
if ds_config.bfloat16_config.enabled and ds_config.float16_config.enabled:
raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")

if ds_config.bfloat16_config.enabled:
self.dtype = torch.bfloat16
elif ds_config.fp16_enabled:
elif ds_config.float16_config.enabled:
self.dtype = torch.half
else:
self.dtype = torch.float
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/runtime/test_ds_config_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

# A test on its own
import deepspeed
from deepspeed.runtime.config import DeepSpeedConfig, get_bfloat16_enabled
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.runtime.precision_config import get_bfloat16_config


class TestBasicConfig(DistributedTest):
Expand Down Expand Up @@ -151,7 +152,7 @@ def test_get_bfloat16_enabled(bf16_key):
"enabled": True,
},
}
assert get_bfloat16_enabled(cfg) == True
assert get_bfloat16_config(cfg).enabled == True


class TestConfigLoad(DistributedTest):
Expand Down

0 comments on commit e8fc098

Please sign in to comment.