diff --git a/CHANGELOG.md b/CHANGELOG.md
index 61a3665282..b74da5e12c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated `Task.serializer` in favour of `Task.output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))
 
+- Deprecated `flash.text.seq2seq.core.metrics` in favour of `torchmetrics[text]` ([#648](https://github.com/PyTorchLightning/lightning-flash/pull/648))
+
 ### Fixed
 
 ### Removed
diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst
index 8e088bddb0..50750dad05 100644
--- a/docs/source/api/text.rst
+++ b/docs/source/api/text.rst
@@ -99,6 +99,3 @@ _______________
     seq2seq.core.data.Seq2SeqOutputTransform
     seq2seq.core.data.Seq2SeqPreprocess
     seq2seq.core.data.Seq2SeqSentencesDataSource
-    seq2seq.core.metrics.BLEUScore
-    seq2seq.core.metrics.RougeBatchAggregator
-    seq2seq.core.metrics.RougeMetric
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 0859763532..cd50f8cb8c 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -94,9 +94,9 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
 _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
 _TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
-_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score")
 _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
 _DATASETS_AVAILABLE = _module_available("datasets")
+_TM_TEXT_AVAILABLE: bool = _module_available("torchmetrics.text")
 _ICEVISION_AVAILABLE = _module_available("icevision")
 _ICEDATA_AVAILABLE = _module_available("icedata")
 _LEARN2LEARN_AVAILABLE = _module_available("learn2learn") and _compare_version("learn2learn", operator.ge, "0.1.6")
@@ -123,9 +123,9 @@ class Image:
 _TEXT_AVAILABLE = all(
     [
         _TRANSFORMERS_AVAILABLE,
-        _ROUGE_SCORE_AVAILABLE,
         _SENTENCEPIECE_AVAILABLE,
         _DATASETS_AVAILABLE,
+        _TM_TEXT_AVAILABLE,
     ]
 )
 _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py
index 7087239c18..1c17b0eca5 100644
--- a/flash/text/question_answering/model.py
+++ b/flash/text/question_answering/model.py
@@ -125,7 +125,7 @@ def __init__(
         self._initialize_model_specific_parameters()
 
         self.rouge = RougeMetric(
-            rouge_newline_sep=rouge_newline_sep,
+            newline_sep=rouge_newline_sep,
             use_stemmer=use_stemmer,
         )
diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py
index cc6398db2e..5eee448851 100644
--- a/flash/text/seq2seq/core/metrics.py
+++ b/flash/text/seq2seq/core/metrics.py
@@ -16,220 +16,29 @@
 # Authors: torchtext authors and @sluks
 # Date: 2020-07-18
 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
-from collections import Counter
-from typing import Dict, List, Tuple
+from functools import partial
+from typing import Tuple
 
-import numpy as np
-import torch
-from torch import tensor
-from torchmetrics import Metric
+from deprecate import deprecated, void
+from pytorch_lightning.utilities import rank_zero_deprecation
+from torchmetrics.text import BLEUScore as _BLEUScore
+from torchmetrics.text import ROUGEScore as _ROUGEScore
 
-from flash.core.utilities.imports import _TEXT_AVAILABLE, requires
-from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence
+_deprecated_text_metrics = partial(deprecated, deprecated_in="0.6.0", remove_in="0.7.0", stream=rank_zero_deprecation)
 
-if _TEXT_AVAILABLE:
-    from rouge_score import rouge_scorer
-    from rouge_score.scoring import AggregateScore, BootstrapAggregator, Score
-else:
-    AggregateScore, Score, BootstrapAggregator = None, None, object
-
-
-def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
-    """
-    Counting how many times each word appears in a given text with ngram
-    Args:
-        ngram_input_list: A list of translated text or reference texts
-        n_gram: gram value ranged 1 to 4
-
-    Return:
-        ngram_counter: a collections.Counter object of ngram
-    """
-
-    ngram_counter = Counter()
-
-    for i in range(1, n_gram + 1):
-        for j in range(len(ngram_input_list) - i + 1):
-            ngram_key = tuple(ngram_input_list[j : (i + j)])
-            ngram_counter[ngram_key] += 1
-
-    return ngram_counter
-
-
-class BLEUScore(Metric):
-    """Calculate BLEU score of machine translated text with one or more references.
-
-    Example:
-        >>> translate_corpus = ['the cat is on the mat'.split()]
-        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
-        >>> metric = BLEUScore()
-        >>> metric(translate_corpus, reference_corpus)
-        tensor(0.7598)
-    """
 
+class BLEUScore(_BLEUScore):
+    @_deprecated_text_metrics(target=_BLEUScore)
     def __init__(self, n_gram: int = 4, smooth: bool = False):
-        """
-        Args:
-            n_gram: Gram value ranged from 1 to 4 (Default 4)
-            smooth: Whether or not to apply smoothing – Lin et al. 2004
-        """
-        super().__init__()
-        self.n_gram = n_gram
-        self.smooth = smooth
-
-        self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
-        self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
-        self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
-        self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
-
-    def compute(self):
-
-        trans_len = self.c.clone().detach()
-        ref_len = self.r.clone().detach()
-
-        if min(self.numerator) == 0.0:
-            return tensor(0.0, device=self.r.device)
-
-        if self.smooth:
-            precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
-        else:
-            precision_scores = self.numerator / self.denominator
-
-        log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log(
-            precision_scores
-        )
-        geometric_mean = torch.exp(torch.sum(log_precision_scores))
-        brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
-        bleu = brevity_penalty * geometric_mean
-        return bleu
-
-    def update(self, translate_corpus, reference_corpus) -> None:
-        """
-        Actual metric computation
-        Args:
-            translate_corpus: An iterable of machine translated corpus
-            reference_corpus: An iterable of iterables of reference corpus
-        """
-        for (translation, references) in zip(translate_corpus, reference_corpus):
-            self.c += len(translation)
-            ref_len_list = [len(ref) for ref in references]
-            ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
-            self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
-            translation_counter = _count_ngram(translation, self.n_gram)
-            reference_counter = Counter()
-
-            for ref in references:
-                reference_counter |= _count_ngram(ref, self.n_gram)
-
-            ngram_counter_clip = translation_counter & reference_counter
-
-            for counter_clip in ngram_counter_clip:
-                self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
+        void(n_gram, smooth)
 
-            for counter in translation_counter:
-                self.denominator[len(counter) - 1] += translation_counter[counter]
-
-
-class RougeMetric(Metric):
-    """Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/
-
-    Example:
-
-        >>> target = "Is your name John".split()
-        >>> preds = "My name is John".split()
-        >>> rouge = RougeMetric()  # doctest: +SKIP
-        >>> from pprint import pprint
-        >>> pprint(rouge(preds, target))  # doctest: +NORMALIZE_WHITESPACE +SKIP
-        {'rouge1_fmeasure': 0.25,
-         'rouge1_precision': 0.25,
-         'rouge1_recall': 0.25,
-         'rouge2_fmeasure': 0.0,
-         'rouge2_precision': 0.0,
-         'rouge2_recall': 0.0,
-         'rougeL_fmeasure': 0.25,
-         'rougeL_precision': 0.25,
-         'rougeL_recall': 0.25,
-         'rougeLsum_fmeasure': 0.25,
-         'rougeLsum_precision': 0.25,
-         'rougeLsum_recall': 0.25}
-    """
-
-    @requires("text")
+class RougeMetric(_ROUGEScore):
+    @_deprecated_text_metrics(target=_ROUGEScore)
     def __init__(
         self,
-        rouge_newline_sep: bool = False,
+        newline_sep: bool = False,
         use_stemmer: bool = False,
         rouge_keys: Tuple[str] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
     ):
-        super().__init__()
-
-        self.rouge_newline_sep = rouge_newline_sep
-        self.rouge_keys = rouge_keys
-        self.use_stemmer = use_stemmer
-        self.aggregator = RougeBatchAggregator()
-        self.scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=self.use_stemmer)
-
-        for key in rouge_keys:
-            self.add_state(key, [])
-
-    def update(self, pred_lns: List[str], tgt_lns: List[str]):
-        for pred, tgt in zip(pred_lns, tgt_lns):
-            # rougeLsum expects "\n" separated sentences within a summary
-            if self.rouge_newline_sep:
-                pred = add_newline_to_end_of_each_sentence(pred)
-                tgt = add_newline_to_end_of_each_sentence(tgt)
-            results = self.scorer.score(pred, tgt)
-            for key, score in results.items():
-                score = tensor([score.precision, score.recall, score.fmeasure])
-                getattr(self, key).append(score)
-
-    def compute(self) -> Dict[str, float]:
-        scores = {key: getattr(self, key) for key in self.rouge_keys}
-        self.aggregator.add_scores(scores)
-        result = self.aggregator.aggregate()
-        return format_rouge_results(result)
-
-    def __hash__(self):
-        # override to hash list objects.
-        # this is a bug in the upstream pytorch release.
-        hash_vals = [self.__class__.__name__]
-
-        for key in self._defaults:
-            value = getattr(self, key)
-            if isinstance(value, list):
-                value = tuple(value)
-            hash_vals.append(value)
-
-        return hash(tuple(hash_vals))
-
-
-class RougeBatchAggregator(BootstrapAggregator):
-    """Aggregates rouge scores and provides confidence intervals."""
-
-    def aggregate(self):
-        """Override function to wrap the final results in `Score` objects.
-
-        This is due to the scores being replaced with a list of torch tensors.
-        """
-        result = {}
-        for score_type, scores in self._scores.items():
-            # Stack scores into a 2-d matrix of (sample, measure).
-            score_matrix = np.vstack(tuple(scores))
-            # Percentiles are returned as (interval, measure).
-            percentiles = self._bootstrap_resample(score_matrix)
-            # Extract the three intervals (low, mid, high).
-            intervals = tuple(Score(*percentiles[j, :]) for j in range(3))
-            result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2])
-        return result
-
-    def add_scores(self, scores):
-        self._scores = scores
-
-
-def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, float]:
-    flattened_result = {}
-    for rouge_key, rouge_aggregate_score in result.items():
-        for stat in ["precision", "recall", "fmeasure"]:
-            mid = rouge_aggregate_score.mid
-            score = round(getattr(mid, stat), decimal_places)
-            flattened_result[f"{rouge_key}_{stat}"] = score
-    return flattened_result
+        void(newline_sep, use_stemmer, rouge_keys)
diff --git a/flash/text/seq2seq/core/utils.py b/flash/text/seq2seq/core/utils.py
deleted file mode 100644
index e48248754c..0000000000
--- a/flash/text/seq2seq/core/utils.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import re
-
-from pytorch_lightning.utilities import _module_available
-
-nltk = None
-if _module_available("nltk"):
-    import nltk
-
-    nltk.download("punkt", quiet=True)
-
-
-def add_newline_to_end_of_each_sentence(x: str) -> str:
-    """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
-    re.sub("<n>", "", x)  # remove pegasus newline char
-    assert nltk, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
-    return "\n".join(nltk.sent_tokenize(x))
diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py
index 261de7a03a..6067eb5ceb 100644
--- a/flash/text/seq2seq/summarization/model.py
+++ b/flash/text/seq2seq/summarization/model.py
@@ -14,9 +14,9 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from torchmetrics import ROUGEScore
 
 from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
-from flash.text.seq2seq.core.metrics import RougeMetric
 from flash.text.seq2seq.core.model import Seq2SeqTask
 
 
@@ -72,8 +72,8 @@ def __init__(
             num_beams=num_beams,
             enable_ort=enable_ort,
         )
-        self.rouge = RougeMetric(
-            rouge_newline_sep=rouge_newline_sep,
+        self.rouge = ROUGEScore(
+            newline_sep=rouge_newline_sep,
             use_stemmer=use_stemmer,
         )
diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py
index d57f7775df..553adb6b7a 100644
--- a/flash/text/seq2seq/translation/model.py
+++ b/flash/text/seq2seq/translation/model.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 from typing import Any, Dict, List, Optional
 
+from torchmetrics import BLEUScore
+
 from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
-from flash.text.seq2seq.core.metrics import BLEUScore
 from flash.text.seq2seq.core.model import Seq2SeqTask
diff --git a/requirements.txt b/requirements.txt
index e29b34f8ac..5e391bf5be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 packaging
 numpy
 torch>=1.7.1
-torchmetrics>=0.4.0,!=0.5.1
+torchmetrics>=0.5.1
 pytorch-lightning>=1.4.0
 pyDeprecate
 pandas<1.3.0
diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt
index 75611db808..aba24a7ef5 100644
--- a/requirements/datatype_text.txt
+++ b/requirements/datatype_text.txt
@@ -1,5 +1,5 @@
-rouge-score>=0.0.4
 sentencepiece>=0.1.95
 filelock
 transformers>=4.5
+torchmetrics[text]>=0.5.1
 datasets>=1.8,<1.13
diff --git a/tests/text/seq2seq/core/test_metrics.py b/tests/text/seq2seq/core/test_metrics.py
deleted file mode 100644
index c16f828c37..0000000000
--- a/tests/text/seq2seq/core/test_metrics.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-import torch
-
-from flash.text.seq2seq.core.metrics import BLEUScore, RougeMetric
-from tests.helpers.utils import _TEXT_TESTING
-
-
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_rouge():
-    preds = "My name is John".split()
-    target = "Is your name John".split()
-    metric = RougeMetric()
-    assert torch.allclose(torch.tensor(metric(preds, target)["rouge1_recall"]).float(), torch.tensor(0.25), 1e-4)
-
-
-@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)])
-def test_bleu_score(smooth, expected):
-    translate_corpus = ["the cat is on the mat".split()]
-    reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]
-    metric = BLEUScore(smooth=smooth)
-    assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4)
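
Migration note (not part of the diff): the sketch below is a minimal, hedged example of how code that previously used `flash.text.seq2seq.core.metrics` can move to the `torchmetrics` classes this PR switches to. It mirrors the removed `tests/text/seq2seq/core/test_metrics.py` and assumes the torchmetrics text API contemporary with the `torchmetrics>=0.5.1` pin above (tokenized corpora for `BLEUScore`, `newline_sep`/`use_stemmer` keyword arguments for `ROUGEScore`); later torchmetrics releases changed these signatures, so treat it as illustrative rather than authoritative.

```python
# Hedged migration sketch: re-creates the removed Flash test against torchmetrics.
from torchmetrics import BLEUScore
from torchmetrics.text import ROUGEScore

# BLEU over one tokenized translation with two tokenized references,
# exactly as in the deleted test; it expected a score of roughly 0.7598.
bleu = BLEUScore(n_gram=4, smooth=False)
translate_corpus = ["the cat is on the mat".split()]
reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]
print(bleu(translate_corpus, reference_corpus))

# ROUGE with the same keyword arguments the Flash tasks now pass through
# (newline_sep, use_stemmer); the deleted test expected rouge1_recall ~= 0.25.
rouge = ROUGEScore(newline_sep=False, use_stemmer=False)
print(rouge(["My name is John"], ["Is your name John"]))
```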