From 41983f5df2bd024c217add27859174779786422f Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 21 Oct 2021 16:04:09 -0700 Subject: [PATCH 1/5] [API] Add method __index__() and __array_namespace__() --- python/mxnet/numpy/multiarray.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a58f1faf5587..a678bc5a210f 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -412,6 +412,28 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- return mx_np_func(*new_args, **new_kwargs) + def __array_namespace__(self, api_version=None): + """ + Returns an object that has all the array API functions on it. + + Notes + ----- + This is a standard API in + https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none. + + Parameters + ---------- + self : ndarray + The indexing key. + api_version : Optional, string + string representing the version of the array API specification to be returned, in `YYYY.MM` form. + If it is None, it should return the namespace corresponding to latest version of the array API specification. + """ + if api_version is not None and not api_version.startswith("2021."): + raise ValueError(f"Unrecognized array API version: {api_version!r}") + return self.__module__ + + def _get_np_basic_indexing(self, key): """ This function indexes ``self`` with a tuple of `slice` objects only. @@ -1255,6 +1277,11 @@ def __bool__(self): __nonzero__ = __bool__ + def __index__(self): + if self.ndim == 0 and _np.issubdtype(self.dtype, _np.integer): + return self.item() + raise TypeError('only integer scalar arrays can be converted to a scalar index') + def __float__(self): num_elements = self.size if num_elements != 1: From 0a4af8a61fbd5a632fd623ca262262df1c3b87b1 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 21 Oct 2021 16:21:25 -0700 Subject: [PATCH 2/5] update doc --- docs/python_docs/python/api/np/arrays.ndarray.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/python_docs/python/api/np/arrays.ndarray.rst b/docs/python_docs/python/api/np/arrays.ndarray.rst index e77d20b8a138..522a667d69b1 100644 --- a/docs/python_docs/python/api/np/arrays.ndarray.rst +++ b/docs/python_docs/python/api/np/arrays.ndarray.rst @@ -512,12 +512,13 @@ Container customization: (see :ref:`Indexing `) ndarray.__getitem__ ndarray.__setitem__ -Conversion; the operations :func:`int()` and :func:`float()`. +Conversion; the operations :func:`index()`, :func:`int()` and :func:`float()`. They work only on arrays that have one element in them and return the appropriate scalar. .. autosummary:: + ndarray.__index__ ndarray.__int__ ndarray.__float__ From 9a328d7d921fdd123dc0c574e4c4a3abc71b30a1 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Thu, 21 Oct 2021 16:32:23 -0700 Subject: [PATCH 3/5] fix lint --- python/mxnet/numpy/multiarray.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a678bc5a210f..b969d059fcbe 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -414,11 +414,11 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- def __array_namespace__(self, api_version=None): """ - Returns an object that has all the array API functions on it. + Returns an object that has all the array API functions on it. - Notes - ----- - This is a standard API in + Notes + ----- + This is a standard API in https://data-apis.org/array-api/latest/API_specification/array_object.html#array-namespace-self-api-version-none. Parameters @@ -427,7 +427,8 @@ def __array_namespace__(self, api_version=None): The indexing key. api_version : Optional, string string representing the version of the array API specification to be returned, in `YYYY.MM` form. - If it is None, it should return the namespace corresponding to latest version of the array API specification. + If it is None, it should return the namespace corresponding to latest version of the array API + specification. """ if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") From 8e2cc75dc26b5eca731aa7b5926d67f213befa54 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 29 Oct 2021 12:05:51 -0700 Subject: [PATCH 4/5] add tests --- python/mxnet/numpy/multiarray.py | 13 +++++++-- tests/python/unittest/test_numpy_ndarray.py | 32 +++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b969d059fcbe..b3dbe04fbbbb 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -31,6 +31,8 @@ from array import array as native_array import functools import ctypes +import sys +import datetime import warnings import numpy as _np from .. import _deferred_compute as dc @@ -430,9 +432,14 @@ def __array_namespace__(self, api_version=None): If it is None, it should return the namespace corresponding to latest version of the array API specification. """ - if api_version is not None and not api_version.startswith("2021."): - raise ValueError(f"Unrecognized array API version: {api_version!r}") - return self.__module__ + if api_version is not None: + try: + date = datetime.datetime.strptime(api_version, '%Y.%m') + if date.year != 2021: + raise ValueError + except ValueError: + raise ValueError(f"Unrecognized array API version: {api_version!r}") + return sys.modules[self.__module__] def _get_np_basic_indexing(self, key): diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 559b8a575f5d..1d0ecdbacebb 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -21,6 +21,7 @@ import itertools import os import pytest +import operator import numpy as _np import mxnet as mx from mxnet import np, npx, autograd @@ -1426,3 +1427,34 @@ def test_mixed_array_types_share_memory(): def test_save_load_empty(tmp_path): mx.npx.savez(str(tmp_path / 'params.npz')) mx.npx.load(str(tmp_path / 'params.npz')) + +@use_np +@pytest.mark.parametrize('shape', [ + (), + (1,), + (1,2) +]) +@pytest.mark.parametrize('dtype', np._DTYPE_2_STR_.keys()) +def test_index_operator(shape, dtype): + if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer): + x = np.ones(shape=shape, dtype=dtype) + pytest.raises(TypeError, operator.index, x) + else: + assert operator.index(np.ones(shape=shape, dtype=dtype)) == \ + operator.index(_np.ones(shape=shape, dtype=dtype)) + + +@pytest.mark.parametrize('api_version, raise_exception', [ + (None, False), + ('2021.10', False), + ('2020.09', True), + ('2021.24', True), +]) +def test_array_namespace(api_version, raise_exception): + x = np.array([1, 2, 3], dtype="float64") + if raise_exception: + pytest.raises(ValueError, x.__array_namespace__, api_version) + else: + xp = x.__array_namespace__(api_version) + y = xp.array([1, 2, 3], dtype="float64") + assert same(x, y) From ead9f83affd63fe10e7d8c92f4824a297d5b33d0 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Fri, 29 Oct 2021 14:39:14 -0700 Subject: [PATCH 5/5] update tests --- tests/python/unittest/test_numpy_ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 1d0ecdbacebb..2da60aa1fc8e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1434,7 +1434,7 @@ def test_save_load_empty(tmp_path): (1,), (1,2) ]) -@pytest.mark.parametrize('dtype', np._DTYPE_2_STR_.keys()) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'bool', 'int32']) def test_index_operator(shape, dtype): if len(shape) >= 1 or not _np.issubdtype(dtype, _np.integer): x = np.ones(shape=shape, dtype=dtype)