eqx.filter_shard; test + update examples/parallelism.ipynb #691

Merged: 4 commits, Apr 7, 2024
4 changes: 4 additions & 0 deletions docs/api/transformations.md
@@ -18,6 +18,10 @@ Most users find that this is a simpler API when working with complicated PyTrees

::: equinox.filter_eval_shape

---

::: equinox.filter_shard

## Automatic differentiation

::: equinox.filter_grad
1 change: 1 addition & 0 deletions equinox/__init__.py
@@ -53,6 +53,7 @@
    tree_deserialise_leaves as tree_deserialise_leaves,
    tree_serialise_leaves as tree_serialise_leaves,
)
from ._sharding import filter_shard as filter_shard
from ._tree import (
    tree_at as tree_at,
    tree_check as tree_check,
41 changes: 41 additions & 0 deletions equinox/_sharding.py
@@ -0,0 +1,41 @@
from typing import Any, Union

import jax
import jax.lax as lax
from jaxlib.xla_extension import Device
from jaxtyping import PyTree

from ._filters import combine, is_array, partition


def filter_shard(
    x: PyTree[Any], device_or_shardings: Union[Device, jax.sharding.Sharding]
):
    """Filtered transform combining `jax.lax.with_sharding_constraint`
    and `jax.device_put`.

    Enforces sharding within a JIT'd computation (that is, how an array is
    split between multiple devices, i.e. multiple GPUs/TPUs), or moves `x` to
    a device.

    **Arguments:**

    - `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves.
        They will have their shardings constrained.
    - `device_or_shardings`: Either a singular device (e.g. CPU or GPU) or a PyTree of
        sharding specifications. The structure should be a prefix of `x`.

    **Returns:**

    A copy of `x` with the specified sharding constraints.

    !!! Example
        See also the [autoparallelism example](../../../examples/parallelism).
    """
    if isinstance(device_or_shardings, Device):
        shardings = jax.sharding.SingleDeviceSharding(device_or_shardings)
    else:
        shardings = device_or_shardings
    dynamic, static = partition(x, is_array)
    dynamic = lax.with_sharding_constraint(dynamic, shardings)
    return combine(dynamic, static)
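
A minimal usage sketch of the new function (not part of the diff; it assumes a CPU backend is available, and the MLP sizes and the `step` function are hypothetical, chosen only for illustration):

```python
import equinox as eqx
import jax
import jax.random as jr

# A model whose PyTree contains both arrays (the weights) and non-arrays (the activation function).
model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2, key=jr.PRNGKey(0))

# Outside of JIT: acts like a filtered `jax.device_put`, moving only the array leaves.
cpu = jax.devices("cpu")[0]
model = eqx.filter_shard(model, cpu)

# Inside of JIT: acts like a filtered `jax.lax.with_sharding_constraint`.
sharding = jax.sharding.SingleDeviceSharding(cpu)


@eqx.filter_jit
def step(m):
    # Constrain the sharding of the array leaves; non-array leaves pass through untouched.
    return eqx.filter_shard(m, sharding)


model = step(model)
```
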
133 changes: 99 additions & 34 deletions examples/parallelism.ipynb
@@ -24,7 +24,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "83bba892-5425-4eed-a7f7-9c325fe5cc53",
"metadata": {},
"outputs": [],
@@ -34,7 +34,7 @@
"import jax.experimental.mesh_utils as mesh_utils\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import jax.sharding as sharding\n",
"import jax.sharding as jshard\n",
"import numpy as np\n",
"import optax # https://github.com/deepmind/optax\n",
"\n",
@@ -57,49 +57,64 @@
"opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n",
"\n",
"\n",
"# Loss function for a batch of data\n",
"def compute_loss(model, x, y):\n",
"    pred_y = jax.vmap(model)(x)\n",
"    return jnp.mean((y - pred_y) ** 2)\n",
"\n",
"\n",
"@eqx.filter_jit\n",
"def make_step(model, opt_state, x, y):\n",
"    grads = eqx.filter_grad(compute_loss)(model, x, y)\n",
"    updates, opt_state = optim.update(grads, opt_state)\n",
"    model = eqx.apply_updates(model, updates)\n",
"    return model, opt_state"
"# Simple dataloader; randomly slices our dataset and shuffles between epochs.\n",
"# In NumPy for speed, as our dataset is small enough to fit entirely in host memory.\n",
"#\n",
"# For larger datasets (that require loading from disk), use PyTorch's `DataLoader`\n",
"# or TensorFlow's `tf.data`.\n",
"def train_dataloader(arrays, batch_size):\n",
"    dataset_size = arrays[0].shape[0]\n",
"    assert all(array.shape[0] == dataset_size for array in arrays)\n",
"    indices = np.arange(dataset_size)\n",
"    while True:\n",
"        perm = np.random.permutation(indices)\n",
"        start = 0\n",
"        end = batch_size\n",
"        while end <= dataset_size:\n",
"            batch_perm = perm[start:end]\n",
"            yield tuple(array[batch_perm] for array in arrays)\n",
"            start = end\n",
"            end = start + batch_size"
]
},
{
"cell_type": "markdown",
"id": "0fb345b0-c9b3-44df-94e8-d74c7ad172b8",
"metadata": {},
"source": [
"Here's a very simple dataloader, that randomly shuffles and slices our dataset. We keep everything in pure-NumPy for speed, as this all happens on the host, prior to moving our data to our devices. (Which will often be a cluster of GPUs.)\n",
"Okay, now the interesting things start happening!\n",
"\n",
"In practice it's also common to load data using either PyTorch's `DataLoader` or TensorFlow's `tf.data` API; see [here](../mnist/) for more details."
"First, we're going to arrange to \"donate\" memory, which specifies that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :)\n",
"\n",
"Second, we're going to use `eqx.filter_shard` to assert (on the inputs) and enforce (on the outputs) how each array is split across each of our devices. As we're doing data parallelism in this example, we'll be replicating our model parameters on to every device, whilst sharding our data between devices."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "fd94db04-9fe4-4530-808e-945becef9df5",
"metadata": {},
"outputs": [],
"source": [
"def dataloader(arrays, batch_size):\n",
"    dataset_size = arrays[0].shape[0]\n",
"    assert all(array.shape[0] == dataset_size for array in arrays)\n",
"    indices = np.arange(dataset_size)\n",
"    while True:\n",
"        perm = np.random.permutation(indices)\n",
"        start = 0\n",
"        end = batch_size\n",
"        while end <= dataset_size:\n",
"            batch_perm = perm[start:end]\n",
"            yield tuple(array[batch_perm] for array in arrays)\n",
"            start = end\n",
"            end = start + batch_size"
"@eqx.filter_jit(donate=\"all\")\n",
"def train_step(model, opt_state, x, y, sharding):\n",
"    replicated = sharding.replicate()\n",
"    model, opt_state = eqx.filter_shard((model, opt_state), replicated)\n",
"    x, y = eqx.filter_shard((x, y), sharding)\n",
"\n",
"    grads = eqx.filter_grad(compute_loss)(model, x, y)\n",
"    updates, opt_state = optim.update(grads, opt_state)\n",
"    model = eqx.apply_updates(model, updates)\n",
"\n",
"    model, opt_state = eqx.filter_shard((model, opt_state), replicated)\n",
"\n",
"    return model, opt_state"
]
},
{
@@ -112,7 +127,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "32c6b58e-f72f-4dd4-bf2c-f1dc75643eda",
"metadata": {
"tags": []
@@ -121,11 +136,57 @@
"source": [
"num_devices = len(jax.devices())\n",
"devices = mesh_utils.create_device_mesh((num_devices, 1))\n",
"shard = sharding.PositionalSharding(devices)\n",
"\n",
"for step, (x, y) in zip(range(num_steps), dataloader((xs, ys), batch_size)):\n",
"    x, y = jax.device_put((x, y), shard)\n",
"    model, opt_state = make_step(model, opt_state, x, y)"
"sharding = jshard.PositionalSharding(devices)\n",
"replicated = sharding.replicate()\n",
"\n",
"model = eqx.filter_shard(model, replicated)\n",
"for step, (x, y) in zip(\n",
"    range(1, num_steps + 1), train_dataloader((xs, ys), batch_size)\n",
"):\n",
"    x, y = eqx.filter_shard((x, y), sharding)\n",
"    model, opt_state = train_step(model, opt_state, x, y, sharding)"
]
},
{
"cell_type": "markdown",
"id": "036c2ddf",
"metadata": {},
"source": [
"Not strictly related to parallelism, but a common question at this point: if we want to evaluate our model, then we probably don't want to donate its parameters (which would render the model unusable, as all its memory is freed). As such, inference looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "299933ca",
"metadata": {},
"outputs": [],
"source": [
"def eval_dataloader(arrays, batch_size):\n",
"    dataset_size = arrays[0].shape[0]\n",
"    assert all(array.shape[0] == dataset_size for array in arrays)\n",
"    start = 0\n",
"    end = batch_size\n",
"    while start < dataset_size:\n",
"        yield tuple(array[start:end] for array in arrays)\n",
"        start = end\n",
"        end = start + batch_size\n",
"\n",
"\n",
"@eqx.filter_jit(donate=\"all-except-first\")\n",
"def evaluate(model, x, y, sharding):\n",
"    replicated = sharding.replicate()\n",
"    model = eqx.filter_shard(model, replicated)\n",
"    x, y = eqx.filter_shard((x, y), sharding)\n",
"    return compute_loss(model, x, y)\n",
"\n",
"\n",
"loss = 0\n",
"num_batches = 0\n",
"for x, y in eval_dataloader((xs, ys), batch_size):\n",
"    loss = loss + evaluate(model, x, y, sharding).item()\n",
"    num_batches = num_batches + 1\n",
"print(f\"train loss={loss/num_batches}\")"
]
},
{
@@ -149,7 +210,11 @@
"\n",
"There are multiple types of parallelism. In this example we demonstrated _data parallelism_, in which we parallelise over the data. This is one of the simplest to set up, and often very effective.\n",
"\n",
"For completeness we note that there are other kinds of parallelism available -- e.g. model parallelism, which instead places different parts of the model on different devices. A discussion on those is a more advanced topic. :)\n",
"For completeness we note that there are other kinds of parallelism available -- e.g. model parallelism, which instead places different parts of the model on different devices. A discussion on those is a more advanced topic.\n",
"\n",
"**{`jax.device_put`, `jax.lax.with_sharding_constraint`} vs `eqx.filter_shard`**\n",
"\n",
"This is the usual story in Equinox: we have a filtered version of these operations, which leaves any non-arrays alone. In this case, it's needed because we have an activation function (i.e. just some arbitrary Python function, which isn't an array) as part of the MLP.\n",
"\n",
"**Further reading**\n",
"\n",
@@ -163,9 +228,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "jax39",
"display_name": ".venv",
"language": "python",
"name": "jax39"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -177,7 +242,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.13"
}
},
"nbformat": 4,
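The summary cell above notes that `eqx.filter_shard` is needed because the MLP carries an activation function as a non-array leaf. A small sketch of that point (hypothetical, not part of the diff; it assumes a CPU backend, and the failure of the unfiltered call is stated as an expectation rather than a quoted error message):

```python
import equinox as eqx
import jax
import jax.random as jr

mlp = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
cpu = jax.devices("cpu")[0]

# The unfiltered version maps over every leaf of the PyTree, including the activation
# function, which is not a valid JAX type -- so the line below is expected to error:
# jax.device_put(mlp, cpu)

# The filtered version moves/constrains only the array leaves and passes the
# activation function through untouched.
mlp = eqx.filter_shard(mlp, cpu)
```
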
22 changes: 22 additions & 0 deletions tests/test_sharding.py
@@ -0,0 +1,22 @@
import equinox as eqx
import jax
import jax.random as jr


[cpu] = jax.local_devices(backend="cpu")
sharding = jax.sharding.PositionalSharding([cpu])


def test_sharding():
    mlp = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))

    eqx.filter_shard(mlp, cpu)

    @eqx.filter_jit
    def f(x):
        a, b = eqx.partition(x, eqx.is_array)
        a = jax.tree_map(lambda x: x + 1, a)
        x = eqx.combine(a, b)
        return eqx.filter_shard(x, sharding)
Owner:

Note that no-op computations are special-cased by XLA, so this might not actually test anything. I think this should just do +1 or something simple.

Contributor Author (@homerjed), Mar 27, 2024:

Fixed the test here to actually do something, basically just filters the params out and adds one to them, then filter_shards them in the JIT'd function.


    f(mlp)
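
Building on the thread above, a hypothetical follow-up test (not part of the PR) that would also assert where the output arrays ended up. It reuses the module-level `cpu` and `sharding` from the test file, and relies on `jax.Array.devices()` returning the set of devices holding an array:

```python
# Hypothetical extra test: assert that every array leaf of the sharded output
# lives on the CPU device.
def test_sharding_output_devices():
    mlp = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))

    @eqx.filter_jit
    def f(x):
        a, b = eqx.partition(x, eqx.is_array)
        a = jax.tree_map(lambda x: x + 1, a)
        x = eqx.combine(a, b)
        return eqx.filter_shard(x, sharding)

    out = f(mlp)
    arrays = jax.tree_util.tree_leaves(eqx.filter(out, eqx.is_array))
    assert all(array.devices() == {cpu} for array in arrays)
```
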