diff --git a/esmvalcore/preprocessor/_shared.py b/esmvalcore/preprocessor/_shared.py index 7ed6ae4375..6ab4c10974 100644 --- a/esmvalcore/preprocessor/_shared.py +++ b/esmvalcore/preprocessor/_shared.py @@ -6,6 +6,7 @@ from __future__ import annotations +import inspect import logging import warnings from collections import defaultdict @@ -221,6 +222,21 @@ def get_normalized_cube( return normalized_cube +def _get_first_arg(func: Callable, *args: Any, **kwargs: Any) -> Any: + """Get first argument given to a function.""" + # If positional arguments are given, use the first one + if args: + return args[0] + + # Otherwise, use the keyword argument given by the name of the first + # function argument + # Note: this function should be called AFTER func(*args, **kwargs) is run, + # so that we can be sure that the required arguments are there + signature = inspect.signature(func) + first_arg_name = list(signature.parameters.values())[0].name + return kwargs[first_arg_name] + + def preserve_float_dtype(func: Callable) -> Callable: """Preserve object's float dtype (all other dtypes are allowed to change). @@ -230,16 +246,34 @@ def preserve_float_dtype(func: Callable) -> Callable: to give output with any type. """ + signature = inspect.signature(func) + if not signature.parameters: + raise TypeError( + f"Cannot preserve float dtype during function '{func.__name__}', " + f"function takes no arguments" + ) @wraps(func) - def wrapper(data: DataType, *args: Any, **kwargs: Any) -> DataType: - dtype = data.dtype - result = func(data, *args, **kwargs) - if np.issubdtype(dtype, np.floating) and result.dtype != dtype: - if isinstance(result, Cube): - result.data = result.core_data().astype(dtype) - else: - result = result.astype(dtype) + def wrapper(*args: Any, **kwargs: Any) -> DataType: + result = func(*args, **kwargs) + first_arg = _get_first_arg(func, *args, **kwargs) + + if hasattr(first_arg, "dtype") and hasattr(result, "dtype"): + dtype = first_arg.dtype + if np.issubdtype(dtype, np.floating) and result.dtype != dtype: + if isinstance(result, Cube): + result.data = result.core_data().astype(dtype) + else: + result = result.astype(dtype) + else: + raise TypeError( + f"Cannot preserve float dtype during function " + f"'{func.__name__}', the function's first argument of type " + f"{type(first_arg)} and/or the function's return value of " + f"type {type(result)} do not have the necessary attribute " + f"'dtype'" + ) + return result return wrapper diff --git a/tests/unit/preprocessor/test_shared.py b/tests/unit/preprocessor/test_shared.py index 46a6283573..8fca24dc1c 100644 --- a/tests/unit/preprocessor/test_shared.py +++ b/tests/unit/preprocessor/test_shared.py @@ -200,27 +200,27 @@ def _dummy_func(obj, arg, kwarg=2.0): return obj -@pytest.mark.parametrize( - "data,dtype", - [ - (np.array([1.0], dtype=np.float64), np.float64), - (np.array([1.0], dtype=np.float32), np.float32), - (np.array([1], dtype=np.int64), np.float64), - (np.array([1], dtype=np.int32), np.float64), - (da.array([1.0], dtype=np.float64), np.float64), - (da.array([1.0], dtype=np.float32), np.float32), - (da.array([1], dtype=np.int64), np.float64), - (da.array([1], dtype=np.int32), np.float64), - (Cube(np.array([1.0], dtype=np.float64)), np.float64), - (Cube(np.array([1.0], dtype=np.float32)), np.float32), - (Cube(np.array([1], dtype=np.int64)), np.float64), - (Cube(np.array([1], dtype=np.int32)), np.float64), - (Cube(da.array([1.0], dtype=np.float64)), np.float64), - (Cube(da.array([1.0], dtype=np.float32)), np.float32), - (Cube(da.array([1], dtype=np.int64)), np.float64), - (Cube(da.array([1], dtype=np.int32)), np.float64), - ], -) +TEST_PRESERVE_FLOAT_TYPE = [ + (np.array([1.0], dtype=np.float64), np.float64), + (np.array([1.0], dtype=np.float32), np.float32), + (np.array([1], dtype=np.int64), np.float64), + (np.array([1], dtype=np.int32), np.float64), + (da.array([1.0], dtype=np.float64), np.float64), + (da.array([1.0], dtype=np.float32), np.float32), + (da.array([1], dtype=np.int64), np.float64), + (da.array([1], dtype=np.int32), np.float64), + (Cube(np.array([1.0], dtype=np.float64)), np.float64), + (Cube(np.array([1.0], dtype=np.float32)), np.float32), + (Cube(np.array([1], dtype=np.int64)), np.float64), + (Cube(np.array([1], dtype=np.int32)), np.float64), + (Cube(da.array([1.0], dtype=np.float64)), np.float64), + (Cube(da.array([1.0], dtype=np.float32)), np.float32), + (Cube(da.array([1], dtype=np.int64)), np.float64), + (Cube(da.array([1], dtype=np.int32)), np.float64), +] + + +@pytest.mark.parametrize("data,dtype", TEST_PRESERVE_FLOAT_TYPE) def test_preserve_float_dtype(data, dtype): """Test `preserve_float_dtype`.""" input_data = data.copy() @@ -238,6 +238,78 @@ def test_preserve_float_dtype(data, dtype): assert list(signature.parameters) == ["obj", "arg", "kwarg"] +@pytest.mark.parametrize("data,dtype", TEST_PRESERVE_FLOAT_TYPE) +def test_preserve_float_dtype_kwargs_only(data, dtype): + """Test `preserve_float_dtype`.""" + input_data = data.copy() + + result = _dummy_func(arg=2.0, obj=input_data, kwarg=2.0) + + assert input_data.dtype == data.dtype + assert result.dtype == dtype + assert isinstance(result, type(data)) + if isinstance(data, Cube): + assert result.has_lazy_data() == data.has_lazy_data() + + assert _dummy_func.__name__ == "_dummy_func" + signature = inspect.signature(_dummy_func) + assert list(signature.parameters) == ["obj", "arg", "kwarg"] + + +def test_preserve_float_dtype_invalid_args(): + """Test `preserve_float_dtype`.""" + msg = r"missing 2 required positional arguments: 'obj' and 'arg'" + with pytest.raises(TypeError, match=msg): + _dummy_func() + + +def test_preserve_float_dtype_invalid_kwarg(): + """Test `preserve_float_dtype`.""" + msg = r"got an unexpected keyword argument 'data'" + with pytest.raises(TypeError, match=msg): + _dummy_func(np.array(1), 2.0, data=3.0) + + +def test_preserve_float_dtype_invalid_func(): + """Test `preserve_float_dtype`.""" + msg = ( + r"Cannot preserve float dtype during function '', function " + r"takes no arguments" + ) + with pytest.raises(TypeError, match=msg): + preserve_float_dtype(lambda: None) + + +def test_preserve_float_dtype_first_arg_no_dtype(): + """Test `preserve_float_dtype`.""" + + @preserve_float_dtype + def func(obj): + return obj * np.array(1) + + msg = ( + r"Cannot preserve float dtype during function 'func', the function's " + r"first argument of type" + ) + with pytest.raises(TypeError, match=msg): + func(1.0) + + +def test_preserve_float_dtype_return_value_no_dtype(): + """Test `preserve_float_dtype`.""" + + @preserve_float_dtype + def func(_): + return 1 + + msg = ( + r"Cannot preserve float dtype during function 'func', the function's " + r"first argument of type" + ) + with pytest.raises(TypeError, match=msg): + func(np.array(1.0)) + + def test_get_array_module_da(): npx = get_array_module(da.array([1, 2])) assert npx is da