[MERGE] Add new scorer: MixValScorer #221

Merged · 4 commits · Aug 13, 2024
3 changes: 3 additions & 0 deletions README.md
@@ -235,3 +235,6 @@ The library is distributed under the 3-Clause BSD license.

[31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia.[ "Optimal transport for multi-source domain adaptation under target shift."](https://proceedings.mlr.press/v89/redko19a/redko19a.pdf) In The 22nd International Conference on artificial intelligence and statistics, pp. 849-858. PMLR, 2019.

[32] Hu, D., Liang, J., Liew, J. H., Xue, C., Bai, S., & Wang, X. (2023). [Mixed Samples as Probes for Unsupervised Model Selection in Domain Adaptation](https://proceedings.neurips.cc/paper_files/paper/2023/file/7721f1fea280e9ffae528dc78c732576-Paper-Conference.pdf). Advances in Neural Information Processing Systems 36 (2024).


1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -180,6 +180,7 @@ DA metrics :py:mod:`skada.metrics`
DeepEmbeddedValidation
SoftNeighborhoodDensity
CircularValidation
MixValScorer


Model Selection :py:mod:`skada.model_selection`
2 changes: 2 additions & 0 deletions skada/deep/tests/test_deep_scorer.py
@@ -13,6 +13,7 @@
from skada.metrics import (
CircularValidation,
DeepEmbeddedValidation,
MixValScorer,
PredictionEntropyScorer,
SoftNeighborhoodDensity,
)
@@ -25,6 +26,7 @@
PredictionEntropyScorer(),
SoftNeighborhoodDensity(),
CircularValidation(),
MixValScorer(),
],
)
def test_generic_scorer_on_deepmodel(scorer, da_dataset):
122 changes: 122 additions & 0 deletions skada/metrics.py
@@ -619,3 +619,125 @@ def _score(self, estimator, X, y, sample_domain=None):
score = self.source_scorer(y[source_idx], y_pred_source)

return self._sign * score


class MixValScorer(_BaseDomainAwareScorer):
"""
MixVal scorer for unsupervised domain adaptation.

This scorer uses mixup to create mixed samples from the target domain,
and evaluates the model's consistency on these mixed samples.

See [32]_ for details.

Parameters
----------
alpha : float, default=0.55
Mixing parameter for mixup.
random_state : int, RandomState instance or None, default=None
Controls the randomness of the mixing process.
greater_is_better : bool, default=True
Whether higher scores are better.
ice_type : {'both', 'intra', 'inter'}, default='both'
Type of ICE score to compute:
- 'both': Compute both intra-cluster and inter-cluster ICE scores (average).
- 'intra': Compute only intra-cluster ICE score.
- 'inter': Compute only inter-cluster ICE score.

Attributes
----------
alpha : float
Mixing parameter.
random_state : RandomState
Random number generator.
_sign : int
1 if greater_is_better is True, -1 otherwise.
ice_type : str
Type of ICE score to compute.

References
----------
.. [32] Dapeng Hu et al. Mixed Samples as Probes for Unsupervised Model
Selection in Domain Adaptation.
NeurIPS, 2023.
"""

def __init__(
self,
alpha=0.55,
random_state=None,
greater_is_better=True,
ice_type="both",
):
super().__init__()
self.alpha = alpha
self.random_state = random_state
self._sign = 1 if greater_is_better else -1
self.ice_type = ice_type

if self.ice_type not in ["both", "intra", "inter"]:
raise ValueError("ice_type must be 'both', 'intra', or 'inter'")

def _score(self, estimator, X, y=None, sample_domain=None, **params):
"""
Compute the Interpolation Consistency Evaluation (ICE) score.

Parameters
----------
estimator : object
The fitted estimator to evaluate.
X : array-like of shape (n_samples, n_features)
The input samples.
y : Ignored
Not used, present for API consistency by convention.
sample_domain : array-like, default=None
Domain labels for each sample.

Returns
-------
score : float
The ICE score.
"""
X, _, sample_domain = check_X_y_domain(X, y, sample_domain)
source_idx = extract_source_indices(sample_domain)
X_target = X[~source_idx]

rng = check_random_state(self.random_state)
rand_idx = rng.permutation(X_target.shape[0])

# Get predictions for target samples
labels_a = estimator.predict(X_target, sample_domain=sample_domain[~source_idx])
labels_b = labels_a[rand_idx]

# Split pairs into intra-cluster (same pseudo-label) and inter-cluster pairs
same_idx = (labels_a == labels_b).nonzero()[0]
diff_idx = (labels_a != labels_b).nonzero()[0]

# Mix pairs of target samples and their hard pseudo-labels
mix_inputs = self.alpha * X_target + (1 - self.alpha) * X_target[rand_idx]
mix_labels = self.alpha * labels_a + (1 - self.alpha) * labels_b

# Obtain predictions for the mixed samples
mix_pred = estimator.predict(
mix_inputs, sample_domain=np.full(mix_inputs.shape[0], -1)
)

# Calculate ICE scores based on ice_type
if self.ice_type in ["both", "intra"]:
ice_same = (
np.sum(mix_pred[same_idx] == mix_labels[same_idx]) / same_idx.shape[0]
)

if self.ice_type in ["both", "inter"]:
ice_diff = (
np.sum(mix_pred[diff_idx] == mix_labels[diff_idx]) / diff_idx.shape[0]
)

if self.ice_type == "both":
ice_score = (ice_same + ice_diff) / 2
elif self.ice_type == "intra":
ice_score = ice_same
else: # self.ice_type == 'inter'
ice_score = ice_diff

return self._sign * ice_score
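
For readers skimming the diff, here is a minimal NumPy-only sketch of the Interpolation Consistency Evaluation (ICE) logic that `_score` above implements. `predict_fn` is a hypothetical stand-in for the fitted estimator's `predict`; the mixing and splitting simply mirror the scorer rather than defining it.

```python
import numpy as np

def ice_scores(predict_fn, X_target, alpha=0.55, seed=0):
    """Sketch of the intra-/inter-cluster ICE scores computed by MixValScorer."""
    rng = np.random.RandomState(seed)
    rand_idx = rng.permutation(X_target.shape[0])

    # Hard pseudo-labels for the target samples and their shuffled partners.
    labels_a = predict_fn(X_target)
    labels_b = labels_a[rand_idx]

    # Mix each target sample with a random partner, and mix their labels.
    mix_inputs = alpha * X_target + (1 - alpha) * X_target[rand_idx]
    mix_labels = alpha * labels_a + (1 - alpha) * labels_b
    mix_pred = predict_fn(mix_inputs)

    # Intra-cluster pairs share a pseudo-label; inter-cluster pairs do not.
    same_idx = np.nonzero(labels_a == labels_b)[0]
    diff_idx = np.nonzero(labels_a != labels_b)[0]

    ice_intra = np.mean(mix_pred[same_idx] == mix_labels[same_idx])
    ice_inter = np.mean(mix_pred[diff_idx] == mix_labels[diff_idx])
    return ice_intra, ice_inter  # ice_type='both' averages the two
```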
51 changes: 51 additions & 0 deletions skada/tests/test_scorer.py
@@ -23,6 +23,7 @@
CircularValidation,
DeepEmbeddedValidation,
ImportanceWeightedScorer,
MixValScorer,
PredictionEntropyScorer,
SoftNeighborhoodDensity,
SupervisedScorer,
@@ -246,3 +247,53 @@ def test_deep_embedding_validation_no_transform(da_dataset):
)["test_score"]
assert scores.shape[0] == 3, "evaluate 3 splits"
assert np.all(~np.isnan(scores)), "all scores are computed"


def test_mixval_scorer(da_dataset):
X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
estimator = make_da_pipeline(
DensityReweightAdapter(),
LogisticRegression()
.set_fit_request(sample_weight=True)
.set_score_request(sample_weight=True),
)
cv = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)

# Test with default parameters
scorer = MixValScorer(alpha=0.55, random_state=42)
scores = cross_validate(
estimator,
X,
y,
cv=cv,
params={"sample_domain": sample_domain},
scoring=scorer,
)["test_score"]

assert scores.shape[0] == 3, "evaluate 3 splits"
assert np.all(~np.isnan(scores)), "all scores are computed"
assert np.all(scores >= 0) and np.all(scores <= 1), "scores are between 0 and 1"

# Test different ice_type options
for ice_type in ["both", "intra", "inter"]:
scorer = MixValScorer(alpha=0.55, random_state=42, ice_type=ice_type)
scores = cross_validate(
estimator,
X,
y,
cv=cv,
params={"sample_domain": sample_domain},
scoring=scorer,
)["test_score"]

assert scores.shape[0] == 3, f"evaluate 3 splits for ice_type={ice_type}"
assert np.all(
~np.isnan(scores)
), f"all scores are computed for ice_type={ice_type}"
assert np.all(scores >= 0) and np.all(
scores <= 1
), f"scores are between 0 and 1 for ice_type={ice_type}"

# Test invalid ice_type
with pytest.raises(ValueError):
MixValScorer(ice_type="invalid")
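
Beyond the cross-validation tests above, here is a hedged sketch of the intended use case, unsupervised model selection: rank candidate DA pipelines by their MixVal score without touching target labels. It assumes the scorer is callable as `scorer(estimator, X, y, sample_domain=...)` like the other skada scorers and that `make_shifted_datasets` returns `(X, y, sample_domain)`; the hyperparameter grid is illustrative only.

```python
from sklearn import set_config
from sklearn.linear_model import LogisticRegression
from skada import DensityReweightAdapter, make_da_pipeline
from skada.datasets import make_shifted_datasets
from skada.metrics import MixValScorer

set_config(enable_metadata_routing=True)  # needed for set_fit_request below

# Toy source/target data (assumed signature of make_shifted_datasets).
X, y, sample_domain = make_shifted_datasets(noise=0.3, random_state=0)

scorer = MixValScorer(alpha=0.55, ice_type="both", random_state=42)

scores = {}
for C in (0.1, 1.0, 10.0):
    pipe = make_da_pipeline(
        DensityReweightAdapter(),
        LogisticRegression(C=C).set_fit_request(sample_weight=True),
    )
    pipe.fit(X, y, sample_domain=sample_domain)
    # No target labels are used: the score relies only on mixed target samples.
    scores[C] = scorer(pipe, X, y, sample_domain=sample_domain)

best_C = max(scores, key=scores.get)  # greater_is_better=True by default
print(scores, best_C)
```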