Copy-view behaviour and mutating arrays #24
@honnibal (of spaCy fame) left a really good comment on this topic on the Twitter announcement thread (https://twitter.com/honnibal/status/1295369359653842945):
I think reframing things around optionally reused buffers could be a good way to bridge the gap between mutation-based and non-mutation-based APIs. For example, most of JAX's APIs are based on non-mutating pure functions, but JAX has a notion of "donated" arguments (`donate_argnums` in `jax.jit`). Would it be too much API innovation to add a new keyword along those lines? There is also the challenge of how to spell optional buffer reuse for indexing assignment.
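For context, buffer donation in JAX is spelled `donate_argnums` on `jax.jit`; here is a minimal sketch (the `add_one` example is mine, not from the thread):

```python
from functools import partial

import jax
import jax.numpy as jnp

# donate_argnums=0 tells XLA it may reuse the buffer of the first argument
# for the output; the caller must not use `x` again after the call.
# (On backends without donation support, JAX warns and copies instead.)
@partial(jax.jit, donate_argnums=0)
def add_one(x):
    return x + 1

x = jnp.ones((3,))
y = add_one(x)  # x's buffer may have been donated to y
```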
Maybe not too much, it's only a single keyword after all. It seems to be an incomplete design though - in the absence of whole-program optimization, the semantics of such a keyword seem hard to pin down.
Maybe I'm missing something here, but I think this
Yep, that's the harder part. Regarding your option 3 (no mutation at all), I'm afraid that'll be a reason for outright rejection of the whole standard by the community - it requires a whole change of thinking and a ton of code changes. Your option 2 (make support for in-place operations optional) seems like the next-best thing. The most likely result will be that functionality that relies on in-place mutation won't support JAX for the time being.
A procedural thought: for this kind of addition of something that doesn't exist yet, we should probably have a separate status in the standard, like "provisional" - a recommendation that, if one wants to add a feature like this, it must be spelled like X and behave like Y. That guarantees that projects experimenting with it - because they'd like to improve on the status quo - don't diverge.
This might become something like a "may overwrite" keyword, where the values that end up filling the output may or may not reuse one of the input buffers. That said, I do think something like "may overwrite" is a better way to spell this than allowing filling into arbitrary buffers - closer in spirit both to working buffers in LAPACK and Python's own in-place arithmetic. As a user, you would write the operation as usual; the values of the overwritten inputs then become undefined. In-place operations now become an optimization detail, rather than a requirement as would be the case for `out=`.
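To make that concrete, a hypothetical sketch - `may_overwrite` is not a real keyword in any library; SciPy's `overwrite_a` working-buffer flags are the closest existing analogue:

```python
import numpy as np
from scipy.linalg import solve

# Hypothetical spelling (not a real keyword in NumPy or the standard):
# the library MAY reuse x's buffer for the result, so the caller must
# not rely on x's values afterwards.
#
#     z = add(x, y, may_overwrite=True)

# The closest existing analogue: SciPy's working-buffer flags.
a = np.random.rand(3, 3)
b = np.random.rand(3)
x = solve(a, b, overwrite_a=True)  # `a` may be clobbered as scratch space
```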
yes, that's the footgun I was worried about
That's better, I can see the appeal of that. A boolean keyword may not generalize well for multiple inputs, which I think was a point @honnibal was making with "one or more buffers".
@rgommers what do you mean by "in-place operations that are unambiguous"? Is this referring specifically to indexing-based assignment and in-place arithmetic?

Making views immutable sounds like a pretty reasonable design decision. I don't know if current array libraries would be happy with that (there are certainly existing cases where views are written to), but at the least this sounds like exactly the sort of design decision we should allow. That said, if we allow "immutable" arrays in the case of views, why not allow raising an error for mutating arrays under other circumstances, too?

In that case, I don't think either (4) or (5) is much better than (2), making in-place operations optional, with an understanding that this will make it harder to write efficient generic logic in some cases (but this is likely already impossible in general). To elaborate:
Examples of in-place arithmetic:

```python
def a(x):  # bad, assumes x is modified in-place
    x += 1
    x /= 2

def b(x):  # bad, assumes x is copied by +=
    x += 1
    return x / 2

def c(x):  # good, in-place arithmetic is only used as a hint within a function
    x = x + 1
    x /= 2
    return x

def d(x):  # good, premature optimization is the root of all evil :)
    return (x + 1) / 2
```

I think users will have to learn these rules themselves. Python function boundaries generally aren't available at the level of array libraries, so we can't do much to help them. Fortunately, these are already best practices for writing code that works generically over both arrays and scalars. Potentially static type checkers could also catch this sort of thing?
Yes, in fact this is exactly the case in JAX.
Yes indeed (unambiguous unless the target affects a view).
I think the differences are (a) the impact on existing code (acceptable for immutable views, a train wreck for all mutations), and (b) how easy the resulting code changes are to make (for views, just add a `.copy()`).

Here's a search result for some forms of in-place assignment use in scikit-image:
That tells me that if mutation is optional, scikit-image is simply going to ignore the array types that don't implement it. As will SciPy et al. Maybe that just is what it is, but what I want to know is whether it can be avoided by JAX adding a simple wrapper that maps `x[idx] = y` to `x = x.at[idx].set(y)`. If that's not that hard to do, and if a lot of those cases are not inefficient (which I think is the case), then would it really be so bad to add that wrapper?
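A rough sketch of what such a wrapper could look like (hypothetical class, not part of JAX; every assignment copies under the hood):

```python
import jax.numpy as jnp

class MutableJaxArray:
    """Hypothetical wrapper giving JAX arrays NumPy-style slice assignment."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, idx):
        return self.value[idx]

    def __setitem__(self, idx, val):
        # Translate `x[idx] = val` into `x = x.at[idx].set(val)`:
        # build a new array and rebind, no in-place mutation.
        self.value = self.value.at[idx].set(val)

x = MutableJaxArray(jnp.zeros((3, 2)))
x[0, :] = 2.0  # looks like NumPy syntactically, copies under the hood
```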
The reason why
So if we can figure out a usage pattern that avoids that confusion by recommending (or requiring?) a call to ensure a copy, I think we could make it work in JAX, e.g.,
If we have a reliable way to detect "views" (with reference counting?), we might be able to call this method automatically where needed. JAX would probably choose to make all operations "views" by default, just to preserve maximum flexibility and require best practices. In fact, I would guess that the default array type wouldn't even implement mutation.
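In usage-pattern terms, that might look like the following (`ensure_copy` is a made-up name for such a method, it does not exist anywhere):

```python
# Hypothetical usage pattern:
y = x[:2, :]         # may be a view (NumPy) or a copy (JAX, TensorFlow)
y = y.ensure_copy()  # copy if `y` aliases `x` (refcount check?), else no-op
y += 1               # now guaranteed to be invisible through `x`
```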
That's an interesting idea. I think it's feasible to build this, and the extra API surface seems manageable.
A quick test to set a baseline: here are the results of running the test suites of some SciPy modules after making the change:
Testing the impact of the change:
Coming back to this after finding it's probably too hard to implement
Played with this some more:
yields
This is actually 100% identical to the problem with slice assignment.
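For reference, the behaviour in question (recent NumPy, where the result of `diag` is a read-only view, as the test script below also relies on):

```python
import numpy as np

x = np.arange(9).reshape(3, 3)
y = np.diag(x)  # a read-only view on x's diagonal
y = y + 2       # fine: allocates a new array
# y += 2        # ValueError: output array is read-only
```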
It seems to me that if JAX were principled about no mutation, it should have implemented this differently.
@mattip is going to help with an experiment, creating a NumPy branch which sets arrays read-only whenever a view is created.
I'm not entirely sure I follow what you mean here. Examples like your
If it helps, we could still call this method `copy()`. My main concern is that we should try to avoid baking specific view vs copy semantics directly into the standard. For example, if a new library wants to implement arithmetic as lazy expressions (like in dask or xtensor), that should be OK. Lazy expressions are effectively a form of views, so you can't count on being able to write something like this, even though it's safe in NumPy:
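The elided snippet was presumably along these lines (my reconstruction of the hazard being described, not the original code):

```python
import numpy as np

x = np.arange(4.0)
y = x + 1      # eager NumPy computes y now, independent of x
x[0] = 99.0    # mutate x afterwards
print(y[0])    # NumPy: 1.0; a library that lazily defers `x + 1` until
               # y is read would observe the mutation and give 100.0
```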
Yes, I agree completely. The question is just what form that takes. I'm trying to get to a place where we don't have to guess at all, but slice assignment is still allowed.
Here is a related PyTorch issue with ideas and rationale for adding immutable tensors and returning them in some cases where view/copy behaviour will be hard to predict: pytorch/pytorch#44027
Here is a hacky patch for NumPy that makes both the base array and the view read-only if a view is created:
It doesn't completely do the job it's supposed to do - it warns when doing a mutating operation that affects a view, however it also prevents regular slice assignment:
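Roughly (my illustration of the described behaviour, not output from the patch itself):

```python
import numpy as np

x = np.ones((3, 3))
y = x[:2, :]  # with the patch, creating a view marks both x and y read-only
# x += 1      # now raises ValueError - the intended protection
x2 = np.ones((3, 3))
# x2[0, :] = 2  # the false positive: plain slice assignment is rejected too
```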
That could be prevented by starting from within `__setitem__`. At that point there's still annoying behaviour left though - for example, merely evaluating an expression that creates a view freezes the base array. Trying out the effect on SciPy:
gives

tl;dr this is going to be a real pain to do.
I wrote some tests that make it easier to figure out which libraries have matching behaviour for in-place operations:

```python
"""
To create an environment with the seven array libraries used in this script
installed:
conda create -n many-libs python=3.7
conda activate many-libs
conda install cudatoolkit=10.2
pip install numpy dask toolz torch jax jaxlib tensorflow mxnet cupy-cuda102
Conda doesn't manage to find a winning combination here; pip has a hard time
too and probably not all constraints are satisfied, but nothing crashes
and the tests here work as they are supposed to.
"""
import numpy as np
import dask.array as da
import torch
import tensorflow as tf
import jax.numpy as jnp
import mxnet
try:
import cupy as cp
except ImportError:
# CuPy is GPU-only, so may not be available
cp = None
import pandas as pd
def materialize(x, y):
if isinstance(x, da.Array):
x = x.compute()
if isinstance(y, da.Array):
y = y.compute()
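    # NOTE: relies on the module-level `mod` set by the test loop below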
if mod == mxnet.nd:
x0 = int(x[0, 0].asscalar())
y0 = int(y[0, 0].asscalar())
else:
x0 = int(x[0, 0])
y0 = int(y[0, 0])
return x0, y0
def ones(mod):
if mod in (da, mxnet.nd):
x = mod.ones((3, 2), dtype=np.int32)
else:
x = mod.ones((3, 2), dtype=mod.int32)
return x
def reshape(mod, x, shape):
if mod == tf:
return tf.reshape(x, shape)
else:
return x.reshape(shape)
def arange(mod, stop):
if mod == tf:
return tf.range(stop)
elif mod == mxnet.nd:
return mod.arange(stop, dtype=np.int64)
else:
return mod.arange(stop)
def diag(mod, x):
if mod == tf:
return tf.linalg.diag_part(x)
else:
return mod.diag(x)
def slice_assign(mod, x):
# Add 2 to first row of (3, 2)-shaped input x
assert x.shape[0] == 3
idx = reshape(mod, arange(mod, 6), x.shape) < 3
if mod in (tf, da):
x = mod.where(idx, x+2, x)
elif mod == jnp:
x = x.at[idx].set(x[idx] + 2)
else:
x[:, 0] += 2
return x
def f1(mod):
"Add, then in-place subtract"
x = ones(mod)
y = x + 2
y -= 1
return materialize(x, y)
def f2(mod):
"In-place add, alias, then in-place subtract"
x = ones(mod)
x += 2
y = x
y -= 1
return materialize(x, y)
def f3(mod):
"Slice, then in-place add on slice"
x = ones(mod)
y = x[:2, :]
y += 1
return materialize(x, y)
def f4(mod):
"Reshape, then in-place add"
x = ones(mod)
y = reshape(mod, x, (2, 3))
y += 1
return materialize(x, y)
def f5(mod):
"Slice with step size 2, then in-place add"
x = ones(mod)
y = x[::2, :]
y += 1
return materialize(x, y)
def f6(mod):
"Alias, then in-place add"
x = ones(mod)
y = x
y += 1
return materialize(x, y)
def f7(mod):
"Check which array types support slice assignment syntax"
x = ones(mod)
try:
x[0, :] = 2
except (NotImplementedError, TypeError):
x = -9*ones(mod)
return materialize(x, 2*ones(mod))
def f8(mod):
"Do the actual slice assignment in the way each library wants it done"
x = ones(mod)
x = slice_assign(mod, x)
y = 2 * ones(mod)
return materialize(x, y)
def f9(mod):
"`diag` is known to have inconsistent behaviour, test it"
x = reshape(mod, arange(mod, 9), (3, 3))
y = diag(mod, x)
if mod == np:
y = y + 2 # `y` is read-only
else:
y += 2
return materialize(x, 2*ones(mod))
def f_indexing(mod, idx):
"Indexing, then in-place add"
x = ones(mod)
try:
y = x[idx, :]
except TypeError:
x = -9 * x
y = ones(mod)
y += 1
return materialize(x, y)
def f10(mod):
"Indexing with list of integers, then in-place add"
return f_indexing(mod, [0, 1])
def f11(mod):
"Boolean indexing, then in-place add"
return f_indexing(mod, [True, True, False])
def f12(mod):
"Indexing with ellipsis, then in-place add"
return f_indexing(mod, Ellipsis)
libraries = {
'numpy': np,
'pytorch': torch,
'MXNet': mxnet.nd,
'dask': da,
'tensorflow': tf,
'jax': jnp,
}
if cp is not None:
libraries['cupy'] = cp
results = libraries.copy()
funcs = [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]
for name, mod in libraries.items():
results[name] = None
res = []
for func in funcs:
x, y = func(mod)
assert y == 2
assert type(x) == int, (type(x), mod)
res.append(x)
results[name] = res
results = pd.DataFrame(results).T
results.columns = ['f{}'.format(i+1) for i in range(len(results.columns))]
print(results.sort_values(by='f2'))
print('\n')
for i, func in enumerate(funcs):
print(func.__name__ + ':', func.__doc__)
if i == 5:
# First six functions are only about in-place operators, no slice
# assignment - separate those in the output
print('')
if i == 8:
# Last three functions are only about indexing behaviour
print('')
print('\nCheck if behaviour equals that of NumPy\n(1 or -1 means copy/view behaviour '
'mismatch, -9 means unsupported behaviour):\n')
results_vs_numpy = results.sort_values(by='f2') - results.loc['numpy']
# Set NA values (from exceptions) back to -9
results_vs_numpy[results_vs_numpy < -5] = -9
print(results_vs_numpy)
```

Results in:
Some observations:
I'm leaning towards the following choices:
Would it be more practical to make only the view read-only? If you modify the base array, then in addition to errors like the IPython example you show, functions that create views would now effectively be mutating their input arguments.
Yes, that's a good point. I undid that change, and added some more fixes (current patch here). That at least makes the NumPy test suite start and run to completion without crashing. Resulting in:
The troubles are:
I think it would take too much time to complete this change and make the NumPy test suite (mostly) pass - and that'd be needed to be able to assess the impact on SciPy and other packages. That, plus the fact that it's a massive backwards-compatibility break for all of NumPy, PyTorch, CuPy and MXNet, makes me think we should simply go with "Add a recommendation that users avoid any mutating operation when a view may be involved".
Everyone in the meeting today was good with points 1-4 in #24 (comment). I'll open a PR that discusses mutability, and @kgryte will remove `out=` from the current spec.
Context:
That issue and PR were about unrelated topics, so I'll try to summarize the copy-view and mutation topic here and we can continue the discussion.
Note that the two topics are fairly coupled, because copy/view differences only matter (for semantics, not for performance) when mixed with mutation.
Mutating arrays
There's a number of things that may rely on mutation:

- in-place operators, e.g. `+=`, `*=`
- the `out=` keyword argument
- slice/element assignment via `__setitem__`
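Concretely, each of these in NumPy terms:

```python
import numpy as np

x = np.ones(4)
x *= 2               # in-place operator
np.add(x, 1, out=x)  # out= keyword argument
x[:2] = 0            # slice assignment via __setitem__
```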
Summary of the issue with mutation by @shoyer was: Mutation can be challenging to support in some execution models (at least without another layer of indirection), which is why several projects currently don't support it (TensorFlow and JAX) or only support it half-heartedly (e.g., Dask). The commonality between these libraries is that they build up abstract computations, which are then transformed (e.g., for autodiff) and/or executed in parallel. Even NumPy has "read only" arrays. I'm particularly concerned about new projects that implement this API, which might find the need to support mutation burdensome.
@alextp said: TensorFlow was planning to add mutability and didn't see a real issue with supporting `out=`.

@shoyer said: It's definitely always possible to support mutation at the Python level via some sort of wrapper layer. `dask.array` is perhaps a good example of this. It supports mutating operations and `out=` in some cases, but its support for mutation is still rather limited. For example, it doesn't support assignment like `x[:2, :] = some_other_array`.
Working around limitations of no support for mutation can usually be done by one of:

1. using `where` for selection, e.g., `where(arange(4) == 2, 1, 0)`
2. using indexing to express where elements come from, e.g., `y = array([0, 1]); x = y[[0, 0, 1, 0]]` in this case

Some version of (2) always works, though it can be tricky to work out (especially with current APIs). The duality between indexing and assignment is the difference between specifying where elements come from or where they end up.
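Both workarounds, spelled out in NumPy for the example above:

```python
import numpy as np

# Mutating version:
x = np.zeros(4)
x[2] = 1

# (1) `where` for selection:
x = np.where(np.arange(4) == 2, 1, 0)

# (2) indexing that specifies where elements come from:
y = np.array([0, 1])
x = y[[0, 0, 1, 0]]
```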
The JAX syntax for slice assignment is: `x.at[idx].set(y)` vs. `x[idx] = y`.

One advantage of the non-mutating version is that JAX can have reliable assigning arithmetic on array slices with `x.at[idx].add(y)` (`x[idx] += y` doesn't work if `x[idx]` returns a copy).

A disadvantage is that doing this sort of thing inside a loop is almost always a bad idea unless you have a JIT compiler, because every indexing assignment operation makes a full copy. So the naive translation of an efficient Python loop that fills out an array row by row would now make a copy in each step. Instead, you'd have to rewrite that loop to use something like `concatenate` (which in my experience is already about as efficient as using indexing assignment).
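A sketch of that loop pattern and its rewrite (my example; `row` is a stand-in for whatever computes each row):

```python
import jax.numpy as jnp

def row(i):  # stand-in for whatever computes each row
    return jnp.full((3,), float(i))

n = 100
# Naive translation of a fill-in loop: outside of jit, every .set()
# copies the whole array, so this does O(n**2) work.
out = jnp.zeros((n, 3))
for i in range(n):
    out = out.at[i, :].set(row(i))

# Rewrite without indexing assignment: build the rows, combine once.
out = jnp.stack([row(i) for i in range(n)])
```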
Copy-view behaviour
Libraries like NumPy and PyTorch return views where possible from function calls. It's sometimes hard to predict when a view will be returned vs. when a copy - it not only depends on the function in question, but also on whether the input array is contiguous, and sometimes even on input dtype.
This is one place where it's hard to avoid implementation choices leaking into the API:

- strided-memory libraries like NumPy and PyTorch return a view from an operation like `transpose()`;
- other implementations return a copy (or a lazy/immutable result) from `transpose()`.

The above copy vs. view difference starts leaking into the API - i.e., the same code starts giving different results for different implementations - when it is combined with an operation that performs in-place mutation of an array (either the base array or the view on it). In the absence of that combination, views are simply a performance optimization that's invisible to the user.
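A minimal NumPy illustration of the leak:

```python
import numpy as np

x = np.ones((2, 3))
y = x.T         # NumPy returns a view
y += 1          # in-place mutation of the view...
print(x[0, 0])  # 2.0 - visible through x; under a copy-returning
                # implementation the same code leaves x[0, 0] == 1.0
```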
The question is whether copy-view differences should be allowed, and if so how to deal with the semantics that vary between libraries.
To answer whether it should be allowed, let's first ask how often the combination of views and mutation is used. A few observations:

- A search for `*=`, `+=` and `] =` in SciPy and scikit-learn `.py` files shows that in-place mutation inside functions is heavily used.
- There's a difference between mutating a whole array at once (e.g. with `+= 1`) and mutating part of an array (e.g. with `x[:, :2] = y`). The former is a lot easier to support for array libraries employing static graphs or a JIT than the latter. See the discussion at #8 (comment) for details.

Options for how to standardize
In #8 @shoyer listed the following options for how to deal with mutability:

1. Require support for in-place operations.
2. Make support for in-place operations optional.
3. Don't support in-place operations at all.
4. Make arrays immutable, with mutability as an explicit opt-in like `ndarray.flags.writeable`. (From later discussion, see #8 (comment) for the implication of that for users of the API.)

To that I'd like to add a more granular option:
5. Require support for in-place operations that are unambiguous, and require raising an exception in case a view is mutated.
Rationale:
(a) This would require libraries that don't support mutation to write a wrapper layer, but the behaviour would be unambiguous and in most cases the wrapper would not be inefficient.
(b) In case inefficient mutation is detected (e.g. mutating a large array row-by-row in a loop), a warning may be emitted.
A variant of this option would be:
Require support for in-place operations that are unambiguous and mutate the whole array at once (i.e. `+=` and `out=` must be supported, element/slice assignment must raise an exception), and require raising an exception in case a view is mutated.

Trade-off here is ease of implementation for libraries like Dask and JAX vs. putting a rewrite burden on SciPy et al. and a usability burden on end users (the alternative to element/slice assignment is unintuitive).