[MRG] Add SPA method #276 (Merged)
Commits (39):
- f2f107b add spa (tgnassou)
- 3ae125c add spa (tgnassou)
- e097db7 update spa (tgnassou)
- 4145223 merge (tgnassou)
- c60819c add sample_idx parameter (tgnassou)
- fc34fd0 merge main (tgnassou)
- ca9ecb8 remove SPALoss sampleidx attrbiute (tgnassou)
- 5e232e1 fix test (tgnassou)
- 2cb7154 add callback to init memorybank (tgnassou)
- 12275de change default hp (tgnassou)
- 90d617e choose euc (tgnassou)
- 0b9268c Fix issue with memory_features on cpu (YanisLalou)
- 0fcd3ff lint (YanisLalou)
- 0392e76 lint (YanisLalou)
- 5536269 Make it fast with matrix multiplications (YanisLalou)
- 959b2b5 Merge branch 'main' into spa (rflamary)
- 678de6c Merge branch 'main' into spa (rflamary)
- 7780b22 add scheduler param (tgnassou)
- 2956ad5 Merge branch 'spa' of https://github.com/tgnassou/skada into spa (tgnassou)
- b9dd30a rm eff for adv loss (antoinecollas)
- 257ab6b fix test (antoinecollas)
- 9279024 fix CountEpochs (antoinecollas)
- d171add rm default max_epochs in SPALoss (antoinecollas)
- af38ea1 revert change of mapping (tgnassou)
- 0e342da Merge branch 'spa' of https://github.com/tgnassou/skada into spa (tgnassou)
- 539a397 fix scheduler (antoinecollas)
- ee25bf3 call scheduler (antoinecollas)
- 0b47150 fix gauss kernel (antoinecollas)
- bd8a698 Merge branch 'spa' of https://github.com/tgnassou/skada into spa (antoinecollas)
- 880d83b revert _mapping (antoinecollas)
- ad8ebf8 fix schedulers (antoinecollas)
- 3c06d8b use gaussian kernel (antoinecollas)
- d8bc359 rm torch deprecated functions (antoinecollas)
- ee59ca0 skip test (antoinecollas)
- 91219e0 fix torch.linalg.norm (antoinecollas)
- aafffdc rm default sample_idx (antoinecollas)
- 64b08d0 fix distance and K nearest neighbors computation (antoinecollas)
- 29917b4 merge memory bank classe + better init (antoinecollas)
- ccfa323 fix nap_loss call (antoinecollas)
@@ -0,0 +1,201 @@
# Author: Theo Gnassounou <[email protected]>
#
# License: BSD 3-Clause

import torch

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)
from skada.deep.callbacks import ComputeMemoryBank, MemoryBankInit
from skada.deep.losses import gda_loss, nap_loss

from .modules import DomainClassifier


class SPALoss(BaseDALoss):
    """SPA loss.

    This loss tries to minimize the divergence between features with
    an adversarial method. The weights are updated to make it harder
    to classify domains (i.e., to remove domain-specific features).

    See [36]_ for details.

    Parameters
    ----------
    domain_criterion : torch criterion (class), default=None
        The initialized criterion (loss) used to compute the
        adversarial loss. If None, a BCELoss is used.
    memory_features : torch.Tensor, default=None
        Memory bank of target features used by the NAP loss.
        If None, it is filled by the memory-bank callbacks.
    memory_outputs : torch.Tensor, default=None
        Memory bank of target predictions used by the NAP loss.
        If None, it is filled by the memory-bank callbacks.
    K : int, default=5
        Number of nearest neighbors used in the NAP loss.
    reg_adv : float, default=1
        Regularization parameter for the adversarial loss.
    reg_gsa : float, default=1
        Regularization parameter for the graph alignment loss.
    reg_nap : float, default=1
        Regularization parameter for the NAP loss.

    References
    ----------
    .. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
        Domain Adaptation. In NeurIPS, 2023.
    """

    def __init__(
        self,
        domain_criterion=None,
        memory_features=None,
        memory_outputs=None,
        K=5,
        reg_adv=1,
        reg_gsa=1,
        reg_nap=1,
    ):
        super().__init__()
        if domain_criterion is None:
            self.domain_criterion_ = torch.nn.BCELoss()
        else:
            self.domain_criterion_ = domain_criterion

        self.reg_adv = reg_adv
        self.reg_gsa = reg_gsa
        self.reg_nap = reg_nap
        self.K = K
        self.memory_features = memory_features
        self.memory_outputs = memory_outputs

    def forward(
        self,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
        sample_idx_t,
        **kwargs,
    ):
        """Compute the domain adaptation loss."""
        domain_label = torch.zeros(
            (domain_pred_s.size()[0]),
            device=domain_pred_s.device,
        )
        domain_label_target = torch.ones(
            (domain_pred_t.size()[0]),
            device=domain_pred_t.device,
        )

        # adversarial domain classification loss, weighted by reg_adv
        loss_adv = self.reg_adv * (
            self.domain_criterion_(domain_pred_s, domain_label)
            + self.domain_criterion_(domain_pred_t, domain_label_target)
        )

        loss_gda = self.reg_gsa * gda_loss(features_s, features_t)

        loss_pl = self.reg_nap * nap_loss(
            features_s,
            features_t,
            self.memory_features,
            self.memory_outputs,
            K=self.K,
            sample_idx=sample_idx_t,
        )
        loss = loss_adv + loss_gda + loss_pl
        return loss


def SPA(
    module,
    layer_name,
    reg_adv=1,
    reg_gsa=1,
    reg_nap=1,
    domain_classifier=None,
    num_features=None,
    base_criterion=None,
    domain_criterion=None,
    callbacks=None,
    **kwargs,
):
    """Domain Adaptation with SPA.

    From [36]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`. In general, the
        uninstantiated class should be passed, although instantiated
        modules will also work.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during training.
    reg_adv : float, default=1
        Regularization parameter for the adversarial loss.
    reg_gsa : float, default=1
        Regularization parameter for the graph alignment loss.
    reg_nap : float, default=1
        Regularization parameter for the NAP loss.
    domain_classifier : torch module, default=None
        A PyTorch :class:`~torch.nn.Module` used to classify the
        domain. If None, a domain classifier is created following [1]_.
    num_features : int, default=None
        Size of the input of the domain classifier,
        e.g. the size of the last layer of
        the feature extractor.
        If domain_classifier is None, num_features has to be
        provided.
    base_criterion : torch criterion (class)
        The base criterion used to compute the loss with source
        labels. If None, the default is `torch.nn.CrossEntropyLoss`.
    domain_criterion : torch criterion (class)
        The criterion (loss) used to compute the
        adversarial loss. If None, a BCELoss is used.
    callbacks : callback or list of callbacks, default=None
        Additional callbacks. The memory-bank callbacks required by
        the NAP loss are always appended.

    References
    ----------
    .. [36] Xiao et al. SPA: A Graph Spectral Alignment Perspective for
        Domain Adaptation. In NeurIPS, 2023.
    """
    if domain_classifier is None:
        # raise an error if num_features is None
        if num_features is None:
            raise ValueError(
                "If domain_classifier is None, num_features has to be provided"
            )
        domain_classifier = DomainClassifier(num_features=num_features)

    if callbacks is None:
        callbacks = [
            ComputeMemoryBank(),
            MemoryBankInit(),
        ]
    else:
        if isinstance(callbacks, list):
            callbacks.append(ComputeMemoryBank())
            callbacks.append(MemoryBankInit())
        else:
            callbacks = [
                callbacks,
                ComputeMemoryBank(),
                MemoryBankInit(),
            ]
    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        module__domain_classifier=domain_classifier,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=1,
        criterion__adapt_criterion=SPALoss(
            domain_criterion=domain_criterion,
            reg_adv=reg_adv,
            reg_gsa=reg_gsa,
            reg_nap=reg_nap,
        ),
        callbacks=callbacks,
        **kwargs,
    )
    return net
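The adversarial term of SPALoss is a standard domain-discrimination objective: source samples get domain label 0, target samples get label 1, and a binary cross-entropy is summed over both batches. A minimal torch-free sketch of that term, with hypothetical toy predictions (the real code uses `torch.nn.BCELoss` on the domain classifier's outputs):

```python
import math

def bce(preds, label):
    # mean binary cross-entropy of sigmoid outputs against a constant label
    eps = 1e-12
    return -sum(
        label * math.log(p + eps) + (1 - label) * math.log(1 - p + eps)
        for p in preds
    ) / len(preds)

# hypothetical domain-classifier outputs (probability of "target" domain)
domain_pred_s = [0.2, 0.1, 0.3]   # source batch, true domain label 0
domain_pred_t = [0.7, 0.9, 0.6]   # target batch, true domain label 1

# mirrors loss_adv in SPALoss.forward: BCE(pred_s, 0) + BCE(pred_t, 1)
loss_adv = bce(domain_pred_s, 0.0) + bce(domain_pred_t, 1.0)
```

The loss is small when the domain classifier separates the domains well; the gradient-reversal setup in the module then pushes the feature extractor in the opposite direction.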
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In Eq. (1) of the paper, they use both the standard cross-entropy and the modified cross-entropy loss (log(1 - prob)).
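For context, the distinction the comment draws, as a plain-Python sketch with toy numbers (see Eq. (1) of the paper for the exact combination): for a target sample with discriminator output p, the standard cross-entropy term is -log(p), while the modified term is log(1 - p); the two have different gradient magnitudes early in training.

```python
import math

def ce_term(p):
    # standard cross-entropy term for the "target" label: -log(p)
    return -math.log(p)

def modified_ce_term(p):
    # modified cross-entropy term: log(1 - p)
    return math.log(1.0 - p)

p = 0.7  # hypothetical discriminator output for a target sample
standard = ce_term(p)       # positive, shrinks as p -> 1
modified = modified_ce_term(p)  # negative, diverges to -inf as p -> 1
```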
We should also use label smoothing, as in the code of the paper: https://github.com/CrownX/SPA/blob/main/code/loss.py#L65-L84
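One common label-smoothing formulation, sketched in plain Python (the linked repository has its own epsilon value and reduction; treat this as illustrative only): the one-hot target is replaced by a mixture that gives the true class 1 - epsilon and spreads epsilon uniformly over all classes.

```python
import math

def softmax(logits):
    # numerically stable softmax
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def label_smoothing_ce(logits, target, epsilon=0.1):
    # cross-entropy against a smoothed target distribution:
    # each class gets epsilon / n, the true class gets an extra 1 - epsilon
    n = len(logits)
    probs = softmax(logits)
    smooth = [epsilon / n] * n
    smooth[target] += 1.0 - epsilon
    return -sum(t * math.log(p) for t, p in zip(smooth, probs))
```

With epsilon=0 this reduces to the standard cross-entropy; for a confidently correct prediction, smoothing strictly increases the loss, which discourages over-confident pseudo-labels.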