Update parallelism.ipynb #804

Merged
merged 1 commit into from
Aug 20, 2024
6 changes: 2 additions & 4 deletions examples/parallelism.ipynb
@@ -90,7 +90,7 @@
"source": [
"Okay, now the interesting things start happening!\n",
"\n",
"First, we're going to arrange to \\\"donate\\\" memory, which specifes that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :)\n",
"First, we're going to arrange to \"donate\" memory, which specifes that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :)\n",
"\n",
"Second, we're going to use `eqx.filter_shard` to assert (on the inputs) and enforce (on the outputs) how each array is split across each of our devices. As we're doing data parallelism in this example, then we'll be replicating our model parameters on to every device, whilst sharding our data between devices."
]
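The donate + `eqx.filter_shard` pattern described in this cell can be sketched roughly as follows. This is a minimal illustration rather than the notebook's actual training loop: the `eqx.nn.MLP` model, the mean-squared-error loss, and the hard-coded learning rate are placeholder choices, while `eqx.filter_jit(donate="all")` and `eqx.filter_shard` are the Equinox utilities the text refers to.

```python
# Rough sketch of the pattern above; the model, loss, and learning rate are placeholders.
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, ("batch",))
replicated = NamedSharding(mesh, P())       # same copy of the parameters on every device
sharded = NamedSharding(mesh, P("batch"))   # data split along the leading (batch) axis

@eqx.filter_jit(donate="all")  # re-use the input buffers to store the outputs
def train_step(model, x, y):
    # Assert the shardings of the inputs...
    model = eqx.filter_shard(model, replicated)
    x, y = eqx.filter_shard((x, y), sharded)

    # ...take one (illustrative) gradient step...
    def loss_fn(m):
        return jnp.mean((jax.vmap(m)(x) - y) ** 2)

    grads = eqx.filter_grad(loss_fn)(model)
    updates = jax.tree_util.tree_map(lambda g: -1e-3 * g, grads)
    model = eqx.apply_updates(model, updates)

    # ...and enforce the sharding of the outputs.
    return eqx.filter_shard(model, replicated)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=32, depth=2,
                   key=jax.random.PRNGKey(0))
x = jnp.zeros((len(devices) * 16, 8))
y = jnp.zeros((len(devices) * 16, 1))
model = train_step(model, x, y)
```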
@@ -198,9 +198,7 @@
"\n",
"Once you've specified how you want to split up your input data, then JAX does the rest of it for you! It takes your single JIT'd computation (which you wrote as if you were targeting a single huge device), and it then automatically determined how to split up that computation and have each device handle part of the computation. This is JAX's computation-follows-data approach to autoparallelism.\n",
"\n",
"If you ran the above example on a cluster of NVIDIA GPUs, then you can check whether you're using as many GPUs as you expected by running `nvidia-smi` from the command line. You can also use `jax.debug.visualize_array_sharding(array)` to inspect the sharding manually.\n",
"\n",
"One possible optimisation here is to re-use the memory used by the input arrays, to store the output arrays. This often improves speed a little bit. This is disabled by default, but can be enabled by passing `eqx.filter_jit(donate=\"all\")`.\n",
"If you ran the above example on a computer with multiple NVIDIA GPUs, then you can check whether you're using as many GPUs as you expected by running `nvidia-smi` from the command line. You can also use `jax.debug.visualize_array_sharding(array)` to inspect the sharding manually.\n",
"\n",
"**What about pmap?**\n",
"\n",
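As a rough, self-contained illustration of the sharding inspection mentioned in this cell (the array shape and the "batch" axis name are arbitrary choices, not from the notebook):

```python
# Minimal sketch of inspecting how an array is sharded across devices.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, ("batch",))
sharded = NamedSharding(mesh, P("batch"))  # split along the leading axis

# Place a toy 2D array across all devices, split along its first axis.
x = jax.device_put(jnp.zeros((len(devices) * 16, 8)), sharded)

jax.debug.visualize_array_sharding(x)  # prints which device holds which slice
print(x.sharding)                      # the NamedSharding object itself
```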