Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not change function argument names when decorator preserve_float_dtype is used #2645

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions esmvalcore/preprocessor/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import inspect
import logging
import warnings
from collections import defaultdict
Expand All @@ -20,6 +21,7 @@
from iris.cube import Cube
from iris.exceptions import CoordinateMultiDimError, CoordinateNotFoundError
from iris.util import broadcast_to_shape
from numpy.typing import DTypeLike

from esmvalcore.iris_helpers import has_regular_grid
from esmvalcore.typing import DataType
Expand Down Expand Up @@ -221,6 +223,25 @@ def get_normalized_cube(
return normalized_cube


def _get_dtype_of_first_arg(
func: Callable,
*args: Any,
**kwargs: Any,
) -> DTypeLike:
"""Get dtype of first argument given to a function."""
# If positional arguments are given, use the first one
if args:
return args[0].dtype

# Otherwise, use the keyword argument given by the name of the first
# function argument
# Note: this function is called AFTER func(*args, **kwargs) is run, so we
# can be sure that the required arguments are there
signature = inspect.signature(func)
first_arg_name = list(signature.parameters.values())[0].name
return kwargs[first_arg_name].dtype


def preserve_float_dtype(func: Callable) -> Callable:
"""Preserve object's float dtype (all other dtypes are allowed to change).

Expand All @@ -230,11 +251,17 @@ def preserve_float_dtype(func: Callable) -> Callable:
to give output with any type.

"""
signature = inspect.signature(func)
if not signature.parameters:
raise TypeError(
f"Cannot preserve float dtype during function '{func.__name__}, "
f"function takes no arguments"
)

@wraps(func)
def wrapper(data: DataType, *args: Any, **kwargs: Any) -> DataType:
dtype = data.dtype
result = func(data, *args, **kwargs)
def wrapper(*args: Any, **kwargs: Any) -> DataType:
result = func(*args, **kwargs)
dtype = _get_dtype_of_first_arg(func, *args, **kwargs)
if np.issubdtype(dtype, np.floating) and result.dtype != dtype:
if isinstance(result, Cube):
result.data = result.core_data().astype(dtype)
Expand Down
78 changes: 57 additions & 21 deletions tests/unit/preprocessor/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,27 +200,27 @@ def _dummy_func(obj, arg, kwarg=2.0):
return obj


@pytest.mark.parametrize(
"data,dtype",
[
(np.array([1.0], dtype=np.float64), np.float64),
(np.array([1.0], dtype=np.float32), np.float32),
(np.array([1], dtype=np.int64), np.float64),
(np.array([1], dtype=np.int32), np.float64),
(da.array([1.0], dtype=np.float64), np.float64),
(da.array([1.0], dtype=np.float32), np.float32),
(da.array([1], dtype=np.int64), np.float64),
(da.array([1], dtype=np.int32), np.float64),
(Cube(np.array([1.0], dtype=np.float64)), np.float64),
(Cube(np.array([1.0], dtype=np.float32)), np.float32),
(Cube(np.array([1], dtype=np.int64)), np.float64),
(Cube(np.array([1], dtype=np.int32)), np.float64),
(Cube(da.array([1.0], dtype=np.float64)), np.float64),
(Cube(da.array([1.0], dtype=np.float32)), np.float32),
(Cube(da.array([1], dtype=np.int64)), np.float64),
(Cube(da.array([1], dtype=np.int32)), np.float64),
],
)
TEST_PRESERVE_FLOAT_TYPE = [
(np.array([1.0], dtype=np.float64), np.float64),
(np.array([1.0], dtype=np.float32), np.float32),
(np.array([1], dtype=np.int64), np.float64),
(np.array([1], dtype=np.int32), np.float64),
(da.array([1.0], dtype=np.float64), np.float64),
(da.array([1.0], dtype=np.float32), np.float32),
(da.array([1], dtype=np.int64), np.float64),
(da.array([1], dtype=np.int32), np.float64),
(Cube(np.array([1.0], dtype=np.float64)), np.float64),
(Cube(np.array([1.0], dtype=np.float32)), np.float32),
(Cube(np.array([1], dtype=np.int64)), np.float64),
(Cube(np.array([1], dtype=np.int32)), np.float64),
(Cube(da.array([1.0], dtype=np.float64)), np.float64),
(Cube(da.array([1.0], dtype=np.float32)), np.float32),
(Cube(da.array([1], dtype=np.int64)), np.float64),
(Cube(da.array([1], dtype=np.int32)), np.float64),
]


@pytest.mark.parametrize("data,dtype", TEST_PRESERVE_FLOAT_TYPE)
def test_preserve_float_dtype(data, dtype):
"""Test `preserve_float_dtype`."""
input_data = data.copy()
Expand All @@ -238,6 +238,42 @@ def test_preserve_float_dtype(data, dtype):
assert list(signature.parameters) == ["obj", "arg", "kwarg"]


@pytest.mark.parametrize("data,dtype", TEST_PRESERVE_FLOAT_TYPE)
def test_preserve_float_dtype_kwargs_only(data, dtype):
"""Test `preserve_float_dtype`."""
input_data = data.copy()

result = _dummy_func(arg=2.0, obj=input_data, kwarg=2.0)

assert input_data.dtype == data.dtype
assert result.dtype == dtype
assert isinstance(result, type(data))
if isinstance(data, Cube):
assert result.has_lazy_data() == data.has_lazy_data()

assert _dummy_func.__name__ == "_dummy_func"
signature = inspect.signature(_dummy_func)
assert list(signature.parameters) == ["obj", "arg", "kwarg"]


def test_preserve_float_dtype_invalid_args():
"""Test `preserve_float_dtype`."""
with pytest.raises(TypeError):
_dummy_func()


def test_preserve_float_dtype_invalid_kwarg():
"""Test `preserve_float_dtype`."""
with pytest.raises(TypeError):
_dummy_func(np.array(1), 2.0, data=3.0)


def test_preserve_float_dtype_invalid_func():
"""Test `preserve_float_dtype`."""
with pytest.raises(TypeError):
preserve_float_dtype(lambda: None)


def test_get_array_module_da():
npx = get_array_module(da.array([1, 2]))
assert npx is da
Expand Down