From 995eed8c897911d7859a34552ef4c0c560e943d9 Mon Sep 17 00:00:00 2001 From: Ruoming Pang Date: Sat, 25 Jan 2025 23:11:05 -0500 Subject: [PATCH] Adds logging to aux loss collection. --- axlearn/common/causal_lm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index 63cbb99b4..6f0e66edf 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -187,12 +187,18 @@ def forward( ctx = ctx.parent module_outputs = ctx.get_module_outputs() + for k, v in flatten_items(module_outputs): + if re.fullmatch(regex, k): + logging.info("aux loss found at %s", k) + else: + logging.info("aux loss not found at %s", k) accumulation = list( v.mean() for k, v in flatten_items(module_outputs) if re.fullmatch(regex, k) ) if accumulation: aux_loss = sum(accumulation) / len(accumulation) else: + logging.warning("aux loss not found: %s", cfg.aux_loss_regex) aux_loss = 0.0 self.add_summary("aux_loss", WeightedScalar(aux_loss, num_targets))