Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Implement a NaN policy for onthefly read_transform #429

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions junifer/onthefly/read_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# Authors: Synchon Mandal <[email protected]>
# License: AGPL


from typing import Optional

import numpy as np
import pandas as pd

from ..typing import StorageLike
Expand All @@ -20,6 +20,7 @@
transform: str,
feature_name: Optional[str] = None,
feature_md5: Optional[str] = None,
nan_policy: Optional[str] = "bypass",
transform_args: Optional[tuple] = None,
transform_kw_args: Optional[dict] = None,
) -> pd.DataFrame:
Expand All @@ -36,6 +37,16 @@
Name of the feature to read (default None).
feature_md5 : str, optional
MD5 hash of the feature to read (default None).
nan_policy : str, optional
The policy to handle NaN values (default "ignore").
Options are:

* "bypass": Do nothing and pass NaN values to the transform function.
* "drop_element": Drop (skip) elements with NaN values.
* "drop_rows": Drop (skip) rows with NaN values.
* "drop_columns": Drop (skip) columns with NaN values.
* "drop_symmetric": Drop (skip) symmetric pairs with NaN values.

transform_args : tuple, optional
The positional arguments for the callable of ``transform``
(default None).
Expand Down Expand Up @@ -64,6 +75,18 @@
transform_args = transform_args or ()
transform_kw_args = transform_kw_args or {}

if nan_policy not in [
"bypass",
"drop_element",
"drop_rows",
"drop_columns",
"drop_symmetric",
]:
raise_error(
f"Unknown nan_policy: {nan_policy}",
klass=ValueError,
)

# Read storage
stored_data = storage.read(
feature_name=feature_name, feature_md5=feature_md5
Expand Down Expand Up @@ -107,22 +130,51 @@
except AttributeError as err:
raise_error(msg=str(err), klass=AttributeError)

# Apply function and store subject-wise
# Apply function and store element-wise
output_list = []
element_list = []
logger.debug(
f"Computing '{package}.{func_str}' for feature "
f"{feature_name or feature_md5} ..."
)
for subject in range(stored_data["data"].shape[2]):
for i_element, element in enumerate(stored_data["element"]):
t_data = stored_data["data"][:, :, i_element]
has_nan = np.isnan(np.min(t_data))
if nan_policy == "drop_element" and has_nan:
logger.debug(
f"Skipping element {element} due to NaN values ..."
)
continue
elif nan_policy == "drop_rows" and has_nan:
logger.debug(
f"Skipping rows with NaN values in element {element} ..."
)
t_data = t_data[~np.isnan(t_data).any(axis=1)]
elif nan_policy == "drop_columns" and has_nan:
logger.debug(
f"Skipping columns with NaN values in element {element} ..."

Check failure on line 155 in junifer/onthefly/read_transform.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

junifer/onthefly/read_transform.py:155:80: E501 Line too long (80 > 79)

Check failure on line 155 in junifer/onthefly/read_transform.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (E501)

junifer/onthefly/read_transform.py:155:80: E501 Line too long (80 > 79)
)
t_data = t_data[:, ~np.isnan(t_data).any(axis=0)]
elif nan_policy == "drop_symmetric":
logger.debug(
f"Skipping pairs of rows/columns with NaN values in "
f"element {element}..."
)
good_rows = ~np.isnan(t_data).any(axis=1)
good_columns = ~np.isnan(t_data).any(axis=0)
good_idx = np.logical_and(good_rows, good_columns)
t_data = t_data[good_idx][:, good_idx]

output = func(
stored_data["data"][:, :, subject],
t_data,
*transform_args,
**transform_kw_args,
)
output_list.append(output)
element_list.append(element)

# Create dataframe for index
idx_df = pd.DataFrame(data=stored_data["element"])
idx_df = pd.DataFrame(data=element_list)
# Create multiindex from dataframe
logger.debug(
"Generating pandas.MultiIndex for feature "
Expand Down
85 changes: 84 additions & 1 deletion junifer/onthefly/tests/test_read_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Authors: Synchon Mandal <[email protected]>
# License: AGPL


import logging
from pathlib import Path

Expand Down Expand Up @@ -65,6 +64,36 @@ def matrix_storage(tmp_path: Path) -> HDF5FeatureStorage:
return storage


@pytest.fixture
def matrix_storage_with_nan(tmp_path: Path) -> HDF5FeatureStorage:
"""Return a HDF5FeatureStorage with matrix data.

Parameters
----------
tmp_path : pathlib.Path
The path to the test directory.

"""
storage = HDF5FeatureStorage(tmp_path / "matrix_store_nan.hdf5")
data = np.arange(36).reshape(3, 3, 4).astype(float)
data[1, 1, 2] = np.nan
data[1, 2, 2] = np.nan
for i in range(4):
storage.store(
kind="matrix",
meta={
"element": {"subject": f"test{i + 1}"},
"dependencies": [],
"marker": {"name": "matrix"},
"type": "BOLD",
},
data=data[:, :, i],
col_names=["f1", "f2", "f3"],
row_names=["g1", "g2", "g3"],
)
return storage


def test_incorrect_package(matrix_storage: HDF5FeatureStorage) -> None:
"""Test error check for incorrect package name.

Expand Down Expand Up @@ -177,3 +206,57 @@ def test_bctpy_function(
)
assert "Computing" in caplog.text
assert "Generating" in caplog.text


@pytest.mark.parametrize(
"nan_policy, error_msg",
[
("drop_element", None),
("drop_rows", "square"),
("drop_columns", "square"),
("drop_symmetric", None),
("bypass", "NaNs"),
("wrong", "Unknown"),
],
)
def test_bctpy_nans(
matrix_storage_with_nan: HDF5FeatureStorage,
caplog: pytest.LogCaptureFixture,
nan_policy: str,
error_msg: str,
) -> None:
"""Test working function of bctpy.

Parameters
----------
matrix_storage_with_nan : HDF5FeatureStorage
The HDF5FeatureStorage with matrix data, as fixture.
caplog : pytest.LogCaptureFixture
The pytest.LogCaptureFixture object.
nan_policy : str
The NAN policy to test.
error_msg : str
The expected error message snippet. If None, no error should be raised.

"""
# Skip test if import fails
pytest.importorskip("bct")

with caplog.at_level(logging.DEBUG):
if error_msg is None:
read_transform(
storage=matrix_storage_with_nan, # type: ignore
feature_name="BOLD_matrix",
transform="bctpy_eigenvector_centrality_und",
nan_policy=nan_policy,
)
assert "Computing" in caplog.text
assert "Generating" in caplog.text
else:
with pytest.raises(ValueError, match=error_msg):
read_transform(
storage=matrix_storage_with_nan, # type: ignore
feature_name="BOLD_matrix",
transform="bctpy_eigenvector_centrality_und",
nan_policy=nan_policy,
)
Loading