This MWE is the `test_aot_compilation` test from `test_jit.py` with the addition of a keyword argument.
```python
import equinox as eqx
import jax.numpy as jnp


def f(x, y, **kwargs):
    return 2 * x + y


x, y = jnp.array(3), 4
lowered = eqx.filter_jit(f, donate="none").lower(x, y, test=123)
lowered.as_text()
compiled = lowered.compile()
compiled(x, y, test=123)
```
As of v0.11.2 (it works in v0.11.1), the MWE above raises the following error:
I believe the error is caused by the updated `_preprocess` function used in `_JitWrapper._call`, which yields a different set of arguments for `Lowered.lowered` than the set yielded for `Compiled.compiled` inside `Compiled.__call__`. For example, if we make the following change to `_preprocess`, without considering any other implications, the error appears to be resolved:
I'm not familiar with the argument donation semantics, so perhaps such a modification is unsuitable?
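To make the hypothesis concrete, here is a toy model in plain Python (no JAX or Equinox needed). It assumes, purely for illustration, that lowering records the structure of the arguments it was traced with, and that the compiled object re-derives that structure on every call through its own preprocessing step. `ToyLowered`, `ToyCompiled`, `keep_kwargs` and `drop_kwargs` are hypothetical names, not Equinox internals; the sketch only shows why two preprocessing paths that disagree about keyword arguments would make an otherwise valid call fail.

```python
from typing import Any, Callable, Dict, Tuple

Args = Tuple[Any, ...]
Kwargs = Dict[str, Any]
Preprocess = Callable[[Args, Kwargs], Tuple[Args, Kwargs]]
Signature = Tuple[int, Tuple[str, ...]]


def signature_of(args: Args, kwargs: Kwargs) -> Signature:
    # Summarise the call structure: number of positional args and sorted keyword names.
    return len(args), tuple(sorted(kwargs))


class ToyCompiled:
    """Hypothetical stand-in for an AOT-compiled function (not Equinox's API)."""

    def __init__(self, fn: Callable, preprocess: Preprocess, expected: Signature):
        self.fn = fn
        self.preprocess = preprocess
        self.expected = expected

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Re-derive the argument structure using the *call-time* preprocessing.
        args, kwargs = self.preprocess(args, kwargs)
        got = signature_of(args, kwargs)
        if got != self.expected:
            raise TypeError(
                f"argument structure {got} does not match lowered structure {self.expected}"
            )
        return self.fn(*args, **kwargs)


class ToyLowered:
    """Hypothetical stand-in for an AOT 'lowered' object (not Equinox's API)."""

    def __init__(self, fn: Callable, preprocess: Preprocess, args: Args, kwargs: Kwargs):
        self.fn = fn
        lowered_args, lowered_kwargs = preprocess(args, kwargs)
        # Record the structure seen at lowering time.
        self.expected = signature_of(lowered_args, lowered_kwargs)

    def compile(self, call_preprocess: Preprocess) -> ToyCompiled:
        # The call path may use a different preprocessing step than lowering did.
        return ToyCompiled(self.fn, call_preprocess, self.expected)


def keep_kwargs(args: Args, kwargs: Kwargs) -> Tuple[Args, Kwargs]:
    # Consistent preprocessing: keyword arguments are passed through unchanged.
    return args, kwargs


def drop_kwargs(args: Args, kwargs: Kwargs) -> Tuple[Args, Kwargs]:
    # Hypothetical inconsistent preprocessing: keyword arguments are discarded.
    return args, {}


def f(x, y, **kwargs):
    return 2 * x + y


if __name__ == "__main__":
    lowered = ToyLowered(f, keep_kwargs, (3, 4), {"test": 123})
    print(lowered.compile(keep_kwargs)(3, 4, test=123))  # 10
    try:
        lowered.compile(drop_kwargs)(3, 4, test=123)
    except TypeError as err:
        print("mismatch:", err)
```

When both sides preprocess the same way, the call succeeds; when the call-time path drops `test`, the recorded and recomputed structures differ and the call is rejected even though the user passed identical arguments both times.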