Filter an array based on another boolean array (JIT compiled) #10105
-
I think my request might be inherently impossible to achieve, but I still want to ask. This is a simplified version of my case: I have an array of points, say … I want to create a new array such that only points which have a … In standard Python this is trivial; you can do … However, I cannot manage to find a solution in compiled JAX with `jit`, since the shape of … The reason why I don't want to use …
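A minimal reproduction of the limitation described above may help. The data and the names `points`, `mask`, and `filter_points` are hypothetical. Outside `jit`, boolean-mask indexing works. Under `jax.jit`, however, the mask is a tracer, so the output length depends on runtime values, and JAX raises `jax.errors.NonConcreteBooleanIndexError`.

```python
import jax
import jax.numpy as jnp

points = jnp.arange(6.0)  # hypothetical data
mask = points % 2 == 0

# Eager (non-jitted) boolean indexing works: the result has 3 elements.
eager = points[mask]


@jax.jit
def filter_points(points, mask):
    # Under jit the output shape would depend on the *values* in mask.
    return points[mask]


try:
    filter_points(points, mask)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    jit_failed = True

print(eager)       # [0. 2. 4.]
print(jit_failed)  # True
```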
-
Hey!
But it doesn't help if the computation isn't much more expensive than memory access.

```python
import jax
from jax import numpy as jnp
from jax import lax


def jax_filter(conds, xs):
    """
    returns:
        result array
        number of valid elements
    """
    res = jnp.full_like(xs, jnp.nan)

    def body_fun(carry, v):
        res, idx = carry
        cond, x = v
        # there is no need for a jnp.where or lax.cond
        # since only [0, last_idx) is valid
        return (res.at[idx].set(x), idx + cond), None

    return lax.scan(body_fun, (res, 0), (conds, xs))[0]


def valid_chunks_vmap(f, num_chunks: int):
    # vmap f over chunks, but only evaluate it on the first num_valid_chunks
    # chunks; the remaining chunks are filled with NaN instead of computed.
    fs = jax.vmap(f)

    def valid_chunks_f(xs, num_valid_chunks):
        def body_fun(idx, x_chunk):
            def no_op(x):
                s = jax.eval_shape(fs, x)
                return jnp.full(s.shape, jnp.nan, s.dtype)

            return idx + 1, lax.cond(idx < num_valid_chunks, fs, no_op, x_chunk)

        y_chunks = lax.scan(body_fun, 0, jnp.reshape(xs, (num_chunks, -1, *xs.shape[1:])))[1]
        return jnp.reshape(y_chunks, (-1, *y_chunks.shape[2:]))

    return valid_chunks_f


def f(x, y):
    return jnp.sum((x - y) ** 2)


@jax.jit
def filtered_distances(conds, xs):
    # Move the valid points to the front; n_valid is a traced scalar.
    xs, n_valid = jax_filter(conds, xs)
    num_chunks = 16
    assert xs.shape[0] % num_chunks == 0
    chunk_size = xs.shape[0] // num_chunks
    num_valid_chunks = (n_valid + chunk_size - 1) // chunk_size

    def inner(x):
        return valid_chunks_vmap(lambda z: f(x, z), num_chunks)(xs, num_valid_chunks)

    # Pairwise distances; only chunks that contain valid points are computed.
    return valid_chunks_vmap(inner, num_chunks)(xs, num_valid_chunks)


xs = jnp.arange(256, dtype=jnp.float32)
conds = (jnp.arange(256) % 2).astype(jnp.bool_)
print(filtered_distances(conds, xs))
```
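Here is a shorter sketch of the same idea, under stated assumptions: compact the valid elements into a fixed-size prefix, then only do work on the chunks that contain valid elements. Instead of the reply's `lax.scan` + `lax.cond`, it uses `lax.fori_loop` with a traced upper bound. JAX lowers that to a `while_loop`, which supports forward-mode differentiation but not reverse-mode. `CHUNK`, `sum_sq_valid`, and the sum-of-squares workload are hypothetical stand-ins for the real computation.

```python
import jax
import jax.numpy as jnp
from jax import lax

CHUNK = 4  # hypothetical chunk size; must divide the input length


@jax.jit
def sum_sq_valid(conds, xs):
    n = xs.shape[0]
    assert n % CHUNK == 0  # static check at trace time
    # Compact valid elements to the front; the output size is static (n).
    (idx,) = jnp.nonzero(conds, size=n, fill_value=0)
    compact = xs[idx]
    n_valid = conds.sum()
    num_valid_chunks = (n_valid + CHUNK - 1) // CHUNK

    def body(i, acc):
        start = i * CHUNK
        chunk = lax.dynamic_slice(compact, (start,), (CHUNK,))
        pos = start + jnp.arange(CHUNK)
        # The last visited chunk may contain padding past n_valid; mask it out.
        return acc + jnp.sum(jnp.where(pos < n_valid, chunk ** 2, 0.0))

    # The upper bound is traced, so fori_loop becomes a while_loop and
    # chunks past the valid prefix are never visited.
    return lax.fori_loop(0, num_valid_chunks, body, jnp.zeros((), xs.dtype))


xs = jnp.arange(16, dtype=jnp.float32)
conds = jnp.arange(16) % 3 == 0  # keeps 0, 3, 6, 9, 12, 15
print(sum_sq_valid(conds, xs))  # 495.0
```

Like the reply's version, this only pays off when the per-element work dominates, because the compaction step itself still touches the whole array.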
-
oops, as @jakevdp pointed out, …

```python
import jax
from jax import numpy as jnp
from jax import lax


@jax.jit
@jax.vmap
def ff_1(conds, xs):
    # Scatter-add each kept element to slot cumsum - 1; dropped elements add 0.
    cumsum = jnp.cumsum(conds)
    return jnp.zeros_like(xs).at[cumsum - 1].add(jnp.where(conds, xs, 0)), cumsum[-1]


@jax.jit
@jax.vmap
def ff_2(conds, xs):
    # Gather: idx[k] is the position of the (k+1)-th kept element.
    cumsum = jnp.cumsum(conds)
    intervals = jnp.zeros_like(cumsum).at[cumsum].add(1)
    idx = jnp.cumsum(intervals)
    return xs[idx], cumsum[-1]


@jax.jit
@jax.vmap
def ff_3(conds, xs):
    # Static-size jnp.where; missing indices are padded with 0.
    return xs[jnp.where(conds, size=len(xs))], conds.sum()


@jax.jit
@jax.vmap
def ff_4(conds, xs):
    # Sequential lax.scan, as in the previous reply.
    res = jnp.full_like(xs, jnp.nan)

    def body_fun(carry, v):
        res, idx = carry
        cond, x = v
        return (res.at[idx].set(x), idx + cond), None

    return lax.scan(body_fun, (res, 0), (conds, xs))[0]


bsz = 1024
n = 65536
xs = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32), (bsz, n))
conds = jnp.broadcast_to(lax.convert_element_type(jnp.arange(n) % 2, jnp.bool_), (bsz, n))


def timer(f):
    from time import time

    f()  # warmup
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)


timer(lambda: jax.block_until_ready(ff_1(conds, xs)))
timer(lambda: jax.block_until_ready(ff_2(conds, xs)))
timer(lambda: jax.block_until_ready(ff_3(conds, xs)))
timer(lambda: jax.block_until_ready(ff_4(conds, xs)))
```
8x larger
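Of the four variants, `ff_2` is probably the least obvious. The worked example below uses a tiny, hypothetical input (unbatched, with no `jit` or `vmap`). It shows why the cumsum / histogram / cumsum sequence yields gather indices for the kept elements. It also checks the valid prefix against the static-size `jnp.nonzero` approach.

```python
import jax.numpy as jnp

# Hypothetical small input: three of six entries are kept.
conds = jnp.array([False, True, False, True, True, False])
xs = jnp.array([10.0, 11.0, 12.0, 13.0, 14.0, 15.0])

cumsum = jnp.cumsum(conds)                            # [0 1 1 2 3 3]
# Histogram of cumsum values: intervals[k] counts positions whose cumsum == k.
intervals = jnp.zeros_like(cumsum).at[cumsum].add(1)  # [1 2 1 2 0 0]
# idx[k] = number of positions with cumsum <= k
#        = position of the (k+1)-th True value (for k < n_valid).
idx = jnp.cumsum(intervals)                           # [1 3 4 6 6 6]
n_valid = int(cumsum[-1])

compacted = xs[idx]          # only the first n_valid entries are meaningful
print(compacted[:n_valid])   # [11. 13. 14.]

# Same prefix as the static-size jnp.nonzero approach (ff_3 style).
ref = xs[jnp.nonzero(conds, size=xs.shape[0])]
print(ref[:n_valid])         # [11. 13. 14.]
```

Entries past `n_valid` come from out-of-bounds or padded indices, so they should be treated as padding rather than data.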
-
No, unfortunately `jax.jit` requires statically shaped arrays, so there is no way to do what you're asking within a JIT-compiled function. The standard workaround is to re-express your computation in terms of statically shaped arrays, and in practice I find this can often be done with some thought. Another possibility is to split your computation so that the dynamically shaped arrays are generated outside of a JIT context, and the computation on those arrays is then done within JIT.

If you have a specific application in mind that you're having trouble re-expressing this way, feel free to open a discussion with more detail.
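Both workarounds from this answer are sketched below on hypothetical data; `points`, `mask`, and the sum-of-squares computation are assumptions for illustration. Masking the computation with `jnp.where` keeps every shape static. Filtering eagerly and passing the result into a jitted function also works, but `jit` then compiles a new version for each distinct filtered length.

```python
import jax
import jax.numpy as jnp

points = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])  # hypothetical data
mask = points > 0


# Workaround 1: keep shapes static and mask the computation, not the array.
@jax.jit
def masked_sum_of_squares(points, mask):
    return jnp.sum(jnp.where(mask, points ** 2, 0.0))


# Workaround 2: build the dynamically shaped array outside jit, compute inside.
@jax.jit
def sum_of_squares(points):
    return jnp.sum(points ** 2)


filtered = points[mask]               # eager: shape (3,) for this data
a = masked_sum_of_squares(points, mask)
b = sum_of_squares(filtered)          # compiled for input shape (3,)
print(a, b)                           # both equal 35.0
```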