
[MRG] Add SPA method #276

Merged 39 commits on Nov 22, 2024
Changes from all commits (39 commits)
f2f107b
add spa
tgnassou Nov 14, 2024
3ae125c
add spa
tgnassou Nov 14, 2024
e097db7
update spa
tgnassou Nov 14, 2024
4145223
merge
tgnassou Nov 14, 2024
c60819c
add sample_idx parameter
tgnassou Nov 18, 2024
fc34fd0
merge main
tgnassou Nov 18, 2024
ca9ecb8
remove SPALoss sample_idx attribute
tgnassou Nov 18, 2024
5e232e1
fix test
tgnassou Nov 18, 2024
2cb7154
add callback to init memory bank
tgnassou Nov 18, 2024
12275de
change default hp
tgnassou Nov 18, 2024
90d617e
choose euc
tgnassou Nov 18, 2024
0b9268c
Fix issue with memory_features on cpu
YanisLalou Nov 19, 2024
0fcd3ff
lint
YanisLalou Nov 19, 2024
0392e76
lint
YanisLalou Nov 19, 2024
5536269
Make it fast with matrix multiplications
YanisLalou Nov 19, 2024
959b2b5
Merge branch 'main' into spa
rflamary Nov 19, 2024
678de6c
Merge branch 'main' into spa
rflamary Nov 19, 2024
7780b22
add scheduler param
tgnassou Nov 20, 2024
2956ad5
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
tgnassou Nov 20, 2024
b9dd30a
rm eff for adv loss
antoinecollas Nov 20, 2024
257ab6b
fix test
antoinecollas Nov 20, 2024
9279024
fix CountEpochs
antoinecollas Nov 20, 2024
d171add
rm default max_epochs in SPALoss
antoinecollas Nov 20, 2024
af38ea1
revert change of mapping
tgnassou Nov 20, 2024
0e342da
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
tgnassou Nov 20, 2024
539a397
fix scheduler
antoinecollas Nov 20, 2024
ee25bf3
call scheduler
antoinecollas Nov 20, 2024
0b47150
fix gauss kernel
antoinecollas Nov 20, 2024
bd8a698
Merge branch 'spa' of https://github.com/tgnassou/skada into spa
antoinecollas Nov 20, 2024
880d83b
revert _mapping
antoinecollas Nov 20, 2024
ad8ebf8
fix schedulers
antoinecollas Nov 20, 2024
3c06d8b
use gaussian kernel
antoinecollas Nov 20, 2024
d8bc359
rm torch deprecated functions
antoinecollas Nov 20, 2024
ee59ca0
skip test
antoinecollas Nov 20, 2024
91219e0
fix torch.linalg.norm
antoinecollas Nov 20, 2024
aafffdc
rm default sample_idx
antoinecollas Nov 20, 2024
64b08d0
fix distance and K nearest neighbors computation
antoinecollas Nov 20, 2024
29917b4
merge memory bank classes + better init
antoinecollas Nov 20, 2024
ccfa323
fix nap_loss call
antoinecollas Nov 20, 2024
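
Several commits above ("Make it fast with matrix multiplications", "choose euc", "use gaussian kernel", "fix distance and K nearest neighbors computation", "merge memory bank classes + better init") revolve around one computational pattern: maintaining a memory bank of features and building a neighborhood graph over it. The sketch below illustrates that pattern only; the function names, momentum value, and kernel bandwidth are illustrative, not skada's actual implementation.

```python
import torch

def pairwise_sqeuclidean(a, b):
    # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>, computed with
    # one matrix multiplication instead of a Python loop over pairs.
    sq_a = (a**2).sum(dim=1, keepdim=True)      # shape (n, 1)
    sq_b = (b**2).sum(dim=1, keepdim=True).T    # shape (1, m)
    return (sq_a + sq_b - 2 * a @ b.T).clamp_min(0)

def gaussian_affinity(dist_sq, sigma=1.0):
    # Gaussian kernel similarities from squared distances.
    return torch.exp(-dist_sq / (2 * sigma**2))

def k_nearest_neighbors(features, bank, k=5):
    # Indices of the k closest memory-bank rows for each feature vector.
    dist_sq = pairwise_sqeuclidean(features, bank)
    return dist_sq.topk(k, dim=1, largest=False).indices

def update_memory_bank(bank, features, sample_idx, momentum=0.9):
    # Exponential moving average on the per-sample slots; this is why the
    # loss needs the batch's dataset indices (sample_idx).
    bank[sample_idx] = momentum * bank[sample_idx] + (1 - momentum) * features
    return bank
```

As the "Fix issue with memory_features on cpu" commit suggests, the bank and the incoming batch features must live on the same device for these operations to work.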
4 changes: 3 additions & 1 deletion README.md
@@ -241,4 +241,6 @@ The library is distributed under the 3-Clause BSD license.

[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020.

[35] Zhang, Y., Liu, T., Long, M., & Jordan, M. I. (2019). [Bridging Theory and Algorithm for Domain Adaptation](https://arxiv.org/abs/1904.05801). In Proceedings of the 36th International Conference on Machine Learning, (pp. 7404-7413).

[36] Xiao, Zhiqing, Wang, Haobo, Jin, Ying, Feng, Lei, Chen, Gang, Huang, Fei, Zhao, Junbo. [SPA: A Graph Spectral Alignment Perspective for Domain Adaptation](https://arxiv.org/pdf/2310.17594). NeurIPS, 2023.
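
Since [36] is the paper this PR implements, a one-screen sketch of its central idea may help reviewers: build an affinity graph per domain and align the graphs' eigenvalue spectra. This is a minimal sketch assuming Gaussian affinities and a plain squared spectral gap; the full SPA method also includes adversarial training and neighborhood-aware pseudo-labeling (visible in the commit history), which this sketch omits.

```python
import torch

def spectral_alignment_sketch(features_s, features_t, sigma=1.0):
    # One Gaussian affinity graph per domain; penalise the gap between
    # the two graphs' eigenvalue spectra.
    def spectrum(x):
        affinity = torch.exp(-torch.cdist(x, x) ** 2 / (2 * sigma**2))
        return torch.linalg.eigvalsh(affinity)  # eigenvalues, ascending

    k = min(len(features_s), len(features_t))
    ev_s = spectrum(features_s)[-k:]  # top-k eigenvalues per domain
    ev_t = spectrum(features_t)[-k:]
    return ((ev_s - ev_t) ** 2).mean()
```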
3 changes: 3 additions & 0 deletions skada/deep/__init__.py
@@ -20,6 +20,7 @@
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss, ModifiedCrossEntropyLoss
from ._class_confusion import MCC, MCCLoss
from ._graph_alignment import SPA, SPALoss
from ._baseline import SourceOnly, TargetOnly

from . import losses
@@ -45,6 +46,8 @@
"ModifiedCrossEntropyLoss",
"CANLoss",
"CAN",
"SPALoss",
"SPA",
"SourceOnly",
"TargetOnly",
]
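
With these exports, the estimator becomes importable from skada.deep. Below is a hypothetical usage sketch: the import path is confirmed by this diff, but every constructor argument shown (module, layer_name, max_epochs) is an assumption to verify against the SPA docstring.

```python
import torch.nn as nn
from skada.deep import SPA

# All hyper-parameters below are assumed, not taken from this PR.
net = SPA(
    module=nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 2)),
    layer_name="1",   # layer whose activations serve as features (assumed)
    max_epochs=10,
)
# net.fit(X, y, sample_domain=sample_domain)  # skada's domain-aware fit
```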
32 changes: 16 additions & 16 deletions skada/deep/_adversarial.py
@@ -53,13 +53,9 @@ def __init__(self, domain_criterion=None):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
domain_label = torch.zeros(
@@ -183,13 +179,11 @@ def __init__(self, domain_criterion=None):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
dtype = torch.float32
@@ -247,7 +241,14 @@ def __init__(
self.max_features = max_features
self.random_state = random_state

def forward(self, X, sample_domain=None, is_fit=False, return_features=False):
def forward(
self,
X,
sample_domain=None,
sample_idx=None,
is_fit=False,
return_features=False,
):
if is_fit:
# predict
y_pred = self.base_module_(X)
@@ -279,6 +280,7 @@ def forward(self, X, sample_domain=None, is_fit=False, return_features=False):
domain_pred,
features,
sample_domain,
sample_idx,
)
else:
if return_features:
@@ -448,22 +450,20 @@ def __init__(self, gamma=4.0):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
disc_pred_s,
disc_pred_t,
features_s,
features_t,
domain_pred_s,
domain_pred_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
# TODO: handle binary classification
# Multiclass classification
pseudo_label_s = torch.argmax(y_pred_s, axis=-1)
pseudo_label_t = torch.argmax(y_pred_t, axis=-1)

disc_loss_s = self.disc_criterion_s(disc_pred_s, pseudo_label_s)
disc_loss_t = self.disc_criterion_t(disc_pred_t, pseudo_label_t)
disc_loss_s = self.disc_criterion_s(domain_pred_s, pseudo_label_s)
disc_loss_t = self.disc_criterion_t(domain_pred_t, pseudo_label_t)

# Compute the MDD loss value
disc_loss = self.gamma * disc_loss_s - disc_loss_t
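
The signature changes in this file, and in _baseline.py, _class_confusion.py, and _divergence.py below, all follow one pattern: each loss now declares only the inputs it actually uses and absorbs everything else through **kwargs, so the wrapper can pass a uniform superset of arguments, including the new sample_idx. A minimal sketch of the convention (class names and the toy loss bodies are illustrative):

```python
import torch
import torch.nn as nn

class FeatureOnlyLoss(nn.Module):
    # Declares only what it uses; y_s, domain predictions, sample_idx, ...
    # are silently absorbed by **kwargs instead of cluttering the signature.
    def forward(self, features_s, features_t, **kwargs):
        return ((features_s.mean(0) - features_t.mean(0)) ** 2).sum()

class IndexAwareLoss(nn.Module):
    # Opts in to sample_idx, e.g. to address per-sample memory-bank rows.
    def forward(self, features_t, sample_idx, **kwargs):
        return features_t[sample_idx].pow(2).mean()

inputs = dict(
    y_s=torch.randint(0, 2, (8,)),
    features_s=torch.randn(8, 16),
    features_t=torch.randn(8, 16),
    sample_idx=torch.arange(8),
)
# Both losses are called with the same superset of keyword arguments:
loss_a = FeatureOnlyLoss()(**inputs)
loss_b = IndexAwareLoss()(**inputs)
```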
8 changes: 1 addition & 7 deletions skada/deep/_baseline.py
@@ -26,13 +26,7 @@ def __init__(

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
return 0
7 changes: 1 addition & 6 deletions skada/deep/_class_confusion.py
@@ -42,13 +42,8 @@ def __init__(self, T=1, eps=1e-7):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = mcc_loss(
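
The MCCLoss hunk above is the same **kwargs signature cleanup. For readers unfamiliar with the underlying mcc_loss, here is a simplified sketch of Minimum Class Confusion [34]; the real loss also applies entropy-based sample weighting, which is omitted here:

```python
import torch

def mcc_loss_sketch(logits_t, T=1.0, eps=1e-7):
    # Temperature-scaled target probabilities -> class-confusion matrix
    # -> penalise the off-diagonal (between-class) mass.
    probs = torch.softmax(logits_t / T, dim=1)            # (n, C)
    confusion = probs.T @ probs                           # (C, C)
    confusion = confusion / (confusion.sum(dim=1, keepdim=True) + eps)
    n_classes = confusion.shape[0]
    return (confusion.sum() - confusion.trace()) / n_classes
```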
17 changes: 3 additions & 14 deletions skada/deep/_divergence.py
@@ -46,13 +46,9 @@ def __init__(

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = deepcoral_loss(features_s, features_t, self.assume_centered)
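
DeepCORALLoss likewise only changes signature. For reference, a covariance alignment sketch in the spirit of deepcoral_loss, with the same assume_centered switch; skada's exact normalization may differ:

```python
import torch

def deep_coral_sketch(features_s, features_t, assume_centered=False):
    # Align second-order statistics: squared Frobenius distance between
    # the source and target feature covariance matrices.
    def cov(x):
        if not assume_centered:
            x = x - x.mean(dim=0, keepdim=True)
        return (x.T @ x) / max(x.shape[0] - 1, 1)

    d = features_s.shape[1]
    return ((cov(features_s) - cov(features_t)) ** 2).sum() / (4 * d * d)
```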
@@ -132,13 +128,9 @@ def __init__(self, sigmas=None, eps=1e-7):

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
"""Compute the domain adaptation loss"""
loss = dan_loss(features_s, features_t, sigmas=self.sigmas, eps=self.eps)
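
DANLoss gets the same treatment. As a refresher, dan_loss computes an MMD-style discrepancy; below is a sketch assuming a sum of Gaussian kernels over the bandwidths in sigmas (skada's estimator may differ, e.g. in bias correction):

```python
import torch

def mmd_sketch(features_s, features_t, sigmas=(1.0, 2.0, 4.0)):
    # Compare mean kernel embeddings of the two domains under a sum of
    # Gaussian kernels (multi-kernel MMD).
    def kernel(a, b):
        d = torch.cdist(a, b) ** 2
        return sum(torch.exp(-d / (2 * s**2)) for s in sigmas)

    k_ss = kernel(features_s, features_s).mean()
    k_tt = kernel(features_t, features_t).mean()
    k_st = kernel(features_s, features_t).mean()
    return k_ss + k_tt - 2 * k_st
```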
@@ -235,12 +227,9 @@ def __init__(
def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
**kwargs,
):
loss = cdd_loss(
y_s,