From 53302cfb5932294561c029a9f172b887aea6872e Mon Sep 17 00:00:00 2001 From: antoinedemathelin Date: Thu, 29 Feb 2024 17:18:32 +0100 Subject: [PATCH] Fix disc + weights avg to one --- skada/_reweight.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/skada/_reweight.py b/skada/_reweight.py index 2deb6aaf..5d325056 100644 --- a/skada/_reweight.py +++ b/skada/_reweight.py @@ -105,7 +105,7 @@ def adapt(self, X, y=None, sample_domain=None): ws = self.weight_estimator_source_.score_samples(X[source_idx]) wt = self.weight_estimator_target_.score_samples(X[source_idx]) source_weights = np.exp(wt - ws) - source_weights /= source_weights.sum() + source_weights /= source_weights.mean() weights = np.zeros(X.shape[0], dtype=source_weights.dtype) weights[source_idx] = source_weights else: @@ -371,7 +371,10 @@ def adapt(self, X, y=None, sample_domain=None, **kwargs): # xxx(okachaiev): move this to API if source_idx.sum() > 0: source_idx, = np.where(source_idx) - source_weights = self.domain_classifier_.predict_proba(X[source_idx])[:, 1] + probas = self.domain_classifier_.predict_proba(X[source_idx])[:, 1] + probas = np.clip(probas, EPS, 1.) + source_weights = (1 - probas) / probas + source_weights /= source_weights.mean() weights = np.zeros(X.shape[0], dtype=source_weights.dtype) weights[source_idx] = source_weights else: