Odd behavior for jax.tree_util.Partial and interactions with eqx.Module #480
Labels: bug (Something isn't working)
patrick-kidger added two commits that referenced this issue on Sep 8, 2023.
Thanks for the report! This is quite the edge case. Indeed, it looks like I've just written #485, which should fix this.
Thanks!
patrick-kidger added a commit that referenced this issue on Sep 8, 2023.
Closing as fixed in #485. This will be included in the next release (v0.11.0) of Equinox.
patrick-kidger added commits that referenced this issue on Sep 12, Sep 14, and Sep 29, 2023.
Here's something pretty odd about `jax.tree_util.Partial`: it mimics the equality behavior of `functools.partial`, under which two partial objects wrapping the same function with the same arguments compare unequal unless they are the same object.

This is a problem, because Equinox wraps its method calls to return a `jax.tree_util.Partial`, which gives some very unexpected behavior: accessing the same instance method twice gives two objects that don't compare equal! This is because each access goes through a separate call to the wrapper that turns the class's method into a `jax.tree_util.Partial`, so a new object is produced every time.

One possible fix is to switch the wrapper to use Equinox's own `Partial`, which behaves sensibly with equality. Except, you can't use `functools.wraps` on an Equinox `Partial`, because it's a frozen dataclass and `wraps` tries to set attributes (such as `__name__` and `__doc__`) on the wrapper. The only simple fix is to drop the `wraps` call altogether, but that seems like a bad choice too. Maybe the `Partial` could be unfrozen while it is being wrapped?

Probably the better solution is for JAX to change how `__eq__` works on `jax.tree_util.Partial`, because the current behavior leads to more unusual results. This choice of equality may be fine for `functools.partial`, but `jax.tree_util.Partial` is PyTree-compatible and so can be flattened. I would expect that if two flattened PyTrees are equal, then the unflattened PyTrees are equal too. That doesn't hold for `jax.tree_util.Partial`: it flattens to the function (kept in the treedef) and its arguments (the leaves), which are equal for two equivalent `jax.tree_util.Partial` objects, yet the objects themselves compare unequal.

Equinox's `Partial` doesn't have any of these problems. Comparing the two, only `jax.tree_util.Partial` is inconsistent between equality of its flattened and unflattened representations.
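The method-wrapping problem described above can be reproduced in plain Python. This is a minimal sketch: `PartialMethod`, `Model`, and `Plain` are hypothetical names, and the descriptor only mimics the kind of wrapper described in the issue (it is not Equinox's code). It uses `functools.partial`, which `jax.tree_util.Partial` subclasses and whose identity-based equality it inherits.

```python
import functools


class PartialMethod:
    """Hypothetical descriptor that binds `self` via functools.partial on each access.

    It mimics the wrapper described in the issue; it is not Equinox's implementation.
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner=None):
        if instance is None:
            return self.method
        # A brand-new partial object is created on every attribute access.
        return functools.partial(self.method, instance)


class Model:
    def __init__(self, scale):
        self.scale = scale

    @PartialMethod
    def forward(self, x):
        return self.scale * x


m = Model(2.0)
a = m.forward
b = m.forward
print(a(3.0), b(3.0))  # 6.0 6.0: identical behavior
print(a == b)          # False: two distinct partials, compared by identity


class Plain:
    def forward(self, x):
        return 2.0 * x


plain = Plain()
# Ordinary bound methods compare their __self__ and __func__, so this is True.
print(plain.forward == plain.forward)  # True
```

The contrast with `Plain` shows why the behavior is surprising: ordinary bound methods already compare equal across accesses, so users expect the same from a wrapped method.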
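The clash between `functools.wraps` and a frozen dataclass, and the "unfreeze while wrapping" idea raised in the issue, can be sketched as follows. `FrozenPartial` and `wraps_frozen` are hypothetical stand-ins, not Equinox's `Partial`; the workaround assumes it is acceptable to bypass the freeze with `object.__setattr__`, the usual escape hatch for frozen dataclasses. Because the class is a dataclass, its generated `__eq__` compares fields, which is the kind of equality the issue asks for.

```python
import dataclasses
import functools
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class FrozenPartial:
    """Hypothetical frozen-dataclass partial (not Equinox's actual class)."""

    func: Callable[..., Any]
    args: tuple = ()
    keywords: dict = dataclasses.field(default_factory=dict)

    def __call__(self, *args, **kwargs):
        return self.func(*self.args, *args, **{**self.keywords, **kwargs})


def greet(greeting, name):
    """Return a greeting."""
    return f"{greeting}, {name}!"


p = FrozenPartial(greet, ("Hello",))

# functools.wraps(greet)(p) calls setattr(p, "__module__", ...), which the frozen
# dataclass rejects with FrozenInstanceError (a subclass of AttributeError).
wraps_error = None
try:
    functools.wraps(greet)(p)
except dataclasses.FrozenInstanceError as e:
    wraps_error = e
print("wraps failed:", type(wraps_error).__name__)


def wraps_frozen(wrapper, wrapped):
    """Hypothetical helper: copy wrapper metadata onto a frozen instance.

    object.__setattr__ bypasses the dataclass's frozen __setattr__.
    """
    for attr in functools.WRAPPER_ASSIGNMENTS:
        if hasattr(wrapped, attr):
            object.__setattr__(wrapper, attr, getattr(wrapped, attr))
    object.__setattr__(wrapper, "__wrapped__", wrapped)
    return wrapper


p = wraps_frozen(p, greet)
print(p.__name__, p.__doc__)  # greet Return a greeting.
print(p("world"))             # Hello, world!
# The generated __eq__ compares (func, args, keywords), so equivalent partials are equal.
print(p == FrozenPartial(greet, ("Hello",)))  # True
```

One caveat of this design: with `frozen=True` and the default `eq=True`, the dataclass also generates a field-based `__hash__`, so hashing an instance whose `keywords` field is a `dict` raises `TypeError`.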
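The flatten/unflatten inconsistency of `jax.tree_util.Partial` can be checked directly. This sketch assumes a recent JAX version; `add` is just an illustrative function. It compares the flattened pieces of two equivalent partials and then the partials themselves.

```python
import jax


def add(x, y):
    return x + y


j1 = jax.tree_util.Partial(add, 1)
j2 = jax.tree_util.Partial(add, 1)

# jax.tree_util.Partial subclasses functools.partial, so == falls back to identity.
print(j1 == j2)      # False
print(j1(2), j2(2))  # 3 3

# Flatten both: the bound arguments become leaves, and the wrapped function
# is stored as static data in the treedef.
leaves1, treedef1 = jax.tree_util.tree_flatten(j1)
leaves2, treedef2 = jax.tree_util.tree_flatten(j2)
print(leaves1 == leaves2)    # True
print(treedef1 == treedef2)  # True

# Rebuilding from identical pieces still yields an object that compares unequal.
rebuilt = jax.tree_util.tree_unflatten(treedef1, leaves1)
print(rebuilt == j1)  # False
```

Equal leaves and equal treedefs, yet unequal objects: this is exactly the mismatch between the flattened and unflattened representations that the issue points out.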