
[MRG] Add SPA method #276

Merged Nov 22, 2024 · 39 commits

Changes from 9 commits

Commits
f2f107b
add spa
tgnassou Nov 14, 2024
3ae125c
add spa
tgnassou Nov 14, 2024
e097db7
update spa
tgnassou Nov 14, 2024
4145223
merge
tgnassou Nov 14, 2024
c60819c
add sample_idx parameter
tgnassou Nov 18, 2024
fc34fd0
merge main
tgnassou Nov 18, 2024
ca9ecb8
remove SPALoss sample_idx attribute
tgnassou Nov 18, 2024
5e232e1
fix test
tgnassou Nov 18, 2024
2cb7154
add callback to init memorybank
tgnassou Nov 18, 2024
12275de
change default hp
tgnassou Nov 18, 2024
90d617e
choose euc
tgnassou Nov 18, 2024
0b9268c
Fix issue with memory_features on cpu
YanisLalou Nov 19, 2024
0fcd3ff
lint
YanisLalou Nov 19, 2024
0392e76
lint
YanisLalou Nov 19, 2024
5536269
Make it fast with matrix multiplications
YanisLalou Nov 19, 2024
959b2b5
Merge branch 'main' into spa
rflamary Nov 19, 2024
678de6c
Merge branch 'main' into spa
rflamary Nov 19, 2024
7780b22
add scheduler param
tgnassou Nov 20, 2024
2956ad5
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
tgnassou Nov 20, 2024
b9dd30a
rm eff for adv loss
antoinecollas Nov 20, 2024
257ab6b
fix test
antoinecollas Nov 20, 2024
9279024
fix CountEpochs
antoinecollas Nov 20, 2024
d171add
rm default max_epochs in SPALoss
antoinecollas Nov 20, 2024
af38ea1
revert change of mapping
tgnassou Nov 20, 2024
0e342da
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
tgnassou Nov 20, 2024
539a397
fix scheduler
antoinecollas Nov 20, 2024
ee25bf3
call scheduler
antoinecollas Nov 20, 2024
0b47150
fix gauss kernel
antoinecollas Nov 20, 2024
bd8a698
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
antoinecollas Nov 20, 2024
880d83b
revert _mapping
antoinecollas Nov 20, 2024
ad8ebf8
fix schedulers
antoinecollas Nov 20, 2024
3c06d8b
use gaussian kernel
antoinecollas Nov 20, 2024
d8bc359
rm torch deprecated functions
antoinecollas Nov 20, 2024
ee59ca0
skip test
antoinecollas Nov 20, 2024
91219e0
fix torch.linalg.norm
antoinecollas Nov 20, 2024
aafffdc
rm default sample_idx
antoinecollas Nov 20, 2024
64b08d0
fix distance and K nearest neighbors computation
antoinecollas Nov 20, 2024
29917b4
merge memory bank classes + better init
antoinecollas Nov 20, 2024
ccfa323
fix nap_loss call
antoinecollas Nov 20, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -241,4 +241,6 @@ The library is distributed under the 3-Clause BSD license.

[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020.

[35] Zhang, Y., Liu, T., Long, M., & Jordan, M. I. (2019). [Bridging Theory and Algorithm for Domain Adaptation](https://arxiv.org/abs/1904.05801). In Proceedings of the 36th International Conference on Machine Learning, (pp. 7404-7413).

[36] Xiao, Zhiqing, Wang, Haobo, Jin, Ying, Feng, Lei, Chen, Gang, Huang, Fei, Zhao, Junbo. [SPA: A Graph Spectral Alignment Perspective for Domain Adaptation](https://arxiv.org/pdf/2310.17594). In NeurIPS, 2023.
3 changes: 3 additions & 0 deletions skada/deep/__init__.py
@@ -20,6 +20,7 @@
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss, ModifiedCrossEntropyLoss
from ._class_confusion import MCC, MCCLoss
from ._graph_alignment import SPA, SPALoss
from ._baseline import SourceOnly, TargetOnly

from . import losses
@@ -45,6 +46,8 @@
"ModifiedCrossEntropyLoss",
"CANLoss",
"CAN",
"SPALoss",
"SPA",
"SourceOnly",
"TargetOnly",
]
32 changes: 16 additions & 16 deletions skada/deep/_adversarial.py
@@ -53,13 +53,9 @@ def __init__(self, domain_criterion=None):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
domain_label = torch.zeros(
@@ -183,13 +179,11 @@ def __init__(self, domain_criterion=None):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
dtype = torch.float32
@@ -247,7 +241,14 @@ def __init__(
self.max_features = max_features
self.random_state = random_state

def forward(self, X, sample_domain=None, is_fit=False, return_features=False):
def forward(
self,
X,
sample_domain=None,
sample_idx=None,
is_fit=False,
return_features=False,
):
if is_fit:
# predict
y_pred = self.base_module_(X)
@@ -279,6 +280,7 @@ def forward(self, X, sample_domain=None, is_fit=False, return_features=False):
domain_pred,
features,
sample_domain,
sample_idx,
)
else:
if return_features:
@@ -417,22 +419,20 @@ def __init__(self, gamma=4.0):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
disc_pred_s,
disc_pred_t,
features_s,
features_t,
domain_pred_s,
domain_pred_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
# TODO: handle binary classification
# Multiclass classification
pseudo_label_s = torch.argmax(y_pred_s, axis=-1)
pseudo_label_t = torch.argmax(y_pred_t, axis=-1)

disc_loss_s = self.disc_criterion_s(disc_pred_s, pseudo_label_s)
disc_loss_t = self.disc_criterion_t(disc_pred_t, pseudo_label_t)
disc_loss_s = self.disc_criterion_s(domain_pred_s, pseudo_label_s)
disc_loss_t = self.disc_criterion_t(domain_pred_t, pseudo_label_t)

# Compute the MDD loss value
disc_loss = self.gamma * disc_loss_s - disc_loss_t
8 changes: 1 addition & 7 deletions skada/deep/_baseline.py
@@ -26,13 +26,7 @@ def __init__(

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
return 0
7 changes: 1 addition & 6 deletions skada/deep/_class_confusion.py
@@ -42,13 +42,8 @@ def __init__(self, T=1, eps=1e-7):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = mcc_loss(
17 changes: 3 additions & 14 deletions skada/deep/_divergence.py
@@ -46,13 +46,9 @@ def __init__(

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = deepcoral_loss(features_s, features_t, self.assume_centered)
@@ -132,13 +128,9 @@ def __init__(self, sigmas=None, eps=1e-7):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = dan_loss(features_s, features_t, sigmas=self.sigmas, eps=self.eps)
@@ -235,12 +227,9 @@ def __init__(
def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
loss = cdd_loss(
y_s,
201 changes: 201 additions & 0 deletions skada/deep/_graph_alignment.py
@@ -0,0 +1,201 @@
# Author: Theo Gnassounou <[email protected]>
#
# License: BSD 3-Clause

import torch

from skada.deep.base import (
BaseDALoss,
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainBalancedDataLoader,
)
from skada.deep.callbacks import ComputeMemoryBank, MemoryBankInit
from skada.deep.losses import gda_loss, nap_loss

from .modules import DomainClassifier


class SPALoss(BaseDALoss):
"""Loss SPA.

This loss tries to minimize the divergence between features with
adversarial method. The weights are updated to make harder
to classify domains (i.e., remove domain-specific features).

See [36]_ for details.

Parameters
----------
domain_criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
adversarial loss. If None, a BCELoss is used.
memory_features : torch tensor, default=None
Initial value of the memory bank of target features.
memory_outputs : torch tensor, default=None
Initial value of the memory bank of target outputs.
K : int, default=5
Number of nearest neighbors used in the NAP loss.
reg_adv : float, default=1
Regularization parameter for the adversarial loss.
reg_gsa : float, default=1
Regularization parameter for the graph alignment loss.
reg_nap : float, default=1
Regularization parameter for the NAP loss.

References
----------
.. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
Domain Adaptation. In NeurIPS, 2023.
"""

def __init__(
self,
domain_criterion=None,
memory_features=None,
memory_outputs=None,
K=5,
reg_adv=1,
reg_gsa=1,
reg_nap=1,
):
super().__init__()
if domain_criterion is None:
self.domain_criterion_ = torch.nn.BCELoss()
else:
self.domain_criterion_ = domain_criterion

self.reg_adv = reg_adv
self.reg_gsa = reg_gsa
self.reg_nap = reg_nap
self.K = K
self.memory_features = memory_features
self.memory_outputs = memory_outputs

def forward(
self,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
sample_idx_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
domain_label = torch.zeros(
(domain_pred_s.size()[0]),
device=domain_pred_s.device,
)
domain_label_target = torch.ones(
(domain_pred_t.size()[0]),
device=domain_pred_t.device,
)

# update classification function
loss_adv = self.domain_criterion_(
domain_pred_s, domain_label
) + self.domain_criterion_(domain_pred_t, domain_label_target)
Collaborator: in eq. (1) of the paper, they use both the cross entropy and the modified cross entropy loss (log(1 - prob)).

Collaborator: we should also use label smoothing, as in the code of the paper: https://github.com/CrownX/SPA/blob/main/code/loss.py#L65-L84
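A minimal sketch of what those two suggestions could look like together, assuming smoothed hard domain labels and the modified log(1 - p) term on the target side; the function name smoothed_adversarial_loss and the smoothing value 0.1 are illustrative, not part of this PR:

import torch

def smoothed_adversarial_loss(domain_pred_s, domain_pred_t, smoothing=0.1, eps=1e-7):
    # Label smoothing: source labels 0 -> smoothing, target labels 1 -> 1 - smoothing.
    p_s = domain_pred_s.clamp(eps, 1 - eps)
    p_t = domain_pred_t.clamp(eps, 1 - eps)
    y_s = torch.full_like(p_s, smoothing)
    y_t = torch.full_like(p_t, 1.0 - smoothing)
    # Standard binary cross entropy on the source domain predictions.
    loss_s = -(y_s * p_s.log() + (1 - y_s) * (1 - p_s).log()).mean()
    # "Modified" cross entropy on the target side: log(1 - p) swaps in for log(p),
    # the non-saturating variant the first comment refers to.
    loss_t = -(y_t * (1 - p_t).log() + (1 - y_t) * p_t.log()).mean()
    return loss_s + loss_t

Such a helper would presumably replace the plain BCELoss pair computed in loss_adv above; which branch gets the modified term should follow eq. (1) of the paper rather than this sketch.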


loss_gda = self.reg_gsa * gda_loss(features_s, features_t)

loss_pl = self.reg_nap * nap_loss(
features_s,
features_t,
self.memory_features,
self.memory_outputs,
K=self.K,
sample_idx=sample_idx_t,
)
loss = loss_adv + loss_gda + loss_pl
return loss


def SPA(
module,
layer_name,
reg_adv=1,
reg_gsa=1,
reg_nap=1,
domain_classifier=None,
num_features=None,
base_criterion=None,
domain_criterion=None,
callbacks=None,
**kwargs,
):
"""Domain Adaptation with SPA.

From [36]_.

Parameters
----------
module : torch module (class or instance)
A PyTorch :class:`~torch.nn.Module`. In general, the
uninstantiated class should be passed, although instantiated
modules will also work.
layer_name : str
The name of the module's layer whose outputs are
collected during the training.
reg_adv : float, default=1
Regularization parameter for the adversarial loss.
reg_gsa : float, default=1
Regularization parameter for the graph alignment loss.
reg_nap : float, default=1
Regularization parameter for the NAP loss.
domain_classifier : torch module, default=None
A PyTorch :class:`~torch.nn.Module` used to classify the
domain. If None, a domain classifier is created following [1]_.
num_features : int, default=None
Size of the input of domain classifier,
e.g size of the last layer of
the feature extractor.
If domain_classifier is None, num_features has to be
provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
domain_criterion : torch criterion (class)
The criterion (loss) used to compute the
adversarial loss. If None, a BCELoss is used.
callbacks : list of callbacks or callback, default=None
Additional callbacks; the memory bank callbacks required by
the NAP loss are appended automatically.

References
----------
.. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
Domain Adaptation. In NeurIPS, 2023.
"""
if domain_classifier is None:
# raise an error if num_features is None
if num_features is None:
raise ValueError(
"If domain_classifier is None, num_features has to be provided"
)
domain_classifier = DomainClassifier(num_features=num_features)

if callbacks is None:
callbacks = [
ComputeMemoryBank(),
MemoryBankInit(),
]
else:
if isinstance(callbacks, list):
callbacks.append(ComputeMemoryBank())
callbacks.append(MemoryBankInit())
else:
callbacks = [
callbacks,
ComputeMemoryBank(),
MemoryBankInit(),
]
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
module__domain_classifier=domain_classifier,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__base_criterion=base_criterion,
criterion__reg=1,
criterion__adapt_criterion=SPALoss(
domain_criterion=domain_criterion,
reg_adv=reg_adv,
reg_gsa=reg_gsa,
reg_nap=reg_nap,
),
callbacks=callbacks,
**kwargs,
)
return net
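For reference, a possible usage sketch of the new estimator, mirroring how other skada.deep estimators are instantiated; the toy module, layer name, shapes, and data conventions below are illustrative assumptions, not taken from this PR:

import numpy as np
import torch
from skada.deep import SPA

class ToyNet(torch.nn.Module):
    def __init__(self, n_features=10, n_classes=2):
        super().__init__()
        # "feature_extractor" is the layer whose outputs feed the SPA losses.
        self.feature_extractor = torch.nn.Linear(n_features, 16)
        self.classifier = torch.nn.Linear(16, n_classes)

    def forward(self, x):
        return self.classifier(torch.relu(self.feature_extractor(x)))

net = SPA(
    ToyNet(),
    layer_name="feature_extractor",
    num_features=16,  # size of the feature_extractor output
    reg_adv=1,
    reg_gsa=1,
    reg_nap=1,
    max_epochs=5,
)

# Positive sample_domain = source, negative = target (skada convention);
# target labels are masked with -1.
X = np.random.randn(100, 10).astype(np.float32)
y = np.random.randint(2, size=100)
sample_domain = np.array([1] * 50 + [-2] * 50)
y[sample_domain < 0] = -1
net.fit(X, y, sample_domain=sample_domain)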
4 changes: 1 addition & 3 deletions skada/deep/_optimal_transport.py
@@ -51,12 +51,10 @@ def __init__(self, reg_dist=1, reg_cl=1, target_criterion=None):
def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = deepjdot_loss(