Use Accuracy from cross_entropy in causal_lm.py::CrossEntropyLossMetrics #983

Open · wants to merge 1 commit into main
6 changes: 2 additions & 4 deletions axlearn/common/causal_lm.py
@@ -107,9 +107,6 @@ def forward(
if live_targets is None:
    live_targets = target_labels >= 0
num_targets = live_targets.sum()
-accuracy = (
-    jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets
-).sum() / jnp.maximum(1, num_targets)

loss, loss_dict = cross_entropy(
    logits=logits,
@@ -118,7 +115,7 @@
    z_loss_scale=cfg.z_loss_scale if cfg.z_loss_scale is not None else 0.0,
)
per_token_loss = loss_dict["per_target_loss"] * live_targets
-self.add_summary("accuracy", WeightedScalar(accuracy, num_targets))
+self.add_summary("accuracy", WeightedScalar(loss_dict["accuracy"], num_targets))
self.add_summary("z_loss", WeightedScalar(loss_dict["z_loss"], num_targets))
if target_num_bytes is not None:
    # N.B. we calculate bpb following Appendix D.2. of <https://arxiv.org/abs/2112.11446>,
@@ -606,6 +603,7 @@ def residual_initializer_cfg(num_layers, scale=0.02):
    return init_cfg


+# pylint: disable=too-many-positional-arguments
def gpt_decoder_config(
    stack_cfg: TransformerStackConfig,
    num_layers: int,
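The deleted block above computed accuracy in place; the sketch below (not part of the PR, plain JAX with an assumed helper name) restates that quantity, which `loss_dict["accuracy"]` is expected to carry so the summary stays numerically identical.

import jax.numpy as jnp

def accuracy_sketch(logits, target_labels):
    # Mirrors the removed in-line computation: positions with a negative label are
    # treated as padding, argmax matches are counted over live targets only, and the
    # count is normalized by the number of live targets (clamped to at least 1).
    live_targets = (target_labels >= 0).astype(logits.dtype)
    num_targets = live_targets.sum()
    correct = jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets
    return correct.sum() / jnp.maximum(1, num_targets)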
12 changes: 9 additions & 3 deletions axlearn/common/loss.py
@@ -68,6 +68,8 @@ def _reduce_loss(
if sample_weight is not None:
    loss = loss * sample_weight

+# Initialize reduced_loss to prevent linter errors
+reduced_loss = loss  # Default initialization
if reduction == ReductionMethod.NONE:
    reduced_loss = loss
elif reduction == ReductionMethod.SUM:
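The two added lines bind `reduced_loss` before the reduction branches, so static analysis never sees a possibly-unbound name. Below is a minimal self-contained sketch of that pattern; the names, the MEAN branch, and the denominator are illustrative assumptions, not axlearn's exact implementation.

import enum
import jax.numpy as jnp

class ReductionMethod(enum.Enum):
    NONE = "none"
    SUM = "sum"
    MEAN = "mean"

def reduce_loss_sketch(loss, reduction, sample_weight=None):
    if sample_weight is not None:
        loss = loss * sample_weight
    # Default initialization: keeps `reduced_loss` bound even if a new enum member
    # has no matching branch, and silences "possibly unbound" linter warnings.
    reduced_loss = loss
    if reduction == ReductionMethod.NONE:
        reduced_loss = loss
    elif reduction == ReductionMethod.SUM:
        reduced_loss = loss.sum()
    elif reduction == ReductionMethod.MEAN:
        denom = sample_weight.sum() if sample_weight is not None else loss.size
        reduced_loss = loss.sum() / jnp.maximum(1, denom)
    return reduced_loss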
@@ -117,14 +119,18 @@ def cross_entropy(
target_labels will only be used for inferring the live targets during loss calculation.

Returns:
-(loss, all_losses), where
+(loss, loss_dict), where
 loss is a scalar tensor for the cross entropy loss;
-all_losses is a dictionary containing:
+loss_dict is a dictionary containing:
     * "total_loss": a scalar representing the overall
         loss = cross_entropy_loss + z_loss_scale * z_loss.
-    * "cross_entropy_loss": the cross_entropy_loss.
     * "z_loss": the unscaled z_loss.
+    * "cross_entropy_loss": the cross_entropy_loss.
     * "per_target_loss": the loss per target, of the same shape as `target_labels`.
+    * "accuracy": the proportion of correctly predicted targets, computed as
+        the number of instances where the predicted class (argmax of logits)
+        matches the target label, weighted by `live_targets`, and normalized by
+        the total number of valid targets.

Raises:
ValueError: If z_loss_scale is negative.
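With the docstring now listing "accuracy", a usage sketch follows. The exact cross_entropy signature and its defaults are assumed from the diff context (only logits and target_labels are passed here), and the "accuracy" entry is assumed to exist in the returned dictionary, as this PR relies on it.

import jax.numpy as jnp
from axlearn.common.loss import cross_entropy

logits = jnp.array([[2.0, 0.1, -1.0], [0.3, 1.5, 0.2]])  # [batch=2, num_classes=3]
target_labels = jnp.array([0, 2])  # second prediction (argmax=1) misses the label
loss, loss_dict = cross_entropy(logits=logits, target_labels=target_labels)
print(loss_dict["accuracy"])  # expected: 0.5, i.e. 1 of 2 live targets correct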