Matching keys lengths and new metrics #358
Conversation
Thanks Ido!
A few comments and questions inline.
fuse/eval/metrics/metrics_common.py
Outdated
@@ -160,7 +160,10 @@ def collect(self, batch: Dict) -> None:
        batch_to_collect = {}

        for name, key in self._keys_to_collect.items():
            value = batch[key]
            try:
Was it just for debugging?
Or would you like to print a more informative message and then reraise the exception?
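For reference, a minimal standalone sketch of that pattern (the helper name `collect_values` and the keys are hypothetical, not the PR's actual code): print which key failed, then re-raise so the original traceback is preserved.

```python
from typing import Dict


def collect_values(batch: Dict, keys_to_collect: Dict[str, str]) -> Dict:
    """Hypothetical helper mirroring the collect() loop: copy requested keys
    out of the batch dict, failing loudly on a missing key."""
    batch_to_collect = {}
    for name, key in keys_to_collect.items():
        try:
            value = batch[key]
        except KeyError:
            # Print an informative message, then re-raise the original exception
            print(f"collect: key '{key}' (for metric arg '{name}') is missing from the batch dict")
            raise
        batch_to_collect[name] = value
    return batch_to_collect


if __name__ == "__main__":
    batch = {"model.logits": [0.1, 0.9], "data.label": 1}
    print(collect_values(batch, {"pred": "model.logits", "target": "data.label"}))
```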
def mcc_wrapper(
    self,
    pred: Optional[str] = None,
Is it a string or a list of numpy arrays?
Are you expecting class predictions or scores after softmax?
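As an illustration of the class-predictions case, a minimal sketch built on sklearn's `matthews_corrcoef` (the function name and argument names here are hypothetical, not the PR's API):

```python
from typing import List, Union

import numpy as np
from sklearn.metrics import matthews_corrcoef


def mcc_from_class_predictions(
    cls_pred: Union[List, np.ndarray], target: Union[List, np.ndarray]
) -> float:
    """Hypothetical standalone version of the wrapper: expects class
    predictions (not softmax scores), one per sample."""
    return float(matthews_corrcoef(np.asarray(target), np.asarray(cls_pred)))


if __name__ == "__main__":
    print(mcc_from_class_predictions(cls_pred=[0, 1, 1, 0], target=[0, 1, 0, 0]))
```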
fuse/data/utils/collates.py
Outdated
@@ -247,3 +255,22 @@ def crop_padding(
        cropped_sequences = [ids[:min_length] for ids in input_ids_list]
        batched_sequences = torch.stack(cropped_sequences, dim=0)
        return batched_sequences

    @staticmethod
    def match_length_to_target_key(
Cropping?
What happens when the target_key is shorter than keys_to_match?
This will raise an error; maybe I'll change the name to crop_length_to_target_key?
The motivation is the case where the encoder input and the labels should have the same length and crop padding was applied to the encoder input, so we want to match the labels to it.
I assume you mean decoder input?
I asked Michal to solve it in a different way.
(Labels and decoder input should always have the same length.)
Sorry, I do mean encoder; this is for the case of an encoder-only model.
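To make the behaviour discussed here concrete, a minimal sketch of such a cropping helper, assuming batched tensors shaped `[batch, seq_len, ...]` (the function name and dict keys below are hypothetical, not the PR's code):

```python
from typing import Dict, Sequence

import torch


def crop_length_to_target_key(
    batch_dict: Dict, target_key: str, keys_to_match: Sequence[str]
) -> None:
    """Hypothetical post-collate handler: crop each key in keys_to_match to the
    sequence length of batch_dict[target_key]. Raises if a key is shorter than
    the target, since there is nothing to crop in that case."""
    target_len = batch_dict[target_key].shape[1]
    for key in keys_to_match:
        cur_len = batch_dict[key].shape[1]
        if cur_len < target_len:
            raise ValueError(
                f"'{key}' (len {cur_len}) is shorter than target '{target_key}' (len {target_len})"
            )
        batch_dict[key] = batch_dict[key][:, :target_len]


if __name__ == "__main__":
    batch = {
        "data.input_ids": torch.zeros(4, 10, dtype=torch.long),  # already crop-padded
        "data.labels": torch.zeros(4, 12, dtype=torch.long),     # still padded to 12
    }
    crop_length_to_target_key(batch, target_key="data.input_ids", keys_to_match=["data.labels"])
    print(batch["data.labels"].shape)  # torch.Size([4, 10])
```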
@@ -43,6 +43,7 @@ def __init__(
        keep_keys: Sequence[str] = tuple(),
        raise_error_key_missing: bool = True,
        special_handlers_keys: Optional[Dict[str, Callable]] = None,
        post_collate_special_handlers_keys: Optional[List[Callable]] = None,
Why not use special_handlers_keys instead?
Can you give an example of a use case?
A use case is match_length_to_target_key from below: we can't use special handlers because each one handles a single element of the batch dict, and here we need to match one key to another after the special handler has been applied.
Why is it a list and not a single callable?
To allow multiple use cases, e.g. aligning one set of keys to key A and a different set of keys to key B.
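A rough sketch of that idea, with hypothetical handler names, showing why a list of post-collate callables is convenient: each callable sees the whole collated batch dict, so different key sets can be aligned independently.

```python
from typing import Callable, Dict, List


def apply_post_collate_handlers(batch_dict: Dict, handlers: List[Callable]) -> Dict:
    """Apply each handler to the already-collated batch dict, in order."""
    for handler in handlers:
        handler(batch_dict)
    return batch_dict


def align_labels_to_input(batch_dict: Dict) -> None:
    # hypothetical handler: crop labels to the encoder input length
    seq_len = len(batch_dict["data.input_ids"][0])
    batch_dict["data.labels"] = [x[:seq_len] for x in batch_dict["data.labels"]]


def align_mask_to_other_input(batch_dict: Dict) -> None:
    # hypothetical handler: a second, independent alignment to a different key
    seq_len = len(batch_dict["data.other_input"][0])
    batch_dict["data.other_mask"] = [x[:seq_len] for x in batch_dict["data.other_mask"]]


if __name__ == "__main__":
    batch = {
        "data.input_ids": [[1, 2, 3]],
        "data.labels": [[1, 2, 3, 0, 0]],
        "data.other_input": [[4, 5]],
        "data.other_mask": [[1, 1, 0, 0]],
    }
    apply_post_collate_handlers(batch, [align_labels_to_input, align_mask_to_other_input])
    print(batch["data.labels"], batch["data.other_mask"])  # [[1, 2, 3]] [[1, 1]]
```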
def balanced_acc_wrapper(
    self,
    pred: Union[List, np.ndarray],
It's still not clear whether you expect softmax scores here (which we call pred) or class predictions (which we call cls_pred).
Oh, I understand. I expect class predictions; I used the same names as in MetricAccuracy. Do you want me to change it to cls_pred?
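For comparison, a minimal standalone sketch using sklearn's `balanced_accuracy_score`, assuming `pred` holds class predictions as discussed (the function and argument names are hypothetical, not the PR's API):

```python
from typing import List, Union

import numpy as np
from sklearn.metrics import balanced_accuracy_score


def balanced_acc_from_class_predictions(
    pred: Union[List, np.ndarray], target: Union[List, np.ndarray]
) -> float:
    """Hypothetical standalone version: 'pred' holds class predictions
    (same naming convention as MetricAccuracy), not softmax scores."""
    return float(balanced_accuracy_score(np.asarray(target), np.asarray(pred)))


if __name__ == "__main__":
    print(balanced_acc_from_class_predictions(pred=[0, 0, 1, 1], target=[0, 1, 1, 1]))
```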
Thanks Ido!
I still have a few comments.
In order not to block you, I'm OK with merging and potentially making some modifications in a new PR.
LGTM!
Thanks!
Added 2 new metrics: Matthews correlation coefficient and balanced accuracy.
Added new functionality to the default collate that allows matching the length of some keys in the batch dict to other keys. Example: for an encoder model, matching the length of the labels to the length of the input after crop padding is applied.