Skip to content

Commit

Permalink
Fixed eqx.tree_at when used alongside an empty namedtuple.
Browse files Browse the repository at this point in the history
See #715 and #717. CC @lockwo.
  • Loading branch information
patrick-kidger committed May 12, 2024
1 parent 7de34a2 commit b31d814
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
7 changes: 5 additions & 2 deletions equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,13 @@ def tree_at(
#
# Whilst we're here: we also double-check that `where` is well-formed and doesn't
# use leaf information. (As else `node_or_nodes` will be wrong.)
is_empty_tuple = (
lambda x: isinstance(x, tuple) and not hasattr(x, "_fields") and x == ()
)
pytree = jtu.tree_map(
lambda x: _DistinctTuple() if isinstance(x, tuple) and x == () else x,
lambda x: _DistinctTuple() if is_empty_tuple(x) else x,
pytree,
is_leaf=lambda x: isinstance(x, tuple) and x == (),
is_leaf=is_empty_tuple,
)
node_or_nodes_nowrapper = where(pytree)
pytree = jtu.tree_map(_LeafWrapper, pytree, is_leaf=is_leaf)
Expand Down
20 changes: 18 additions & 2 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import collections as co

import equinox as eqx
import jax
import jax.core
Expand Down Expand Up @@ -47,21 +49,35 @@ def test_tree_at_replace(getkey):
def test_tree_at_empty_tuple():
# Tuples are singletons, so we have a specific test for the wrapper
a = ()
b = (1,)
x1 = [a]
x2 = [a, a]
x3 = [(), ()]

b = (1,)
x4 = [b]
x5 = [b, b]
x6 = [(1,), (1,)]

for x in (x1, x2, x3, x4, x5, x6):
Empty = co.namedtuple("Empty", [])
empty = Empty()
x7 = [empty]
x8 = [empty, empty]
x9 = [Empty(), Empty()]

for x in (x1, x2, x3, x4, x5, x6, x7, x8, x9):
new_x = eqx.tree_at(lambda xi: xi[0], x, "hello")
assert new_x[0] == "hello"
if len(new_x) != 1:
assert new_x[1] != "hello"


def test_tree_at_empty_namedtuple():
Empty = co.namedtuple("Empty", [])
pytree = [Empty(), 5]
out = eqx.tree_at(lambda x: x[1], pytree, 4)
assert isinstance(out[0], Empty)


def test_tree_at_replace_fn(getkey):
key = getkey()
pytree = [1, 2, 3, {"a": jnp.array([1.0, 2.0])}, eqx.nn.Linear(1, 2, key=key)]
Expand Down

0 comments on commit b31d814

Please sign in to comment.