-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Add MDD method #263
[MRG] Add MDD method #263
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #263 +/- ##
==========================================
- Coverage 97.20% 96.98% -0.22%
==========================================
Files 61 61
Lines 6334 6372 +38
==========================================
+ Hits 6157 6180 +23
- Misses 177 192 +15 |
skada/deep/_adversarial.py
Outdated
raise ValueError( | ||
"If domain_classifier is None, num_features has to be provided" | ||
) | ||
domain_classifier = DomainClassifier( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if we use here the domain Classifier or if we want a DiscClassifier
maybe the same that is in the paper
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed all the nomenclature "domain" to "disc". In the paper, the disc_classifier must have the same architecture as the task classifier. For the moment, I propose to keep domain_classifier and then do another PR to the adversarial modules from the DomainClassifier (call DomainClassifier or DiscClassifier only in the loss definitions like DANNLoss, MDDLoss instead of the modules like DANN and MDD). It may also be needed to update the DomainAwareNet to change the "domain_classifier" to "disc_classifier".
It is almost ok for me |
Sounds good to me! |
Implement the MDD method from https://arxiv.org/pdf/1904.05801.
Related issue: #254
Core implementation
skada/deep/_adversarial.py
skada/deep/tests/test_deep_adversarial.py
skada/deep/modules.py
Documentation
skada/deep/__init__.py
README.md
docs/source/all.rst