utils_taemin.py
import os
from typing import Callable, Dict, List, Tuple

import evaluate
import pandas as pd
import transformers
from datasets import (Dataset, DatasetDict, Features, Sequence, Value,
                      load_from_disk)
from numpy import array
from transformers import (AutoTokenizer, DataCollatorWithPadding,
                          EvalPrediction, TrainingArguments)

from arguments import DataTrainingArguments
from data_preprocessing import Preprocess
from retrieval import SparseRetrieval, SparseRetrievalBM25
from utils_qa import postprocess_qa_predictions


def data_collators(tokenizer):
    # Dynamically pad each batch to a multiple of 8 (friendly to fp16 tensor cores).
    data_collator = DataCollatorWithPadding(
        tokenizer, pad_to_multiple_of=8
    )
    return data_collator
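

# --- Usage sketch (illustrative; not part of the original file) ---
# Shows how the collator returned by data_collators() pads a small batch.
# The tokenizer checkpoint name is an assumption for the example only.
def _example_data_collator_usage():
    tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")  # assumed checkpoint
    collator = data_collators(tokenizer)
    features = [tokenizer(text) for text in ["짧은 문장", "조금 더 긴 두 번째 문장입니다"]]
    batch = collator(features)
    # Every tensor in the batch is padded to the same length, a multiple of 8.
    return batch["input_ids"].shape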


def compute_metrics(p: EvalPrediction):
    # SQuAD metric: exact match (EM) and F1 over the predicted answer texts.
    metric = evaluate.load("squad")
    return metric.compute(predictions=p.predictions, references=p.label_ids)
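

# --- Usage sketch (illustrative; not part of the original file) ---
# compute_metrics expects exactly the formats that post_processing_function produces:
# predictions as [{"id", "prediction_text"}] and label_ids as
# [{"id", "answers": {"text": [...], "answer_start": [...]}}]; the values below are made up.
def _example_compute_metrics_usage():
    p = EvalPrediction(
        predictions=[{"id": "q-0", "prediction_text": "서울"}],
        label_ids=[{"id": "q-0", "answers": {"text": ["서울"], "answer_start": [10]}}],
    )
    return compute_metrics(p)  # e.g. {"exact_match": 100.0, "f1": 100.0}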


def post_processing_function(examples, features, predictions, training_args):
    # Post-processing: match the start and end logits to answer spans in the original context.
    # datasets = load_from_disk(os.path.join(os.path.abspath(os.path.dirname(__file__)), "data/train_dataset/"))
    predictions = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        max_answer_length=30,
        output_dir=training_args.output_dir,
    )

    # Reformat the predictions so the metric can be computed.
    formatted_predictions = [
        {"id": k, "prediction_text": v} for k, v in predictions.items()
    ]

    if training_args.do_predict:
        return formatted_predictions
    elif training_args.do_eval:
        references = [
            # "answers" is stored as a stringified dict in the CSV, so parse it back here.
            {"id": ex["id"], "answers": eval(ex["answers"])}
            # for ex in datasets["validation"]
            for _, ex in examples.iterrows()
        ]
        return EvalPrediction(
            predictions=formatted_predictions, label_ids=references
        )
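

# --- Usage sketch (illustrative; not part of the original file) ---
# post_processing_function is meant to be called with the raw (start_logits, end_logits)
# the model returns for every tokenized feature; the argument names below are assumptions.
def _example_post_processing_usage(raw_predictions, eval_examples, eval_features, training_args):
    result = post_processing_function(eval_examples, eval_features, raw_predictions, training_args)
    if training_args.do_eval:
        # With do_eval=True the result is an EvalPrediction ready for compute_metrics.
        return compute_metrics(result)
    # With do_predict=True only the formatted predictions are returned.
    return result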


def run_sparse_retrieval(
    tokenize_fn: Callable[[str], List[str]],
    datasets: pd.DataFrame,
    data_path: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "csv_data"),
    context_path: str = "wikipedia_documents.json",
    bm25: str = None,
) -> DatasetDict:
    # Retrieve the passages that best match each query.
    if bm25:
        assert bm25 in ["Okapi", "L", "plus"], "Invalid type for BM25 has been passed."
        print(f"BM25 {bm25} is being used for passage retrieval")
        retriever = SparseRetrievalBM25(
            tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path, bm25_type=bm25
        )
    else:
        print("TF-IDF is being used for passage retrieval")
        retriever = SparseRetrieval(
            tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path
        )
    retriever.get_sparse_embedding()
    df = retriever.retrieve(datasets, topk=30)

    # Hard-coded switch: k == 1 builds the inference schema, k == 0 the train/eval schema.
    # Test data has no answers, so the dataset only contains id, question, and context.
    k = 1
    if k == 1:
        f = Features(
            {
                "context": Value(dtype="string", id=None),
                "id": Value(dtype="string", id=None),
                "question": Value(dtype="string", id=None),
            }
        )
    # Train data has answers, so the dataset also contains an answers column.
    elif k == 0:
        f = Features(
            {
                "answers": Sequence(
                    feature={
                        "text": Value(dtype="string", id=None),
                        "answer_start": Value(dtype="int32", id=None),
                    },
                    length=-1,
                    id=None,
                ),
                "context": Value(dtype="string", id=None),
                "id": Value(dtype="string", id=None),
                "question": Value(dtype="string", id=None),
            }
        )
    datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
    return datasets
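

# --- Usage sketch (illustrative; not part of the original file) ---
# Builds a validation DatasetDict whose "context" column holds the top-30 retrieved
# passages per question. The checkpoint name and DataFrame columns are assumptions.
def _example_run_sparse_retrieval_usage(test_df: pd.DataFrame) -> DatasetDict:
    tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")  # assumed checkpoint
    # test_df is expected to contain at least "id" and "question" columns.
    return run_sparse_retrieval(
        tokenize_fn=tokenizer.tokenize,
        datasets=test_df,
        bm25="plus",  # or None to fall back to TF-IDF retrieval
    )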