-
Notifications
You must be signed in to change notification settings - Fork 296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements and fixes to gradient accumulation #993
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ | |
from axlearn.common.config import ConfigOr, maybe_instantiate | ||
from axlearn.common.metrics import MetricAccumulator | ||
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs | ||
from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint | ||
from axlearn.common.utils import Nested, Tensor | ||
|
||
|
||
def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int: | ||
|
@@ -64,9 +64,6 @@ def _make_scan_minibatch_inputs( | |
within a scan function body and is meant to slice the inputs | ||
into `minibatch_size` sized slices to run the ForwardFn on. | ||
|
||
Note that this only preserves the input sharding if the `input_partition_spec` | ||
returns the correct partition spec to shard the input slices with. | ||
|
||
Args: | ||
inputs: Same pytree as ForwardFn inputs. | ||
forward_key: The `forward_key` from the ForwardFn inputs | ||
|
@@ -78,18 +75,16 @@ def _make_scan_minibatch_inputs( | |
A tuple of minibatch inputs which of the same structure as `inputs` | ||
and new (carry) forward_key and param_noise_key. | ||
""" | ||
minibatch_input = with_sharding_constraint( | ||
jax.tree.map( | ||
lambda x: jax.lax.dynamic_slice_in_dim( | ||
x, | ||
start_index=minibatch_index * minibatch_size, | ||
slice_size=minibatch_size, | ||
axis=0, | ||
), | ||
inputs["input_batch"], | ||
minibatch_input = jax.tree.map( | ||
lambda x: jax.lax.dynamic_slice_in_dim( | ||
x, | ||
start_index=minibatch_index * minibatch_size, | ||
slice_size=minibatch_size, | ||
axis=0, | ||
), | ||
input_partition_spec(), | ||
inputs["input_batch"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suppose we have a global input batch of size 100 running on 10 chips (so a per chip size of 10) and we want to switch to doing 10 grad accumulation steps each with a global batch size of 10 (1 per chip per accumulation step). Suppose that the input is originally sharded evenly across the chips (first 10 on first chip, second 10 on second chip, etc). Then when we get the first slice of 10 for the first grad accumulation step, won't all these examples be on the same chip? Will that cause a problem? (E.g., if we worry XLA might not automatically reshard the examples across chips?) Maybe we should reshard the batch axis only? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 on the potential design problem here. Can you double check and ensure that axis=0 is confirmed to be batch size? |
||
) | ||
|
||
next_forward_key, forward_key = jax.random.split(forward_key) | ||
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key) | ||
|
||
|
@@ -172,12 +167,26 @@ def fwd_helper( | |
otherwise None. | ||
""" | ||
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps) | ||
|
||
# Create a sample minibatch for the carry buffer creation below | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain in more detail why this is needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
( | ||
sample_minibatch_inputs, | ||
_, | ||
_, | ||
) = _make_scan_minibatch_inputs( | ||
inputs, | ||
forward_key=inputs["forward_key"], | ||
param_noise_key=inputs["param_noise_key"], | ||
minibatch_size=minibatch_size, | ||
minibatch_index=0, | ||
) | ||
|
||
# Carry initialization for the lax.scan procedure. Since we are passing a | ||
# `MetricAccumulator` into carry and carry input/output shapes must match | ||
# we need initialize the `MetricAccumulator` summary with the right PyTree | ||
# structure. | ||
_, primal_output_shape = jax.eval_shape( | ||
original_func_positional_args, model_params, inputs | ||
original_func_positional_args, model_params, sample_minibatch_inputs | ||
) | ||
init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape) | ||
init_accumulator = maybe_instantiate(metric_accumulator) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me, it seems rather a hack than a proper solution, that is, we want to have a different
input_partition_spec()
than the default one, then we need this?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed the default case, added it.
I think the below partition spec is good as a default, but the ability to change PartitionSpec might be good to have, what do you think?