Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from Legacy JAX APIs jax.tree_util to jax.tree #986

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

apivovarov
Copy link
Contributor

Description

This PR migrates the axlearn codebase from Legacy JAX APIs (jax.tree_util) to the recommended jax.tree module.

The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over jax.tree_util. Upgrading to jax.tree ensures better compatibility with future JAX versions and improves code maintainability.

jax.tree doc

jax.tree_util doc

pre-commit

$ pre-commit run -a      
Check Yaml...............................................................Passed
Fix End of Files.........................................................Passed
Trim Trailing Whitespace.................................................Passed
black....................................................................Passed
isort....................................................................Passed
pylint...................................................................Passed

pytype

$ pytype -j auto axlearn
...
Success: no errors found

pytest

pytest -v -n 96 -m "not (gs_login or tpu or high_cpu or fp64)" axlearn/common

========== 0 failed, 6220 passed, 10364 skipped in 734.23s (0:12:14) ==========

@apivovarov apivovarov requested review from ruomingp, markblee and a team as code owners February 12, 2025 22:41
@@ -58,7 +58,7 @@ class SlotsPoint:

def test_pytree_nodes(self):
p = _Point(x=1, y=2, meta={"abc": True})
leaves = jax.tree_util.tree_leaves(p)
leaves = jax.jax.tree.leaves(p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this looks like a typo?

Suggested change
leaves = jax.jax.tree.leaves(p)
leaves = jax.tree.leaves(p)

Likewise below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@@ -146,7 +146,7 @@ def _get_visual_features(self, visual_outputs: NestedTensor) -> Tensor:
pass

def _paths(x):
return jax.tree_util.tree_leaves(tree_paths(x))
return jax.jax.tree.leaves(tree_paths(x))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Contributor Author

@apivovarov apivovarov Feb 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed all jax.jax typos and rerun pre-commit and pytype

@markblee
Copy link
Contributor

I wonder how the CI passed with those typos. Do they fail locally for you?

@apivovarov
Copy link
Contributor Author

apivovarov commented Feb 14, 2025

Apparently, jax.jax.jax.tree.leaves works file

import jax
import jax.numpy as jnp
import sys

x = (jnp.array(0), jnp.array(1))
y = jax.jax.jax.tree.leaves(x)

print(y)
print(type(y))
print(jax.jax)
print(jax.jax.jax)
[Array(0, dtype=int32, weak_type=True), Array(1, dtype=int32, weak_type=True)]
<class 'list'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>
<module 'jax' from '/usr/local/lib/python3.11/dist-packages/jax/__init__.py'>

https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing

@kelvin-zou
Copy link
Contributor

Thanks @apivovarov , though this kind of cleanups would be better handled if you propose and let us fix it. We typically need to run many internal validation etc before merging the PR. The hairy part is from our internal repo which uses AxLearn as the core library.

We will take this cleanup PR as a low priority, so does the other cleanups since they are not blocking anything at the moment. It would be great if aws can focus more on prioritizing trainium2 fixes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants