Skip to content

Commit

Permalink
implement typed annotation collection
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Mar 4, 2022
1 parent 2c6602c commit 3e0c861
Show file tree
Hide file tree
Showing 19 changed files with 141 additions and 246 deletions.
2 changes: 1 addition & 1 deletion examples/predict/ner_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main():

ner_pipeline(document, predict_field="entities")

for entity in document.predictions["entities"].as_spans:
for entity in document.predictions.spans["entities"]:
entity_text = document.text[entity.start : entity.end]
label = entity.label
print(f"{entity_text} -> {label}")
Expand Down
2 changes: 1 addition & 1 deletion examples/predict/re_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def main():
pipeline(document, predict_field="relations")

relation: BinaryRelation
for relation in document.predictions["relations"].as_binary_relations:
for relation in document.predictions.binary_relations["relations"]:
head, tail = relation.head, relation.tail
head_text = document.text[head.start : head.end]
tail_text = document.text[tail.start : tail.end]
Expand Down
2 changes: 1 addition & 1 deletion examples/predict/re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main():

re_pipeline(document, predict_field="relations", batch_size=2)

for relation in document.predictions["relations"].as_binary_relations:
for relation in document.predictions.binary_relations["relations"]:
head, tail = relation.head, relation.tail
head_text = document.text[head.start : head.end]
tail_text = document.text[tail.start : tail.end]
Expand Down
18 changes: 9 additions & 9 deletions src/pytorch_ie/data/datasets/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def convert_brat_to_document(
doc = Document(text=brat_doc["context"], doc_id=brat_doc["file_name"])

# add spans
doc.annotations.create_layer(name=span_annotation_name)
doc.annotations.spans.create_layer(name=span_annotation_name)
span_id_mapping = {}
for brat_span in dl_to_ld(brat_doc["spans"]):
locations = dl_to_ld(brat_span["locations"])
Expand Down Expand Up @@ -85,10 +85,10 @@ def convert_brat_to_document(
brat_span["id"] not in span_id_mapping
), f'brat span id "{brat_span["id"]}" already exists'
span_id_mapping[brat_span["id"]] = span
doc.annotations.add(name=span_annotation_name, annotation=span)
doc.add_annotation(name=span_annotation_name, annotation=span)

# add relations
doc.annotations.create_layer(name=relation_annotation_name)
doc.annotations.binary_relations.create_layer(name=relation_annotation_name)
for brat_relation in dl_to_ld(brat_doc["relations"]):
# strip annotation type identifier from id
metadata = {"id": brat_relation["id"][1:]}
Expand All @@ -102,7 +102,7 @@ def convert_brat_to_document(
relation = BinaryRelation(
label=brat_relation["type"], head=head, tail=tail, metadata=metadata
)
doc.annotations.add(name=relation_annotation_name, annotation=relation)
doc.add_annotation(name=relation_annotation_name, annotation=relation)

# add events -> not yet implement
# add equivalence_relations -> not yet implement
Expand Down Expand Up @@ -221,20 +221,20 @@ def convert_document_to_brat(
if relation_annotation_names is None:
relation_annotation_names = [DEFAULT_RELATION_ANNOTATION_NAME]
for entity_annotation_name in span_annotation_names:
if doc.annotations.has_layer(entity_annotation_name):
for span_ann in doc.annotations[entity_annotation_name].as_spans:
if doc.annotations.spans.has_layer(entity_annotation_name):
for span_ann in doc.annotations.spans[entity_annotation_name]:
serialized_annotations[span_ann] = serialize_labeled_span(span_ann, doc, **kwargs)
for relation_annotation_name in relation_annotation_names:
if doc.annotations.has_layer(relation_annotation_name):
for rel_ann in doc.annotations[relation_annotation_name].as_binary_relations:
if doc.annotations.binary_relations.has_layer(relation_annotation_name):
for rel_ann in doc.annotations.binary_relations[relation_annotation_name]:
serialized_annotations[rel_ann] = serialize_binary_relation(
rel_ann,
doc,
head_argument_name=head_argument_name,
tail_argument_name=tail_argument_name,
**kwargs,
)
for name, annots in doc.annotations.named_layers:
for layer_type, name, annots in doc.annotations.typed_named_layers:
not_serialized = [ann for ann in annots if ann not in serialized_annotations]
if len(not_serialized) > 0:
logger.warning(
Expand Down
128 changes: 68 additions & 60 deletions src/pytorch_ie/data/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)


class Annotation:
Expand Down Expand Up @@ -144,58 +156,12 @@ def __repr__(self) -> str:
)


# simple list for now
AnnotationLayer = list
T_annotation = TypeVar("T_annotation", bound=Annotation)


class AnnotationLayer(List[T_annotation]):
"""
An AnnotationLayer is a special List with some sanity checks and typed getters. It is ensured that all
entries have the same type.
It is totally optional to use the typed getters, they are just available to
ease typing, e.g. the following would not cause any trouble for mypy:
layer = AnnotationLayer([SpanAnnotation(start=0, end=2, label="e1")])
entity = layer.as_spans[0]
start, end = entity.start, entity.end # access to .end and .start would cause issues otherwise
"""

def _check_type(self, type_to_check: Type):
if len(self) > 0 and not isinstance(self[0], type_to_check):
raise TypeError(
f"Entry caused a type mismatch. Expected type: {type(self[0])}, actual type: {type_to_check}."
)

def append(self, __object: T_annotation) -> None:
self._check_type(type(__object))
super().append(__object)

def extend(self, __iterable: Iterable[T_annotation]) -> None:
for e in __iterable:
self.append(e)

def __setitem__(self, key, value):
self._check_type(type(value))
super().__setitem__(i=key, o=value)

def as_type(self, annotation_type: Type) -> List[T_annotation]:
self._check_type(annotation_type)
return self

@property
def as_spans(self) -> List[LabeledSpan]:
return cast(List[LabeledSpan], self.as_type(LabeledSpan))

@property
def as_binary_relations(self) -> List[BinaryRelation]:
return cast(List[BinaryRelation], self.as_type(BinaryRelation))

@property
def as_labels(self) -> List[Label]:
return cast(List[Label], self.as_type(Label))


class AnnotationCollection(Dict[str, AnnotationLayer]):
class TypedAnnotationCollection(Generic[T_annotation], Dict[str, AnnotationLayer[T_annotation]]):
"""
An `AnnotationCollection` holds a mapping from layer names to `AnnotationLayers`. However, it
also provides an `add` method to directly add an Annotation to a certain layer and create that if necessary.
Expand All @@ -204,20 +170,58 @@ class AnnotationCollection(Dict[str, AnnotationLayer]):
def has_layer(self, name: str) -> bool:
return name in self

def add(self, name: str, annotation: Annotation = None):
def add(self, name: str, annotation: T_annotation):
if not self.has_layer(name=name):
self.create_layer(name=name)
self[name].append(annotation)

def create_layer(self, name: str, allow_exists: bool = False) -> AnnotationLayer:
def create_layer(self, name: str, allow_exists: bool = False) -> AnnotationLayer[T_annotation]:
if self.has_layer(name) and not allow_exists:
raise ValueError(f"layer with name {name} already exists")
self[name] = AnnotationLayer()
self[name] = AnnotationLayer[T_annotation]()
return self[name]

@property
def named_layers(self) -> List[Tuple[str, AnnotationLayer]]:
return list(self.items())
def named_layers(self) -> Sequence[Tuple[str, AnnotationLayer[T_annotation]]]:
return [item for item in self.items()]


class AnnotationCollection:
def __init__(self):
self.labels = TypedAnnotationCollection[Label]()
self.spans = TypedAnnotationCollection[LabeledSpan]()
self.binary_relations = TypedAnnotationCollection[BinaryRelation]()

self._types_to_collections = {
Label: self.labels,
LabeledSpan: self.spans,
BinaryRelation: self.binary_relations,
}

def add(self, name: str, annotation: Annotation):
collection = self._types_to_collections.get(type(annotation))
if collection is None:
raise TypeError(f"annotation has unknown type: {type(annotation)}")
collection.add(name=name, annotation=annotation)

@property
def typed_collections(self) -> Sequence[Tuple[Type, TypedAnnotationCollection]]:
return [item for item in self._types_to_collections.items()]

@property
def typed_named_layers(self) -> Sequence[Tuple[Type, str, AnnotationLayer]]:
res = []
for base_type, typed_collection in self.typed_collections:
res.extend(
[(base_type, name, layer) for (name, layer) in typed_collection.named_layers]
)
return res

def __repr__(self) -> str:
return (
f"Document(labels={self.labels}, spans={self.spans}, "
f"binary_relations={self.binary_relations})"
)


class Document:
Expand Down Expand Up @@ -247,8 +251,10 @@ def add_prediction(self, name: str, prediction: Annotation):
self.predictions.add(name=name, annotation=prediction)

def clear_predictions(self, name: str) -> None:
if name in self.predictions:
del self.predictions[name]
# TODO: should we respect the base_type?
for base_type, collection in self.predictions.typed_collections:
if name in collection:
del collection[name]

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -292,13 +298,15 @@ def construct_document(

if spans is not None:
for layer_name, layer_spans in spans.items():
doc.annotations[layer_name] = AnnotationLayer[LabeledSpan](layer_spans)
doc.annotations.spans[layer_name] = AnnotationLayer[LabeledSpan](layer_spans)
if assert_span_text:
for ann in doc.annotations[layer_name]:
for ann in doc.annotations.spans[layer_name]:
_assert_span_text(doc, ann)
if binary_relations is not None:
for layer_name, layer_binary_relations in binary_relations.items():
doc.annotations[layer_name] = AnnotationLayer[BinaryRelation](layer_binary_relations)
doc.annotations.binary_relations[layer_name] = AnnotationLayer[BinaryRelation](
layer_binary_relations
)

return doc

Expand Down
14 changes: 7 additions & 7 deletions src/pytorch_ie/taskmodules/transformer_re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def prepare(self, documents: List[Document]) -> None:
entity_labels: Set[str] = set()
relation_labels: Set[str] = set()
for document in documents:
entities = document.annotations[self.entity_annotation].as_spans
relations = document.annotations[self.relation_annotation].as_binary_relations
entities = document.annotations.spans[self.entity_annotation]
relations = document.annotations.binary_relations[self.relation_annotation]

if self.add_type_to_marker:
for entity in entities:
Expand Down Expand Up @@ -286,16 +286,16 @@ def encode_input(
)

for document in documents:
entities = document.annotations[self.entity_annotation].as_spans
if document.annotations.has_layer(self.relation_annotation):
relations = document.annotations[self.relation_annotation].as_binary_relations
entities = document.annotations.spans[self.entity_annotation]
if document.annotations.binary_relations.has_layer(self.relation_annotation):
relations = document.annotations.binary_relations[self.relation_annotation]
else:
relations = None
relation_mapping = {(rel.head, rel.tail): rel.label for rel in relations or []}

partitions: Sequence[Optional[LabeledSpan]]
if self.partition_annotation is not None:
partitions = document.annotations[self.partition_annotation].as_spans
partitions = document.annotations.spans[self.partition_annotation]
else:
# use single dummy partition
partitions = [None]
Expand Down Expand Up @@ -437,7 +437,7 @@ def encode_target(
target: List[TransformerReTextClassificationTargetEncoding] = []
for i, document in enumerate(documents):
meta = metadata[i]
relations = document.annotations[self.relation_annotation].as_binary_relations
relations = document.annotations.binary_relations[self.relation_annotation]

head_tail_to_labels = {
(relation.head, relation.tail): relation.labels for relation in relations
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_ie/taskmodules/transformer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def encode_input(
)

def document_to_target_string(self, document: Document) -> str:
relations = document.annotations[self.relation_annotation].as_binary_relations
relations = document.annotations.binary_relations[self.relation_annotation]

head_to_relation: Dict[LabeledSpan, List[BinaryRelation]] = {}
for relation in relations:
Expand Down
14 changes: 7 additions & 7 deletions src/pytorch_ie/taskmodules/transformer_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _config(self) -> Dict[str, Any]:
def prepare(self, documents: List[Document]) -> None:
labels = set()
for document in documents:
entities = document.annotations[self.entity_annotation].as_spans
entities = document.annotations.spans[self.entity_annotation]

for entity in entities:
# TODO: labels is a set, use update
Expand All @@ -109,7 +109,7 @@ def encode_input(
expanded_documents = []
for doc in documents:
if self.single_sentence:
partitions = doc.annotations[self.sentence_annotation].as_spans
partitions = doc.annotations.spans[self.sentence_annotation]
else:
partitions = [LabeledSpan(start=0, end=len(doc.text), label="FULL_DOCUMENT")]
for partition in partitions:
Expand Down Expand Up @@ -137,7 +137,7 @@ def encode_input(
i = 0
for document in documents:
for sentence_index in range(
len(document.annotations[self.sentence_annotation].as_spans)
len(document.annotations.spans[self.sentence_annotation])
):
metadata[i]["sentence_index"] = sentence_index
i += 1
Expand All @@ -153,9 +153,9 @@ def encode_target(
target = []
if self.single_sentence:
for i, document in enumerate(documents):
entities = document.annotations[self.entity_annotation].as_spans
entities = document.annotations.spans[self.entity_annotation]
sentence_idx = metadata[i]["sentence_index"]
partitions = document.annotations[self.sentence_annotation].as_spans
partitions = document.annotations.spans[self.sentence_annotation]
assert (
partitions is not None
), f"document has no span annotations with name '{self.sentence_annotation}'"
Expand Down Expand Up @@ -184,7 +184,7 @@ def encode_target(
target.append(label_ids)
else:
for i, document in enumerate(documents):
entities = document.annotations[self.entity_annotation].as_spans
entities = document.annotations.spans[self.entity_annotation]
label_ids = []
for entity in entities:
start_idx = input_encodings[i].char_to_token(entity.start)
Expand Down Expand Up @@ -226,7 +226,7 @@ def create_annotations_from_output(
if self.single_sentence:
document = encoding.document
metadata = encoding.metadata
partitions = document.annotations[self.sentence_annotation].as_spans
partitions = document.annotations.spans[self.sentence_annotation]
sentence = partitions[metadata["sentence_index"]]

# tag_sequence = [
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_ie/taskmodules/transformer_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _config(self) -> Dict[str, Any]:
def prepare(self, documents: List[Document]) -> None:
labels = set()
for document in documents:
annotations = document.annotations[self.annotation].as_labels
annotations = document.annotations.labels[self.annotation]

for annotation in annotations:
# TODO: labels is a set...
Expand Down Expand Up @@ -157,7 +157,7 @@ def encode_target(

target: List[TransformerTextClassificationTargetEncoding] = []
for i, document in enumerate(documents):
annotations = document.annotations[self.annotation].as_labels
annotations = document.annotations.labels[self.annotation]
if self.multi_label:
label_ids = [0] * len(self.label_to_id)
for annotation in annotations:
Expand Down
Loading

0 comments on commit 3e0c861

Please sign in to comment.