new task: pointer network for joint ner and re #10

Merged
merged 290 commits into from
Jan 11, 2024
290 commits
d87d8b6
move source files
ArneBinder Dec 13, 2023
5d21ad8
fix name
ArneBinder Dec 13, 2023
b5c18e5
add comments
ArneBinder Dec 13, 2023
80ee064
add BartAsPointerNetwork.overwrite_decoder_label_embeddings_with_mapp…
ArneBinder Dec 13, 2023
4c013cb
make pre-commit happy
ArneBinder Dec 13, 2023
41b8c81
add SimplePointerNetworkModel
ArneBinder Dec 13, 2023
748b0cd
add metrics to SimplePointerNetworkModel
ArneBinder Dec 13, 2023
3e41638
rearrange tests (and skip test_bart_generate)
ArneBinder Dec 13, 2023
4fb1c1b
raise an exception if decoder_input_ids are missing
ArneBinder Dec 13, 2023
4762a6c
fix predict
ArneBinder Dec 13, 2023
8c2078f
add tests for SimplePointerNetworkModel
ArneBinder Dec 13, 2023
9030458
test_bart_pointer_network_generate_with_scores()
ArneBinder Dec 13, 2023
c1ab42a
fix configure_optimizers() and add tests
ArneBinder Dec 13, 2023
edf514e
remove lm_head
ArneBinder Dec 13, 2023
323efcf
fix loss calculation for PointerHead
ArneBinder Dec 13, 2023
42b45fb
partly revert 4bfc866fbb5420b8d1e7a5ad488f5947c35a4b1b
ArneBinder Dec 13, 2023
8b51771
align generation defaults with previous hyperparameters (num_beams) a…
ArneBinder Dec 13, 2023
f1b2390
loss for val and test; flatten loss dict
ArneBinder Dec 15, 2023
90b0704
align loss logging with simple model
ArneBinder Dec 15, 2023
d861eaa
return loss for test and val
ArneBinder Dec 17, 2023
a2a8efc
align tests
ArneBinder Dec 17, 2023
6d2e36c
outsource parameter collections and test them
ArneBinder Dec 17, 2023
8a6eec8
fix test_configure_optimizers()
ArneBinder Dec 17, 2023
2b8476d
add tgt_attention_mask (in the taskmodule, it is not yet used in the …
ArneBinder Dec 17, 2023
5a24388
use decoder_attention_mask in simple model
ArneBinder Dec 17, 2023
101652a
fix optimizer setup for simple pointer network
ArneBinder Dec 17, 2023
930e511
fix optimizer setup for original pointer network
ArneBinder Dec 17, 2023
3c4d234
make pre-commit happy
ArneBinder Dec 17, 2023
a5e1e73
try to restore original weight_decays (original model)
ArneBinder Dec 18, 2023
7a3a3d8
align optimizer (especially weight decay) of simple model with original
ArneBinder Dec 18, 2023
51f3db5
simplify configure_optimizers() for normal model
ArneBinder Dec 18, 2023
e1bbc53
add trained model tests
ArneBinder Dec 18, 2023
4a6349a
fix SimplePointerNetworkModel.from_pretrained()
ArneBinder Dec 18, 2023
6e9a343
add tests with trained models
ArneBinder Dec 18, 2023
6d96ef5
further disentangle parameter groups
ArneBinder Dec 18, 2023
64028e0
add generation_kwargs parameter to SimplePointerNetworkModel
ArneBinder Dec 18, 2023
bbf09d7
outsource fixture data
ArneBinder Dec 18, 2023
1fe72a9
remove unused fixture methods
ArneBinder Dec 18, 2023
62ea376
truncate first entry in SimplePointerNetworkModel.predict()
ArneBinder Dec 18, 2023
4ef4fe8
add test_sciarg_predict()
ArneBinder Dec 18, 2023
43fb445
fix (slow) tests
ArneBinder Dec 18, 2023
0d46219
add metric_intervals parameter
ArneBinder Dec 18, 2023
fc8e361
check fn and fp in test_sciarg_predict, for both model types
ArneBinder Dec 19, 2023
c5fa013
add metric_intervals parameter also to PointerNetworkModel
ArneBinder Dec 19, 2023
d8f2c8c
add tests with other model weights
ArneBinder Dec 19, 2023
e636667
minor
ArneBinder Dec 19, 2023
303a90c
shorten tests (also set num_beams=4 for simple model to align with no…
ArneBinder Dec 19, 2023
8e9a817
translate comments in original generate
ArneBinder Dec 19, 2023
62f2760
set no_repeat_ngram_size to 7 (was 3 implicitly)
ArneBinder Dec 19, 2023
09c3f9e
use new simple model
ArneBinder Dec 19, 2023
bed17a5
outsource tests for predict
ArneBinder Dec 19, 2023
0659ffe
generation: disable ForcedBOSTokenLogitsProcessor and ForcedEOSTokenL…
ArneBinder Dec 19, 2023
b33c52f
improve generation parameter interface
ArneBinder Dec 19, 2023
2b677b0
fix handling of generation kwargs from annotation_encoder_decoder
ArneBinder Dec 19, 2023
59f7b5f
make isort happy
ArneBinder Dec 19, 2023
c2f1559
add prefix_allowed_tokens_fn (not yet used because of strange behaviour)
ArneBinder Dec 19, 2023
dcf5164
make pre-commit happy
ArneBinder Dec 19, 2023
4750ee0
fix pie_modules.taskmodules.pointer_network_taskmodule.TokenDocumentW…
ArneBinder Dec 20, 2023
b25c0c4
pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAnd…
ArneBinder Dec 20, 2023
79c0fe1
set default of create_constraints to False
ArneBinder Dec 20, 2023
d7a7540
do not set strict_span_conversion and verbose values for tokenize_doc…
ArneBinder Dec 20, 2023
7ad2ec1
fix tests
ArneBinder Dec 20, 2023
7d0c3a5
fix tests
ArneBinder Dec 20, 2023
f645f36
outsource maybe_pad_values() and maybe_to_tensor() to utils.py; add u…
ArneBinder Dec 20, 2023
5a8b319
add test_maybe_log_example(_disabled)
ArneBinder Dec 20, 2023
a1fa252
tgt_attention_mask as property of TargetEncodingType
ArneBinder Dec 20, 2023
ec1eb5f
rename and move AnnotationEncoderDecoder to AnnotationLayersEncoderDe…
ArneBinder Dec 20, 2023
118c20e
move BatchableMixin to common.py
ArneBinder Dec 20, 2023
ce0228c
PointerNetworkSpanAndRelationEncoderDecoder.encode()/decode() works w…
ArneBinder Dec 20, 2023
df3bde7
add AnnotationLayersEncoderDecoder.get_metric(); move AnnotationLayer…
ArneBinder Dec 21, 2023
67e1565
fix unbatch_output
ArneBinder Dec 21, 2023
6535ae3
outsource unbatching (but with grain of salt, see todos)
ArneBinder Dec 21, 2023
e2e4d8f
make AnnotationLayerMetric a Metric
ArneBinder Dec 21, 2023
166fc29
remove metrics.py
ArneBinder Dec 21, 2023
177c551
huge refactor: re-integrate main encoding / decoding logic back into t…
ArneBinder Dec 21, 2023
2154a00
move common taskmodule components and add tests
ArneBinder Dec 21, 2023
52ed1f7
add checks with HasBuildMetric
ArneBinder Dec 21, 2023
c294951
remove Optional from result type of AnnotationEncoderDecoder.encode()…
ArneBinder Dec 22, 2023
362ef66
increase test coverage
ArneBinder Dec 22, 2023
20ffcbf
increase test coverage
ArneBinder Dec 22, 2023
e929f49
better / more correct weight decay parametrization
ArneBinder Dec 22, 2023
e7d6712
restore backwards compatibility
ArneBinder Dec 22, 2023
9fec037
fix type of zero recall and precision
ArneBinder Dec 22, 2023
1213a98
do not ignore taskmodule_config for model.save_hyperparameters() to g…
ArneBinder Dec 22, 2023
8e1161d
fix warning
ArneBinder Dec 22, 2023
fb83df8
minor rename
ArneBinder Dec 23, 2023
6347f05
add use_prediction_for_metrics parameter
ArneBinder Dec 23, 2023
3af9980
raise exceptions if label_ids or target_token_ids is not provided to …
ArneBinder Dec 23, 2023
8646c94
revert (does not work as expected): raise exceptions if label_ids or …
ArneBinder Dec 23, 2023
f17364d
improve modularization: move optimizer related code and parameters in…
ArneBinder Dec 23, 2023
49156eb
rename taskmodule to PointerNetworkTaskModuleForEnd2EndRE
ArneBinder Dec 23, 2023
b7dc672
rename em to em_original and add simplified calculation for em
ArneBinder Dec 23, 2023
e707183
minor
ArneBinder Dec 23, 2023
6b0b3c4
show not encoded annotations
ArneBinder Dec 23, 2023
8f05c4d
fix: show not encoded annotations
ArneBinder Dec 23, 2023
f545b81
remove ignore_error_types parameter and sanitize_sequence() returns r…
ArneBinder Dec 23, 2023
b17be10
outsource get_valid_relation_encoding() from sanitize_sequence(); pre…
ArneBinder Dec 23, 2023
ee64a84
rename base_model_kwargs to base_model_config and move base_model_nam…
ArneBinder Dec 23, 2023
fabc197
remove max_target_length
ArneBinder Dec 23, 2023
37b310f
fix: remove max_target_length
ArneBinder Dec 23, 2023
c3f9faf
improve taskmodule tests
ArneBinder Dec 23, 2023
78835b8
do not show remaining encoding ids if they are only eos_id
ArneBinder Dec 23, 2023
17482b3
rename get_valid_relation_encoding() to validate_relation_encoding() …
ArneBinder Dec 24, 2023
0e5e147
add tests for validate_relation_encoding()
ArneBinder Dec 24, 2023
24ced01
add validate_encoding() to AnnotationEncoderDecoder
ArneBinder Dec 24, 2023
f7e00d5
fix: add validate_encoding() to AnnotationEncoderDecoder (test)
ArneBinder Dec 24, 2023
48a4621
remove validate_relation_encoding()
ArneBinder Dec 24, 2023
2276c61
AnnotationEncoderDecoder raise DecodingExceptions
ArneBinder Dec 24, 2023
f8b4bd2
remove AnnotationEncoderDecoder.validate_encoding(); rename get_valid…
ArneBinder Dec 24, 2023
50514fa
add metric tests, also with use_prediction_for_metrics=False
ArneBinder Dec 24, 2023
e84d9b7
fix slow test
ArneBinder Dec 24, 2023
aa3b2f1
add test_predict_step
ArneBinder Dec 24, 2023
37710a4
simplify metric logging
ArneBinder Dec 24, 2023
8ea6153
try to make metrics work with lightning logging
ArneBinder Dec 24, 2023
a7437b1
revert: try to make metrics work with lightning logging
ArneBinder Dec 24, 2023
cb966d0
add comment
ArneBinder Dec 24, 2023
98d13b8
remove deprecated parameter layernorm_decay
ArneBinder Dec 24, 2023
71cee73
use assert_close() instead of assert_allclose() which is deprecated
ArneBinder Dec 24, 2023
6196757
add test_on_train_epoch_end(), test_on_validation_epoch_end(), test_o…
ArneBinder Dec 24, 2023
3607d8e
remove original pointer network model
ArneBinder Dec 24, 2023
bdcc9b4
fix setting decoder_start_token_id
ArneBinder Dec 25, 2023
70421db
improve decoder mask handling
ArneBinder Dec 25, 2023
1af62e1
test with batch of differently sized (encoder) inputs
ArneBinder Dec 25, 2023
58bf21b
prepare decoder_position_id_pattern for PointerHead
ArneBinder Dec 25, 2023
af0d7cf
fix BartAsPointerNetwork tests
ArneBinder Dec 25, 2023
027cca3
check prepare_decoder_position_ids with increase_position_ids_per_rec…
ArneBinder Dec 25, 2023
9dbeceb
implement prepare_decoder_position_ids (yet without bart model modifi…
ArneBinder Dec 26, 2023
9aaa201
add slow test test_sciarg_predict_with_position_id_pattern()
ArneBinder Dec 26, 2023
202a705
add modeling_bart_with_position_ids.py
ArneBinder Dec 26, 2023
917c1b3
make mypy happy and move into BartModelWithDecoderPositionIds to mode…
ArneBinder Dec 26, 2023
c7593a9
use BartModelWithDecoderPositionIds in BartAsPointerNetwork
ArneBinder Dec 26, 2023
0dcd512
add tests with decoder_position_id_pattern
ArneBinder Dec 26, 2023
308640e
allow use_prediction_for_metrics to be a dict (specify per stage)
ArneBinder Dec 26, 2023
2fe0d22
complain if use_prediction_for_metrics defines anything not in metric…
ArneBinder Dec 26, 2023
04864b9
make use_prediction_for_metrics a list (instead of a dict)
ArneBinder Dec 26, 2023
2fa97ef
add test_build_metric() and test_generation_kwargs()
ArneBinder Dec 26, 2023
3cd4811
only call configure_optimizer() on the base model if available and de…
ArneBinder Dec 26, 2023
2e22875
rename adjust_original_model() to adjust_after_loading_original_model()
ArneBinder Dec 26, 2023
0447613
directly call adjust_after_loading_original_model() in BartAsPointerN…
ArneBinder Dec 26, 2023
a8e0d8a
adjust forced_bos_token_id and forced_eos_token_id directly in adjust…
ArneBinder Dec 26, 2023
957b836
improve comment
ArneBinder Dec 26, 2023
c2a9c26
fix
ArneBinder Dec 26, 2023
97fac59
fix test_bart_pointer_network_generate_with_scores()
ArneBinder Dec 26, 2023
52e7ceb
rearrange model components
ArneBinder Dec 26, 2023
487f1d6
add base_model_type parameter
ArneBinder Dec 26, 2023
f93b256
do not wrap model output
ArneBinder Dec 26, 2023
0441133
rename HasBuildMetric.build_metric() to HasConfigureMetric.configure_…
ArneBinder Dec 27, 2023
8c3a7f6
use AutoTaskModule._from_pretrained() in the model
ArneBinder Dec 27, 2023
ea975c3
add optimizer_type parameter
ArneBinder Dec 27, 2023
1e89db8
make mypy happy without cheating
ArneBinder Dec 27, 2023
cfe1312
add todos
ArneBinder Dec 27, 2023
00fc888
improve metrics
ArneBinder Dec 27, 2023
7f45c2b
move metrics
ArneBinder Dec 27, 2023
cddb9d0
check full content in test_sciarg_predict()
ArneBinder Dec 27, 2023
43a6f95
separate test_sciarg_metric() from test_sciarg_predict()
ArneBinder Dec 27, 2023
b49a9fa
skip test_sciarg_metric if model not available
ArneBinder Dec 27, 2023
c653563
move batch data
ArneBinder Dec 27, 2023
48cd3e3
add sciarg batch encoding fixture data as json
ArneBinder Dec 27, 2023
6c2f05d
add typing
ArneBinder Dec 27, 2023
51192b2
add taskmodule config for tests
ArneBinder Dec 27, 2023
1f82247
add comments in preparation of output format change
ArneBinder Dec 27, 2023
504f9cb
unbatch_output() truncates before eos token
ArneBinder Dec 27, 2023
2ed109c
simplify and fix prediction tests
ArneBinder Dec 27, 2023
504081f
no need to ensure that eos token is in the model output
ArneBinder Dec 27, 2023
41921d6
model output and tgt_tokens are without initial bos token (to get par…
ArneBinder Dec 27, 2023
96f0332
remove (src/tgt)_seq_len
ArneBinder Dec 27, 2023
d6cf1f8
rename (src_tokens, src_attention_mask, tgt_tokens, tgt_attention_mas…
ArneBinder Dec 27, 2023
145dfd7
pass all inputs (and targets) into the model (this fixes the missed a…
ArneBinder Dec 27, 2023
baad01c
remove unused fixture data
ArneBinder Dec 27, 2023
bdba2be
rename SimplePointerNetworkModel to SimpleGenerativeModel
ArneBinder Dec 27, 2023
8a8b384
count errors per relation encoding, not per batch entry, and fix metric
ArneBinder Dec 28, 2023
f0163f8
use https://github.com/ArneBinder/pytorch-ie/pull/392
ArneBinder Dec 28, 2023
9b01cf5
use pytorch-ie release 0.29.5
ArneBinder Dec 28, 2023
801d59a
rename PointerNetworkTaskModuleForEnd2EndRE.generation_kwargs to .gen…
ArneBinder Dec 31, 2023
2d2e74c
fix predict tests
ArneBinder Jan 2, 2024
7e653fb
use updated model (but fixed) from other branch
ArneBinder Jan 2, 2024
aeda7c6
add test_build_constraints()
ArneBinder Jan 2, 2024
d3fdd96
add test_build_constraints_single_label()
ArneBinder Jan 2, 2024
9852f4b
simplify build_constraints()
ArneBinder Jan 2, 2024
8674d59
make constraint tests more readable
ArneBinder Jan 2, 2024
da4f605
make constraint tests more readable again
ArneBinder Jan 2, 2024
e198252
rework and fix build_constraints() and add tests
ArneBinder Jan 2, 2024
792fa42
implement constrained generation
ArneBinder Jan 2, 2024
baa3910
minor
ArneBinder Jan 2, 2024
ec8024f
minor fix for _build_constraint()
ArneBinder Jan 2, 2024
53cd8d3
implement training with constraints
ArneBinder Jan 2, 2024
2a9a8be
catch exceptions and skip respective items in encode_target()
ArneBinder Jan 2, 2024
cdbfd71
make use_constraints_encoder_mlp configurable
ArneBinder Jan 2, 2024
66aa3b0
fix constrained_generation by using PrefixConstrainedLogitsProcessorW…
ArneBinder Jan 2, 2024
011e2e4
improve model tests
ArneBinder Jan 2, 2024
5f3bb6c
align with main branch to prepare rebase
ArneBinder Jan 4, 2024
f94a94f
fix rebase errors
ArneBinder Jan 4, 2024
a2b83d6
make tokenizer_name_or_path a mandatory parameter; cleanup
ArneBinder Jan 4, 2024
1a3ddfe
add test_configure_model_generation_with_constrained_generation()
ArneBinder Jan 4, 2024
8aff0df
fix _build_constraint() and add test_prefix_allowed_tokens_fn_with_ma…
ArneBinder Jan 4, 2024
c19a66f
add test_prefix_constrained_logits_processor_with_maximum()
ArneBinder Jan 4, 2024
7b44e19
disentangle metrics
ArneBinder Jan 4, 2024
2452d5e
rename metric keys: invalid to errors and em to encoding_match
ArneBinder Jan 4, 2024
2daf007
fix test_simple_generative_pointer_predict.py
ArneBinder Jan 4, 2024
028aabe
check that encoder_input_ids_index.max() >= encoder_input_length and …
ArneBinder Jan 4, 2024
2a7b180
remove metadata from decode_annotations
ArneBinder Jan 5, 2024
36e7871
simplify and harden metrics; add collect_encoding_matches parameter
ArneBinder Jan 5, 2024
a24396f
rearrange metrics
ArneBinder Jan 5, 2024
23d907d
rename WrappedLayerMetricsWithUnbatchAndDecodingFunction to WrappedLa…
ArneBinder Jan 5, 2024
c876937
add parameters key_micro and in_percent to PrecisionRecallAndF1ForLab…
ArneBinder Jan 5, 2024
9a8d27e
rearrange
ArneBinder Jan 5, 2024
11ce0bb
simplify tests for WrappedMetricWithUnbatchFunction
ArneBinder Jan 5, 2024
e959b72
add tests for PrecisionRecallAndF1ForLabeledAnnotations
ArneBinder Jan 5, 2024
0224ea8
make pre-commit happy
ArneBinder Jan 5, 2024
ae69ce4
move taskmodule metric tests to the actual taskmodule tests and strea…
ArneBinder Jan 5, 2024
c2f45e1
cleanup imports
ArneBinder Jan 5, 2024
872f4c5
rename decode_annotations_with_errors_function to decode_layers_with_…
ArneBinder Jan 5, 2024
1857af8
use add_state() for metrics
ArneBinder Jan 5, 2024
f553542
cleanup
ArneBinder Jan 5, 2024
4e59e78
re-arrange metrics
ArneBinder Jan 5, 2024
7391898
move (test_)pointer_network_taskmodule_for_end2end_re.py to (test_)po…
ArneBinder Jan 5, 2024
2728085
mention taskmodule in readme
ArneBinder Jan 5, 2024
ad94a2e
revert change in utils.py
ArneBinder Jan 8, 2024
4f741a6
simplify PointerHead: remove decoder, just keep its embeddings
ArneBinder Jan 8, 2024
1cb502f
add comment
ArneBinder Jan 8, 2024
adf5479
fix / improve comments
ArneBinder Jan 8, 2024
6a1bacb
simplify embedding handling
ArneBinder Jan 8, 2024
fdd3a4d
streamline overwrite_embeddings_with_mapping()
ArneBinder Jan 8, 2024
d9db9cf
streamline position ids configuration
ArneBinder Jan 8, 2024
e49250a
improve comments
ArneBinder Jan 8, 2024
23fce44
remove max_target_positions from BartAsPointerNetworkConfig
ArneBinder Jan 8, 2024
82aea3f
remove output_size from PointerHead because it was not used; add todo
ArneBinder Jan 8, 2024
1503f0e
add some tests for PointerHead
ArneBinder Jan 8, 2024
bdc071e
add bos_ids to PointerHead, but remove label_ids (can be reconstructe…
ArneBinder Jan 8, 2024
324df70
improve and add tests for PointerHead
ArneBinder Jan 8, 2024
bb35130
rearrange
ArneBinder Jan 8, 2024
8faaefa
add argument names
ArneBinder Jan 9, 2024
af32a62
add and re-arrange tests for BartAsPointerNetwork
ArneBinder Jan 9, 2024
e9c2155
add todo
ArneBinder Jan 9, 2024
43d287d
greatly simplify BartAsPointerNetwork
ArneBinder Jan 9, 2024
6ac4f3a
add test_configure_optimizer()
ArneBinder Jan 9, 2024
458b57e
improve todo
ArneBinder Jan 9, 2024
418e8fa
remove todo
ArneBinder Jan 9, 2024
4941c91
cleanup
ArneBinder Jan 9, 2024
1b83d29
remove todo (moved to PR description)
ArneBinder Jan 9, 2024
286714e
remove test_simple_generative_pointer.py
ArneBinder Jan 9, 2024
6c662db
remove attention_mask parameter from prepare_decoder_inputs(), prepar…
ArneBinder Jan 10, 2024
3683874
handle deprecation warning (assert_allclose)
ArneBinder Jan 10, 2024
39974de
mask bos positions (in addition to eos and pad) on pointer scores; co…
ArneBinder Jan 10, 2024
a145fe0
improve test cases
ArneBinder Jan 10, 2024
865f9b5
use sshleifer/bart-tiny-random to test BartAsPointerNetwork
ArneBinder Jan 10, 2024
9a0ade6
remove test_simple_generative_pointer_predict.py and respective fixtu…
ArneBinder Jan 10, 2024
ec1ebaa
do not show precision / recall / f1 in percent
ArneBinder Jan 10, 2024
3dddc17
revert: mask bos token offsets
ArneBinder Jan 11, 2024
b79aa68
minor change
ArneBinder Jan 11, 2024
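Several commits above ("add prefix_allowed_tokens_fn", "implement constrained generation", "fix constrained_generation by using PrefixConstrainedLogitsProcessorW…") add prefix-constrained decoding. A minimal sketch of the underlying idea, using a toy target grammar that is illustrative only and not the PR's actual constraint logic: at each decoding step, a callback maps the prefix generated so far to the set of target ids that are valid continuations.

```python
# Toy target vocabulary (illustrative ids, not from the PR).
BOS, EOS, SPAN_START, SPAN_END, LABEL = 0, 1, 2, 3, 4


def allowed_tokens(prefix):
    """Enforce the pattern: BOS (SPAN_START SPAN_END LABEL)* EOS.

    Returns the target ids that may follow the given prefix; a logits
    processor would set all other ids to -inf before sampling/beam search.
    """
    pos_in_record = (len(prefix) - 1) % 3  # position within the current triple
    if pos_in_record == 0:
        return [SPAN_START, EOS]  # may start a new record or stop
    elif pos_in_record == 1:
        return [SPAN_END]  # a span start must be followed by a span end
    else:
        return [LABEL]  # a complete span must be followed by a label


print(allowed_tokens([BOS]))                        # [2, 1]
print(allowed_tokens([BOS, SPAN_START]))            # [3]
print(allowed_tokens([BOS, SPAN_START, SPAN_END]))  # [4]
```

In `transformers`, a function of this shape (taking `batch_id` and the prefix tensor) can be passed to `generate(..., prefix_allowed_tokens_fn=...)`, which is the mechanism the commits above build on.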
1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ Available taskmodules:
- [TokenClassificationTaskModule](src/pie_modules/taskmodules/token_classification.py)
- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)
- [PointerNetworkTaskModuleForEnd2EndRE](src/pie_modules/taskmodules/pointer_network_for_end2end_re.py)

Available metrics:

1 change: 1 addition & 0 deletions src/pie_modules/models/base_models/__init__.py
@@ -1 +1,2 @@
from .bart_as_pointer_network import BartAsPointerNetwork
from .bart_with_decoder_position_ids import BartModelWithDecoderPositionIds
470 changes: 470 additions & 0 deletions src/pie_modules/models/base_models/bart_as_pointer_network.py

Large diffs are not rendered by default.

307 changes: 307 additions & 0 deletions src/pie_modules/models/components/pointer_head.py
@@ -0,0 +1,307 @@
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.utils import logging

logger = logging.get_logger(__name__)


class PointerHead(torch.nn.Module):
    # Head that combines copying (pointing into the encoder input) with generating labels and
    # special tokens.
    def __init__(
        self,
        # (decoder) input space
        target_token_ids: List[int],
        # output space (targets)
        bos_id: int,
        eos_id: int,
        pad_id: int,
        # embeddings
        embeddings: nn.Embedding,
        embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None,
        # other parameters
        use_encoder_mlp: bool = False,
        use_constraints_encoder_mlp: bool = False,
        decoder_position_id_pattern: Optional[List[int]] = None,
        increase_position_ids_per_record: bool = False,
    ):
        super().__init__()

        self.embeddings = embeddings

        self.pointer_offset = len(target_token_ids)

        # check that bos, eos, and pad are not out of bounds
        for target_id, target_id_name in zip(
            [bos_id, eos_id, pad_id], ["bos_id", "eos_id", "pad_id"]
        ):
            if target_id >= len(target_token_ids):
                raise ValueError(
                    f"{target_id_name} [{target_id}] must be smaller than the number of target token ids "
                    f"[{len(target_token_ids)}]!"
                )

        self.bos_id = bos_id
        self.eos_id = eos_id
        self.pad_id = pad_id
        # all ids that are not bos, eos or pad are label ids
        self.label_ids = [
            target_id
            for target_id in range(len(target_token_ids))
            if target_id not in [self.bos_id, self.eos_id, self.pad_id]
        ]

        target2token_id = torch.LongTensor(target_token_ids)
        self.register_buffer("target2token_id", target2token_id)
        self.label_token_ids = self.target2token_id[self.label_ids]
        self.eos_token_id = target_token_ids[self.eos_id]
        self.pad_token_id = target_token_ids[self.pad_id]

        hidden_size = self.embeddings.embedding_dim
        if use_encoder_mlp:
            self.encoder_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Dropout(0.3),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )
        if use_constraints_encoder_mlp:
            self.constraints_encoder_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.Dropout(0.3),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
            )

        self.embedding_weight_mapping = None
        if embedding_weight_mapping is not None:
            # Because of config serialization, the keys may be strings. Convert them back to ints.
            self.embedding_weight_mapping = {
                int(k): v for k, v in embedding_weight_mapping.items()
            }

        if decoder_position_id_pattern is not None:
            self.register_buffer(
                "decoder_position_id_pattern", torch.tensor(decoder_position_id_pattern)
            )
        self.increase_position_ids_per_record = increase_position_ids_per_record

    @property
    def use_prepared_position_ids(self):
        return hasattr(self, "decoder_position_id_pattern")

    def set_embeddings(self, embedding: nn.Embedding) -> None:
        self.embeddings = embedding

    def overwrite_embeddings_with_mapping(self) -> None:
        """Overwrite individual embeddings with embeddings for other tokens.

        This is useful, for instance, if the label vocabulary is a subset of the source vocabulary.
        In this case, this method can be used to initialize each label embedding with one or
        multiple (averaged) source embeddings.
        """
        if self.embedding_weight_mapping is not None:
            for special_token_index, source_indices in self.embedding_weight_mapping.items():
                self.embeddings.weight.data[special_token_index] = self.embeddings.weight.data[
                    source_indices
                ].mean(dim=0)

    def prepare_decoder_input_ids(
        self,
        input_ids: torch.LongTensor,
        encoder_input_ids: torch.LongTensor,
    ) -> torch.LongTensor:
        mapping_token_mask = input_ids.lt(self.pointer_offset)
        mapped_tokens = input_ids.masked_fill(input_ids.ge(self.pointer_offset), 0)
        tag_mapped_tokens = self.target2token_id[mapped_tokens]

        encoder_input_ids_index = input_ids - self.pointer_offset
        encoder_input_ids_index = encoder_input_ids_index.masked_fill(
            encoder_input_ids_index.lt(0), 0
        )
        encoder_input_length = encoder_input_ids.size(1)
        if encoder_input_ids_index.max() >= encoder_input_length:
            raise ValueError(
                f"encoder_input_ids_index.max() [{encoder_input_ids_index.max()}] must be smaller "
                f"than encoder_input_length [{encoder_input_length}]!"
            )

        word_mapped_tokens = encoder_input_ids.gather(index=encoder_input_ids_index, dim=1)

        decoder_input_ids = torch.where(
            mapping_token_mask, tag_mapped_tokens, word_mapped_tokens
        ).to(torch.long)

        # Note: we do not need to explicitly handle the padding (via a decoder attention mask)
        # because it gets automatically mapped to the pad token id

        return decoder_input_ids

    def prepare_decoder_position_ids(
        self,
        input_ids: torch.LongTensor,
        # will be used to create the padding mask from the input_ids. Needs to be provided because
        # the input_ids may be in token space or target space.
        pad_input_id: int,
    ) -> torch.LongTensor:
        bsz, tokens_len = input_ids.size()
        pattern_len = len(self.decoder_position_id_pattern)
        # the number of full and partial records. note that tokens_len includes the bos token
        repeat_num = (tokens_len - 2) // pattern_len + 1
        position_ids = self.decoder_position_id_pattern.repeat(bsz, repeat_num)

        if self.increase_position_ids_per_record:
            position_ids_reshaped = position_ids.view(bsz, -1, pattern_len)
            # one shift value per record: 0, 1, ..., repeat_num - 1
            add_shift_pos = (
                torch.arange(repeat_num, device=position_ids_reshaped.device)
                .repeat(bsz)
                .view(bsz, -1)
                .unsqueeze(-1)
            )
            # multiply by the highest position id in the pattern so that the position ids are unique
            # for any decoder_position_id_pattern across all records
            add_shift_pos *= max(self.decoder_position_id_pattern) + 1
            position_ids_reshaped = add_shift_pos + position_ids_reshaped
            position_ids = position_ids_reshaped.view(bsz, -1).long()
        # use start_position_id=0
        start_pos = torch.zeros(bsz, 1, dtype=position_ids.dtype, device=position_ids.device)
        # shift by 2 to account for start_position_id=0 and pad_position_id=1
        all_position_ids = torch.cat([start_pos, position_ids + 2], dim=-1)
        all_position_ids_truncated = all_position_ids[:bsz, :tokens_len]

        # mask the padding tokens
        mask_invalid = input_ids.eq(pad_input_id)
        all_position_ids_truncated_masked = all_position_ids_truncated.masked_fill(mask_invalid, 1)

        return all_position_ids_truncated_masked

    def prepare_decoder_inputs(
        self,
        input_ids: torch.LongTensor,
        encoder_input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        inputs = {}
        if self.use_prepared_position_ids:
            if position_ids is None:
                position_ids = self.prepare_decoder_position_ids(
                    # the input_ids are in the target space, so we provide pointer_head.pad_id as
                    # the pad_input_id
                    input_ids=input_ids,
                    pad_input_id=self.pad_id,
                )
            inputs["position_ids"] = position_ids

        inputs["input_ids"] = self.prepare_decoder_input_ids(
            input_ids=input_ids,
            encoder_input_ids=encoder_input_ids,
        )
        return inputs

    def forward(
        self,
        last_hidden_state,
        encoder_input_ids,
        encoder_last_hidden_state,
        encoder_attention_mask,
        labels: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        constraints: Optional[torch.LongTensor] = None,
    ):
        # assemble the logits
        logits = last_hidden_state.new_full(
            (
                last_hidden_state.size(0),
                last_hidden_state.size(1),
                self.pointer_offset + encoder_input_ids.size(-1),
            ),
            fill_value=-1e24,
        )

        # eos and label scores depend only on the decoder output
        # bsz x max_len x 1
        eos_scores = F.linear(last_hidden_state, self.embeddings.weight[[self.eos_token_id]])
        label_embeddings = self.embeddings.weight[self.label_token_ids]
        # bsz x max_len x num_class
        label_scores = F.linear(last_hidden_state, label_embeddings)

        # the pointer depends on the src token embeddings, the encoder output and the decoder output
        # bsz x max_bpe_len x hidden_size
        src_outputs = encoder_last_hidden_state
        if getattr(self, "encoder_mlp", None) is not None:
            src_outputs = self.encoder_mlp(src_outputs)

        # bsz x max_word_len x hidden_size
        input_embed = self.embeddings(encoder_input_ids)

        # bsz x max_len x max_word_len
        word_scores = torch.einsum("blh,bnh->bln", last_hidden_state, src_outputs)
        gen_scores = torch.einsum("blh,bnh->bln", last_hidden_state, input_embed)
        avg_word_scores = (gen_scores + word_scores) / 2

        # never point to the padding or the eos token in the encoder input
        # TODO: why not also exclude the bos token? Doing so seems to give worse results, but this
        #  was not tested extensively.
        mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id)
        avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32)

        # Note: the remaining row in logits contains the score for the bos token, which should
        # never be generated!
        logits[:, :, [self.eos_id]] = eos_scores
        logits[:, :, self.label_ids] = label_scores
        logits[:, :, self.pointer_offset :] = avg_word_scores

        loss = None
        # compute the loss if labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            logits_resized = logits.reshape(-1, logits.size(-1))
            labels_resized = labels.reshape(-1)
            if decoder_attention_mask is None:
                raise ValueError("decoder_attention_mask must be provided to compute the loss!")
            mask_resized = decoder_attention_mask.reshape(-1)
            labels_masked = labels_resized.masked_fill(
                ~mask_resized.to(torch.bool), loss_fct.ignore_index
            )
            loss = loss_fct(logits_resized, labels_masked)

        # compute the constraints loss if constraints are provided
        if constraints is not None:
            if getattr(self, "constraints_encoder_mlp", None) is not None:
                # TODO: is it fine to apply constraints_encoder_mlp to both src_outputs and
                #  label_embeddings? This is what the original code seems to do, but it is
                #  different from the usage of encoder_mlp.
                constraints_src_outputs = self.constraints_encoder_mlp(src_outputs)
                constraints_label_embeddings = self.constraints_encoder_mlp(label_embeddings)
            else:
                constraints_src_outputs = src_outputs
                constraints_label_embeddings = label_embeddings
            constraints_label_scores = F.linear(last_hidden_state, constraints_label_embeddings)
            # bsz x max_len x max_word_len
            constraints_word_scores = torch.einsum(
                "blh,bnh->bln", last_hidden_state, constraints_src_outputs
            )
            constraints_logits = last_hidden_state.new_full(
                (
                    last_hidden_state.size(0),
                    last_hidden_state.size(1),
                    self.pointer_offset + encoder_input_ids.size(-1),
                ),
                fill_value=-1e24,
            )
            constraints_logits[:, :, self.label_ids] = constraints_label_scores
            constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores

            # entries < 0 mark positions without a constraint
            mask = constraints >= 0
            constraints_logits_valid = constraints_logits[mask]
            constraints_valid = constraints[mask]
            loss_c = F.binary_cross_entropy(
                torch.sigmoid(constraints_logits_valid), constraints_valid.float()
            )

            if loss is None:
                loss = loss_c
            else:
                loss += loss_c

        return logits, loss
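The core idea of `prepare_decoder_input_ids` above is that the decoder works in a mixed "target space": ids below `pointer_offset` select special or label tokens via `target_token_ids`, while ids at or above `pointer_offset` point at positions in the encoder input and copy the token found there. A minimal plain-Python sketch of that mapping (ids and vocabulary are illustrative, not from the PR):

```python
def map_target_to_token_ids(target_ids, target_token_ids, encoder_input_ids):
    """Map target-space ids to token-space ids, as PointerHead does.

    Ids < len(target_token_ids) are looked up in target_token_ids
    (special/label tokens); larger ids point into encoder_input_ids.
    """
    pointer_offset = len(target_token_ids)
    token_ids = []
    for tid in target_ids:
        if tid < pointer_offset:
            token_ids.append(target_token_ids[tid])  # special or label token
        else:
            # pointer: copy the source token at the referenced position
            token_ids.append(encoder_input_ids[tid - pointer_offset])
    return token_ids


# toy setup: bos_id=0 -> 101, eos_id=1 -> 102, pad_id=2 -> 0, one label id 3 -> 50001
target_token_ids = [101, 102, 0, 50001]
encoder_input_ids = [101, 7592, 2088, 102]  # e.g. [CLS] hello world [SEP]

# sequence: bos, pointer to source position 1, label, eos
print(map_target_to_token_ids([0, 5, 3, 1], target_token_ids, encoder_input_ids))
# -> [101, 7592, 50001, 102]
```

The batched tensor version in `PointerHead` does the same thing with `masked_fill`, `gather`, and `torch.where` instead of the per-id branch.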
1 change: 1 addition & 0 deletions src/pie_modules/taskmodules/__init__.py
@@ -1,4 +1,5 @@
from .extractive_question_answering import ExtractiveQuestionAnsweringTaskModule
from .pointer_network_for_end2end_re import PointerNetworkTaskModuleForEnd2EndRE
from .re_text_classification_with_indices import (
RETextClassificationWithIndicesTaskModule,
)
1 change: 1 addition & 0 deletions src/pie_modules/taskmodules/common/__init__.py
@@ -1,2 +1,3 @@
from .interfaces import AnnotationEncoderDecoder, DecodingException
from .mixins import BatchableMixin
from .utils import get_first_occurrence_index
34 changes: 34 additions & 0 deletions src/pie_modules/taskmodules/common/interfaces.py
@@ -0,0 +1,34 @@
import abc
from typing import Any, Dict, Generic, Optional, TypeVar

from pytorch_ie import Annotation

# Annotation Encoding type: encoding for a single annotation
AE = TypeVar("AE")
# Annotation type
A = TypeVar("A", bound=Annotation)
# Annotation Collection Encoding type: encoding for a collection of annotations,
# e.g. all relevant annotations for a document
ACE = TypeVar("ACE")


class DecodingException(Exception, Generic[AE], abc.ABC):
    """Exception raised when decoding fails."""

    identifier: str

    def __init__(self, message: str, encoding: AE):
        super().__init__(message)
        self.message = message
        self.encoding = encoding


class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]):
    """Base class for annotation encoders and decoders."""

    @abc.abstractmethod
    def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE:
        pass

    @abc.abstractmethod
    def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A:
        pass
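To illustrate the `AnnotationEncoderDecoder` contract, here is a minimal hypothetical implementation. It is self-contained: a frozen dataclass stands in for a pytorch_ie span annotation, and the interface is restated inline rather than imported; the inclusive-end encoding is an illustrative choice, not the PR's actual span codec.

```python
import abc
from dataclasses import dataclass
from typing import Generic, Tuple, TypeVar


@dataclass(frozen=True)
class Span:
    """Stand-in for a pytorch_ie span annotation (start inclusive, end exclusive)."""
    start: int
    end: int


A = TypeVar("A")   # annotation type
AE = TypeVar("AE")  # encoding type


class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]):
    """Restated interface: encode an annotation to a target encoding and back."""

    @abc.abstractmethod
    def encode(self, annotation: A, metadata=None) -> AE:
        pass

    @abc.abstractmethod
    def decode(self, encoding: AE, metadata=None) -> A:
        pass


class SpanEncoderDecoder(AnnotationEncoderDecoder[Span, Tuple[int, int]]):
    """Encodes a span as a (start, inclusive_end) pair, a common pointer-network target format."""

    def encode(self, annotation, metadata=None):
        return (annotation.start, annotation.end - 1)

    def decode(self, encoding, metadata=None):
        return Span(start=encoding[0], end=encoding[1] + 1)


codec = SpanEncoderDecoder()
print(codec.encode(Span(2, 5)))   # (2, 4)
print(codec.decode((2, 4)))       # Span(start=2, end=5)
```

Round-tripping (`decode(encode(x)) == x`) is the invariant the taskmodule's tests exercise; a real implementation would additionally raise a `DecodingException` for malformed encodings.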
Empty file.