Remove pytree flattening.
ntessore committed Dec 20, 2024
1 parent 6ee0079 commit 68e968d
Showing 5 changed files with 155 additions and 178 deletions.
106 changes: 41 additions & 65 deletions README.md
@@ -1,84 +1,60 @@
# NumPy random number generator API for JAX
# `rng-jax` — JAX random number generation as a NumPy generator

**This is a proof of concept only.**

Wraps stateless JAX random number generation in the
[`numpy.random.Generator`](generator) API.
Wraps JAX's stateless random number generation in a class implementing the
[`numpy.random.Generator`](generator) interface.

```py
from jrng import JRNG

rng = JRNG(42)

rng.standard_normal(3)
# Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
rng.standard_normal(3)
# Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
```

The goal of this experiment is to investigate ways in which there can be a
random number generation API that works in tandem with the Python Array API.

## How it works

The `JRNG` class works in the obvious way: it keeps track of the JAX `key` and
calls `jax.random.split()` before every random operation.

## JIT

The problem with a stateful RNG is that it cannot easily be passed into a
compiled function. However, the `JRNG` class is only "stateful" in that it
keeps track of the current `key`.

When a `JRNG` pytree is flattened, the resulting child node contains an
_independent_ random key, while the internal state of the existing `JRNG` is
advanced at the same time. This allows passing `JRNG` instances into compiled
functions and still obtaining independent random outputs:
## Example

```py
import jax
from jrng import JRNG

def print_key(key): print(jax.random.key_data(key))

@jax.jit
def f(x, rng):
return x + rng.standard_normal(x.shape)
>>> import rng_jax
>>>
>>> rng = rng_jax.Generator(42) # same arguments as jax.random.key()
>>> rng.standard_normal(3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> rng.standard_normal(3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
```

x = jax.numpy.array([1, 2, 3])
rng = JRNG(42)
## Rationale

print_key(rng.key) # [ 0 42]
The [Array API](array_api) makes it possible to write array-agnostic Python
libraries. The `rng-jax` package makes it easy to extend this to random number
generation in NumPy and JAX. End users only need to provide a `rng` object, as
usual, which can either be a NumPy one or a `rng_jax.Generator` instance
wrapping JAX's stateless random number generation.
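
As a rough illustration of the idea above, an array-agnostic function only needs to call the generator methods it is given. The sketch below uses a hypothetical `add_noise` function (not part of this package) and runs it with a NumPy generator; according to the README, a `rng_jax.Generator` instance could be passed in its place when working with JAX arrays.

```py
import numpy as np


def add_noise(x, rng):
    """Hypothetical library function: add standard normal noise to x.

    Only the generator interface is used, so the caller decides
    which backend produces the random numbers.
    """
    return x + rng.standard_normal(x.shape)


x = np.zeros(3)
noisy = add_noise(x, np.random.default_rng(42))  # NumPy generator supplied by the end user
```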

print(f(x, rng)) # [0.047065 1.6797752 3.9650078]
## How it works

print_key(rng.key) # [4249898905 2425127087]
The `rng_jax.Generator` class works in the obvious way: it keeps track of the
JAX `key` and calls `jax.random.split()` before every random operation.
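
The bookkeeping described above can be written out by hand in plain JAX. This is only a sketch of the pattern (split the stored key, keep one half as the new state, sample with the other half), not the package's actual code:

```py
import jax

key = jax.random.key(42)  # initial state, as stored by the generator

# First draw: split, keep the new state, sample with the subkey.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))

# Second draw: the state has advanced, so the subkey differs.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (3,))
```

Because the state is replaced on every draw, repeated calls give different samples without the caller ever handling a key.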

print(f(x, rng)) # [0.60631436 1.0040649 2.4605024 ]
## JIT and native JAX code

print_key(rng.key) # [ 499334550 3925197703]
```
The problem with a stateful RNG is that it cannot be passed into a compiled JAX
function. In practice, this is not usually an issue, since the goal of this
package is to work in tandem with the Array API: array-agnostic code is not
usually compiled at low level. Conversely, native JAX code usually expects a
`key`, anyway, not a `rng_jax.Generator` instance.

However, this mechanism means flattening the `JRNG` pytree changes internal
state (due to the internal details of JAX, it actually advances the random
number generator multiple times).
To interface with a native JAX function expecting a `key`, use the `.key()`
method to obtain a new random key and advance the internal state of the
generator:

```py
# same initial state as above
key = jax.random.key(42)
print_key(key) # [ 0 42]

# pytree is flattened 4 times per invocation
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
print_key(key) # [4249898905 2425127087]
>>> rng = rng_jax.Generator(42)
>>> key = rng.key()
>>> jax.random.normal(key, 3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> key = rng.key()
>>> jax.random.normal(key, 3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
```
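
The `key()` method in this commit also accepts a `size` argument. A minimal sketch of how such a batched draw might work, written as a free function with the hypothetical name `next_keys` rather than as a method: split the state into one more key than requested, keep the first as the new state, and reshape the remainder.

```py
import math

import jax


def next_keys(state, shape=()):
    """Hypothetical helper: return (new_state, keys) with keys of the given shape."""
    keys = jax.random.split(state, 1 + math.prod(shape))
    return keys[0], keys[1:].reshape(shape)


state = jax.random.key(42)
state, single = next_keys(state)         # one key, shape ()
state, batch = next_keys(state, (2, 3))  # six keys arranged as a 2x3 array
```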

While this is not an ideal solution, it may be an acceptable one: the goal of
this API is to work in tandem with the Array API. Array-agnostic code is not
usually compiled at low level. Using the `JRNG` class _inside_ a compiled
function works without issue.
The right way to compile array-agnostic code is usually to compile the "main"
function at the highest level of the code. Using the `rng_jax.Generator` class
fully _within_ a compiled function works without issue.
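
To make the last point concrete, the sketch below compiles a top-level `main` function that takes a key, builds a generator inside the compiled region, and hands it to array-agnostic code. `SplitGenerator` is a hypothetical stand-in written for this example so that it is self-contained; it is assumed, not confirmed, that `rng_jax.Generator.from_key` could be used the same way.

```py
import jax
import jax.numpy as jnp


class SplitGenerator:
    """Hypothetical stand-in: holds a key and splits it before each draw."""

    def __init__(self, key):
        self._key = key

    def standard_normal(self, shape):
        self._key, subkey = jax.random.split(self._key)
        return jax.random.normal(subkey, shape)


def add_noise(x, rng):
    # Array-agnostic code: only uses the generator interface.
    return x + rng.standard_normal(x.shape)


@jax.jit
def main(x, key):
    rng = SplitGenerator(key)  # created and used entirely within the compiled function
    return add_noise(add_noise(x, rng), rng)


out = main(jnp.zeros(3), jax.random.key(42))
```

The generator never crosses the `jax.jit` boundary here; only the key does, which is why no pytree registration is needed.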

[array-api]: https://data-apis.org/array-api/latest/
[generator]: https://numpy.org/doc/stable/reference/random/generator.html
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -3,8 +3,8 @@ requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "jrng"
description = "NumPy random number generator API for JAX"
name = "rng-jax"
description = "JAX random number generation as a NumPy generator"
readme = "README.md"
license = "MIT"
authors = [
@@ -32,8 +32,8 @@ test = [
]

[project.urls]
Repository = "https://github.com/glass-dev/jrng"
Issues = "https://github.com/glass-dev/jrng"
Repository = "https://github.com/glass-dev/rng-jax"
Issues = "https://github.com/glass-dev/rng-jax"

[tool.hatch.version]
source = "vcs"
41 changes: 17 additions & 24 deletions jrng.py → rng_jax.py
@@ -1,7 +1,8 @@
"""
NumPy random number generator API for JAX.
JAX random number generation as a NumPy generator.
"""

import math
from typing import Literal, Self, TypeAlias

from jax import Array
@@ -25,7 +26,6 @@
split,
uniform,
)
from jax.tree_util import register_pytree_node_class
from jax.typing import ArrayLike, DTypeLike


@@ -47,13 +47,13 @@ def _s(size: Size, *bcast: ArrayLike) -> tuple[int, ...]:
return size


@register_pytree_node_class
class JRNG:
class Generator:
"""
Wrapper class for JAX random number generation.
"""

__slots__ = ("key",)
__slots__ = ("_key",)
_key: Array

@classmethod
def from_key(cls, key: Array) -> Self:
@@ -63,44 +63,37 @@ def from_key(cls, key: Array) -> Self:
if not isinstance(key, Array) or not issubdtype(key.dtype, prng_key):
raise ValueError("not a random key")
rng = object.__new__(cls)
rng.key = key
rng._key = key
return rng

def __init__(self, seed: int | ArrayLike, *, impl: str | None = None) -> None:
"""
Create a wrapper instance with a new key.
"""
self.key = key(seed, impl=impl)
self._key = key(seed, impl=impl)

@property
def __key(self) -> Array:
"""
Return next key for sampling while updating internal state.
"""
self.key, key = split(self.key)
self._key, key = split(self._key)
return key

def tree_flatten(self) -> tuple[tuple[Array], None]:
def key(self, size: Size = None) -> Array:
"""
Return pytree representation of JRNG instance.
Return random key, advancing internal state.
"""
return (self.__key,), None
shape = _s(size)
keys = split(self._key, 1 + math.prod(shape))
self._key = keys[0]
return keys[1:].reshape(shape)

@classmethod
def tree_unflatten(cls, aux_data: None, children: tuple[Array]) -> Self:
"""
Construct JRNG instance from pytree representation.
"""
(key,) = children
rng = object.__new__(cls)
rng.key = key
return rng

def spawn(self, n: int) -> list[Self]:
def spawn(self, n_children: int) -> list[Self]:
"""
Create new independent child generators.
"""
self.key, *subkeys = split(self.key, num=n + 1)
self._key, *subkeys = split(self._key, num=n_children + 1)
return list(map(self.from_key, subkeys))

def integers(
Expand All @@ -126,7 +119,7 @@ def random(self, size: Size = None, dtype: DTypeLike = float) -> Array:
"""
Return random floats in the half-open interval [0.0, 1.0).
"""
self.key, key = split(self.key)
self._key, key = split(self._key)
return uniform(self.__key, _s(size), dtype)

def choice(
85 changes: 0 additions & 85 deletions test_jrng.py

This file was deleted.
