-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Fix mdd loss #277
[MRG] Fix mdd loss #277
Conversation
It seems nice, but why adding the new arguments domain_criterion_t and domain_criterion_s? I feel it adds too much complexity. Maybe we can say that the original MDD paper only focus on multi classification, and we do the same here. We can almost remove the domain_criterion argument and hardset it to be crossentropy for source and custom crossentropy for target. We can let to future PRs the generalization of MDD to other losses? |
I agree, I change this |
Ok, I just have one comment, but the rest is ok for me. |
"""Compute the modified CrossEntropyLoss""" | ||
prob = F.softmax(input, dim=-1) | ||
prob = prob[..., target] | ||
log_one_minus_prob = torch.log(1 - prob) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe use this one: https://pytorch.org/docs/stable/generated/torch.log1p.html
Use equations (26)-(29) from https://arxiv.org/pdf/1904.05801