Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Handle edge cases for CAN #269

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions skada/deep/_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ class CANLoss(BaseDALoss):
If None, uses sigmas proposed in [1]_.
target_kmeans : sklearn KMeans instance, default=None,
Pre-computed target KMeans clustering model.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.

References
----------
Expand All @@ -220,12 +222,14 @@ def __init__(
class_threshold=3,
sigmas=None,
target_kmeans=None,
eps=1e-7,
):
super().__init__()
self.distance_threshold = distance_threshold
self.class_threshold = class_threshold
self.sigmas = sigmas
self.target_kmeans = target_kmeans
self.eps = eps

def forward(
self,
Expand All @@ -245,6 +249,7 @@ def forward(
target_kmeans=self.target_kmeans,
distance_threshold=self.distance_threshold,
class_threshold=self.class_threshold,
eps=self.eps,
)

return loss
Expand Down
60 changes: 31 additions & 29 deletions skada/deep/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,36 @@ def on_epoch_begin(self, net, dataset_train=None, **kwargs):

X_t = X["X"][X["sample_domain"] < 0]

features_s = net.predict_features(X_s)
features_t = net.predict_features(X_t)

features_s = torch.tensor(features_s, device=net.device)
y_s = torch.tensor(y_s, device=net.device)

features_t = torch.tensor(features_t, device=net.device)

n_classes = len(y_s.unique())
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.mean(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)
# Disable gradient computation for feature extraction
with torch.no_grad():
features_s = net.predict_features(X_s)
features_t = net.predict_features(X_t)

features_s = torch.tensor(features_s, device=net.device)
y_s = torch.tensor(y_s, device=net.device)

features_t = torch.tensor(features_t, device=net.device)

n_classes = len(y_s.unique())
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)

net.criterion__adapt_criterion.target_kmeans = target_kmeans
82 changes: 53 additions & 29 deletions skada/deep/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def cdd_loss(
sigmas=None,
distance_threshold=0.5,
class_threshold=3,
eps=1e-7,
):
"""Define the contrastive domain discrepancy loss based on [33]_.

Expand All @@ -225,6 +226,8 @@ def cdd_loss(
too far from the centroids.
class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.

Returns
-------
Expand All @@ -240,31 +243,34 @@ def cdd_loss(
"""
n_classes = len(y_s.unique())

# Use pre-computed cluster_labels_t
# Use pre-computed target_kmeans
if target_kmeans is None:
warnings.warn(
"Source centroids are not computed for the whole training set, "
"computing them on the current batch set."
)
with torch.no_grad():
warnings.warn(
"Source centroids are not computed for the whole training set, "
"computing them on the current batch set."
)

source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)

# Predict clusters for target samples
cluster_labels_t = target_kmeans.predict(features_t)
Expand All @@ -283,10 +289,11 @@ def cdd_loss(
mask_t = valid_classes[cluster_labels_t]
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Define sigmas
if sigmas is None:
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
median_pairwise_distance = (
torch.median(torch.cdist(features_s, features_s)) + eps
)
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
Expand All @@ -299,26 +306,43 @@ def cdd_loss(
# Compute CDD
intraclass = 0
interclass = 0

for c1 in range(n_classes):
for c2 in range(c1, n_classes):
if valid_classes[c1] and valid_classes[c2]:
# Compute e1
kernel_ss = _gaussian_kernel(features_s, features_s, sigmas)
mask_c1_c1 = (y_s == c1).float()
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)

# e1 measure the intra-class domain discrepancy
# Thus if mask_c1_c1.sum() = 0 --> e1 = 0
if mask_c1_c1.sum() > 0:
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)
else:
e1 = 0

# Compute e2
kernel_tt = _gaussian_kernel(features_t, features_t, sigmas)
mask_c2_c2 = (cluster_labels_t == c2).float()
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)

# e2 measure the intra-class domain discrepancy
# Thus if mask_c2_c2.sum() = 0 --> e2 = 0
if mask_c2_c2.sum() > 0:
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)
else:
e2 = 0

# Compute e3
kernel_st = _gaussian_kernel(features_s, features_t, sigmas)
mask_c1 = (y_s == c1).float().unsqueeze(1)
mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0)
mask_c1_c2 = mask_c1 * mask_c2
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)

# e3 measure the inter-class domain discrepancy
# Thus if mask_c1_c2.sum() = 0 --> e3 = 0
if mask_c1_c2.sum() > 0:
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)
else:
e3 = 0

if c1 == c2:
intraclass += e1 + e2 - 2 * e3
Expand Down
13 changes: 12 additions & 1 deletion skada/deep/tests/test_deep_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from skada.datasets import make_shifted_datasets
from skada.deep import CAN, DAN, DeepCoral
from skada.deep.losses import dan_loss
from skada.deep.losses import cdd_loss, dan_loss
from skada.deep.modules import ToyModule2D


Expand Down Expand Up @@ -199,6 +199,17 @@ def test_can_with_custom_callbacks():
assert "ComputeSourceCentroids" in callback_classes


def test_cdd_loss_edge_cases():
    """cdd_loss must stay finite when the median pairwise distance is 0.

    With all source features identical, the median pairwise distance is 0;
    without the ``eps`` added inside ``cdd_loss`` the derived sigmas would be
    zero and the Gaussian kernel would produce NaNs.
    """
    # Seed so the random target features are reproducible across runs
    torch.manual_seed(0)

    # Test when median pairwise distance is 0
    features_s = torch.ones((4, 2))  # All features are identical
    features_t = torch.randn((4, 2))
    y_s = torch.tensor([0, 0, 1, 1])  # Two classes

    # This should not raise any errors due to the eps we added;
    # use torch.isnan directly on the tensor result (no numpy round-trip)
    loss = cdd_loss(y_s, features_s, features_t)
    assert not torch.isnan(loss)


def test_dan_loss_edge_cases():
# Create identical source features to get median distance = 0
features_s = torch.tensor([[1.0, 2.0], [1.0, 2.0]], dtype=torch.float32)
Expand Down
Loading
Loading