Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Handle edge cases for CAN #269

Merged
merged 4 commits into from
Oct 31, 2024

Conversation

YanisLalou
Copy link
Collaborator

@YanisLalou YanisLalou commented Oct 29, 2024

This pull request introduces several improvements focusing on enhancing numerical stability and performance. The changes include adding gradient computation disabling for specific operations, and ensuring edge cases are handled correctly in loss functions.

Enhancements to Numerical Stability and Performance:

  • skada/deep/callbacks.py: Disabled gradient computation for feature extraction in on_epoch_begin to improve performance.
  • skada/deep/utils.py: Updated _compute_dissimilarities to compute dissimilarities in batches, improving performance for large datasets.
  • skada/deep/losses.py: Added a small epsilon (eps) to the median pairwise distance calculation to avoid division by zero errors.
  • skada/deep/utils.py: Wrapped centroid initialization, fitting, and prediction methods in torch.no_grad() to prevent unnecessary gradient computations. [1] [2] [3]

Refinements in Centroid Calculations:

  • skada/deep/callbacks.py: Changed centroid calculation from mean to sum in on_epoch_begin to ensure consistency with normalization.
  • skada/deep/losses.py: Ensured source_centroids is stacked properly after calculation in cdd_loss.

Handling Edge Cases in Loss Functions:

  • skada/deep/losses.py: Added checks to ensure intra-class and inter-class domain discrepancy measures (e1, e2, e3) are set to zero if their corresponding masks sum to zero. This prevents potential errors from invalid operations.

Enhancements to tests:

# Define sigmas
if sigmas is None:
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
eps = 1e-7
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eps as parameter

@YanisLalou YanisLalou changed the title [TO_REVIEW] Handle edge cases for CAN [MRG] Handle edge cases for CAN Oct 31, 2024
@YanisLalou YanisLalou merged commit f4f68b1 into scikit-adaptation:main Oct 31, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants