
Ahead of time compilation issue with kwargs in the function signature. #625

Closed
tttc3 opened this issue Dec 22, 2023 · 1 comment · Fixed by #627
Labels
bug Something isn't working


tttc3 commented Dec 22, 2023

This MWE is `test_aot_compilation` from `test_jit.py`, with a keyword argument added.

```python
import equinox as eqx
import jax.numpy as jnp

def f(x, y, **kwargs):
    return 2 * x + y

x, y = jnp.array(3), 4
lowered = eqx.filter_jit(f, donate="none").lower(x, y, test=123)
lowered.as_text()
compiled = lowered.compile()
compiled(x, y, test=123)
```

As of v0.11.2 (it worked in v0.11.1), the MWE above raises the following error:

```
TypeError: function compiled for PyTreeDef((({}, {'first': (*,), 'fun': (None,), 'rest': (None, None)}), {})), called with PyTreeDef((({}, {'first': (*,), 'fun': (None,), 'rest': (None,)}), {}))
```
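For context, this kind of `TypeError` comes from JAX's ahead-of-time API: a compiled function checks that the pytree structure of the arguments it is called with matches the structure it was lowered with. Below is a minimal sketch using plain `jax.jit`; the function `g` and its unused `rest` argument are hypothetical, and the exact error wording varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def g(first, rest):
    # `rest` is unused; it only contributes to the input pytree structure.
    return 2 * first


# Lower and compile for a `rest` argument with structure (None, None).
compiled = jax.jit(g).lower(jnp.array(3), (None, None)).compile()

# Same structure as at lowering time: runs normally.
result = compiled(jnp.array(3), (None, None))
print(result)

# Different structure, (None,) instead of (None, None): rejected.
try:
    compiled(jnp.array(3), (None,))
except TypeError as err:
    print("TypeError:", err)
```

This is the same kind of mismatch as in the error above, where `'rest'` is `(None, None)` at compile time but `(None,)` at call time: the `None` entries carry no array data, yet they still change the structure the compiled function expects.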

I believe the error is caused by the updated `_preprocess` function used in `_JitWrapper._call`: it yields a different set of arguments for `Lowered.lowered` than those yielded for `Compiled.compiled` inside `Compiled.__call__`. For example, if we make the following change to `_preprocess`, without considering any other implications, the error appears to be resolved:

```python
def _preprocess(info, args, kwargs, return_static: bool = False):
    signature, dynamic_fun, static_fun, donate_first, donate_rest = info
    args, kwargs = _bind(signature, args, kwargs)
    # add dummy to avoid special casing `len(args) == 0`.
    args = args + (None,)
    first_arg = args[0]
    rest_args = args[1:]
    if return_static:
        dynamic_first, static_first = hashable_partition(first_arg, is_array)
        dynamic_rest, static_rest = hashable_partition((rest_args, kwargs), is_array)
    else:
        dynamic_first = hashable_filter(first_arg, is_array)
        # dynamic_rest = hashable_filter(rest_args, is_array)
        # Include kwargs in dynamic_rest
        dynamic_rest = hashable_filter((rest_args, kwargs), is_array)
    dynamic_donate = dict()
    dynamic_nodonate = dict()
    if donate_first:
        dynamic_donate["first"] = dynamic_first
    else:
        dynamic_nodonate["first"] = dynamic_first
    if donate_rest:
        dynamic_donate["fun"] = dynamic_fun
        dynamic_donate["rest"] = dynamic_rest
    else:
        dynamic_nodonate["fun"] = dynamic_fun
        dynamic_nodonate["rest"] = dynamic_rest

    if return_static:
        static = (static_fun, static_first, static_rest)  # pyright: ignore
        return dynamic_donate, dynamic_nodonate, static
    else:
        return dynamic_donate, dynamic_nodonate
```

I'm not familiar with the argument-donation semantics, so perhaps such a modification is unsuitable?
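To make the mismatch concrete without depending on Equinox internals, here is a hypothetical, heavily simplified model of the two branches of `_preprocess`. The names `preprocess_buggy`, `preprocess_fixed`, `filter_dynamic` and the local `is_array` are made up for illustration, donation is ignored, and this is not Equinox's actual code. It only checks whether the `return_static=True` and `return_static=False` branches produce the same pytree structure once a keyword argument is present.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array(x):
    # Hypothetical stand-in for eqx.is_array: treat JAX/NumPy arrays as dynamic.
    return isinstance(x, (jax.Array, np.ndarray))


def filter_dynamic(tree):
    # Keep array leaves and replace every other leaf with None.
    # (None is an empty pytree node in JAX, so it contributes no leaves.)
    return jax.tree_util.tree_map(lambda leaf: leaf if is_array(leaf) else None, tree)


def preprocess_buggy(args, kwargs, return_static):
    # Hypothetical model of the behaviour described above: only one branch
    # includes kwargs in the dynamic "rest" entry.
    args = args + (None,)  # dummy, as in the original _preprocess
    first, rest = args[0], args[1:]
    if return_static:
        dynamic_rest = filter_dynamic((rest, kwargs))  # kwargs included
    else:
        dynamic_rest = filter_dynamic(rest)  # kwargs dropped
    return {"first": filter_dynamic(first), "rest": dynamic_rest}


def preprocess_fixed(args, kwargs, return_static):
    # Hypothetical model of the proposed fix: both branches filter (rest, kwargs),
    # so `return_static` no longer affects the dynamic structure.
    args = args + (None,)
    first, rest = args[0], args[1:]
    dynamic_rest = filter_dynamic((rest, kwargs))
    return {"first": filter_dynamic(first), "rest": dynamic_rest}


args, kwargs = (jnp.array(3), 4), {"test": 123}
for name, fn in [("buggy", preprocess_buggy), ("fixed", preprocess_fixed)]:
    a = jax.tree_util.tree_structure(fn(args, kwargs, return_static=True))
    b = jax.tree_util.tree_structure(fn(args, kwargs, return_static=False))
    print(name, "structures match:", a == b)
```

With `kwargs={"test": 123}`, the buggy model prints `False` and the fixed model prints `True`; with no keyword arguments, both branches of the buggy model would still differ, because the `return_static=True` branch wraps `rest` together with an empty dict.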

@patrick-kidger
Owner

Thanks for the report! I agree with your proposed fix -- I've just written #627 to fix this.
We'll have this fixed in the next release :)
