-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
155 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,84 +1,60 @@ | ||
# NumPy random number generator API for JAX | ||
# `rng-jax` — JAX random number generation as a NumPy generator | ||
|
||
**This is a proof of concept only.** | ||
|
||
Wraps stateless JAX random number generation in the | ||
[`numpy.random.Generator`](generator) API. | ||
Wraps JAX's stateless random number generation in a class implementing the | ||
[`numpy.random.Generator`](generator) interface. | ||
|
||
```py | ||
from jrng import JRNG | ||
|
||
rng = JRNG(42) | ||
|
||
rng.standard_normal(3) | ||
# Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32) | ||
rng.standard_normal(3) | ||
# Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32) | ||
``` | ||
|
||
The goal of this experiment is to investigate ways in which there can be a | ||
random number generation API that works in tandem with the Python Array API. | ||
|
||
## How it works | ||
|
||
The `JRNG` class works in the obvious way: it keeps track of the JAX `key` and | ||
calls `jax.random.split()` before every random operation. | ||
|
||
## JIT | ||
|
||
The problem with a stateful RNG is that it cannot easily be passed into a | ||
compiled function. However, the `JRNG` class is only "stateful" in that it | ||
keeps track of the current `key`. | ||
|
||
When a `JRNG` pytree is flattened, the resulting child node contains an | ||
_independent_ random key, while the internal state of the existing `JRNG` is | ||
advanced at the same time. This allows passing `JRNG` instances into compiled | ||
functions and still obtaining independent random outputs: | ||
## Example | ||
|
||
```py | ||
import jax | ||
from jrng import JRNG | ||
|
||
def print_key(key): print(jax.random.key_data(key)) | ||
|
||
@jax.jit | ||
def f(x, rng): | ||
return x + rng.standard_normal(x.shape) | ||
>>> import rng_jax | ||
>>> | ||
>>> rng = rng_jax.Generator(42) # same arguments as jax.random.key() | ||
>>> rng.standard_normal(3) | ||
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32) | ||
>>> rng.standard_normal(3) | ||
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32) | ||
``` | ||
|
||
x = jax.numpy.array([1, 2, 3]) | ||
rng = JRNG(42) | ||
## Rationale | ||
|
||
print_key(rng.key) # [ 0 42] | ||
The [Array API](array_api) makes it possible to write array-agnostic Python | ||
libraries. The `rng-jax` package makes it easy to extend this to random number | ||
generation in NumPy and JAX. End users only need to provide a `rng` object, as | ||
usual, which can either be a NumPy one or a `rng_jax.Generator` instance | ||
wrapping JAX's stateless random number generation. | ||
|
||
print(f(x, rng)) # [0.047065 1.6797752 3.9650078] | ||
## How it works | ||
|
||
print_key(rng.key) # [4249898905 2425127087] | ||
The `rng_jax.Generator` class works in the obvious way: it keeps track of the | ||
JAX `key` and calls `jax.random.split()` before every random operation. | ||
|
||
print(f(x, rng)) # [0.60631436 1.0040649 2.4605024 ] | ||
## JIT and native JAX code | ||
|
||
print_key(rng.key) # [ 499334550 3925197703] | ||
``` | ||
The problem with a stateful RNG is that it cannot be passed into a compiled JAX | ||
function. In practice, this is not usually an issue, since the goal of this | ||
package is to work in tandem with the Array API: array-agnostic code is not | ||
usually compiled at low level. Conversely, native JAX code usually expects a | ||
`key`, anyway, not a `rng_jax.Generator` instance. | ||
|
||
However, this mechanism means flattening the `JRNG` pytree changes internal | ||
state (due to the internal details of JAX, it actually advances the random | ||
number generator multiple times). | ||
To interface with a native JAX function expecting a `key`, use the `.key()` | ||
method to obtain a new random key and advance the internal state of the | ||
generator: | ||
|
||
```py | ||
# same initial state as above | ||
key = jax.random.key(42) | ||
print_key(key) # [ 0 42] | ||
|
||
# pytree is flattened 4 times per invocation | ||
key, _ = jax.random.split(key) | ||
key, _ = jax.random.split(key) | ||
key, _ = jax.random.split(key) | ||
key, _ = jax.random.split(key) | ||
print_key(key) # [4249898905 2425127087] | ||
>>> rng = rng_jax.Generator(42) | ||
>>> key = rng.key() | ||
>>> jax.random.normal(key, 3) | ||
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32) | ||
>>> key = rng.key() | ||
>>> jax.random.normal(key, 3) | ||
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32) | ||
``` | ||
|
||
While this is not an ideal solution, it may be an acceptable one: the goal of | ||
this API is to work in tandem with the Array API. Array-agnostic code is not | ||
usually compiled at low level. Using the `JRNG` class _inside_ a compiled | ||
function works without issue. | ||
The right way to compile array-agnostic code is usually to compile the "main" | ||
function at the highest level of the code. Using the `rng_jax.Generator` class | ||
fully _within_ a compiled function works without issue. | ||
|
||
[array-api]: https://data-apis.org/array-api/latest/ | ||
[generator]: https://numpy.org/doc/stable/reference/random/generator.html |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.