Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Document.add_all_annotations_from_other() #429

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def add_all_annotations_from_other(
process_predictions: bool = True,
strict: bool = True,
verbose: bool = True,
) -> Dict[str, Dict[Annotation, Annotation]]:
) -> Dict[str, Dict[int, Annotation]]:
"""Adds all annotations from another document to this document. It allows to blacklist annotations
and also to override annotations. For each annotation field, it returns a mapping from the ids of the
original annotations to the new annotations that were added to the current document.
Expand Down Expand Up @@ -862,7 +862,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
```
"""
removed_annotations = defaultdict(set, removed_annotations or dict())
added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict)
added_annotations: Dict[str, Dict[int, Annotation]] = defaultdict(dict)

annotation_store: Dict[str, Dict[int, Annotation]] = defaultdict(dict)
named_annotation_fields = {field.name: field for field in self.annotation_fields()}
Expand Down Expand Up @@ -905,7 +905,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].append(new_ann)
added_annotations[field_name][ann] = new_ann
added_annotations[field_name][ann._id] = new_ann
else:
if strict:
raise ValueError(
Expand All @@ -930,7 +930,7 @@ class TokenBasedDocumentWithEntitiesRelationsAndRelationAttributes(TokenBasedDoc
if ann._id != new_ann._id:
annotation_store[field_name][ann._id] = new_ann
self[field_name].predictions.append(new_ann)
added_annotations[field_name][ann] = new_ann
added_annotations[field_name][ann._id] = new_ann
else:
if strict:
raise ValueError(
Expand Down
21 changes: 11 additions & 10 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,11 @@ def test_document_extend_from_other_full_copy(text_document):
for layer_name, annotation_mapping in added_annotations.items():
assert len(annotation_mapping) > 0
available_annotations = text_document[layer_name]
assert set(annotation_mapping) == set(available_annotations)
available_annotation_ids = [a._id for a in available_annotations]
assert set(annotation_mapping) == set(available_annotation_ids)
assert len(annotation_mapping) == 1
# since we have only one annotation, we can construct the expected mapping
assert annotation_mapping == {available_annotations[0]: doc_new[layer_name][0]}
assert annotation_mapping == {available_annotation_ids[0]: doc_new[layer_name][0]}


def test_document_extend_from_other_wrong_override_annotation_mapping(text_document):
Expand Down Expand Up @@ -711,16 +712,16 @@ class TestDocument2(TokenBasedDocument):
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
# check that the added annotations are as expected (the entity annotations are already there)
assert added_annotation_sets == {
"relations": set(text_document.relations),
"relation_attributes": set(text_document.relation_attributes),
"labels": set(text_document.labels),
"relations": {ann._id for ann in text_document.relations},
"relation_attributes": {ann._id for ann in text_document.relation_attributes},
"labels": {ann._id for ann in text_document.labels},
}
for layer_name, annotation_mapping in added_annotations.items():
text_annotations = text_document[layer_name]
token_annotations = token_document[layer_name]
assert len(annotation_mapping) == len(text_annotations) == len(token_annotations) == 1
# since we have only one annotation, we can construct the expected mapping
assert annotation_mapping == {text_annotations[0]: token_annotations[0]}
assert annotation_mapping == {text_annotations[0]._id: token_annotations[0]}

assert (
len(token_document.entities1)
Expand Down Expand Up @@ -753,12 +754,12 @@ def test_document_extend_from_other_remove(text_document):
added_annotation_sets = {k: set(v) for k, v in added_annotations.items()}
# the only entity in entities1 is removed and since the relation has it as head, the relation is removed as well
assert added_annotation_sets == {
"entities2": set(text_document.entities2),
"labels": set(text_document.labels),
"entities2": {ann._id for ann in text_document.entities2},
"labels": {ann._id for ann in text_document.labels},
}
assert added_annotations == {
"entities2": {text_document.entities2[0]: doc_new.entities2[0]},
"labels": {text_document.labels[0]: doc_new.labels[0]},
"entities2": {text_document.entities2[0]._id: doc_new.entities2[0]},
"labels": {text_document.labels[0]._id: doc_new.labels[0]},
}

assert len(doc_new.entities1) == 0
Expand Down
Loading