diff --git a/examples/deep/plot_optimal_transport.py b/examples/deep/plot_optimal_transport.py index 67ce14dc..a34d7f11 100644 --- a/examples/deep/plot_optimal_transport.py +++ b/examples/deep/plot_optimal_transport.py @@ -49,7 +49,7 @@ batch_size=128, max_epochs=5, train_split=False, - reg=0.1, + reg_dist=0.1, reg_cl=0.01, lr=1e-2, ) diff --git a/skada/deep/_optimal_transport.py b/skada/deep/_optimal_transport.py index f1f0f547..f48f2fe1 100644 --- a/skada/deep/_optimal_transport.py +++ b/skada/deep/_optimal_transport.py @@ -24,6 +24,8 @@ class DeepJDOTLoss(BaseDALoss): Parameters ---------- + reg_dist : float, default=1 + Divergence regularization parameter. reg_cl : float, default=1 Class distance term regularization parameter. target_criterion : torch criterion (class) @@ -40,8 +42,9 @@ class DeepJDOTLoss(BaseDALoss): September 2018. Springer. """ - def __init__(self, reg_cl=1, target_criterion=None): + def __init__(self, reg_dist=1, reg_cl=1, target_criterion=None): super().__init__() + self.reg_dist = reg_dist self.reg_cl = reg_cl self.criterion_ = target_criterion @@ -61,6 +64,7 @@ def forward( y_pred_t, features_s, features_t, + self.reg_dist, self.reg_cl, criterion=self.criterion_, ) @@ -70,7 +74,7 @@ def forward( def DeepJDOT( module, layer_name, - reg=1, + reg_dist=1, reg_cl=1, base_criterion=None, target_criterion=None, @@ -117,8 +121,8 @@ def DeepJDOT( iterator_train=DomainBalancedDataLoader, criterion=DomainAwareCriterion, criterion__base_criterion=base_criterion, - criterion__adapt_criterion=DeepJDOTLoss(reg_cl, target_criterion), - criterion__reg=reg, + criterion__adapt_criterion=DeepJDOTLoss(reg_dist, reg_cl, target_criterion), + criterion__reg=1, **kwargs, ) return net diff --git a/skada/deep/losses.py b/skada/deep/losses.py index 7c1f2ceb..c4249365 100644 --- a/skada/deep/losses.py +++ b/skada/deep/losses.py @@ -47,6 +47,7 @@ def deepjdot_loss( y_pred_t, features_s, features_t, + reg_dist, reg_cl, sample_weights=None, target_sample_weights=None, @@ -64,6 +65,8 @@ def deepjdot_loss( features of the source data used to perform the distance matrix. features_t : tensor features of the target data used to perform the distance matrix. + reg_dist : float + Divergence term regularization parameter. reg_cl : float, default=1 Class distance term regularization parameter. sample_weights : tensor @@ -98,7 +101,7 @@ def deepjdot_loss( criterion = torch.nn.CrossEntropyLoss(reduction="none") loss_target = criterion(y_target_matrix, y_s.repeat(len(y_s), 1)).T - M = dist + reg_cl * loss_target + M = reg_dist * dist + reg_cl * loss_target # Compute the loss if sample_weights is None: diff --git a/skada/deep/tests/test_deep_optimal_transport.py b/skada/deep/tests/test_deep_optimal_transport.py index 5bd25f9a..46e4adfd 100644 --- a/skada/deep/tests/test_deep_optimal_transport.py +++ b/skada/deep/tests/test_deep_optimal_transport.py @@ -29,7 +29,7 @@ def test_deepjdot(): method = DeepJDOT( ToyModule2D(), - reg=1, + reg_dist=1, reg_cl=1, layer_name="dropout", batch_size=10,