Fix nightly tests for qat_lora_finetune_distributed (#2085)
andrewor14 authored Nov 27, 2024
1 parent 160fd96 commit ecf8d22
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/recipes/test_qat_lora_finetune_distributed.py
@@ -45,7 +45,7 @@ def _get_test_config_overrides(self):
 
     def _fetch_expected_loss_values(self, model_type):
         loss_values_map = {
-            "llama3": [11.9325, 11.9325, 11.9325, 11.9369],
+            "llama3": [11.9835, 11.9694, 11.9615, 11.9383],
         }
         return loss_values_map[model_type]
 
@@ -66,6 +66,7 @@ def test_loss(
     ):
         ckpt = "llama3_tune"
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS["llama3"])
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
         cmd = f"""
@@ -80,11 +81,12 @@
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA3 \
             metric_logger.filename={log_file} \
-            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.path={tokenizer_path} \
             tokenizer.prompt_template=null \
             compile={should_compile} \
             enable_activation_checkpointing=False \
             enable_activation_offloading=False \
+            quantizer.groupsize=32 \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS["llama3_lora"]
@@ -154,6 +156,7 @@ def test_training_state_on_resume(
             save_adapter_weights_only={save_adapter_weights_only} \
             enable_activation_checkpointing=True \
             enable_activation_offloading=True \
+            quantizer.groupsize=32 \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
@@ -182,6 +185,7 @@ def test_training_state_on_resume(
             metric_logger.filename={log_file} \
             enable_activation_checkpointing=True \
             enable_activation_offloading=True \
+            quantizer.groupsize=32 \
         """.split()
 
         cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
@@ -228,6 +232,7 @@ def test_save_and_load_merged_weights(
             tokenizer.prompt_template=null \
             enable_activation_checkpointing=True \
             enable_activation_offloading=True \
+            quantizer.groupsize=32 \
         """.split()
 
         model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
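
Note on the repeated quantizer.groupsize=32 override: group-wise QAT quantizers slice each linear layer's weights along the input dimension, so in_features must be divisible by the group size, and the tiny models used in these recipe tests are likely too small for the default group size in the recipe configs. Below is a minimal sketch of the quantizer this override configures, assuming torchao's prototype QAT API from around the time of this commit; the import path and toy dimensions are illustrative, not taken from this commit.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Group-wise int4 weight quantization requires in_features % groupsize == 0,
# so small test models need a correspondingly small group size.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# Toy stand-in for a test model: 64-dim linears divide evenly by 32.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
model = quantizer.prepare(model)  # swaps nn.Linear for fake-quantized QAT linears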