-
Notifications
You must be signed in to change notification settings - Fork 296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate from Legacy JAX APIs jax.tree_util to jax.tree #986
base: main
Are you sure you want to change the base?
Conversation
axlearn/common/struct_test.py
Outdated
@@ -58,7 +58,7 @@ class SlotsPoint: | |||
|
|||
def test_pytree_nodes(self): | |||
p = _Point(x=1, y=2, meta={"abc": True}) | |||
leaves = jax.tree_util.tree_leaves(p) | |||
leaves = jax.jax.tree.leaves(p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, this looks like a typo?
leaves = jax.jax.tree.leaves(p) | |
leaves = jax.tree.leaves(p) |
Likewise below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
axlearn/vision/virtex.py
Outdated
@@ -146,7 +146,7 @@ def _get_visual_features(self, visual_outputs: NestedTensor) -> Tensor: | |||
pass | |||
|
|||
def _paths(x): | |||
return jax.tree_util.tree_leaves(tree_paths(x)) | |||
return jax.jax.tree.leaves(tree_paths(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed all jax.jax typos and rerun pre-commit and pytype
I wonder how the CI passed with those typos. Do they fail locally for you? |
Apparently, jax.jax.jax.tree.leaves works file
https://colab.research.google.com/drive/1ruOWXG6GXFSh1xdHyBVJVLyRwZTYq6GQ?usp=sharing |
Thanks @apivovarov , though this kind of cleanups would be better handled if you propose and let us fix it. We typically need to run many internal validation etc before merging the PR. The hairy part is from our internal repo which uses AxLearn as the core library. We will take this cleanup PR as a low priority, so does the other cleanups since they are not blocking anything at the moment. It would be great if aws can focus more on prioritizing trainium2 fixes. |
Description
This PR migrates the axlearn codebase from Legacy JAX APIs (
jax.tree_util
) to the recommendedjax.tree
module.The jax.tree API was introduced in JAX v0.4.25 and is now the preferred approach over
jax.tree_util
. Upgrading tojax.tree
ensures better compatibility with future JAX versions and improves code maintainability.jax.tree doc
jax.tree_util doc
pre-commit
pytype
pytest