-
-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
eqx.filter_shard
; test + update examples/parallelism.ipynb
#691
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from typing import Any, Union | ||
|
||
import jax | ||
import jax.lax as lax | ||
from jaxlib.xla_extension import Device | ||
from jaxtyping import PyTree | ||
|
||
from ._filters import combine, is_array, partition | ||
|
||
|
||
def filter_shard( | ||
x: PyTree[Any], device_or_shardings: Union[Device, jax.sharding.Sharding] | ||
): | ||
"""Filtered transform combining `jax.lax.with_sharding_constraint` | ||
and `jax.device_put`. | ||
|
||
Enforces sharding within a JIT'd computation (That is, how an array is | ||
split between multiple devices, i.e. multiple GPUs/TPUs.), or moves `x` to | ||
a device. | ||
|
||
**Arguments:** | ||
|
||
- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. | ||
They will have their shardings constrained. | ||
- `device_or_shardings`: Either a singular device (e.g. CPU or GPU) or PyTree of | ||
sharding specifications. The structure should be a prefix of `x`. | ||
|
||
**Returns:** | ||
|
||
A copy of `x` with the specified sharding constraints. | ||
|
||
!!! Example | ||
See also the [autoparallelism example](../../../examples/parallelism). | ||
""" | ||
if isinstance(device_or_shardings, Device): | ||
shardings = jax.sharding.SingleDeviceSharding(device_or_shardings) | ||
else: | ||
shardings = device_or_shardings | ||
dynamic, static = partition(x, is_array) | ||
dynamic = lax.with_sharding_constraint(dynamic, shardings) | ||
return combine(dynamic, static) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import equinox as eqx | ||
import jax | ||
import jax.random as jr | ||
|
||
|
||
[cpu] = jax.local_devices(backend="cpu") | ||
sharding = jax.sharding.PositionalSharding([cpu]) | ||
|
||
|
||
def test_sharding(): | ||
mlp = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0)) | ||
|
||
eqx.filter_shard(mlp, cpu) | ||
|
||
@eqx.filter_jit | ||
def f(x): | ||
a, b = eqx.partition(x, eqx.is_array) | ||
a = jax.tree_map(lambda x: x + 1, a) | ||
x = eqx.combine(a, b) | ||
return eqx.filter_shard(x, sharding) | ||
|
||
f(mlp) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that no-op computations are special-cased by XLA, so this might not actually test anything. I think this should just do
+1
or something simple.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed the test here to actually do something, basically just filters the params out and adds one to them, then
filter_shard
s them in the JIT'd function.