Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runtime errors inside eqx.filter_jit are now very readable. #803

Merged
merged 1 commit into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion equinox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
is_inexact_array_like as is_inexact_array_like,
partition as partition,
)
from ._jit import filter_jit as filter_jit
from ._jit import EquinoxRuntimeError as EquinoxRuntimeError, filter_jit as filter_jit
from ._make_jaxpr import filter_make_jaxpr as filter_make_jaxpr
from ._module import (
field as field,
Expand Down
93 changes: 34 additions & 59 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Int, PyTree

from . import _jit
from ._ad import filter_custom_jvp
from ._config import EQX_ON_ERROR, EQX_ON_ERROR_BREAKPOINT_FRAMES
from ._doc_utils import doc_remove_args
from ._filters import combine, is_array, partition
from ._jit import filter_jit
from ._misc import currently_jitting
from ._unvmap import unvmap_any, unvmap_max

Expand Down Expand Up @@ -52,76 +52,50 @@ def _nan_like(x: Union[Array, np.ndarray]) -> Union[Array, np.ndarray]:
"""


_on_error_msg = """
---------------------------------------------------------------------------

An error occurred during the runtime of your JAX program.

---------------------------------------------------------------------------

Traceback:

{stack}

---------------------------------------------------------------------------

Error message:

{msg}

---------------------------------------------------------------------------

You have a few options to try and debug this issue.

1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.

If taking this approach, then it is recommended to also set
`EQX_ON_ERROR_BREAKPOINT_FRAMES=<some number>`, corresponding to the number of frames to
add to the debugger.

If you get trace-time errors from JAX then try reducing the value of
`EQX_ON_ERROR_BREAKPOINT_FRAMES`. See
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.

2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.

3) For more suggestions, see `https://docs.kidger.site/equinox/api/debug/`.
"""


_frames_msg = f"""
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames.
You can control this value by setting the environment variable
`EQX_ON_ERROR_BREAKPOINT_FRAMES=<some number>`.
-------------------

Note that setting large values of this number may lead to crashes at trace time; see
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.
Opening a breakpoint with {EQX_ON_ERROR_BREAKPOINT_FRAMES} frames. You can control this
value by setting the environment variable `EQX_ON_ERROR_BREAKPOINT_FRAMES=<some value>`.
(Note that setting large values of this number may lead to crashes at trace time; see
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.)
"""


# This is never actually surfaced to an end user -- it always becomes an XlaRuntimeError
class EqxRuntimeError(RuntimeError):
# The name of this is looked for in `_jit.py` in order to determine if we have a
# runtime error -- and if so then the custom reporting will engage.
#
# Note that this is *not* the class that is raised at runtime to a user: this is an
# internal implementation detail of Equinox. It is caught by `equinox.filter_jit` and
# replaced with the actual runtime error. (Without any of the misleading baggage that
# XLA would otherwise attach.)
class _EquinoxRuntimeError(RuntimeError):
pass


class EquinoxTracetimeError(RuntimeError):
pass


EquinoxTracetimeError.__module__ = "equinox"


@filter_custom_jvp
def _error(x, pred, index, *, msgs, on_error, stack):
if on_error == "raise":

def raises(_index):
raise EqxRuntimeError(
_on_error_msg.format(stack=stack, msg=msgs[_index.item()])
# Sneakily smuggle out the information about the error. Inspired by
# `sys.last_value`.
_jit.last_msg = msg = msgs[_index.item()]
_jit.last_stack = stack
raise _EquinoxRuntimeError(
f"{msg}\n\n\n"
"--------------------\n"
"An error occurred during the runtime of your JAX program! "
"Unfortunately you do not appear to be using `equinox.filter_jit` "
"(perhaps you are using `jax.jit` instead?) and so further information "
"about the error cannot be displayed. (Probably you are seeing a very "
"large but uninformative error message right now.) Please wrap your "
"program with `equinox.filter_jit`.\n"
"--------------------\n"
)

def tpu_msg(_out, _index):
Expand All @@ -148,7 +122,7 @@ def handle_error(): # pyright: ignore

def display_msg(_index):
print(_frames_msg)
print(msgs[_index.item()])
print("equinox.EquinoxRuntimeError: " + msgs[_index.item()])
return _index

def to_nan(_index):
Expand Down Expand Up @@ -356,11 +330,12 @@ def branched_error_if_impl(
return x

tb = None
frames = list(traceback.walk_stack(None))
for f, lineno in reversed(frames):
for f, lineno in traceback.walk_stack(None):
if f.f_locals.get("__equinox_filter_jit__", False):
break
if traceback_util.include_frame(f):
tb = types.TracebackType(tb, f, f.f_lasti, lineno)
stack = "\n".join(traceback.format_tb(tb))
stack = "".join(traceback.format_tb(tb)).rstrip()
dynamic_x, static_x = partition(x, is_array)
flat = jtu.tree_leaves(dynamic_x)
if len(flat) == 0:
Expand All @@ -373,7 +348,7 @@ def branched_error_if_impl(

# filter_jit does some work to produce nicer runtime error messages.
# We also place it here to ensure a consistent experience when using JAX in eager mode.
branched_error_if_impl_jit = filter_jit(branched_error_if_impl)
branched_error_if_impl_jit = _jit.filter_jit(branched_error_if_impl)


def assert_dce(
Expand Down
154 changes: 92 additions & 62 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools as ft
import inspect
import sys
import warnings
from collections.abc import Callable
from typing import Any, Literal, overload, TypeVar
Expand All @@ -22,6 +23,7 @@
from ._deprecate import deprecated_0_10
from ._doc_utils import doc_remove_args
from ._filters import combine, is_array, partition
from ._misc import currently_jitting
from ._module import field, Module, module_update_wrapper, Partial, Static


Expand Down Expand Up @@ -110,38 +112,54 @@ class XlaRuntimeError(Exception):
pass


def _modify_traceback(e: Exception):
# Remove JAX's UnfilteredStackTrace, with its huge error messages.
e.__cause__ = None
# Remove _JitWrapper.__call__ and _JitWrapper._call and Method.__call__ from the
# traceback
tb = e.__traceback__ = e.__traceback__.tb_next.tb_next.tb_next # pyright: ignore
try:
# See https://github.com/google/jax/blob/69cd3ebe99ce12a9f22e50009c00803a095737c7/jax/_src/traceback_util.py#L190 # noqa: E501
jax.lib.xla_extension.replace_thread_exc_traceback(tb) # pyright: ignore
except AttributeError:
pass
# IPython ignores __tracebackhide__ directives for the frame that actually raises
# the error. We fix that here.
try:
get_ipython() # pyright: ignore
except NameError:
pass
else:
import IPython # pyright: ignore

# Check that IPython supports __tracebackhide__
if IPython.version_info[:2] >= (7, 17): # pyright: ignore
tb_stack = []
while tb is not None:
tb_stack.append(tb)
tb = tb.tb_next
for tb in reversed(tb_stack):
if not tb.tb_frame.f_locals.get("__tracebackhide__", False):
tb.tb_next = None
break
else:
e.__traceback__ = None
# This is the class we use to raise runtime errors from `eqx.error_if`.
class EquinoxRuntimeError(RuntimeError):
pass


# Magic value that means error messages are displayed as `{__qualname__}: ...` rather
# than `{__module__}.{__qualname__}: ...`. (At least, this was checked in the default
# Python interpreter, the default Python REPL, ptpython, ipython, pdb, and ipdb.)
EquinoxRuntimeError.__module__ = "builtins"
# Note that we don't also override `__name__` or `__qualname__`. Suppressing the
# `equinox._jit` module bit is useful for readability, but we don't want to go so far as
# deleting the name altogether. (Or even e.g. setting it to the 'Above is the stack...'
# first section of our error message below!) The reason is that whilst that gives a
# nicer displayed error in default Python, it doesn't necessarily do as well with other
# tools, e.g. debuggers. So what we have here is a compromise.


last_msg = None
last_stack = None


_on_error_msg = """Above is the stack outside of JIT. Below is the stack inside of JIT:
{stack}
equinox.EquinoxRuntimeError: {msg}

-------------------

An error occurred during the runtime of your JAX program.

1) Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful
way to debug such errors. This can be interacted with using most of the usual commands
for the Python debugger: `u` and `d` to move up and down frames, the name of a variable
to print its value, etc.

2) You may also like to try setting `JAX_DISABLE_JIT=1`. This will mean that you can
(mostly) inspect the state of your program as if it was normal Python.

3) See `https://docs.kidger.site/equinox/api/debug/` for more suggestions.
"""


class _FilteredStderr:
    """Stand-in for a stderr-like stream that drops Equinox's internal error text.

    Any chunk passed to `write` that contains the substring
    `"_EquinoxRuntimeError"` is discarded; every other chunk is forwarded
    unchanged to the wrapped stream.

    Every attribute other than `write` (for example `flush`, `isatty`, `fileno`
    or `encoding`) is looked up on the wrapped stream. This object is installed
    as `sys.stderr`, so without that delegation any code that touches one of
    those attributes while the filter is active would fail with
    `AttributeError`.

    **Arguments:**

    - `stderr`: the stream to wrap. It must provide `write(str)`.
    """

    def __init__(self, stderr):
        self.stderr = stderr

    def write(self, data: str):
        """Forward `data` to the wrapped stream unless it is filtered out.

        **Returns:**

        Whatever the wrapped stream's `write` returns, or `len(data)` if the
        chunk was suppressed.
        """
        if "_EquinoxRuntimeError" in data:
            # Report the chunk as fully consumed, as a real text stream would,
            # so that callers inspecting the return value do not retry.
            return len(data)
        return self.stderr.write(data)

    def __getattr__(self, name: str):
        # Only called when normal attribute lookup fails. The guard on `stderr`
        # avoids infinite recursion if the instance exists without `__init__`
        # having run (e.g. during copying or unpickling).
        if name == "stderr":
            raise AttributeError(name)
        return getattr(self.stderr, name)

class _JitWrapper(Module):
Expand All @@ -160,6 +178,9 @@ def __wrapped__(self):

def _call(self, is_lower, args, kwargs):
__tracebackhide__ = True
# Used by our error messages when figuring out where to stop walking the stack.
if not currently_jitting():
__equinox_filter_jit__ = True # noqa: F841
info = (
self._signature,
self._dynamic_fun,
Expand All @@ -178,49 +199,58 @@ def _call(self, is_lower, args, kwargs):
_postprocess, # pyright: ignore
)
else:
if self.filter_warning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable*"
)
# Filter stderr to remove our default "you don't seem to be using
# `equinox.filter_jit`" message. (Which also comes with a misleading stack
# trace from XLA.)
stderr = sys.stderr
sys.stderr = _FilteredStderr(stderr)
try:
if self.filter_warning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Some donated buffers were not usable*"
)
out = self._cached(dynamic_donate, dynamic_nodonate, static)
else:
out = self._cached(dynamic_donate, dynamic_nodonate, static)
else:
out = self._cached(dynamic_donate, dynamic_nodonate, static)
except XlaRuntimeError as e:
# Catch Equinox's runtime errors, and re-raise them with actually useful
# information. (By default XlaRuntimeError produces a lot of terrifying
# but useless information.)
if (
last_msg is not None
and last_stack is not None
and "_EquinoxRuntimeError: " in str(e)
):
# We check `last_msg` and `last_stack` just in case. I'm not sure
# what happens in distributed/multiprocess environments. Is the
# callback necessarily executed in the same interpreter as we are in
# here?
raise EquinoxRuntimeError(
_on_error_msg.format(msg=last_msg, stack=last_stack)
) from None
# `from None` to hide the large but uninformative XlaRuntimeError.
else:
raise
finally:
sys.stderr = stderr
return _postprocess(out)

def __call__(self, /, *args, **kwargs):
__tracebackhide__ = True
try:
return self._call(False, args, kwargs)
except XlaRuntimeError as e:
# Catch Equinox's runtime errors, and strip the more intimidating parts of
# the error message.
if len(e.args) != 1 or not isinstance(e.args[0], str):
raise # No idea if this ever happens. But if it does, just bail.
(msg,) = e.args
if "EqxRuntimeError: " in msg:
_, msg = msg.split("EqxRuntimeError: ", 1)
msg, *_ = msg.rsplit("\n\nAt:\n", 1)
e.args = (msg,)
if jax.config.jax_traceback_filtering in ( # pyright: ignore
None,
"auto",
):
_modify_traceback(e)
except EquinoxRuntimeError as e:
# Use a two-part try/except here and in `_call` to delete the
# `raise EquinoxRuntimeError` line from the stack trace.
e.__traceback__ = None
raise
# I considered also catching `Exception`, and removing the terrifying-looking
# JAX exception that occurs by default.
# This ends up being difficult to get working reliably (e.g. KeyError has a
# different __str__ so modifying the `.args` is hard/undefined; JAX errors have
# a different __init__ so overwriting __str__ in a new class ends up requiring
# magic; taking a different approach and overwriting sys.excepthook is ignored
# under IPython, ...)
# All in all, not worth it.

def lower(self, /, *args, **kwargs) -> Lowered:
return self._call(True, args, kwargs)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
Expand Down
Loading