Add TextGeneration Evaluator #350
Changes from all commits: `5e740ba`, `58c9338`, `64fa5da`, `f39b566`, `42ba716`, `c27b843`
Changes to the evaluator base module (`base.py`):

```diff
@@ -14,7 +14,7 @@
 from abc import ABC, abstractmethod
 from numbers import Number
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 # Lint as: python3
 from datasets import Dataset, load_dataset
```
```diff
@@ -234,7 +234,7 @@ def compute(
         input_column: str = "text",
         label_column: str = "label",
         label_mapping: Optional[Dict[str, Number]] = None,
-    ) -> Tuple[Dict[str, float], Any]:
+    ) -> Dict[str, float]:

         result = {}
```
```diff
@@ -347,7 +347,7 @@ def load_data(self, data: Union[str, Dataset], subset: str = None, split: str =
             "Please specify a valid `data` object - either a `str` with a name or a `Dataset` object."
         )

-    def prepare_data(self, data: Dataset, input_column: str, label_column: str):
+    def prepare_data(self, data: Dataset, input_column: str, label_column: str, *args, **kwargs):
         """
         Prepare data.
```

> **Reviewer:** Why is this necessary?
>
> **Author:** The TextGeneration Evaluator's `prepare_data` has a different signature: it takes no `label_column`, so the base signature needs to accept (and ignore) the extra arguments.
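A minimal sketch of the pattern at work (hypothetical classes and bodies, not the library's actual code): the base `compute` always forwards `label_column` to `prepare_data`, so an override that has no use for labels must still be able to swallow the argument.

```python
from datasets import Dataset


class BaseEvaluator:
    def prepare_data(self, data: Dataset, input_column: str, label_column: str, *args, **kwargs):
        # Default: tasks with labels use both columns.
        return {"references": data[label_column]}, data[input_column]

    def compute(self, data: Dataset, input_column: str = "text", label_column: str = "label"):
        # The base class always passes label_column; because prepare_data
        # accepts *args/**kwargs, subclasses that ignore labels don't force
        # a change to this call site.
        metric_inputs, pipe_inputs = self.prepare_data(
            data, input_column=input_column, label_column=label_column
        )
        ...


class GenerationEvaluator(BaseEvaluator):
    def prepare_data(self, data: Dataset, input_column: str, *args, **kwargs):
        # Text generation needs no labels; the forwarded label_column
        # lands in **kwargs and is simply ignored.
        return {}, data[input_column]
```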
New file: the `TextGenerationEvaluator` module (`@@ -0,0 +1,68 @@`):

```python
# Copyright 2022 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Tuple

from datasets import Dataset

from .base import Evaluator
from .utils import DatasetColumn


TASK_DOCUMENTATION_KWARGS = r"""
        input_column (`str`, defaults to `"text"`):
            The name of the column containing the input text in the dataset specified by `data`.
        generation_kwargs (`Dict`, *optional*, defaults to `None`):
            The generation kwargs are passed to the pipeline and set the text generation strategy.
"""


class TextGenerationEvaluator(Evaluator):
    """
    Text generation evaluator.
    This text generation evaluator can currently be loaded from [`evaluator`] using the default task name
    `text-generation`.
    Methods in this class assume a data format compatible with the [`TextGenerationPipeline`].
    """

    def predictions_processor(self, predictions, *args, **kwargs):
        """
        Args:
            predictions: A list of lists of dicts.

        Returns:
            `dict`: All the generated texts are flattened and stored under the "data" key.
        """
        return {"data": [pred[f"{self.predictions_prefix}_text"] for pred_list in predictions for pred in pred_list]}

    def __init__(self, task="text-generation", default_metric_name=None, predictions_prefix: str = "generated"):
        super().__init__(task=task, default_metric_name=default_metric_name)
        self.predictions_prefix = predictions_prefix

    def prepare_data(self, data: Dataset, input_column: str, *args, **kwargs) -> Tuple[Dict, DatasetColumn]:
        """
        Prepare data.

        Args:
            data (`Dataset`): Specifies the dataset we will run evaluation on.
            input_column (`str`, defaults to `"text"`):
                The name of the column containing the text feature in the dataset specified by `data`.

        Returns:
            `dict`: metric inputs.
            `list`: pipeline inputs.
        """
        self.check_required_columns(data, {"input_column": input_column})

        return {}, DatasetColumn(data, input_column)
```
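A hedged usage sketch of the new evaluator; the dataset, split size, and metric choices below are illustrative assumptions, not part of this PR:

```python
from datasets import load_dataset
from evaluate import evaluator

# evaluator("text-generation") should return a TextGenerationEvaluator
# once this PR lands; everything below is an assumed-typical invocation.
task_evaluator = evaluator("text-generation")
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:32]")

results = task_evaluator.compute(
    model_or_pipeline="gpt2",  # any text-generation model or pipeline
    data=data,
    input_column="text",       # column holding the prompts
    metric="word_count",       # illustrative; any text-level metric/measurement works
)
print(results)
```

Note how `predictions_processor` stores the flattened generations under a `"data"` key: that matches the input name expected by text measurements such as `word_count`, which is presumably why that key was chosen.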
> **Reviewer:** To solve the perplexity issue we could just make `gpt2` the default model so it's a kwarg instead of an arg.
>
> **Author:** Not sure I understand – `gpt2` would be the default value for `model_or_pipeline` for the TextGenerationEvaluator's `compute` method?
>
> **Reviewer:** No, I meant we can update the `perplexity` metric to have a default value for the model (`gpt2`) so it works easily with the Evaluator. What do you think?
>
> **Author:** Oh! Sure, that sounds good to me. I think one of the basic requirements for perplexity is that it also needs the ability to receive the actual model itself, so I'll make sure that's possible, and I'll include that as an option in the `TextGenerationEvaluator` here. I'll open a separate PR for the perplexity change.
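For reference, a sketch of how the `perplexity` metric is invoked with an explicit `model_id` today; the thread's proposal is to default it to `"gpt2"` so the argument becomes optional (the example string is illustrative):

```python
import evaluate

perplexity = evaluate.load("perplexity", module_type="metric")
results = perplexity.compute(
    predictions=["The quick brown fox jumps over the lazy dog."],
    model_id="gpt2",  # the value the thread proposes making the default
)
print(round(results["mean_perplexity"], 2))
```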