Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve test seeding in test_numpy_interoperablity.py (#18762)
Browse files Browse the repository at this point in the history
  • Loading branch information
DickJC123 authored Jul 20, 2020
1 parent a7c6606 commit 6bb3d72
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
from common import assertRaises, with_seed, random_seed
from common import assertRaises, with_seed, random_seed, setup_module, teardown_module
from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST

Expand Down Expand Up @@ -62,8 +62,15 @@ def add_workload(name, *args, **kwargs):

@staticmethod
def get_workloads(name):
if OpArgMngr._args == {}:
_prepare_workloads()
return OpArgMngr._args.get(name, None)

@staticmethod
def randomize_workloads():
# Force a new _prepare_workloads(), which will be based on new random numbers
OpArgMngr._args = {}


def _add_workload_all():
# check bad element in all positions
Expand Down Expand Up @@ -516,8 +523,8 @@ def _add_workload_linalg_cholesky():
shapes = [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)]
dtypes = (np.float32, np.float64)

for shape, dtype in itertools.product(shapes, dtypes):
with random_seed(1):
with random_seed(1):
for shape, dtype in itertools.product(shapes, dtypes):
a = _np.random.randn(*shape)

t = list(range(len(shape)))
Expand Down Expand Up @@ -3183,9 +3190,6 @@ def _prepare_workloads():
_add_workload_vander()


_prepare_workloads()


def _get_numpy_op_output(onp_op, *args, **kwargs):
onp_args = [arg.asnumpy() if isinstance(arg, np.ndarray) else arg for arg in args]
onp_kwargs = {k: v.asnumpy() if isinstance(v, np.ndarray) else v for k, v in kwargs.items()}
Expand All @@ -3197,7 +3201,7 @@ def _get_numpy_op_output(onp_op, *args, **kwargs):
return onp_op(*onp_args, **onp_kwargs)


def _check_interoperability_helper(op_name, *args, **kwargs):
def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
strs = op_name.split('.')
if len(strs) == 1:
onp_op = getattr(_np, op_name)
Expand All @@ -3213,11 +3217,11 @@ def _check_interoperability_helper(op_name, *args, **kwargs):
assert type(out) == type(expected_out)
for arr, expected_arr in zip(out, expected_out):
if isinstance(arr, np.ndarray):
assert_almost_equal(arr.asnumpy(), expected_arr, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
assert_almost_equal(arr.asnumpy(), expected_arr, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
else:
_np.testing.assert_equal(arr, expected_arr)
elif isinstance(out, np.ndarray):
assert_almost_equal(out.asnumpy(), expected_out, rtol=1e-3, atol=1e-4, use_broadcast=False, equal_nan=True)
assert_almost_equal(out.asnumpy(), expected_out, rtol=rel_tol, atol=abs_tol, use_broadcast=False, equal_nan=True)
elif isinstance(out, _np.dtype):
_np.testing.assert_equal(out, expected_out)
else:
Expand All @@ -3229,6 +3233,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs):


def check_interoperability(op_list):
OpArgMngr.randomize_workloads()
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
Expand All @@ -3240,13 +3245,17 @@ def check_interoperability(op_list):
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
continue
default_tols = (1e-3, 1e-4)
tols = {'linalg.tensorinv': (1e-2, 5e-3),
'linalg.solve': (1e-3, 5e-2)}
(rel_tol, abs_tol) = tols.get(name, default_tols)
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, *workload['args'], **workload['kwargs'])
_check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])


@with_seed()
Expand Down

0 comments on commit 6bb3d72

Please sign in to comment.