Commit
Merge pull request #10 from ArneBinder/pointer_network_joint_ner_and_re
new task: pointer network for joint ner and re
Showing 18 changed files with 4,851 additions and 0 deletions.
@@ -1 +1,2 @@
from .bart_as_pointer_network import BartAsPointerNetwork
from .bart_with_decoder_position_ids import BartModelWithDecoderPositionIds
470 changes: 470 additions & 0 deletions
src/pie_modules/models/base_models/bart_as_pointer_network.py
Large diffs are not rendered by default.
@@ -0,0 +1,307 @@
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.utils import logging

logger = logging.get_logger(__name__)


class PointerHead(torch.nn.Module):
    # Copy and generate: this head scores both generated target tokens (labels, eos)
    # and pointers that copy tokens from the encoder input.
    def __init__(
        self,
        # (decoder) input space
        target_token_ids: List[int],
        # output space (targets)
        bos_id: int,
        eos_id: int,
        pad_id: int,
        # embeddings
        embeddings: nn.Embedding,
        embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None,
        # other parameters
        use_encoder_mlp: bool = False,
        use_constraints_encoder_mlp: bool = False,
        decoder_position_id_pattern: Optional[List[int]] = None,
        increase_position_ids_per_record: bool = False,
    ):
        super().__init__()

        self.embeddings = embeddings

        self.pointer_offset = len(target_token_ids)

        # check that bos, eos, and pad are not out of bounds
        for target_id, target_id_name in zip(
            [bos_id, eos_id, pad_id], ["bos_id", "eos_id", "pad_id"]
        ):
            if target_id >= len(target_token_ids):
                raise ValueError(
                    f"{target_id_name} [{target_id}] must be smaller than the number of target token ids "
                    f"[{len(target_token_ids)}]!"
                )

        self.bos_id = bos_id
        self.eos_id = eos_id
        self.pad_id = pad_id
        # all ids that are not bos, eos or pad are label ids
        self.label_ids = [
            target_id
            for target_id in range(len(target_token_ids))
            if target_id not in [self.bos_id, self.eos_id, self.pad_id]
        ]

        target2token_id = torch.LongTensor(target_token_ids)
        self.register_buffer("target2token_id", target2token_id)
        self.label_token_ids = self.target2token_id[self.label_ids]
        self.eos_token_id = target_token_ids[self.eos_id]
        self.pad_token_id = target_token_ids[self.pad_id]

        hidden_size = self.embeddings.embedding_dim
        if use_encoder_mlp:
            self.encoder_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Dropout(0.3),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        if use_constraints_encoder_mlp:
            self.constraints_encoder_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Dropout(0.3),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )

        self.embedding_weight_mapping = None
        if embedding_weight_mapping is not None:
            # Because of config serialization, the keys may be strings. Convert them back to ints.
            self.embedding_weight_mapping = {
                int(k): v for k, v in embedding_weight_mapping.items()
            }

        if decoder_position_id_pattern is not None:
            self.register_buffer(
                "decoder_position_id_pattern", torch.tensor(decoder_position_id_pattern)
            )
        self.increase_position_ids_per_record = increase_position_ids_per_record

    @property
    def use_prepared_position_ids(self):
        return hasattr(self, "decoder_position_id_pattern")

    def set_embeddings(self, embedding: nn.Embedding) -> None:
        self.embeddings = embedding

    def overwrite_embeddings_with_mapping(self) -> None:
        """Overwrite individual embeddings with embeddings for other tokens.

        This is useful, for instance, if the label vocabulary is a subset of the source vocabulary.
        In this case, this method can be used to initialize each label embedding with one or
        multiple (averaged) source embeddings.
        """
        if self.embedding_weight_mapping is not None:
            for special_token_index, source_indices in self.embedding_weight_mapping.items():
                self.embeddings.weight.data[special_token_index] = self.embeddings.weight.data[
                    source_indices
                ].mean(dim=0)

    def prepare_decoder_input_ids(
        self,
        input_ids: torch.LongTensor,
        encoder_input_ids: torch.LongTensor,
    ) -> torch.LongTensor:
        mapping_token_mask = input_ids.lt(self.pointer_offset)
        mapped_tokens = input_ids.masked_fill(input_ids.ge(self.pointer_offset), 0)
        tag_mapped_tokens = self.target2token_id[mapped_tokens]

        encoder_input_ids_index = input_ids - self.pointer_offset
        encoder_input_ids_index = encoder_input_ids_index.masked_fill(
            encoder_input_ids_index.lt(0), 0
        )
        encoder_input_length = encoder_input_ids.size(1)
        if encoder_input_ids_index.max() >= encoder_input_length:
            raise ValueError(
                f"encoder_input_ids_index.max() [{encoder_input_ids_index.max()}] must be smaller "
                f"than encoder_input_length [{encoder_input_length}]!"
            )

        word_mapped_tokens = encoder_input_ids.gather(index=encoder_input_ids_index, dim=1)

        decoder_input_ids = torch.where(
            mapping_token_mask, tag_mapped_tokens, word_mapped_tokens
        ).to(torch.long)

        # Note: we do not need to explicitly handle the padding (via a decoder attention mask)
        # because it gets automatically mapped to the pad token id

        return decoder_input_ids

    def prepare_decoder_position_ids(
        self,
        input_ids: torch.LongTensor,
        # will be used to create the padding mask from the input_ids. Needs to be provided because
        # the input_ids may be in token space or target space.
        pad_input_id: int,
    ) -> torch.LongTensor:
        bsz, tokens_len = input_ids.size()
        pattern_len = len(self.decoder_position_id_pattern)
        # the number of full and partial records. Note that tokens_len includes the bos token.
        repeat_num = (tokens_len - 2) // pattern_len + 1
        position_ids = self.decoder_position_id_pattern.repeat(bsz, repeat_num)

        if self.increase_position_ids_per_record:
            position_ids_reshaped = position_ids.view(bsz, -1, pattern_len)
            add_shift_pos = (
                torch.range(0, repeat_num - 1, device=position_ids_reshaped.device)
                .repeat(bsz)
                .view(bsz, -1)
                .unsqueeze(-1)
            )
            # multiply by the highest position id in the pattern so that the position ids are unique
            # for any decoder_position_id_pattern across all records
            add_shift_pos *= max(self.decoder_position_id_pattern) + 1
            position_ids_reshaped = add_shift_pos + position_ids_reshaped
            position_ids = position_ids_reshaped.view(bsz, -1).long()
        # use start_position_id=0
        start_pos = torch.zeros(bsz, 1, dtype=position_ids.dtype, device=position_ids.device)
        # shift by 2 to account for start_position_id=0 and pad_position_id=1
        all_position_ids = torch.cat([start_pos, position_ids + 2], dim=-1)
        all_position_ids_truncated = all_position_ids[:bsz, :tokens_len]

        # mask the padding tokens
        mask_invalid = input_ids.eq(pad_input_id)
        all_position_ids_truncated_masked = all_position_ids_truncated.masked_fill(mask_invalid, 1)

        return all_position_ids_truncated_masked

    def prepare_decoder_inputs(
        self,
        input_ids: torch.LongTensor,
        encoder_input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        inputs = {}
        if self.use_prepared_position_ids:
            if position_ids is None:
                position_ids = self.prepare_decoder_position_ids(
                    # the input_ids are in the target space, so we provide self.pad_id as the pad_input_id
                    input_ids=input_ids,
                    pad_input_id=self.pad_id,
                )
            inputs["position_ids"] = position_ids

        inputs["input_ids"] = self.prepare_decoder_input_ids(
            input_ids=input_ids,
            encoder_input_ids=encoder_input_ids,
        )
        return inputs

    def forward(
        self,
        last_hidden_state,
        encoder_input_ids,
        encoder_last_hidden_state,
        encoder_attention_mask,
        labels: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        constraints: Optional[torch.LongTensor] = None,
    ):
        # assemble the logits
        logits = last_hidden_state.new_full(
            (
                last_hidden_state.size(0),
                last_hidden_state.size(1),
                self.pointer_offset + encoder_input_ids.size(-1),
            ),
            fill_value=-1e24,
        )

        # eos and label scores depend only on the decoder output
        # bsz x max_len x 1
        eos_scores = F.linear(last_hidden_state, self.embeddings.weight[[self.eos_token_id]])
        label_embeddings = self.embeddings.weight[self.label_token_ids]
        # bsz x max_len x num_class
        label_scores = F.linear(last_hidden_state, label_embeddings)

        # the pointer depends on the src token embeddings, the encoder output and the decoder output
        # bsz x max_bpe_len x hidden_size
        src_outputs = encoder_last_hidden_state
        if getattr(self, "encoder_mlp", None) is not None:
            src_outputs = self.encoder_mlp(src_outputs)

        # bsz x max_word_len x hidden_size
        input_embed = self.embeddings(encoder_input_ids)

        # bsz x max_len x max_word_len
        word_scores = torch.einsum("blh,bnh->bln", last_hidden_state, src_outputs)
        gen_scores = torch.einsum("blh,bnh->bln", last_hidden_state, input_embed)
        avg_word_scores = (gen_scores + word_scores) / 2

        # never point to the padding or the eos token in the encoder input
        # TODO: why not also exclude the bos token? Excluding it seems to give worse results,
        #  but this was not tested extensively.
        mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id)
        avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32)

        # Note: the remaining row in logits contains the score for the bos token, which should never be generated!
        logits[:, :, [self.eos_id]] = eos_scores
        logits[:, :, self.label_ids] = label_scores
        logits[:, :, self.pointer_offset :] = avg_word_scores

        loss = None
        # compute the loss if labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            logits_resized = logits.reshape(-1, logits.size(-1))
            labels_resized = labels.reshape(-1)
            if decoder_attention_mask is None:
                raise ValueError("decoder_attention_mask must be provided to compute the loss!")
            mask_resized = decoder_attention_mask.reshape(-1)
            labels_masked = labels_resized.masked_fill(
                ~mask_resized.to(torch.bool), loss_fct.ignore_index
            )
            loss = loss_fct(logits_resized, labels_masked)

        # compute the constraints loss if constraints are provided
        if constraints is not None:
            if getattr(self, "constraints_encoder_mlp", None) is not None:
                # TODO: is it fine to apply constraints_encoder_mlp to both src_outputs and label_embeddings?
                #  This is what the original code seems to do, but this is different from the usage of encoder_mlp.
                constraints_src_outputs = self.constraints_encoder_mlp(src_outputs)
                constraints_label_embeddings = self.constraints_encoder_mlp(label_embeddings)
            else:
                constraints_src_outputs = src_outputs
                constraints_label_embeddings = label_embeddings
            constraints_label_scores = F.linear(last_hidden_state, constraints_label_embeddings)
            # bsz x max_len x max_word_len
            constraints_word_scores = torch.einsum(
                "blh,bnh->bln", last_hidden_state, constraints_src_outputs
            )
            constraints_logits = last_hidden_state.new_full(
                (
                    last_hidden_state.size(0),
                    last_hidden_state.size(1),
                    self.pointer_offset + encoder_input_ids.size(-1),
                ),
                fill_value=-1e24,
            )
            constraints_logits[:, :, self.label_ids] = constraints_label_scores
            constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores

            mask = constraints >= 0
            constraints_logits_valid = constraints_logits[mask]
            constraints_valid = constraints[mask]
            loss_c = F.binary_cross_entropy(
                torch.sigmoid(constraints_logits_valid), constraints_valid.float()
            )

            if loss is None:
                loss = loss_c
            else:
                loss += loss_c

        return logits, loss
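
To illustrate how this head is meant to be wired together, here is a minimal, hypothetical usage sketch (not part of this commit). It assumes PointerHead is importable from the new module added in this PR and uses made-up token ids, vocabulary sizes, and tensor shapes; in practice the target vocabulary and pointer targets come from the taskmodule and the hidden states come from the BART encoder/decoder.

import torch
from torch import nn

# Hypothetical toy setup (ids and sizes are made up for illustration only).
# target_token_ids maps the target vocabulary (bos, eos, pad, two labels)
# to token ids in the underlying BART-like vocabulary.
target_token_ids = [0, 2, 1, 50265, 50266]
embeddings = nn.Embedding(num_embeddings=50267, embedding_dim=16)

head = PointerHead(
    target_token_ids=target_token_ids,
    bos_id=0,  # position of the bos token id within target_token_ids
    eos_id=1,  # position of the eos token id within target_token_ids
    pad_id=2,  # position of the pad token id within target_token_ids
    embeddings=embeddings,
    use_encoder_mlp=True,
)

# pointer_offset == len(target_token_ids) == 5, so target ids >= 5 point
# to positions in the encoder input.
batch_size, src_len, tgt_len, hidden = 2, 12, 7, 16
encoder_input_ids = torch.randint(10, 50265, (batch_size, src_len))
target_ids = torch.tensor(
    [
        [0, 3, 7, 4, 9, 1, 2],  # bos, label, pointer->pos 2, label, pointer->pos 4, eos, pad
        [0, 4, 6, 3, 8, 1, 2],
    ]
)

# Map target ids to real decoder input token ids: label/special ids are looked up
# in target2token_id, pointer ids copy the pointed-to encoder input token ids.
decoder_inputs = head.prepare_decoder_inputs(
    input_ids=target_ids, encoder_input_ids=encoder_input_ids
)

# Score decoder states against eos/labels (generation) and encoder positions (pointing).
last_hidden_state = torch.randn(batch_size, tgt_len, hidden)
encoder_last_hidden_state = torch.randn(batch_size, src_len, hidden)
encoder_attention_mask = torch.ones(batch_size, src_len, dtype=torch.long)
logits, loss = head(
    last_hidden_state=last_hidden_state,
    encoder_input_ids=encoder_input_ids,
    encoder_last_hidden_state=encoder_last_hidden_state,
    encoder_attention_mask=encoder_attention_mask,
)
# logits has shape (batch_size, tgt_len, pointer_offset + src_len);
# loss is None because no labels were passed.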
@@ -1,2 +1,3 @@
from .interfaces import AnnotationEncoderDecoder, DecodingException
from .mixins import BatchableMixin
from .utils import get_first_occurrence_index
@@ -0,0 +1,34 @@
import abc
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

from pytorch_ie import Annotation

# Annotation Encoding type: encoding for a single annotation
AE = TypeVar("AE")
# Annotation type
A = TypeVar("A", bound=Annotation)
# Annotation Collection Encoding type: encoding for a collection of annotations,
# e.g. all relevant annotations for a document
ACE = TypeVar("ACE")


class DecodingException(Exception, Generic[AE], abc.ABC):
    """Exception raised when decoding fails."""

    identifier: str

    def __init__(self, message: str, encoding: AE):
        self.message = message
        self.encoding = encoding


class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]):
    """Base class for annotation encoders and decoders."""

    @abc.abstractmethod
    def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE:
        pass

    @abc.abstractmethod
    def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A:
        pass
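
As a quick illustration of the intended contract, here is a hedged sketch (not code from this commit) of a concrete subclass: it encodes a pytorch_ie LabeledSpan as a (start, end, label) tuple and decodes it back, assuming LabeledSpan with start/end/label fields is available from pytorch_ie.annotations.

from typing import Any, Dict, Optional, Tuple

from pytorch_ie.annotations import LabeledSpan


class SpanTupleEncoderDecoder(AnnotationEncoderDecoder[LabeledSpan, Tuple[int, int, str]]):
    """Toy example: encode a LabeledSpan as a (start, end, label) tuple."""

    def encode(
        self, annotation: LabeledSpan, metadata: Optional[Dict[str, Any]] = None
    ) -> Tuple[int, int, str]:
        return (annotation.start, annotation.end, annotation.label)

    def decode(
        self, encoding: Tuple[int, int, str], metadata: Optional[Dict[str, Any]] = None
    ) -> LabeledSpan:
        return LabeledSpan(start=encoding[0], end=encoding[1], label=encoding[2])


# Usage:
# codec = SpanTupleEncoderDecoder()
# codec.encode(LabeledSpan(start=0, end=5, label="PER"))  # -> (0, 5, "PER")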