Better compatibility between eqx.filter_eval_shape and eqx.tree_serialise_leaves #259

Closed
patrick-kidger opened this issue Feb 13, 2023 · 0 comments · Fixed by #623
Labels
feature New feature next Higher-priority items

Comments

@patrick-kidger
Owner

Right now it's possible to load a model from a checkpoint by doing something like:

def run(..., checkpoint=None):
    model = Model(...)
    if checkpoint is not None:
        model = eqx.tree_deserialise_leaves(checkpoint, model)
    ...

If one really cares about efficiency then the model = Model(...) line can be replaced by model = eqx.filter_eval_shape(Model, ...), which skips initialising parameters that are about to be overwritten anyway. However this returns jax.ShapeDtypeStructs rather than arrays, so the deserialisation line won't work. Right now you need to manually hack something together. We should instead make this easier to do.
