Migrate from legacy jax.tree_util APIs to jax.tree
apivovarov committed Feb 12, 2025
1 parent d47d5ce commit a24b322
Showing 24 changed files with 57 additions and 77 deletions.
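
For reference, the jax.tree namespace (added around JAX 0.4.25) exposes the same pytree helpers as jax.tree_util under shorter names, so the migration below is a mechanical rename: tree_flatten -> jax.tree.flatten, tree_unflatten -> jax.tree.unflatten, tree_leaves -> jax.tree.leaves, tree_structure -> jax.tree.structure, tree_map -> jax.tree.map, tree_reduce -> jax.tree.reduce. A minimal sketch of the equivalence (the toy tree and values below are illustrative, not taken from the repo):

import jax

tree = {"w": [1, 2], "b": 3}

# Legacy spellings.
leaves_old, treedef_old = jax.tree_util.tree_flatten(tree)

# jax.tree equivalents used throughout this commit.
leaves, treedef = jax.tree.flatten(tree)
assert leaves == leaves_old == [3, 1, 2]  # dict keys are traversed in sorted order
assert treedef == treedef_old == jax.tree.structure(tree)
assert jax.tree.leaves(tree) == leaves
assert jax.tree.unflatten(treedef, leaves) == tree
assert jax.tree.reduce(lambda acc, x: acc + x, tree, 0) == 6
assert jax.tree.map(lambda x: x * 2, tree) == {"w": [2, 4], "b": 6}
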
4 changes: 1 addition & 3 deletions axlearn/common/array_serialization.py
@@ -378,9 +378,7 @@ async def _run_serializer():

asyncio.run(_run_serializer())

-self._add_futures(
-    jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or [])
-)
+self._add_futures(jax.tree.flatten(commit_futures)[0] + (additional_futures or []))

# Used in wait_until_finished to check on process != 0, if the checkpoint
# has finished writing.
2 changes: 1 addition & 1 deletion axlearn/common/attention_test.py
@@ -5161,7 +5161,7 @@ def has_prebuilt_layers(path):
lambda path, spec: spec if has_prebuilt_layers(path) else None, param_specs
)
if prebuilt_layers:
-self.assertNotEmpty(jax.tree_util.tree_leaves(prebuilt_specs))
+self.assertNotEmpty(jax.tree.leaves(prebuilt_specs))
initialized_state = layer.initialize_parameters_recursively(
prng_key=jax.random.PRNGKey(123), prebuilt=prebuilt_specs
)
4 changes: 1 addition & 3 deletions axlearn/common/checkpointer.py
@@ -572,9 +572,7 @@ def restore_from_dir(
else:
raise RuntimeError(f"Unknown index entry '{value}'")

-restored_state = jax.tree_util.tree_unflatten(
-    jax.tree_util.tree_structure(state), state_leaves
-)
+restored_state = jax.tree.unflatten(jax.tree.structure(state), state_leaves)
multihost_utils.sync_global_devices(ckpt_dir)
return restored_state

4 changes: 1 addition & 3 deletions axlearn/common/checkpointer_test.py
@@ -425,9 +425,7 @@ def test_custom_dict(self, checkpointer_cls, custom_dict_type):
step, restored_state = ckpt.restore(step=None, state=state0)
self.assertEqual(100, step)
self.assertEqual(type(restored_state), custom_dict_type)
-self.assertIn(
-    custom_dict_type.__name__, str(jax.tree_util.tree_structure(restored_state))
-)
+self.assertIn(custom_dict_type.__name__, str(jax.tree.structure(restored_state)))
self.assertNestedEqual(state0, restored_state)
ckpt.stop()

2 changes: 1 addition & 1 deletion axlearn/common/inference.py
@@ -108,7 +108,7 @@ def __call__(self, input_batch: NestedTensor) -> Output:
isinstance(x, jax.Array) and len(x.devices()) == 1
) or isinstance(x, np.ndarray)
all_host_local_inputs = all(
-is_host_local_input_check(t) for t in jax.tree_util.tree_leaves(input_batch)
+is_host_local_input_check(t) for t in jax.tree.leaves(input_batch)
)

if all_host_local_inputs:
2 changes: 1 addition & 1 deletion axlearn/common/inference_output.py
@@ -199,7 +199,7 @@ def write(self, *, input_batch: NestedTensor, output_batch: NestedTensor):
output_batch: A NestedTensor whose leaves must be tensors of shape [batch_size, ...].
"""
local_data = dict(input=input_batch, output=output_batch)
-local_batch_size = jax.tree_util.tree_leaves(local_data)[0].shape[0]
+local_batch_size = jax.tree.leaves(local_data)[0].shape[0]

for i in range(local_batch_size):
example = jax.tree.map(lambda x, index=i: x[index], local_data)
2 changes: 1 addition & 1 deletion axlearn/common/inference_pipeline.py
@@ -138,7 +138,7 @@ def run(self, **kwargs):
self.summary_writer(step=batch_index, values=output.summaries)

if (batch_index + 1) % 10 == 0:
-global_batch_size = len(jax.tree_util.tree_leaves(global_input_batch)[0])
+global_batch_size = len(jax.tree.leaves(global_input_batch)[0])
logging.info(
"Processed %d batches and %d examples",
batch_index + 1,
4 changes: 1 addition & 3 deletions axlearn/common/learner.py
@@ -427,9 +427,7 @@ def _learner_tree(self, params: Nested[Any]) -> Nested[str]:
tree_paths(params),
)
# Check that all params is covered.
-if not jax.tree_util.tree_reduce(
-    lambda x, y: x and (y != ""), learner_name_tree, initializer=True
-):
+if not jax.tree.reduce(lambda x, y: x and (y != ""), learner_name_tree, initializer=True):
raise ValueError("Composite learner rules do not update all model params.")
return learner_name_tree

14 changes: 7 additions & 7 deletions axlearn/common/learner_test.py
@@ -165,8 +165,8 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The structure of updated params and Adam mu states are same.
self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(learner_state["optimizer"][1].mu),
)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
@@ -983,14 +983,14 @@ def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutp
# The structure of updated params and optimizer states are same.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
)
self.assertNestedEqual(
-    jax.tree_util.tree_structure(updated_model_params),
-    jax.tree_util.tree_structure(
+    jax.tree.structure(updated_model_params),
+    jax.tree.structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
)
@@ -1156,7 +1156,7 @@ def loss_fn(model_params, inputs):
summaries={},
module_outputs={},
)
-result = jax.tree_util.tree_reduce(lambda x, y: x.sum() + y.sum(), model_params)
+result = jax.tree.reduce(lambda x, y: x.sum() + y.sum(), model_params)
return ForwardOutputs(loss=result, aux={}, output_collection=output_collection)

grads = jax.tree_map(lambda p: jnp.ones_like(p.value), params)
12 changes: 6 additions & 6 deletions axlearn/common/metrics_test.py
@@ -51,8 +51,8 @@ def test_metric_accumulator(self):
)

chex.assert_trees_all_equal_structs(result, expected)
-result = jax.tree_util.tree_leaves(result)
-expected = jax.tree_util.tree_leaves(expected)
+result = jax.tree.leaves(result)
+expected = jax.tree.leaves(expected)
chex.assert_trees_all_close(result, expected)

def test_flatten_unflatten_metric_accumulator(self):
@@ -75,10 +75,10 @@ def test_flatten_unflatten_metric_accumulator(self):
for s in summaries_copy:
acc.update(s)

-flat, tree = jax.tree_util.tree_flatten(acc)
-unflattened = jax.tree_util.tree_unflatten(tree, flat)
-expected = jax.tree_util.tree_leaves(acc.summaries())
-result = jax.tree_util.tree_leaves(unflattened.summaries())
+flat, tree = jax.tree.flatten(acc)
+unflattened = jax.tree.unflatten(tree, flat)
+expected = jax.tree.leaves(acc.summaries())
+result = jax.tree.leaves(unflattened.summaries())
chex.assert_trees_all_close(result, expected)

@parameterized.parameters(
2 changes: 1 addition & 1 deletion axlearn/common/mixture_of_experts.py
@@ -991,7 +991,7 @@ def convert_fn(source_parameters: Nested[Tensor]) -> Nested[Tensor]:
) from e
# The target layer is a RepeatedTransformerLayer.
target_parameters = {"repeat": VDict({"layer": {}})}
-num_stages = jax.tree_util.tree_leaves(stage_parameter_specs)[0].shape[0]
+num_stages = jax.tree.leaves(stage_parameter_specs)[0].shape[0]
# The target stage is expected to be a StackedTransformerLayer.
num_layers_per_stage = len(stage_parameter_specs)
for layer_i in range(num_layers_per_stage):
2 changes: 1 addition & 1 deletion axlearn/common/module.py
@@ -202,7 +202,7 @@ def propagate_repeated_output_collections(
# if a repeated layer outputs a scalar summary value, it will have shape [N].
# Below we split the stacked values and output them separately under scope
# "{child_name_prefix}{i}" so that scalar summaries can be handled correctly.
-summary_values = jax.tree_util.tree_leaves(repeated_output_collection.summaries)
+summary_values = jax.tree.leaves(repeated_output_collection.summaries)
if summary_values:
first_summary_value = summary_values[0]
assert first_summary_value.shape, "Stacked summaries should have a leading stack dimension."
2 changes: 1 addition & 1 deletion axlearn/common/optimizers.py
@@ -1351,7 +1351,7 @@ def _is_valid_step(
return is_valid, new_drop_stats

# Check if every gradient is finite.
-flat_updates = jax.tree_util.tree_flatten(updates)[0]
+flat_updates = jax.tree.flatten(updates)[0]
is_finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
g_norm = optax.global_norm(updates)
if drop_norm is not None:
2 changes: 1 addition & 1 deletion axlearn/common/pipeline.py
@@ -766,7 +766,7 @@ def _run(
cfg: Pipeline.Config = self.config
self.vlog(1, "carry=%s xs=%s", shapes(carry), shapes(xs))

-carry_leaves = jax.tree_util.tree_leaves(carry)
+carry_leaves = jax.tree.leaves(carry)
if not carry_leaves:
raise ValueError("Expected at least one input leaf.")
if carry_leaves[0].ndim < 2:
2 changes: 1 addition & 1 deletion axlearn/common/repeat.py
@@ -190,7 +190,7 @@ def _run(self, fn, carry=None, *, xs=None):
with child_context("layer", output_collection=layer_output_collection) as layer_context:
# Note, actual `num_layers` might be smaller than `cfg.num_layers` depending on
# the invocation context.
-num_layers = jax.tree_util.tree_reduce(
+num_layers = jax.tree.reduce(
lambda num, x: min(num, x.shape[0]),
tree=(layer_context.state, xs),
initializer=cfg.num_layers,
8 changes: 3 additions & 5 deletions axlearn/common/rnn_test.py
@@ -185,11 +185,9 @@ def test_repeat_forward_vs_layerwise(self, norm_cfg, hidden_dim, num_layers):
final_states_list.append(output_collections.module_outputs["final_states"])

# Stack the tree leaves.
-tree_leaves = [jax.tree_util.tree_flatten(t)[0] for t in final_states_list]
-tree_def = jax.tree_util.tree_structure(final_states_list[0])
-final_states = jax.tree_util.tree_unflatten(
-    tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)]
-)
+tree_leaves = [jax.tree.flatten(t)[0] for t in final_states_list]
+tree_def = jax.tree.structure(final_states_list[0])
+final_states = jax.tree.unflatten(tree_def, [jnp.stack(leaf) for leaf in zip(*tree_leaves)])
self.assertEqual(shapes(final_states), shapes(init_states))

forward_outputs, forward_collections = F(
6 changes: 2 additions & 4 deletions axlearn/common/state_builder_test.py
@@ -673,9 +673,7 @@ def _run_builder(
**extra_converter_config_kwargs,
):
source_state = _mock_state(source_cfg, seed=0)
-initial_trainer_state_tree_structure = jax.tree_util.tree_structure(
-    source_state.trainer_state
-)
+initial_trainer_state_tree_structure = jax.tree.structure(source_state.trainer_state)

builder = (
builder_cls.default_config()
@@ -689,7 +687,7 @@ def _run_builder(
source_model = source_state.trainer_state.model

converted_state = builder(deepcopy(source_state))
-assert initial_trainer_state_tree_structure == jax.tree_util.tree_structure(
+assert initial_trainer_state_tree_structure == jax.tree.structure(
converted_state.trainer_state
)
converted_model = converted_state.trainer_state.model
6 changes: 3 additions & 3 deletions axlearn/common/struct_test.py
@@ -58,7 +58,7 @@ class SlotsPoint:

def test_pytree_nodes(self):
p = _Point(x=1, y=2, meta={"abc": True})
-leaves = jax.tree_util.tree_leaves(p)
+leaves = jax.tree.leaves(p)
self.assertEqual(leaves, [1, 2])
new_p = jax.tree.map(lambda x: x + x, p)
self.assertEqual(new_p, _Point(x=2, y=4, meta={"abc": True}))
@@ -104,7 +104,7 @@ def test_chex_tree_leaves_compatibility(self):
)
# tree_flatten_with_path is not preserved because Chex does not support this so the
# fallback jax implementation with numbered keys gets used.
-flattened.append(jax.tree_util.tree_leaves(instance))
+flattened.append(jax.tree.leaves(instance))
chex.assert_trees_all_equal(*flattened)

def test_constructor_order(self):
@@ -133,7 +133,7 @@ class C:
field_b: int
field_a: int

-result = jax.tree_util.tree_leaves(C(field_b=1, field_a=2))
+result = jax.tree.leaves(C(field_b=1, field_a=2))
expected = (1, 2)
self.assertSequenceEqual(result, expected)

12 changes: 5 additions & 7 deletions axlearn/common/test_utils.py
@@ -196,8 +196,8 @@ def _compute_layer_outputs(
# Optionally, test that trees also have the same structure.
if require_same_tree_structure:
# Prune empty subtrees so we don't require empty dicts for layers with no params.
-ref_structure = jax.tree_util.tree_structure(prune_empty(params_from_ref))
-test_structure = jax.tree_util.tree_structure(prune_empty(layer_params))
+ref_structure = jax.tree.structure(prune_empty(params_from_ref))
+test_structure = jax.tree.structure(prune_empty(layer_params))
self.assertEqual(
ref_structure, test_structure, msg=f"\nRef: {ref_structure}\nTest: {test_structure}"
)
@@ -428,8 +428,8 @@ def replace_keys(v, mapping):
params_with_nones = jax.tree_map(
partial(replace_keys, mapping={k: None for k in delegates}), params, is_leaf=is_leaf
)
-_, treedef = jax.tree_util.tree_flatten(params_with_nones)
-inits_with_nones = jax.tree_util.tree_unflatten(treedef, param_init_specs)
+_, treedef = jax.tree.flatten(params_with_nones)
+inits_with_nones = jax.tree.unflatten(treedef, param_init_specs)

# Replace the Nones with a delegate.
return jax.tree_map(partial(replace_keys, mapping=delegates), inits_with_nones, is_leaf=is_leaf)
@@ -563,9 +563,7 @@ def patched_register_per_param_settings(
model_params = model.initialize_parameters_recursively(jax.random.PRNGKey(0))

model_specs = model.create_parameter_specs_recursively()
-model_specs = complete_partition_spec_tree(
-    jax.tree_util.tree_structure(model_params), model_specs
-)
+model_specs = complete_partition_spec_tree(jax.tree.structure(model_params), model_specs)
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
4 changes: 2 additions & 2 deletions axlearn/common/trainer.py
@@ -604,7 +604,7 @@ def _opt_params(self, model_params: NestedTensor) -> NestedOptParam:
"""Returns a tree of OptParam for Learner.{init,update}."""
# self._model_param_specs can be incomplete. Complete it first.
specs = utils.complete_partition_spec_tree(
-    jax.tree_util.tree_structure(model_params), self._model_param_specs
+    jax.tree.structure(model_params), self._model_param_specs
)
return jax.tree.map(
lambda param, spec: OptParam(
@@ -852,7 +852,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
# Log trainer state tree.
if not self.step and jax.process_index() == 0:
with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f:
-f.write(str(jax.tree_util.tree_structure(self._trainer_state)))
+f.write(str(jax.tree.structure(self._trainer_state)))

with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
f.write(model_analysis)
4 changes: 1 addition & 3 deletions axlearn/common/trainer_test.py
@@ -579,9 +579,7 @@ def test_compile_train_step(self, *, platform, mesh_shape):
trainer: SpmdTrainer = cfg.instantiate(parent=None)
compiled_without_args = trainer.compile_train_step()
# pylint: disable=protected-access
-input_batch = jax.tree_util.tree_map(
-    jnp.array, next(trainer.input.batches(trainer._input_iter))
-)
+input_batch = jax.tree.map(jnp.array, next(trainer.input.batches(trainer._input_iter)))
# pylint: enable=protected-access
compiled_with_input_batch = trainer.compile_train_step(input_batch=input_batch)
# In a single-host environment, both compiled functions should match.
22 changes: 10 additions & 12 deletions axlearn/common/utils.py
@@ -441,7 +441,7 @@ def vectorized_tree_map(fn, tree, *rest):

def vectorized_fn(*nodes):
if isinstance(nodes[0], VDict):
-if not jax.tree_util.tree_leaves(nodes[0]):
+if not jax.tree.leaves(nodes[0]):
# This can happen when all VDict values are None and cause issues with jax.vmap.
return nodes[0]
nodes = [dict(**node) for node in nodes]
@@ -469,7 +469,7 @@ def fn(value: Union[Tensor, VDict]) -> NestedTensor:
if not isinstance(value, VDict):
return value

-leaves = jax.tree_util.tree_leaves(value)
+leaves = jax.tree.leaves(value)
if not leaves:
# An empty VDict.
return value
@@ -653,7 +653,7 @@ def complete_partition_spec_tree(
prefix of treedef.
"""
proxy = object()
-dummy = jax.tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)
+dummy = jax.tree.unflatten(treedef, [object()] * treedef.num_leaves)
axes = []

def replace_none_with_proxy(tree):
@@ -672,17 +672,17 @@ def replace_none_with_proxy(tree):
partition_spec_tree_with_proxy = replace_none_with_proxy(partition_spec_tree)

def add_leaves(i, x):
-axes.extend([i] * len(jax.tree_util.tree_flatten(x)[0]))
+axes.extend([i] * len(jax.tree.flatten(x)[0]))

try:
jax.tree.map(add_leaves, partition_spec_tree_with_proxy, dummy)
except ValueError as err:
logging.info("[complete_partition_spec_tree] ValueError: %s", err)
logging.info(
"[complete_partition_spec_tree] partition_spec_tree_with_proxy=%s",
-    jax.tree_util.tree_structure(partition_spec_tree_with_proxy),
+    jax.tree.structure(partition_spec_tree_with_proxy),
)
logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree_util.tree_structure(dummy))
logging.info("[complete_partition_spec_tree] dummy=%s", jax.tree.structure(dummy))
for path, value in flatten_items(partition_spec_tree_with_proxy):
logging.info(
"[complete_partition_spec_tree] partition_spec_tree_with_proxy leaf: %s=%s",
@@ -701,7 +701,7 @@ def add_leaves(i, x):
assert (
len(axes) == treedef.num_leaves
), f"({len(axes)} vs. {treedef.num_leaves}) {axes} {treedef}"
-return jax.tree_util.tree_unflatten(treedef, axes)
+return jax.tree.unflatten(treedef, axes)


def input_partition_spec() -> PartitionSpec:
@@ -801,9 +801,7 @@ def host_to_global_device_array(
"""
mesh = thread_resources.env.physical_mesh
partition_spec = data_partition_type_to_spec(partition)
-partition_specs = complete_partition_spec_tree(
-    jax.tree_util.tree_structure(host_arrays), partition_spec
-)
+partition_specs = complete_partition_spec_tree(jax.tree.structure(host_arrays), partition_spec)
process_count = jax.process_count()

def make_gda(x, partition_spec):
@@ -1031,7 +1029,7 @@ def cast(x: Union[Tensor, TensorSpec]) -> Union[Tensor, TensorSpec]:

def count_model_params(tree: NestedTensor) -> int:
"""Count the number of parameters in a model."""
-return sum(x.size for x in jax.tree_util.tree_leaves(tree))
+return sum(x.size for x in jax.tree.leaves(tree))


def check_param_shape_alignment(
@@ -1095,7 +1093,7 @@ def check_jax_type(
pretty_named_args.update({f"kwargs[{key}]": kwargs[key] for key in kwargs})

for name, arg in pretty_named_args.items():
-values, _ = jax.tree_util.tree_flatten(arg)
+values, _ = jax.tree.flatten(arg)
for value in values:
if not isinstance(value, (type(None), jax.Array, int, float)):
if msg is None: