simplify transformer_token_classification_taskmodule (#71)
* create separate (prepared_)taskmodule_with_partition instead of modifying it in place

* move from parameters single_sentence and sentence_annotation to partition_annotation (backwards compatible)

* simplify encode_input

* simplify encode_target; factor out _convert_span_annotations_to_tag_sequence() and _encode_text() to ease testing

* fix test_convert_span_annotations_to_tag_sequence_with_partition()

* raise an exception when partitioning is enabled in the taskmodule but no partition is provided to _encode_text

* simplify create_annotations_from_output

* create tests for create_annotations_from_output

* add documentation for _convert_span_annotations_to_tag_sequence

* remove backwards compatibility for partition_annotation (single_sentence + sentence_annotation)

* don't use leading underscores where they are not necessary

Co-authored-by: Arne Binder <[email protected]>
ArneBinder authored Feb 23, 2022
1 parent 81d4bf4 commit a83f179
Showing 2 changed files with 503 additions and 144 deletions.
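
In short, the commit replaces the single_sentence / sentence_annotation parameter pair with a single optional partition_annotation. A minimal usage sketch (editor's example, not part of the commit; the import path follows the file path below, and the checkpoint name is an arbitrary placeholder):

    from pytorch_ie.taskmodules.transformer_token_classification import (
        TransformerTokenClassificationTaskModule,
    )

    # Partition-wise encoding: every span in the "sentences" annotation layer
    # becomes its own input example.
    taskmodule = TransformerTokenClassificationTaskModule(
        tokenizer_name_or_path="bert-base-cased",
        entity_annotation="entities",
        partition_annotation="sentences",
    )

    # Whole-document encoding: leave partition_annotation at its default (None).
    taskmodule_full = TransformerTokenClassificationTaskModule(
        tokenizer_name_or_path="bert-base-cased",
        entity_annotation="entities",
    )
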
pytorch_ie/taskmodules/transformer_token_classification.py (249 changes: 110 additions & 139 deletions)
@@ -43,13 +43,47 @@
logger = logging.getLogger(__name__)


+ def convert_span_annotations_to_tag_sequence(
+ spans: List[LabeledSpan], encoding: BatchEncoding, partition: Optional[LabeledSpan] = None
+ ) -> Sequence[Optional[str]]:
+ """
+ Given a list of span annotations and an encoding (tokenizer output), create a tag sequence
+ whose length equals the number of tokens in the encoding. Positions whose word id is None
+ (i.e. special tokens) get None as their tag. If a partition is provided, only spans that lie
+ within that partition are considered. For now, the BIO encoding is used.
+ Note: The spans are not allowed to overlap (an exception is raised if they do).
+ """
+ word_ids = encoding.word_ids()
+ tag_sequence = [None if word_ids[j] is None else "O" for j in range(len(word_ids))]
+ offset = partition.start if partition is not None else 0
+ for span in spans:
+ if partition is not None and (span.start < partition.start or span.end > partition.end):
+ continue
+
+ start_idx = encoding.char_to_token(span.start - offset)
+ end_idx = encoding.char_to_token(span.end - 1 - offset)
+ if start_idx is None or end_idx is None:
+ logger.warning(
+ f"Entity annotation does not start or end with a token, it will be skipped: {span}"
+ )
+ continue
+
+ for j in range(start_idx, end_idx + 1):
+ if tag_sequence[j] != "O":
+ # TODO: is ValueError a good exception type for this?
+ raise ValueError(f"tag already assigned (current span has an overlap: {span})")
+ prefix = "B" if j == start_idx else "I"
+ tag_sequence[j] = f"{prefix}-{span.label_single}"
+
+ return tag_sequence
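
A quick editor's sketch of what the new helper returns (not part of the commit; assumes a fast tokenizer, so that word_ids() and char_to_token() are available, and constructs LabeledSpan positionally as start, end, label, mirroring calls elsewhere in this diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # arbitrary checkpoint
    encoding = tokenizer("New York is nice")
    # "New York" covers characters [0, 8) of the text:
    tags = convert_span_annotations_to_tag_sequence([LabeledSpan(0, 8, "LOC")], encoding)
    # Special tokens ([CLS]/[SEP]) come back as None; expected result, roughly:
    # [None, "B-LOC", "I-LOC", "O", "O", None]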


class TransformerTokenClassificationTaskModule(_TransformerTokenClassificationTaskModule):
def __init__(
self,
tokenizer_name_or_path: str,
entity_annotation: str = "entities",
- single_sentence: bool = False,
- sentence_annotation: str = "sentences",
+ partition_annotation: Optional[str] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
@@ -62,8 +62,7 @@ def __init__(

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
self.entity_annotation = entity_annotation
- self.single_sentence = single_sentence
- self.sentence_annotation = sentence_annotation
+ self.partition_annotation = partition_annotation
self.label_to_id = label_to_id or {}
self.id_to_label = {v: k for k, v in self.label_to_id.items()}
self.padding = padding
@@ -86,10 +86,7 @@ def prepare(self, documents: List[Document]) -> None:
), f"document has no span annotations with name '{self.entity_annotation}'"

for entity in entities:
- # TODO: labels is a set...
- for label in entity.labels:
- if label not in labels:
- labels.add(label)
+ labels.update(entity.labels)

self.label_to_id["O"] = 0
current_id = 1
@@ -100,55 +100,53 @@

self.id_to_label = {v: k for k, v in self.label_to_id.items()}

+ def encode_text(self, text, partition: Optional[LabeledSpan] = None):
+ if self.partition_annotation is not None and partition is None:
+ raise ValueError("partitioning is enabled, but no partition is provided")
+ text_partition = text[partition.start : partition.end] if partition is not None else text
+ return self.tokenizer(
+ text_partition,
+ padding=False,
+ truncation=False,
+ max_length=None,
+ is_split_into_words=False,
+ return_offsets_mapping=True,
+ return_special_tokens_mask=True,
+ )
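
Note that encode_text tokenizes only the partition slice, so every character position known to the returned encoding is relative to partition.start; this is the same offset that convert_span_annotations_to_tag_sequence subtracts. An editor's sketch of the arithmetic (hypothetical values):

    text = "Hello world. Berlin is nice."
    partition = LabeledSpan(13, 28, "sentence")  # covers "Berlin is nice."
    span = LabeledSpan(13, 19, "LOC")            # "Berlin", in document coordinates

    offset = partition.start  # 13
    # The encoding only saw text[13:28], so lookups use partition-relative positions:
    # encoding.char_to_token(span.start - offset) == encoding.char_to_token(0)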

def encode_input(
self, documents: List[Document]
) -> Tuple[
List[TransformerTokenClassificationInputEncoding],
List[Metadata],
Optional[List[Document]],
]:
metadata = []
expanded_documents = []
input_ = []
for doc in documents:
- if self.single_sentence:
- partitions_or_none = doc.span_annotations(self.sentence_annotation)
+ partitions: Sequence[Optional[LabeledSpan]]
+ if self.partition_annotation is not None:
+ partitions_or_none = doc.span_annotations(self.partition_annotation)
assert (
partitions_or_none
), f"document has no span annotations with name '{self.sentence_annotation}'"
), f"document has no span annotations with name '{self.partition_annotation}'"
partitions = partitions_or_none
else:
- partitions = [LabeledSpan(start=0, end=len(doc.text), label="FULL_DOCUMENT")]
-
- for partition in partitions:
- encoding = self.tokenizer(
- doc.text[partition.start : partition.end],
- padding=False,
- truncation=False,
- max_length=None,
- is_split_into_words=False,
- return_offsets_mapping=True,
- return_special_tokens_mask=True,
- )
+ partitions = [None]
+
+ for partition_index, partition in enumerate(partitions):
+ encoding = self.encode_text(text=doc.text, partition=partition)
+ current_metadata = {
+ "offset_mapping": encoding.pop("offset_mapping"),
+ "special_tokens_mask": encoding.pop("special_tokens_mask"),
+ }
+ if partition is not None:
+ current_metadata["sentence_index"] = partition_index
+ metadata.append(current_metadata)
input_.append(encoding)
expanded_documents.append(doc)

- metadata = [
- {
- "offset_mapping": inp.pop("offset_mapping"),
- "special_tokens_mask": inp.pop("special_tokens_mask"),
- }
- for inp in input_
- ]
-
- if self.single_sentence:
- i = 0
- for document in documents:
- for sentence_index in range(
- len(document.span_annotations(self.sentence_annotation) or [])
- ):
- metadata[i]["sentence_index"] = sentence_index
- i += 1

return input_, metadata, expanded_documents

@@ -158,68 +158,27 @@ def encode_target(
metadata: List[Metadata],
) -> List[TransformerTokenClassificationTargetEncoding]:
target = []
- if self.single_sentence:
- for i, document in enumerate(documents):
- entities = document.span_annotations(self.entity_annotation)
- assert (
- entities
- ), f"document has no span annotations with name '{self.entity_annotation}'"
- sentence_idx = metadata[i]["sentence_index"]
- partitions = document.span_annotations(self.sentence_annotation)
+ for i, document in enumerate(documents):
+ entities = document.span_annotations(self.entity_annotation)
+ assert (
+ entities
+ ), f"document has no span annotations with name '{self.entity_annotation}'"
+ partition = None
+ if self.partition_annotation is not None:
+ partition_index = metadata[i]["sentence_index"]
+ partitions = document.span_annotations(self.partition_annotation)
assert (
partitions
), f"document has no span annotations with name '{self.sentence_annotation}'"
sentence = partitions[sentence_idx]

word_ids = input_encodings[i].word_ids()
label_ids = [
self.label_pad_token_id if word_ids[j] is None else self.label_to_id["O"]
for j in range(len(word_ids))
]

entity: LabeledSpan
for entity in entities:
if entity.start < sentence.start or entity.end > sentence.end:
continue

entity_start = entity.start - sentence.start
entity_end = entity.end - sentence.start

start_idx = input_encodings[i].char_to_token(entity_start)
end_idx = input_encodings[i].char_to_token(entity_end - 1)
# TODO: remove this is if case
if start_idx is None or end_idx is None:
logger.warning(
f"Entity annotation does not start or end with a token, it will be skipped: {entity}"
)
continue

for j in range(start_idx, end_idx + 1):
prefix = "B" if j == start_idx else "I"
label_ids[j] = self.label_to_id[f"{prefix}-{entity.label_single}"]

target.append(label_ids)
else:
for i, document in enumerate(documents):
word_ids = input_encodings[i].word_ids()
label_ids = [
self.label_pad_token_id if word_ids[j] is None else self.label_to_id["O"]
for j in range(len(word_ids))
]

entities = document.span_annotations(self.entity_annotation)
assert (
entities
), f"document has no span annotations with name '{self.entity_annotation}'"

for entity in entities:
start_idx = input_encodings[i].char_to_token(entity.start)
end_idx = input_encodings[i].char_to_token(entity.end - 1)
for j in range(start_idx, end_idx + 1):
prefix = "B" if j == start_idx else "I"
label_ids[j] = self.label_to_id[f"{prefix}-{entity.label_single}"]

target.append(label_ids)
), f"document has no span annotations with name '{self.partition_annotation}'"
partition = partitions[partition_index]
tag_sequence = convert_span_annotations_to_tag_sequence(
spans=entities, encoding=input_encodings[i], partition=partition
)
label_ids = [
self.label_to_id[tag] if tag is not None else self.label_pad_token_id
for tag in tag_sequence
]
target.append(label_ids)

return target
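
With the branching gone, encode_target reduces to the partition lookup plus this tag-to-id mapping; None tags (special tokens) become label_pad_token_id so the loss ignores them. Schematically (editor's sketch; -100 is an assumed default, following the usual Hugging Face convention):

    tag_sequence = [None, "B-PER", "O", None]
    label_to_id = {"O": 0, "B-PER": 1, "I-PER": 2}
    label_pad_token_id = -100  # assumed default, as in common Hugging Face setups

    label_ids = [
        label_to_id[tag] if tag is not None else label_pad_token_id
        for tag in tag_sequence
    ]
    # -> [-100, 1, 0, -100]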

@@ -237,48 +237,32 @@ def create_annotations_from_output(
encoding: TransformerTokenClassificationTaskEncoding,
output: TransformerTokenClassificationTaskOutput,
) -> Iterator[Tuple[str, Annotation]]:
- if self.single_sentence:
- document = encoding.document
- metadata = encoding.metadata
- partitions = document.span_annotations(self.sentence_annotation)
-
+ offset = 0
+ if self.partition_annotation is not None:
+ partitions = encoding.document.span_annotations(self.partition_annotation)
assert (
partitions
), f"document has no span annotations with name '{self.sentence_annotation}'"
sentence = partitions[metadata["sentence_index"]]

tag_sequence = [
"O" if stm else tag
for tag, stm in zip(output["tags"], metadata["special_tokens_mask"])
]

spans = bio_tags_to_spans(tag_sequence)
for label, (start, end) in spans:
yield (
self.entity_annotation,
LabeledSpan(
sentence.start + metadata["offset_mapping"][start][0],
sentence.start + metadata["offset_mapping"][end][1],
label,
),
)
else:
metadata = encoding.metadata

tag_sequence = [
"O" if stm else tag
for tag, stm in zip(output["tags"], metadata["special_tokens_mask"])
]
), f"document has no span annotations with name '{self.partition_annotation}'"
offset = partitions[encoding.metadata["sentence_index"]].start

tag_sequence = [
"O" if is_special_token else tag
for tag, is_special_token in zip(
output["tags"], encoding.metadata["special_tokens_mask"]
)
]

- spans = bio_tags_to_spans(tag_sequence)
- for label, (start, end) in spans:
- yield (
- self.entity_annotation,
- LabeledSpan(
- metadata["offset_mapping"][start][0],
- metadata["offset_mapping"][end][1],
- label,
- ),
- )
+ spans = bio_tags_to_spans(tag_sequence)
+ for label, (start, end) in spans:
+ yield (
+ self.entity_annotation,
+ LabeledSpan(
+ encoding.metadata["offset_mapping"][start][0] + offset,
+ encoding.metadata["offset_mapping"][end][1] + offset,
+ label,
+ ),
+ )
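
Decoding is now identical with and without partitions; only offset differs. For reference, bio_tags_to_spans follows the AllenNLP convention and returns (label, (start, end)) token spans with an inclusive end index, which is why the code reads offset_mapping[end][1]. A one-line editor's illustration:

    bio_tags_to_spans(["B-PER", "I-PER", "O", "B-LOC"])
    # -> [("PER", (0, 1)), ("LOC", (3, 3))]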

def collate(
self, encodings: List[TransformerTokenClassificationTaskEncoding]
[diff of the second changed file (tests) not shown]
