
Matching key lengths and new metrics #358

Merged
merged 7 commits into from
Jun 30, 2024

Conversation

IdoAmosIBM
Collaborator

Added two new metrics: Matthews correlation coefficient and balanced accuracy.
Added a new option to the default collate that crops the lengths of selected keys in the batch dict to match the length of another key. Example: for an encoder model, matching the length of the labels to the length of the input after crop padding is applied.
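A minimal sketch of how such a length-matching step could look. The function name echoes the diff below, but the exact signature and the assumption that tensors are `(batch, seq_len)` are illustrative, not the repo's actual API:

```python
import torch

def match_length_to_target_key(batch: dict, target_key: str, keys_to_match: list) -> dict:
    # Crop each tensor listed in keys_to_match along the sequence
    # dimension (dim 1) so it matches batch[target_key]'s length.
    target_len = batch[target_key].shape[1]
    for key in keys_to_match:
        if batch[key].shape[1] < target_len:
            raise ValueError(f"'{key}' is shorter than target key '{target_key}'")
        batch[key] = batch[key][:, :target_len]
    return batch
```

For an encoder-only model, `target_key` would be the crop-padded input ids and `keys_to_match` would contain the labels.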

Collaborator

@mosheraboh mosheraboh left a comment


Thanks Ido!
A few comments and questions inline.

@@ -160,7 +160,10 @@ def collect(self, batch: Dict) -> None:
batch_to_collect = {}

for name, key in self._keys_to_collect.items():
value = batch[key]
try:
Collaborator


Was it just for debugging?
Or would you like to print a more informative message and then reraise the exception?


def mcc_wrapper(
self,
pred: Optional[str] = None,
Collaborator


Is it a string or a list of numpy arrays?
Are you expecting class predictions or scores after softmax?

@@ -247,3 +255,22 @@ def crop_padding(
cropped_sequences = [ids[:min_length] for ids in input_ids_list]
batched_sequences = torch.stack(cropped_sequences, dim=0)
return batched_sequences

@staticmethod
def match_length_to_target_key(
Collaborator


Cropping?
What happens when the target_key is shorter than keys_to_match?

Collaborator Author


This will raise an error. Maybe I'll change the name to crop_length_to_target_key? The motivation is the case where the encoder input and labels should have the same length and pad cropping was applied to the encoder input, so we want to crop the labels to match.
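A quick illustration of that motivating scenario, assuming crop padding trims the encoder inputs to the shortest sequence in the batch (as in the crop_padding snippet above) and the labels are then cropped to the same length:

```python
import torch

# Per-sample encoder inputs of different lengths; crop padding keeps
# only the shortest common length before stacking into a batch.
input_ids_list = [torch.arange(10), torch.arange(7), torch.arange(9)]
min_length = min(len(ids) for ids in input_ids_list)  # 7
inputs = torch.stack([ids[:min_length] for ids in input_ids_list], dim=0)

# Labels were produced for the uncropped inputs, so they must be
# cropped to the encoder input's length to stay aligned.
labels_list = [torch.arange(10) for _ in range(3)]
labels = torch.stack([lab[: inputs.shape[1]] for lab in labels_list], dim=0)
```

If the labels were instead *shorter* than the target, cropping cannot fix the mismatch, which is why raising an error there is reasonable.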

Collaborator


I assume you mean decoder input?
I asked Michal to solve it in a different way (labels and decoder input should always have the same length).

Collaborator Author


Sorry, I do mean encoder; this is for the case of an encoder-only model.

@@ -43,6 +43,7 @@ def __init__(
keep_keys: Sequence[str] = tuple(),
raise_error_key_missing: bool = True,
special_handlers_keys: Optional[Dict[str, Callable]] = None,
post_collate_special_handlers_keys: Optional[List[Callable]] = None,
Collaborator


Why not use special_handlers_keys instead?
Can you give an example of a use case?

Collaborator Author


A use case is match_length_to_target_key below. We can't use special_handlers_keys because each special handler processes a single element of the batch dict at a time; here we need to match one key to another after the special handlers have been applied.
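The distinction can be sketched as follows. A post-collate handler receives the whole batch dict, so it can align one key against another, and a list of them (built here with `functools.partial`) lets different key sets align to different targets. The handler signature and key names are assumptions for illustration, not the repo's actual API:

```python
from functools import partial

def crop_to_target(batch: dict, target_key: str, keys_to_match: list) -> dict:
    # Unlike a per-key special handler, this sees the whole batch dict,
    # so it can crop one key's sequences to another key's length.
    target_len = len(batch[target_key][0])
    for key in keys_to_match:
        batch[key] = [seq[:target_len] for seq in batch[key]]
    return batch

# Different key sets aligned to different targets:
post_collate_handlers = [
    partial(crop_to_target, target_key="encoder_input", keys_to_match=["labels"]),
    partial(crop_to_target, target_key="decoder_input", keys_to_match=["decoder_labels"]),
]

batch = {
    "encoder_input": [[1, 2, 3]], "labels": [[1, 2, 3, 4, 5]],
    "decoder_input": [[7, 8]], "decoder_labels": [[7, 8, 9]],
}
for handler in post_collate_handlers:
    batch = handler(batch)
```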

Collaborator


Why is it a list and not a single callable?

Collaborator Author


To allow multiple use cases, e.g. aligning one set of keys to key A and a different set of keys to key B.


def balanced_acc_wrapper(
self,
pred: Union[List, np.ndarray],
Collaborator


It's still not clear whether you expect softmax scores here (which we call pred) or class predictions (which we call cls_pred).

Collaborator Author


Oh, I understand. I expect class predictions; I used the same names as in MetricAccuracy. Do you want me to change it to cls_pred?
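For reference, both new metrics operate on class predictions, not softmax scores. A self-contained numpy sketch of what they compute (the repo presumably wraps an existing implementation such as sklearn's; this is only illustrative, and the binary MCC variant and argument names here are assumptions):

```python
import numpy as np

def balanced_accuracy(cls_pred: np.ndarray, target: np.ndarray) -> float:
    # Mean of per-class recalls; robust to class imbalance.
    recalls = [np.mean(cls_pred[target == c] == c) for c in np.unique(target)]
    return float(np.mean(recalls))

def matthews_corrcoef_binary(cls_pred: np.ndarray, target: np.ndarray) -> float:
    # Matthews correlation coefficient from the binary confusion matrix.
    tp = np.sum((cls_pred == 1) & (target == 1))
    tn = np.sum((cls_pred == 0) & (target == 0))
    fp = np.sum((cls_pred == 1) & (target == 0))
    fn = np.sum((cls_pred == 0) & (target == 1))
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    return float((tp * tn - fp * fn) / denom) if denom else 0.0
```

Both functions take integer class labels, which is why the `pred` vs `cls_pred` naming question matters.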

mosheraboh
mosheraboh previously approved these changes Jun 26, 2024
Collaborator

@mosheraboh mosheraboh left a comment


Thanks Ido!
I still have a few comments.
In order not to block you, I'm OK with merging and potentially making some modifications in a new PR.

@SagiPolaczek SagiPolaczek marked this pull request as ready for review June 30, 2024 08:37
Collaborator

@SagiPolaczek SagiPolaczek left a comment


LGTM!
Thanks!

@SagiPolaczek SagiPolaczek merged commit d690708 into master Jun 30, 2024
5 checks passed