Strange error with multiple tuples as attributes and equinox.tree_at? #715

Closed
leonard-gleyzer opened this issue May 7, 2024 · 5 comments · Fixed by #717
Labels
bug Something isn't working

Comments

@leonard-gleyzer

Hello!

I've come across an error I'm not quite sure what to do about.

MWE:

import equinox as eqx


class M(eqx.Module):
    a: tuple
    b: tuple

    def __init__(self):
        self.a = ()
        self.b = ()


eqx.tree_at(lambda m: m.a, M(), ())

Running the above gives me the following error:

ValueError: `where` does not uniquely identify a single element of `pytree`. 
This usually occurs when trying to replace a `None` value:

  >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True)


for which the fix is to specify that `None`s should be treated as leaves:

  >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True,
  ...             is_leaf=lambda x: x is None) 

However, if I use lists, or a tuple and a list:

import equinox as eqx


class M(eqx.Module):
    a: tuple
    b: list

    def __init__(self):
        self.a = ()
        self.b = []


eqx.tree_at(lambda m: m.a, M(), ())

or just a single tuple:

import equinox as eqx


class M(eqx.Module):
    a: tuple

    def __init__(self):
        self.a = ()


eqx.tree_at(lambda m: m.a, M(), ())

there is no error.

@lockwo
Contributor

lockwo commented May 7, 2024

Not 100% sure of the cause, but one workaround is to force it to treat empty tuples as leaves, via:

import equinox as eqx


class M(eqx.Module):
    a: tuple
    b: tuple

    def __init__(self):
        self.a = ()
        self.b = (1,)


eqx.tree_at(lambda m: m.a, M(), (), is_leaf=lambda x: x == ())

@patrick-kidger
Owner

Interesting! It looks like Python actually reuses the exact same object for these tuple literals:

() is ()  # True
(1,) is (1,)  # True
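A quick way to check the empty-tuple case, assuming CPython (the empty-tuple singleton is an implementation detail there, not a language guarantee). Building the tuples at runtime sidesteps any sharing of equal literal constants by the compiler, so only the genuine singleton shows up as identical:

```python
# Assumes CPython, where the empty tuple is a cached singleton.
empty_literal = ()
empty_call = tuple()
empty_from_list = tuple([])

# All three names refer to the one shared empty-tuple object.
print(empty_literal is empty_call, empty_call is empty_from_list)  # True True

# Equal non-empty tuples built at runtime are separate objects.
x = tuple([1])
y = tuple([1])
print(x == y, x is y)  # True False
```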

So the logic inside eqx.tree_at works by requiring that for any two nodes x, y at different positions in the PyTree, x is not y. This is how it can identify what to replace. (Here I use "node" to refer to both composite nodes and leaves.)

Leaves in particular frequently fail this (e.g. x = [4, 4]; x[0] is x[1]  # True), so we make this work by first doing pytree = jtu.tree_map(Wrapper, pytree), which wraps each leaf in a fresh object with its own id. No big deal.

Until now, there haven't been any examples of composite nodes that fail the x is not y check. Now it seems we have one! This is why two tuples are specifically what is required to bump into this issue: an empty tuple contains no leaves for Wrapper to make unique, so both fields still refer to the same () object.
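To see why repeated leaves are fine but a pair of empty tuples is not, here is a pure-Python toy of the identity check described above. It is a hypothetical sketch, not Equinox's code: `Wrapper`, `wrap_leaves`, `all_nodes` and `count_matches` are made-up names, plain dicts stand in for the Module, and only dicts, lists and tuples count as containers. The optional `is_leaf` argument mirrors the `is_leaf=lambda x: x == ()` workaround from earlier in the thread.

```python
class Wrapper:
    """Hypothetical leaf wrapper: each instance is a brand-new object."""

    def __init__(self, value):
        self.value = value


def wrap_leaves(tree, is_leaf=None):
    # Rebuild dicts/lists/tuples and wrap every leaf. Rebuilding an empty
    # tuple yields CPython's shared () again, since it has no leaves to wrap.
    if is_leaf is not None and is_leaf(tree):
        return Wrapper(tree)
    if isinstance(tree, dict):
        return {k: wrap_leaves(v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, (tuple, list)):
        return type(tree)([wrap_leaves(c, is_leaf) for c in tree])
    return Wrapper(tree)


def all_nodes(tree):
    # Yield every node, composite or leaf, parent before children.
    yield tree
    if isinstance(tree, dict):
        for v in tree.values():
            yield from all_nodes(v)
    elif isinstance(tree, (tuple, list)):
        for c in tree:
            yield from all_nodes(c)


def count_matches(where, tree, is_leaf=None):
    """Count nodes of the wrapped tree that are the very object where() returns."""
    wrapped = wrap_leaves(tree, is_leaf)
    target = where(wrapped)
    return sum(node is target for node in all_nodes(wrapped))


# Identical leaves are fine once wrapped: exactly one match.
print(count_matches(lambda t: t[0], [4, 4]))  # 1

# A tuple and a list: the rebuilt list is a new object, so still one match.
print(count_matches(lambda m: m["a"], {"a": (), "b": []}))  # 1

# Two empty tuples: both fields are the same () object, so two matches.
print(count_matches(lambda m: m["a"], {"a": (), "b": ()}))  # 2

# Treating () as a leaf wraps each occurrence separately: one match again.
print(count_matches(lambda m: m["a"], {"a": (), "b": ()}, is_leaf=lambda x: x == ()))  # 1
```

A uniqueness check in this style would reject the third case, which matches the `ValueError` reported in the issue.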

If anyone feels like tackling this in the near future then I'd be happy to take a pull request.

@patrick-kidger patrick-kidger added the bug Something isn't working label May 8, 2024
@lockwo
Contributor

lockwo commented May 8, 2024

Is this something that needs a fix? It seems like it's working as it should. This happens because the empty tuple is a singleton in CPython: only one ever exists in memory, so every () has the same id (https://stackoverflow.com/questions/14135542/how-is-tuple-implemented-in-cpython/). If you have something like:

class M(eqx.Module):
    a: tuple
    b: tuple

    def __init__(self):
        self.a = (1,)
        self.b = (1,)

this works because the ids are different. In my mind, Equinox is behaving as it should: the where function is underspecified, since the object it points to sits in two places in the tree. It's hard to think of a workaround for tuples that are identical in memory that doesn't break other things that are identical in memory (short of a hardcoded if x == () or similar). Even if there were a workaround, I don't think it would be ideal, for two reasons: 1) it's very much an edge case, and 2) people who already knew that empty tuples are singletons before coming to Equinox (I just learned it) would then be surprised by the behavior, because Equinox would be going against what CPython does.

@patrick-kidger
Owner

patrick-kidger commented May 8, 2024

So I think of eqx.tree_at as just being "please change the node at this path in the PyTree". The fact that it goes via CPython's id is an implementation detail, not its defining characteristic.
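To make the "change the node at this path" view concrete, here is a toy replacement by explicit path. It is a hypothetical helper (`replace_at_path` is a made-up name, and plain dicts/lists/tuples stand in for a Module), not how `eqx.tree_at` is implemented. Because it navigates by position rather than by identity, two fields holding the same () object cannot be confused:

```python
def replace_at_path(tree, path, value):
    """Return a copy of `tree` with the node at `path` swapped for `value`.

    Hypothetical helper: `path` is a sequence of dict keys / sequence indices.
    Containers along the path are rebuilt; everything else is shared.
    """
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)
        new[key] = replace_at_path(tree[key], rest, value)
        return new
    if isinstance(tree, (tuple, list)):
        items = list(tree)
        items[key] = replace_at_path(tree[key], rest, value)
        return type(tree)(items)
    raise TypeError(f"cannot index into leaf {tree!r} with {key!r}")


m = {"a": (), "b": ()}
print(replace_at_path(m, ["a"], (1, 2)))  # {'a': (1, 2), 'b': ()}
```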

I think on that basis this is a minor bug. Though as you say, this is a very very edge case with a loud failure and an easy work-around, so it's not one I'm too worried by overall!

@lockwo lockwo mentioned this issue May 8, 2024
@lockwo
Contributor

lockwo commented May 8, 2024

Hmmm, I see. Here is a fix that basically makes a wrapper object for the tuple singleton: #717. This is definitely the craziest edge case to write this much code for lol
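A hypothetical pure-Python sketch of the general idea of wrapping the tuple singleton, under toy assumptions (plain dicts/lists/tuples only, made-up helper names such as `toy_tree_at`; this is not the code in #717): give every occurrence of () its own fresh placeholder object before the identity lookup, then turn the placeholders back into () when rebuilding.

```python
class _Leaf:
    """Hypothetical wrapper giving each leaf a unique identity."""

    def __init__(self, value):
        self.value = value


class _EmptyTuple:
    """Hypothetical placeholder for one particular occurrence of ()."""


def _wrap(tree):
    # Rebuild the tree so that every node has its own identity.
    if isinstance(tree, dict):
        return {k: _wrap(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [_wrap(c) for c in tree]
    if isinstance(tree, tuple):
        if not tree:
            return _EmptyTuple()  # fresh object per occurrence of ()
        return tuple(_wrap(c) for c in tree)
    return _Leaf(tree)


def _nodes(tree):
    # Yield every node, composite or leaf.
    yield tree
    if isinstance(tree, dict):
        for v in tree.values():
            yield from _nodes(v)
    elif isinstance(tree, (list, tuple)):
        for c in tree:
            yield from _nodes(c)


def _unwrap(tree, target, replacement):
    # Undo _wrap, swapping in `replacement` at the node that is `target`.
    if tree is target:
        return replacement
    if isinstance(tree, dict):
        return {k: _unwrap(v, target, replacement) for k, v in tree.items()}
    if isinstance(tree, list):
        return [_unwrap(c, target, replacement) for c in tree]
    if isinstance(tree, tuple):
        return tuple(_unwrap(c, target, replacement) for c in tree)
    if isinstance(tree, _EmptyTuple):
        return ()
    return tree.value  # a _Leaf


def toy_tree_at(where, tree, replacement):
    """Hypothetical identity-based tree_at that tolerates repeated ()."""
    wrapped = _wrap(tree)
    target = where(wrapped)
    matches = sum(node is target for node in _nodes(wrapped))
    if matches != 1:
        raise ValueError("`where` does not uniquely identify a single node")
    return _unwrap(wrapped, target, replacement)


m = {"a": (), "b": ()}
print(toy_tree_at(lambda t: t["a"], m, (1,)))  # {'a': (1,), 'b': ()}
```

The uniqueness check is kept as-is; the sketch only ensures that empty tuples, like leaves, each get their own identity before it runs.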
