from .base import BaseClass
import torch
from .utils import get_max_length


class SingleMetric(BaseClass):
    """
    A class representing a single metric.

    This class inherits from BaseClass and serves as the common base for
    metrics computed from a single model.

    Args:
        **kwargs: Additional keyword arguments passed to the BaseClass constructor.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class Perplexity(SingleMetric):
    def __init__(self, model, tokenizer, **kwargs):
        """
        Initializes the Perplexity metric.

        Args:
            model: The model used for contamination detection.
            tokenizer: The tokenizer used for tokenizing input data.
            **kwargs: Additional keyword arguments.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = get_max_length(model.config)
        super().__init__(**kwargs)

    def batch_call(self, outputs, inputs=None, batch_size=1):
        """
        Calculate perplexity for a batch of outputs.

        Args:
            outputs (list): A list of output strings.
            inputs (list, optional): A list of input strings on which the
                outputs are conditioned. Defaults to None.
            batch_size (int, optional): The batch size. Defaults to 1.

        Returns:
            list: A list of perplexity values, one per item in `outputs`.
        """
        # Filter out items that are empty or not strings; they receive a score of 0.
        indices_with_0_length_output = []
        for i in range(len(outputs)):
            if not isinstance(outputs[i], str) or len(outputs[i]) == 0:
                indices_with_0_length_output.append(i)
        if len(indices_with_0_length_output) > 0:
            outputs_here = [outputs[i] for i in range(len(outputs)) if i not in indices_with_0_length_output]
            inputs_here = None
            if inputs is not None:
                inputs_here = [inputs[i] for i in range(len(inputs)) if i not in indices_with_0_length_output]
            perplexity = self.batch_call(outputs_here, inputs_here, batch_size)
            # Re-insert a 0 at each filtered index so the result list has the
            # same length as the outputs list.
            for i in range(len(indices_with_0_length_output)):
                perplexity.insert(indices_with_0_length_output[i], 0)
            return perplexity

        # Tokenize outputs
        output_tokens = [self.tokenizer.encode(output, return_tensors='pt', add_special_tokens=False).to(self.model.device) for output in outputs]
        # Tokenize inputs if provided
        input_tokens = None
        if inputs is not None:
            input_tokens = [self.tokenizer.encode(input_, return_tensors='pt').to(self.model.device) for input_ in inputs]

        perplexities = []
        for i in range(0, len(outputs), batch_size):
            batch_output_tokens = output_tokens[i:i + batch_size]
            # Handle input tokens for the batch
            batch_input_tokens = None
            if input_tokens is not None:
                batch_input_tokens = input_tokens[i:i + batch_size]
            # Concatenate input and output tokens when inputs are given
            if batch_input_tokens is not None:
                token_tensors = [torch.cat([batch_input_tokens[j], batch_output_tokens[j]], dim=-1) for j in range(len(batch_output_tokens))]
            else:
                token_tensors = batch_output_tokens
            # Pad token tensors to get a rectangular tensor
            token_tensors_padded = torch.nn.utils.rnn.pad_sequence(
                [token_tensor[0] for token_tensor in token_tensors],
                batch_first=True, padding_value=self.tokenizer.pad_token_id
            ).to(self.model.device)
            # Truncate the padded tensor if it is longer than the max length
            if token_tensors_padded.size(1) > self.max_length:
                token_tensors_padded = token_tensors_padded[:, :self.max_length - 1]
            # Calculate per-token log-probabilities for the batch
            with torch.no_grad():
                model_outputs = self.model(input_ids=token_tensors_padded)
            logits = torch.log_softmax(model_outputs.logits, dim=-1)
            # Compute perplexity for each item in the batch
            for j in range(logits.shape[0]):
                logits_index = logits[j]
                if batch_output_tokens[j].shape[1] == 0:
                    perplexities.append(0)
                    continue
                if batch_input_tokens is not None:
                    # Keep only the positions that predict output tokens
                    logits_index = logits_index[batch_input_tokens[j].shape[1] - 1:]
                    if logits_index.shape[0] == 0:
                        # The input alone already fills the context window
                        perplexities.append(10000)
                        continue
                    log_likelihood = logits_index[:-1, :].gather(1, batch_output_tokens[j][0, :logits_index.shape[0] - 1].unsqueeze(-1)).mean()
                else:
                    log_likelihood = logits_index[:-1, :].gather(1, batch_output_tokens[j][0, 1:logits_index.shape[0]].unsqueeze(-1)).mean()
                perplexity = torch.exp(-log_likelihood)
                perplexities.append(perplexity.item())
        return perplexities
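

# Example usage (a sketch, assuming a Hugging Face causal LM whose config is
# understood by get_max_length; the "gpt2" checkpoint below is illustrative and
# not part of this module):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
#   metric = Perplexity(model, tokenizer)
#   scores = metric.batch_call(["The quick brown fox jumps over the lazy dog."])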


class Lowercase(Perplexity):
    # https://arxiv.org/pdf/2012.07805.pdf
    def __init__(self, model, tokenizer, **kwargs):
        """
        Initializes the Lowercase metric.

        Args:
            model: The model used for contamination detection.
            tokenizer: The tokenizer used for tokenizing input data.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(model, tokenizer, **kwargs)

    def batch_call(self, outputs, inputs=None, batch_size=1):
        """
        Compute perplexity on lowercased outputs via the superclass's `batch_call`.

        Args:
            outputs (list): A list of outputs to be processed. Non-string entries
                are replaced by 0, which the superclass filters out.
            inputs (list, optional): The inputs passed through to the superclass's
                `batch_call`. Defaults to None.
            batch_size (int, optional): The batch size for processing. Defaults to 1.

        Returns:
            list: A list of perplexities computed on the lowercased outputs.
        """
        perplexities_lower = super().batch_call([output.lower() if isinstance(output, str) else 0 for output in outputs], inputs, batch_size)
        return perplexities_lower
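

# In the referenced paper (Carlini et al., 2021), lowercased perplexity serves
# as a reference measure: the ratio between a sample's regular perplexity and
# its lowercased perplexity is used to flag memorized text. A sketch of that
# comparison under the same setup as above (names illustrative):
#
#   ppl = Perplexity(model, tokenizer).batch_call(samples)
#   ppl_lower = Lowercase(model, tokenizer).batch_call(samples)
#   ratios = [p / pl if pl != 0 else 0.0 for p, pl in zip(ppl, ppl_lower)]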


class TopKMin(SingleMetric):
    # https://arxiv.org/pdf/2310.16789.pdf
    def __init__(self, model, tokenizer, k=0.2, **kwargs):
        """
        Initializes the TopKMin metric.

        Args:
            model: The model used for contamination detection.
            tokenizer: The tokenizer used for tokenizing input data.
            k (float): The fraction of least likely tokens to consider. Defaults to 0.2.
            **kwargs: Additional keyword arguments.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.k = k
        self.max_length = get_max_length(model.config)
        super().__init__(**kwargs)

    def batch_call(self, outputs, inputs=None, batch_size=1):
        """
        Compute the Min-K% score for a batch of outputs.

        Args:
            outputs (list): A list of output strings.
            inputs (list, optional): A list of input strings. Defaults to None.
            batch_size (int, optional): The batch size for processing. Defaults to 1.

        Returns:
            list: For each output, the mean log-probability of its k fraction
                of least likely tokens.
        """
        # Filter out items that are empty or not strings; they receive a score of 0.
        indices_with_0_length_output = []
        for i in range(len(outputs)):
            if not isinstance(outputs[i], str) or len(outputs[i]) == 0:
                indices_with_0_length_output.append(i)
        if len(indices_with_0_length_output) > 0:
            outputs_here = [outputs[i] for i in range(len(outputs)) if i not in indices_with_0_length_output]
            inputs_here = None
            if inputs is not None:
                inputs_here = [inputs[i] for i in range(len(inputs)) if i not in indices_with_0_length_output]
            topkmin = self.batch_call(outputs_here, inputs_here, batch_size)
            # Re-insert a 0 at each filtered index so the result list has the
            # same length as the outputs list.
            for i in range(len(indices_with_0_length_output)):
                topkmin.insert(indices_with_0_length_output[i], 0)
            return topkmin

        # Tokenize outputs
        output_tokens = [self.tokenizer.encode(output, return_tensors='pt', add_special_tokens=False).to(self.model.device) for output in outputs]
        # Tokenize inputs if provided
        input_tokens = None
        if inputs is not None:
            input_tokens = [self.tokenizer.encode(input_, return_tensors='pt').to(self.model.device) for input_ in inputs]

        topkmin = []
        for i in range(0, len(outputs), batch_size):
            batch_output_tokens = output_tokens[i:i + batch_size]
            # Handle input tokens for the batch
            batch_input_tokens = None
            if input_tokens is not None:
                batch_input_tokens = input_tokens[i:i + batch_size]
            # Concatenate input and output tokens when inputs are given
            if batch_input_tokens is not None:
                token_tensors = [torch.cat([batch_input_tokens[j], batch_output_tokens[j]], dim=-1) for j in range(len(batch_output_tokens))]
            else:
                token_tensors = batch_output_tokens
            # Pad token tensors to get a rectangular tensor
            token_tensors_padded = torch.nn.utils.rnn.pad_sequence(
                [token_tensor[0] for token_tensor in token_tensors],
                batch_first=True, padding_value=self.tokenizer.pad_token_id
            ).to(self.model.device)
            # Truncate the padded tensor if it is longer than the max length
            if token_tensors_padded.size(1) > self.max_length:
                token_tensors_padded = token_tensors_padded[:, :self.max_length - 1]
            # Calculate per-token log-probabilities for the batch
            with torch.no_grad():
                model_outputs = self.model(token_tensors_padded)
            logits = torch.log_softmax(model_outputs.logits, dim=-1)
            # Compute the Min-K% score for each item in the batch
            for j in range(logits.shape[0]):
                logits_index = logits[j]
                if batch_output_tokens[j].shape[1] == 0:
                    topkmin.append(0)
                    continue
                if batch_input_tokens is not None:
                    # Keep only the positions that predict output tokens
                    logits_index = logits_index[batch_input_tokens[j].shape[1] - 1:]
                    if logits_index.shape[0] == 0:
                        # The input alone already fills the context window
                        topkmin.append(10000)
                        continue
                    log_likelihood = logits_index[:-1, :].gather(1, batch_output_tokens[j][0, :logits_index.shape[0] - 1].unsqueeze(-1))
                else:
                    log_likelihood = logits_index[:-1, :].gather(1, batch_output_tokens[j][0, 1:logits_index.shape[0]].unsqueeze(-1))
                # Keep the k fraction of least likely tokens (at least one)
                top_k = int(self.k * log_likelihood.size(0))
                if top_k == 0:
                    top_k = 1
                least_likely_tokens = torch.topk(log_likelihood, top_k, dim=0, largest=False)[0]
                # The score is the mean log-probability of those tokens
                mean = least_likely_tokens.mean(dim=0)
                topkmin.append(mean.item())
        return topkmin
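

# Example usage (a sketch, reusing the model/tokenizer setup shown above for
# Perplexity; k=0.2 keeps the 20% least likely output tokens, matching the
# Min-K% Prob setting of the referenced paper):
#
#   metric = TopKMin(model, tokenizer, k=0.2)
#   scores = metric.batch_call(["Some candidate training sample."], batch_size=2)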