From 919a20fa5bd85d5fb6cc3e9ba4f3b264c303cb2b Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 12 Feb 2025 00:45:01 +0000 Subject: [PATCH] Use Accuracy from cross_entropy in CrossEntropyLossMetrics --- axlearn/common/causal_lm.py | 6 ++---- axlearn/common/loss.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index e4797fa9f..e24c47015 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -107,9 +107,6 @@ def forward( if live_targets is None: live_targets = target_labels >= 0 num_targets = live_targets.sum() - accuracy = ( - jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets - ).sum() / jnp.maximum(1, num_targets) loss, loss_dict = cross_entropy( logits=logits, @@ -118,7 +115,7 @@ def forward( z_loss_scale=cfg.z_loss_scale if cfg.z_loss_scale is not None else 0.0, ) per_token_loss = loss_dict["per_target_loss"] * live_targets - self.add_summary("accuracy", WeightedScalar(accuracy, num_targets)) + self.add_summary("accuracy", WeightedScalar(loss_dict["accuracy"], num_targets)) self.add_summary("z_loss", WeightedScalar(loss_dict["z_loss"], num_targets)) if target_num_bytes is not None: # N.B. we calculate bpb following Appendix D.2. of , @@ -606,6 +603,7 @@ def residual_initializer_cfg(num_layers, scale=0.02): return init_cfg +# pylint: disable=too-many-positional-arguments def gpt_decoder_config( stack_cfg: TransformerStackConfig, num_layers: int, diff --git a/axlearn/common/loss.py b/axlearn/common/loss.py index 1e6881d25..f04da5cab 100644 --- a/axlearn/common/loss.py +++ b/axlearn/common/loss.py @@ -68,6 +68,8 @@ def _reduce_loss( if sample_weight is not None: loss = loss * sample_weight + # Initialize reduced_loss to prevent linter errors + reduced_loss = loss # Default initialization if reduction == ReductionMethod.NONE: reduced_loss = loss elif reduction == ReductionMethod.SUM: @@ -117,14 +119,18 @@ def cross_entropy( target_labels will only be used for inferring the live targets during loss calculation. Returns: - (loss, all_losses), where + (loss, loss_dict), where loss is a scalar tensor for the cross entropy loss; - all_losses is a dictionary containing: + loss_dict is a dictionary containing: * "total_loss": a scalar representing the overall loss = cross_entropy_loss + z_loss_scale * z_loss. - * "cross_entropy_loss": the cross_entropy_loss. * "z_loss": the unscaled z_loss. + * "cross_entropy_loss": the cross_entropy_loss. * "per_target_loss": the loss per target, of the same shape as `target_labels`. + * "accuracy": the proportion of correctly predicted targets, computed as + the number of instances where the predicted class (argmax of logits) + matches the target label, weighted by `live_targets`, and normalized by + the total number of valid targets. Raises: ValueError: If z_loss_scale is negative.