Improvements and fixes to gradient accumulation #993

Open · wants to merge 2 commits into main
39 changes: 24 additions & 15 deletions axlearn/common/gradient_accumulation.py
@@ -13,7 +13,7 @@
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs
-from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint
+from axlearn.common.utils import Nested, Tensor


def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
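The body of `_compute_minibatch_size` is collapsed in this view. As a rough sketch only (not the actual axlearn implementation), a helper with this signature typically derives the per-step minibatch size from the leading batch axis and validates divisibility:

```python
import jax


def compute_minibatch_size_sketch(input_batch, *, steps: int) -> int:
    """Illustrative sketch: per-step minibatch size from the leading (batch) axis."""
    leaves = jax.tree_util.tree_leaves(input_batch)
    if not leaves:
        raise ValueError("input_batch must contain at least one array.")
    batch_size = leaves[0].shape[0]
    if batch_size % steps != 0:
        raise ValueError(f"Global batch size {batch_size} must be divisible by steps={steps}.")
    return batch_size // steps
```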
@@ -64,9 +64,6 @@ def _make_scan_minibatch_inputs(
within a scan function body and is meant to slice the inputs
into `minibatch_size` sized slices to run the ForwardFn on.

-Note that this only preserves the input sharding if the `input_partition_spec`
-returns the correct partition spec to shard the input slices with.

Args:
inputs: Same pytree as ForwardFn inputs.
forward_key: The `forward_key` from the ForwardFn inputs
@@ -78,18 +75,16 @@
A tuple of minibatch inputs of the same structure as `inputs`
and new (carry) forward_key and param_noise_key.
"""
-minibatch_input = with_sharding_constraint(
-    jax.tree.map(
-        lambda x: jax.lax.dynamic_slice_in_dim(
-            x,
-            start_index=minibatch_index * minibatch_size,
-            slice_size=minibatch_size,
-            axis=0,
-        ),
-        inputs["input_batch"],
-    ),
+minibatch_input = jax.tree.map(
+    lambda x: jax.lax.dynamic_slice_in_dim(
+        x,
+        start_index=minibatch_index * minibatch_size,
+        slice_size=minibatch_size,
+        axis=0,
+    ),
-    input_partition_spec(),
Contributor: To me, this seems more like a hack than a proper solution; that is, if we want a different `input_partition_spec()` than the default one, then we need this?
Contributor Author: Sorry, I missed the default case; added it.

I think the partition spec below is a good default, but the ability to change the PartitionSpec might be good to have. What do you think?

    (None, 1): PartitionSpec(("data", "expert", "fsdp")),
    (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),

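For what it's worth, one way to keep this configurable is to thread an optional spec through the slicing helper. This is only a sketch, not part of this PR or of axlearn: the names `DEFAULT_SPECS` and `constrain_minibatch` are hypothetical, and it assumes a mesh whose axes include "data", "expert", "fsdp", and "seq".

```python
from typing import Optional

import jax
from jax.sharding import PartitionSpec

# Hypothetical defaults mirroring the specs quoted above: shard the batch axis
# across ("data", "expert", "fsdp"); for rank-2 inputs, also shard axis 1 across "seq".
DEFAULT_SPECS = {
    1: PartitionSpec(("data", "expert", "fsdp")),
    2: PartitionSpec(("data", "expert", "fsdp"), "seq"),
}


def constrain_minibatch(minibatch, specs: Optional[dict] = None):
    """Constrain each sliced leaf with a rank-dependent partition spec; replicate otherwise."""
    specs = DEFAULT_SPECS if specs is None else specs
    return jax.tree.map(
        lambda x: jax.lax.with_sharding_constraint(x, specs.get(x.ndim, PartitionSpec())),
        minibatch,
    )
```

This would have to run inside `jit` with the relevant mesh in scope for the constraint to take effect.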
+    inputs["input_batch"],
Contributor (@apghml, Feb 27, 2025): Suppose we have a global input batch of size 100 running on 10 chips (so a per-chip size of 10), and we want to switch to 10 grad accumulation steps, each with a global batch size of 10 (1 per chip per accumulation step).

Suppose the input is originally sharded evenly across the chips (first 10 on the first chip, second 10 on the second chip, etc.). When we take the first slice of 10 for the first grad accumulation step, won't all of those examples be on the same chip? Will that cause a problem? (E.g., might XLA not automatically reshard the examples across chips?)

Maybe we should reshard the batch axis only?

Contributor: +1 on the potential design problem here. Can you double-check and confirm that axis=0 is the batch dimension?

)
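A minimal sketch of the "reshard the batch axis only" idea from the thread above, reusing the `inputs`, `minibatch_index`, and `minibatch_size` names from `_make_scan_minibatch_inputs` and assuming mesh axes named "data", "expert", and "fsdp" with axis 0 as the batch dimension (this is not what the PR currently does):

```python
import jax
from jax.sharding import PartitionSpec

# Constrain only axis 0 of every sliced leaf so each minibatch is re-spread
# across the batch-sharding mesh axes instead of potentially sitting on one chip.
batch_only_spec = PartitionSpec(("data", "expert", "fsdp"))

minibatch_input = jax.tree.map(
    lambda x: jax.lax.with_sharding_constraint(
        jax.lax.dynamic_slice_in_dim(
            x,
            start_index=minibatch_index * minibatch_size,
            slice_size=minibatch_size,
            axis=0,
        ),
        batch_only_spec,
    ),
    inputs["input_batch"],
)
```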

next_forward_key, forward_key = jax.random.split(forward_key)
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key)

@@ -172,12 +167,26 @@ def fwd_helper(
otherwise None.
"""
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps)

# Create a sample minibatch for the carry buffer creation below
Contributor: Could you explain in more detail why this is needed?

Contributor: +1
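For context on the question above (my reading, not an authoritative answer): `jax.eval_shape` is used below to size the `lax.scan` carry, and each scan step produces minibatch-shaped outputs, so tracing the shape with the full-batch `inputs` would make any batch-dependent output the wrong size. A toy illustration of the mismatch, with made-up shapes:

```python
import jax
import jax.numpy as jnp


def forward(params, batch):
    # Toy forward fn whose auxiliary output keeps the batch dimension.
    logits = batch["x"] @ params
    return logits.mean(), logits


params = jnp.ones((4, 2))
full_batch = {"x": jnp.ones((8, 4))}   # global batch of 8
minibatch = {"x": jnp.ones((2, 4))}    # 4 accumulation steps -> minibatch of 2

_, full_aux = jax.eval_shape(forward, params, full_batch)
_, mini_aux = jax.eval_shape(forward, params, minibatch)
print(full_aux.shape)  # (8, 2) -- too large for a per-step scan carry
print(mini_aux.shape)  # (2, 2) -- matches what each scan iteration produces
```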

(
sample_minibatch_inputs,
_,
_,
) = _make_scan_minibatch_inputs(
inputs,
forward_key=inputs["forward_key"],
param_noise_key=inputs["param_noise_key"],
minibatch_size=minibatch_size,
minibatch_index=0,
)

# Carry initialization for the lax.scan procedure. Since we are passing a
# `MetricAccumulator` into carry, and carry input/output shapes must match,
# we need to initialize the `MetricAccumulator` summary with the right PyTree
# structure.
_, primal_output_shape = jax.eval_shape(
-    original_func_positional_args, model_params, inputs
+    original_func_positional_args, model_params, sample_minibatch_inputs
)
init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape)
init_accumulator = maybe_instantiate(metric_accumulator)
1 change: 1 addition & 0 deletions axlearn/common/gradient_accumulation_test.py
@@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
"""Test module for gradient_accumulation.py"""

import chex
import jax
import jax.numpy as jnp
3 changes: 3 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
@@ -352,6 +352,9 @@ def get_trainer_kwargs(
),
*trn2_config.module_modifications,
*trn2_config.partition_spec_modifications,
GradientAccumulationModifier.default_config().set(
grad_acc_steps=4,
),
],
),
),
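For readers unfamiliar with the modifier: `grad_acc_steps=4` splits each global batch into 4 sequential minibatches and averages their gradients before the optimizer update. The axlearn implementation uses the `lax.scan`-based machinery changed above; the loop below is only a conceptual sketch with made-up shapes:

```python
import jax
import jax.numpy as jnp


def accumulated_grads(loss_fn, params, minibatches):
    """Average gradients over sequential minibatches (what grad_acc_steps does conceptually)."""
    grads = None
    for mb in minibatches:
        g = jax.grad(loss_fn)(params, mb)
        grads = g if grads is None else jax.tree.map(jnp.add, grads, g)
    return jax.tree.map(lambda x: x / len(minibatches), grads)


# Example: a global batch of 8 split into grad_acc_steps=4 minibatches of 2.
loss_fn = lambda w, x: jnp.mean((x @ w) ** 2)
w = jnp.ones((4, 2))
minibatches = jnp.split(jnp.ones((8, 4)), 4)
print(jax.tree.map(lambda g: g.shape, accumulated_grads(loss_fn, w, minibatches)))  # (4, 2)
```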