Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Jan 8, 2024
1 parent 8a8b384 commit f0163f8
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 55 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.9"
pytorch-ie = ">=0.29.5,<0.30.0"
# requires https://github.com/ArneBinder/pytorch-ie/pull/392
#pytorch-ie = ">=0.29.5,<0.30.0"
pytorch-ie = {git = "https://github.com/arnebinder/pytorch-ie.git", branch = "taskmodule_configure_model_metric"}
pytorch-lightning = "^2.1.0"
torchmetrics = "^1"
pytorch-crf = ">=0.7.2"
Expand Down
14 changes: 3 additions & 11 deletions src/pie_modules/models/simple_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torchmetrics import Metric
from transformers import PreTrainedModel, get_linear_schedule_with_warmup

from pie_modules.taskmodules.common import HasConfigureMetric
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,16 +109,9 @@ def __init__(
strict=False,
config=taskmodule_config,
)
# TODO: remove this check when TaskModule.build_metric() is implemented
if not isinstance(taskmodule, HasConfigureMetric):
logger.warning(
f"taskmodule {taskmodule} does not implement HasConfigureMetric interface, no metrics will be "
f"used."
)
else:
metrics = {stage: taskmodule.configure_metric(stage) for stage in metric_splits}
# keep only the metrics that are not None
self.metrics = {k: v for k, v in metrics.items() if v is not None}
metrics = {stage: taskmodule.configure_model_metric(stage) for stage in metric_splits}
# keep only the metrics that are not None
self.metrics = {k: v for k, v in metrics.items() if v is not None}

def predict(self, inputs, **kwargs) -> torch.LongTensor:
is_training = self.training
Expand Down
9 changes: 0 additions & 9 deletions src/pie_modules/taskmodules/common/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,3 @@ def decode_annotations(
self, encoding: ACE, metadata: Optional[Dict[str, Any]] = None
) -> Tuple[Dict[str, List[Annotation]], Any]:
pass


# TODO: move into pytorch_ie
class HasConfigureMetric(abc.ABC):
"""Interface for modules that can configure a metric."""

@abc.abstractmethod
def configure_metric(self, stage: Optional[str] = None) -> Metric:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import torch
from pytorch_ie import AnnotationLayer, Document
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation, TaskEncoding, TaskModule
from pytorch_ie.core.taskmodule import (
InputEncoding,
Expand All @@ -39,7 +39,7 @@

from ..document.processing import token_based_document_to_text_based, tokenize_document
from ..utils import resolve_type
from .common import BatchableMixin, HasConfigureMetric, HasDecodeAnnotations
from .common import BatchableMixin, HasDecodeAnnotations
from .common.interfaces import DecodingException
from .common.metrics import AnnotationLayerMetric
from .pointer_network.annotation_encoder_decoder import (
Expand Down Expand Up @@ -121,7 +121,6 @@ def get_first_occurrence_index(

@TaskModule.register()
class PointerNetworkTaskModuleForEnd2EndRE(
HasConfigureMetric,
HasDecodeAnnotations[TaskOutputType],
TaskModule[
DocumentType,
Expand Down Expand Up @@ -388,7 +387,7 @@ def pointer_offset(self) -> int:
def target_ids(self) -> Set[int]:
return set(range(self.pointer_offset))

def configure_metric(self, stage: Optional[str] = None) -> Metric:
def configure_model_metric(self, stage: Optional[str] = None) -> Optional[Metric]:
return AnnotationLayerMetric(
taskmodule=self,
layer_names=self.layer_names,
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_simple_generative_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_sciarg_predict(trained_model, sciarg_batch_truncated, loaded_taskmodule
assert prediction.tolist() == expected_prediction.tolist()

# calculate metrics just to check the scores
metric = loaded_taskmodule.configure_metric()
metric = loaded_taskmodule.configure_model_metric()
metric.update(prediction, targets["labels"])
values = metric.compute()
assert values == {
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_sciarg_predict_with_position_id_pattern(sciarg_batch_truncated, loaded_
assert prediction.tolist() == expected_prediction.tolist()

# calculate metrics just to check the scores
metric = loaded_taskmodule.configure_metric()
metric = loaded_taskmodule.configure_model_metric()
metric.update(prediction, targets["labels"])

values = metric.compute()
Expand Down
26 changes: 0 additions & 26 deletions tests/taskmodules/common/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any, Dict, List, Set, Tuple

from pytorch_ie.annotations import Span
from torchmetrics import Metric

from pie_modules.taskmodules.common import (
AnnotationEncoderDecoder,
HasConfigureMetric,
HasDecodeAnnotations,
)

Expand Down Expand Up @@ -56,27 +54,3 @@ def decode_annotations(
{"spans": [Span(start=1, end=2)]},
{"too_long": True},
)


def test_has_build_metric():
    """Exercise the HasConfigureMetric interface via a minimal concrete pair.

    Defines a stub Metric and a builder implementing HasConfigureMetric, then
    checks that the builder produces a working Metric instance.
    """

    class DummyMetric(Metric):
        """Metric stub: accepts any update and always computes 0."""

        def update(self, x):
            pass

        def compute(self):
            return 0

    class DummyBuilder(HasConfigureMetric):
        """Minimal HasConfigureMetric implementation returning a DummyMetric."""

        def configure_metric(self, stage: str = None):
            return DummyMetric()

    builder = DummyBuilder()
    produced = builder.configure_metric()
    assert isinstance(produced, Metric)
    assert produced.compute() == 0
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,8 @@ def test_annotations_from_output(task_encodings, task_outputs, taskmodule):
)


def test_configure_metric(taskmodule):
metric = taskmodule.configure_metric()
def test_configure_model_metric(taskmodule):
    """Check that the taskmodule fixture yields a model metric of the expected type."""
    result = taskmodule.configure_model_metric()
    # isinstance also rules out None, but keep the explicit check for a clearer failure message
    assert result is not None
    assert isinstance(result, AnnotationLayerMetric)

Expand Down

0 comments on commit f0163f8

Please sign in to comment.