From 2aebdb162e62ed227d7a81b941f8a94dbf005dd3 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 20:53:12 +0100 Subject: [PATCH 01/10] implement BartModelWithDecoderPositionIds --- .../models/base_models/__init__.py | 0 .../bart_with_decoder_position_ids.py | 526 ++++++++++++++++++ 2 files changed, 526 insertions(+) create mode 100644 src/pie_modules/models/base_models/__init__.py create mode 100644 src/pie_modules/models/base_models/bart_with_decoder_position_ids.py diff --git a/src/pie_modules/models/base_models/__init__.py b/src/pie_modules/models/base_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py new file mode 100644 index 000000000..e0caa9e93 --- /dev/null +++ b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py @@ -0,0 +1,526 @@ +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model, but the decoder accepts predefined position ids. If not provided, the +original logic is used to create the position ids. + +The model is based on the BartModel from Transformers 4.35.0, +i.e. https://github.com/huggingface/transformers/blob/v4.35.0/src/transformers/models/bart/modeling_bart.py. + +Note: This also contains some minor modifications to make the code mypy (v1.4.1) compliant. +. +""" +import math +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqModelOutput, +) +from transformers.models.bart import BartConfig +from transformers.models.bart.modeling_bart import ( + _CHECKPOINT_FOR_DOC, + _CONFIG_FOR_DOC, + _EXPECTED_OUTPUT_SHAPE, + BART_INPUTS_DOCSTRING, + BART_START_DOCSTRING, + BartDecoderLayer, + BartEncoder, + BartPreTrainedModel, + shift_tokens_right, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +logger = logging.get_logger(__name__) + + +class BartLearnedPositionalEmbeddingWithPositionIds(nn.Embedding): + """This module learns positional embeddings up to a fixed maximum size.""" + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + input_ids: torch.Tensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.Tensor] = None, + ): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + positions = torch.arange( + past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device, + ).expand(bsz, -1) + else: + positions = position_ids + + return super().forward(positions + self.offset) + + +class BartDecoderWithPositionIds(BartPreTrainedModel): + """Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a + [`BartDecoderLayer`] + + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbeddingWithPositionIds( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList( + [BartDecoderLayer(config) for _ in range(config.decoder_layers)] + ) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) * self.embed_scale + + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = ( + attention_mask if (attention_mask is not None and 0 in attention_mask) else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if getattr(self.config, "_flash_attn_2_enabled", False): + encoder_attention_mask = ( + encoder_attention_mask if 0 in encoder_attention_mask else None + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length, position_ids) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None + all_self_attns: Optional[Tuple[Any, ...]] = () if output_attentions else None + all_cross_attentions: Optional[Tuple[Any, ...]] = ( + () if (output_attentions and encoder_hidden_states is not None) else None + ) + next_decoder_cache: Optional[Tuple[Any, ...]] = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if next_decoder_cache is not None: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if all_self_attns is not None: + all_self_attns += (layer_outputs[1],) + + if all_cross_attentions is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare BART Model outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class BartModelWithDecoderPositionIds(BartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: BartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoderWithPositionIds(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + 
expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + # different to other models, Bart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." + ) + + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + if not isinstance(encoder_outputs, BaseModelOutput): + raise ValueError( + "Inconsistent output: The output of the model encoder should be of type " + f"`BaseModelOutputWithPastAndCrossAttentions`, but is of type `{type(encoder_outputs)}`." 
+ ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) From 0d9a359f2ce49d904a9a6244ea1513154686f785 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 21:31:25 +0100 Subject: [PATCH 02/10] add test_bart_learned_positional_embedding_with_position_ids() --- tests/models/base_models/__init__.py | 0 .../test_bart_with_decoder_position_ids.py | 26 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 tests/models/base_models/__init__.py create mode 100644 tests/models/base_models/test_bart_with_decoder_position_ids.py diff --git a/tests/models/base_models/__init__.py b/tests/models/base_models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py new file mode 100644 index 000000000..8a0175132 --- /dev/null +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -0,0 +1,26 @@ +import torch + +from pie_modules.models.base_models.bart_with_decoder_position_ids import ( + BartLearnedPositionalEmbeddingWithPositionIds, +) + + +def test_bart_learned_positional_embedding_with_position_ids(): + # Arrange + torch.manual_seed(42) + model = BartLearnedPositionalEmbeddingWithPositionIds(10, 6) + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) + position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) + position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 2]]) + + # Act + original = model(input_ids=input_ids) + replaced_original = model(input_ids=input_ids, position_ids=position_ids_original) + replaced_different = model(input_ids=input_ids, position_ids=position_ids_different) + + # Assert + assert original.shape == (1, 10, 6) + assert replaced_original.shape == (1, 10, 6) + torch.testing.assert_close(original, replaced_original) + assert replaced_different.shape == (1, 10, 6) + assert not torch.allclose(original, replaced_different) From 81d2c9f24f98e6632c843176621b0f31a54d2da7 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 21:48:26 +0100 Subject: [PATCH 03/10] add test_bart_decoder_with_position_ids(), test_bart_decoder_with_position_ids_get_input_embeddings(), and test_set_input_embeddings() --- .../test_bart_with_decoder_position_ids.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py 
b/tests/models/base_models/test_bart_with_decoder_position_ids.py index 8a0175132..65385ad0f 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -1,6 +1,11 @@ +import pytest import torch +from torch.nn import Embedding +from transformers import BartConfig +from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding from pie_modules.models.base_models.bart_with_decoder_position_ids import ( + BartDecoderWithPositionIds, BartLearnedPositionalEmbeddingWithPositionIds, ) @@ -24,3 +29,49 @@ def test_bart_learned_positional_embedding_with_position_ids(): torch.testing.assert_close(original, replaced_original) assert replaced_different.shape == (1, 10, 6) assert not torch.allclose(original, replaced_different) + + +@pytest.fixture(scope="module") +def bart_config(): + return BartConfig( + vocab_size=30, + d_model=10, + encoder_layers=1, + decoder_layers=1, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=20, + decoder_ffn_dim=20, + max_position_embeddings=10, + ) + + +@pytest.fixture(scope="module") +def bart_decoder_with_position_ids(bart_config): + return BartDecoderWithPositionIds(config=bart_config) + + +def test_bart_decoder_with_position_ids(bart_decoder_with_position_ids): + assert bart_decoder_with_position_ids is not None + + +def test_bart_decoder_with_position_ids_get_input_embeddings(bart_decoder_with_position_ids): + input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() + assert input_embeddings is not None + assert isinstance(input_embeddings, Embedding) + assert input_embeddings.embedding_dim == 10 + assert input_embeddings.num_embeddings == 30 + + +def test_set_input_embeddings(bart_decoder_with_position_ids): + original_input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() + torch.manual_seed(42) + new_input_embeddings = Embedding( + original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim + ) + bart_decoder_with_position_ids.set_input_embeddings(new_input_embeddings) + input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() + assert input_embeddings == new_input_embeddings + assert input_embeddings is not original_input_embeddings + # recover original input embeddings + bart_decoder_with_position_ids.set_input_embeddings(original_input_embeddings) From 13e240ad1dcbd62745c040389849392b70293ea6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 21:49:17 +0100 Subject: [PATCH 04/10] fix naming --- tests/models/base_models/test_bart_with_decoder_position_ids.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index 65385ad0f..7e6915429 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -63,7 +63,7 @@ def test_bart_decoder_with_position_ids_get_input_embeddings(bart_decoder_with_p assert input_embeddings.num_embeddings == 30 -def test_set_input_embeddings(bart_decoder_with_position_ids): +def test_bart_decoder_with_position_ids_set_input_embeddings(bart_decoder_with_position_ids): original_input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() torch.manual_seed(42) new_input_embeddings = Embedding( From 1de31965dc0cf9c6b7c513412280aa0631d1dbda Mon Sep 17 00:00:00 2001 From: Arne Binder 
Date: Fri, 5 Jan 2024 22:00:44 +0100 Subject: [PATCH 05/10] add test_bart_decoder_with_position_ids_forward() --- .../test_bart_with_decoder_position_ids.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index 7e6915429..dd44ff00e 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -2,6 +2,7 @@ import torch from torch.nn import Embedding from transformers import BartConfig +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding from pie_modules.models.base_models.bart_with_decoder_position_ids import ( @@ -75,3 +76,29 @@ def test_bart_decoder_with_position_ids_set_input_embeddings(bart_decoder_with_p assert input_embeddings is not original_input_embeddings # recover original input embeddings bart_decoder_with_position_ids.set_input_embeddings(original_input_embeddings) + + +def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): + # Arrange + model = bart_decoder_with_position_ids + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) + + # Act + torch.manual_seed(42) + original = model(input_ids=input_ids) + torch.manual_seed(42) + replaced_original = model(input_ids=input_ids, position_ids=position_ids_original) + torch.manual_seed(42) + replaced_different = model(input_ids=input_ids, position_ids=position_ids_different) + + # Assert + assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions) + assert original.last_hidden_state.shape == (1, 8, 10) + assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions) + torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) + + assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions) + assert replaced_different.last_hidden_state.shape == (1, 8, 10) + assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) From f4135fda6440f1cca22cd162e8a595235a48e0f5 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 22:08:17 +0100 Subject: [PATCH 06/10] add tests for BartModelWithDecoderPositionIds --- .../test_bart_with_decoder_position_ids.py | 84 ++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index dd44ff00e..ebd0e8828 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -2,12 +2,19 @@ import torch from torch.nn import Embedding from transformers import BartConfig -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqModelOutput, +) +from transformers.models.bart.modeling_bart import ( + BartEncoder, + BartLearnedPositionalEmbedding, +) from pie_modules.models.base_models.bart_with_decoder_position_ids import ( BartDecoderWithPositionIds, 
BartLearnedPositionalEmbeddingWithPositionIds, + BartModelWithDecoderPositionIds, ) @@ -102,3 +109,76 @@ def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions) assert replaced_different.last_hidden_state.shape == (1, 8, 10) assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) + + +@pytest.fixture(scope="module") +def bart_model_with_decoder_position_ids(bart_config): + return BartModelWithDecoderPositionIds(config=bart_config) + + +def test_bart_model_with_decoder_position_ids(bart_model_with_decoder_position_ids): + assert bart_model_with_decoder_position_ids is not None + + +def test_bart_model_with_decoder_position_ids_get_input_embeddings( + bart_model_with_decoder_position_ids, +): + input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() + assert input_embeddings is not None + assert isinstance(input_embeddings, Embedding) + assert input_embeddings.embedding_dim == 10 + assert input_embeddings.num_embeddings == 30 + + +def test_bart_model_with_decoder_position_ids_set_input_embeddings( + bart_model_with_decoder_position_ids, +): + original_input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() + torch.manual_seed(42) + new_input_embeddings = Embedding( + original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim + ) + bart_model_with_decoder_position_ids.set_input_embeddings(new_input_embeddings) + input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() + assert input_embeddings == new_input_embeddings + assert input_embeddings is not original_input_embeddings + # recover original input embeddings + bart_model_with_decoder_position_ids.set_input_embeddings(original_input_embeddings) + + +def test_bart_model_with_decoder_position_ids_get_encoder(bart_model_with_decoder_position_ids): + encoder = bart_model_with_decoder_position_ids.get_encoder() + assert encoder is not None + assert isinstance(encoder, BartEncoder) + + +def test_bart_model_with_decoder_position_ids_get_decoder(bart_model_with_decoder_position_ids): + decoder = bart_model_with_decoder_position_ids.get_decoder() + assert decoder is not None + assert isinstance(decoder, BartDecoderWithPositionIds) + + +def test_bart_model_with_decoder_position_forward(bart_model_with_decoder_position_ids): + # Arrange + model = bart_model_with_decoder_position_ids + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) + + # Act + torch.manual_seed(42) + original = model(input_ids=input_ids) + torch.manual_seed(42) + replaced_original = model(input_ids=input_ids, decoder_position_ids=position_ids_original) + torch.manual_seed(42) + replaced_different = model(input_ids=input_ids, decoder_position_ids=position_ids_different) + + # Assert + assert isinstance(original, Seq2SeqModelOutput) + assert original.last_hidden_state.shape == (1, 8, 10) + assert isinstance(replaced_original, Seq2SeqModelOutput) + torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) + + assert isinstance(replaced_different, Seq2SeqModelOutput) + assert replaced_different.last_hidden_state.shape == (1, 8, 10) + assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) From 62c137e357bee9fa6e0f3cf8ef1fd1e535946f16 Mon Sep 17 00:00:00 
2001 From: Arne Binder Date: Fri, 5 Jan 2024 22:38:56 +0100 Subject: [PATCH 07/10] check actual logits --- .../bart_with_decoder_position_ids.py | 6 +- .../test_bart_with_decoder_position_ids.py | 91 ++++++++++++++++--- 2 files changed, 84 insertions(+), 13 deletions(-) diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py index e0caa9e93..f86169f5e 100644 --- a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py +++ b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py @@ -488,10 +488,12 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - if not isinstance(encoder_outputs, BaseModelOutput): + if not ( + isinstance(encoder_outputs, BaseModelOutput) or isinstance(encoder_outputs, tuple) + ): raise ValueError( "Inconsistent output: The output of the model encoder should be of type " - f"`BaseModelOutputWithPastAndCrossAttentions`, but is of type `{type(encoder_outputs)}`." + f"`BaseModelOutput` or tuple, but is of type `{type(encoder_outputs)}`." ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index ebd0e8828..e1e7a12ed 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -113,6 +113,7 @@ def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): @pytest.fixture(scope="module") def bart_model_with_decoder_position_ids(bart_config): + torch.manual_seed(42) return BartModelWithDecoderPositionIds(config=bart_config) @@ -158,7 +159,13 @@ def test_bart_model_with_decoder_position_ids_get_decoder(bart_model_with_decode assert isinstance(decoder, BartDecoderWithPositionIds) -def test_bart_model_with_decoder_position_forward(bart_model_with_decoder_position_ids): +@pytest.mark.parametrize( + "return_dict", + [True, False], +) +def test_bart_model_with_decoder_position_forward( + bart_model_with_decoder_position_ids, return_dict +): # Arrange model = bart_model_with_decoder_position_ids input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) @@ -167,18 +174,80 @@ def test_bart_model_with_decoder_position_forward(bart_model_with_decoder_positi # Act torch.manual_seed(42) - original = model(input_ids=input_ids) + original = model(input_ids=input_ids, return_dict=return_dict)[0] torch.manual_seed(42) - replaced_original = model(input_ids=input_ids, decoder_position_ids=position_ids_original) + replaced_original = model( + input_ids=input_ids, decoder_position_ids=position_ids_original, return_dict=return_dict + )[0] torch.manual_seed(42) - replaced_different = model(input_ids=input_ids, decoder_position_ids=position_ids_different) + replaced_different = model( + input_ids=input_ids, decoder_position_ids=position_ids_different, return_dict=return_dict + )[0] # Assert - assert isinstance(original, Seq2SeqModelOutput) - assert original.last_hidden_state.shape == (1, 8, 10) - assert isinstance(replaced_original, Seq2SeqModelOutput) - torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) + assert isinstance(original, torch.FloatTensor) + assert original.shape == (1, 8, 10) + torch.testing.assert_close( + original[0, :5, :3], + torch.tensor( + [ + [0.7961970567703247, 1.2232387065887451, 0.7286717295646667], + 
[0.034051503986120224, -0.9746682047843933, -0.700711190700531], + [0.1363907903432846, -0.4540761113166809, -1.2949464321136475], + [1.1136258840560913, -0.1388537585735321, 1.538393259048462], + [-1.1127841472625732, 0.22768200933933258, 1.6438117027282715], + ] + ), + ) + torch.testing.assert_close( + original.sum(dim=-1), + torch.tensor( + [ + [ + -2.384185791015625e-07, + -4.76837158203125e-07, + -2.682209014892578e-07, + 2.086162567138672e-07, + 5.960464477539063e-08, + 5.960464477539063e-08, + 0.0, + 0.0, + ] + ] + ), + ) + assert isinstance(replaced_original, torch.FloatTensor) + torch.testing.assert_close(original, replaced_original) - assert isinstance(replaced_different, Seq2SeqModelOutput) - assert replaced_different.last_hidden_state.shape == (1, 8, 10) - assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) + assert isinstance(replaced_different, torch.FloatTensor) + assert replaced_different.shape == (1, 8, 10) + torch.testing.assert_close( + replaced_different[0, :5, :3], + torch.tensor( + [ + [0.7961970567703247, 1.2232387065887451, 0.7286717295646667], + [0.1183161735534668, -0.7555443048477173, -1.230163812637329], + [1.2578136920928955, 0.18759475648403168, -0.1578090786933899], + [0.5176712870597839, 0.9378399848937988, 1.3435578346252441], + [0.6121589541435242, -1.0105386972427368, 2.361997365951538], + ] + ), + ) + torch.testing.assert_close( + replaced_different.sum(dim=-1), + torch.tensor( + [ + [ + -2.384185791015625e-07, + -4.76837158203125e-07, + -2.682209014892578e-07, + 2.086162567138672e-07, + 5.960464477539063e-08, + 5.960464477539063e-08, + 0.0, + 0.0, + ] + ] + ), + ) + assert not torch.allclose(replaced_different, original) From 686dff359c159c0d019058cd5f53230d64eb8283 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 5 Jan 2024 23:00:28 +0100 Subject: [PATCH 08/10] increase test coverage --- .../test_bart_with_decoder_position_ids.py | 80 +++++++++++-------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index e1e7a12ed..4e2d181a1 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -114,7 +114,9 @@ def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): @pytest.fixture(scope="module") def bart_model_with_decoder_position_ids(bart_config): torch.manual_seed(42) - return BartModelWithDecoderPositionIds(config=bart_config) + model = BartModelWithDecoderPositionIds(config=bart_config) + model.train() + return model def test_bart_model_with_decoder_position_ids(bart_model_with_decoder_position_ids): @@ -160,29 +162,37 @@ def test_bart_model_with_decoder_position_ids_get_decoder(bart_model_with_decode @pytest.mark.parametrize( - "return_dict", - [True, False], + "return_dict, prepare_encoder_outputs, output_everything", + [(True, True, True), (False, False, False)], ) def test_bart_model_with_decoder_position_forward( - bart_model_with_decoder_position_ids, return_dict + bart_model_with_decoder_position_ids, return_dict, prepare_encoder_outputs, output_everything ): - # Arrange model = bart_model_with_decoder_position_ids + + # Arrange + model.eval() input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) + common_kwargs = 
{"input_ids": input_ids, "return_dict": return_dict} + if prepare_encoder_outputs: + common_kwargs["encoder_outputs"] = bart_model_with_decoder_position_ids.get_encoder()( + input_ids=input_ids, return_dict=False + ) + else: + common_kwargs["encoder_outputs"] = None + if output_everything: + common_kwargs["output_attentions"] = True + common_kwargs["output_hidden_states"] = True # Act - torch.manual_seed(42) - original = model(input_ids=input_ids, return_dict=return_dict)[0] - torch.manual_seed(42) + original = model(**common_kwargs)[0] replaced_original = model( - input_ids=input_ids, decoder_position_ids=position_ids_original, return_dict=return_dict - )[0] - torch.manual_seed(42) - replaced_different = model( - input_ids=input_ids, decoder_position_ids=position_ids_different, return_dict=return_dict + decoder_position_ids=position_ids_original, + **common_kwargs, )[0] + replaced_different = model(decoder_position_ids=position_ids_different, **common_kwargs)[0] # Assert assert isinstance(original, torch.FloatTensor) @@ -191,11 +201,11 @@ def test_bart_model_with_decoder_position_forward( original[0, :5, :3], torch.tensor( [ - [0.7961970567703247, 1.2232387065887451, 0.7286717295646667], - [0.034051503986120224, -0.9746682047843933, -0.700711190700531], - [0.1363907903432846, -0.4540761113166809, -1.2949464321136475], - [1.1136258840560913, -0.1388537585735321, 1.538393259048462], - [-1.1127841472625732, 0.22768200933933258, 1.6438117027282715], + [0.7589594721794128, 1.0452316999435425, 0.7063764333724976], + [-0.12192550301551819, -0.9932114481925964, -0.722382664680481], + [0.24711951613426208, -0.291597843170166, -1.0466505289077759], + [1.1228691339492798, -0.0873560905456543, 1.534016728401184], + [-1.1132177114486694, 0.2277398556470871, 1.6456809043884277], ] ), ) @@ -204,14 +214,14 @@ def test_bart_model_with_decoder_position_forward( torch.tensor( [ [ - -2.384185791015625e-07, - -4.76837158203125e-07, + 0.0, + -1.1920928955078125e-07, + -1.1920928955078125e-07, -2.682209014892578e-07, - 2.086162567138672e-07, 5.960464477539063e-08, 5.960464477539063e-08, - 0.0, - 0.0, + 2.384185791015625e-07, + -5.960464477539063e-08, ] ] ), @@ -225,11 +235,11 @@ def test_bart_model_with_decoder_position_forward( replaced_different[0, :5, :3], torch.tensor( [ - [0.7961970567703247, 1.2232387065887451, 0.7286717295646667], - [0.1183161735534668, -0.7555443048477173, -1.230163812637329], - [1.2578136920928955, 0.18759475648403168, -0.1578090786933899], - [0.5176712870597839, 0.9378399848937988, 1.3435578346252441], - [0.6121589541435242, -1.0105386972427368, 2.361997365951538], + [0.7589594721794128, 1.0452316999435425, 0.7063764333724976], + [-0.0127173513174057, -0.8127143383026123, -1.256797194480896], + [1.0517312288284302, 0.037927787750959396, -0.28661563992500305], + [0.5884698629379272, 0.9930593371391296, 1.3842554092407227], + [0.6132885813713074, -1.0105736255645752, 2.361264228820801], ] ), ) @@ -238,14 +248,14 @@ def test_bart_model_with_decoder_position_forward( torch.tensor( [ [ - -2.384185791015625e-07, - -4.76837158203125e-07, - -2.682209014892578e-07, - 2.086162567138672e-07, - 5.960464477539063e-08, - 5.960464477539063e-08, - 0.0, 0.0, + -2.384185791015625e-07, + -1.7881393432617188e-07, + 2.5331974029541016e-07, + 1.4901161193847656e-07, + 1.1920928955078125e-07, + -1.1920928955078125e-07, + -1.7881393432617188e-07, ] ] ), From f281c2ee81770486bcfd23c6dbe955e22b586ba6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 8 Jan 2024 11:48:51 +0100 Subject: [PATCH 09/10] add 
test_bart_decoder_with_position_ids_forward_with_inputs_embeds() and def test_bart_decoder_with_position_ids_forward_wrong_position_ids_shape() --- .../bart_with_decoder_position_ids.py | 14 +++--- .../test_bart_with_decoder_position_ids.py | 44 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py index f86169f5e..937e54171 100644 --- a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py +++ b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py @@ -196,11 +196,11 @@ def forward( If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of - shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing - `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more - control over how to convert `input_ids` indices into associated vectors than the model's internal - embedding lookup matrix. + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -270,6 +270,10 @@ def forward( ) # embed positions + if position_ids is not None and position_ids.shape != input_shape: + raise ValueError( + f"Position IDs shape {position_ids.shape} does not match input ids shape {input_shape}." 
+ ) positions = self.embed_positions(input, past_key_values_length, position_ids) positions = positions.to(inputs_embeds.device) diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py index 4e2d181a1..56c13d158 100644 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ b/tests/models/base_models/test_bart_with_decoder_position_ids.py @@ -111,6 +111,50 @@ def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) +def test_bart_decoder_with_position_ids_forward_with_inputs_embeds(bart_decoder_with_position_ids): + # Arrange + model = bart_decoder_with_position_ids + inputs_embeds = torch.randn(1, 8, 10) + position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) + + # Act + torch.manual_seed(42) + original = model(inputs_embeds=inputs_embeds) + torch.manual_seed(42) + replaced_original = model(inputs_embeds=inputs_embeds, position_ids=position_ids_original) + torch.manual_seed(42) + replaced_different = model(inputs_embeds=inputs_embeds, position_ids=position_ids_different) + + # Assert + assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions) + assert original.last_hidden_state.shape == (1, 8, 10) + assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions) + torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) + + assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions) + assert replaced_different.last_hidden_state.shape == (1, 8, 10) + assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) + + +def test_bart_decoder_with_position_ids_forward_wrong_position_ids_shape( + bart_decoder_with_position_ids, +): + # Arrange + model = bart_decoder_with_position_ids + input_ids = torch.tensor([[0, 1, 2, 3]]) + position_ids_wrong_shape = torch.tensor([[0, 1, 2]]) + + # Act + torch.manual_seed(42) + with pytest.raises(ValueError) as excinfo: + model(input_ids=input_ids, position_ids=position_ids_wrong_shape) + assert ( + str(excinfo.value) + == "Position IDs shape torch.Size([1, 3]) does not match input ids shape torch.Size([1, 4])." 
+ ) + + @pytest.fixture(scope="module") def bart_model_with_decoder_position_ids(bart_config): torch.manual_seed(42) From 2544024e3b453ca4fc41549ccb6832a285a79ca9 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 8 Jan 2024 11:59:44 +0100 Subject: [PATCH 10/10] add documentation --- .../models/base_models/bart_with_decoder_position_ids.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py index 937e54171..03ca5dea6 100644 --- a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py +++ b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py @@ -132,8 +132,8 @@ def set_input_embeddings(self, value): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -155,6 +155,10 @@ def forward( [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Position indices for each input sequence token that are used to create the position embedding + of the sequence. If `None` (default), position ids are automatically created as sequential + integers (takes previous `past_key_values` into account, if provided). attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: