Skip to content

Commit

Permalink
Remove (Multi)LabeledMultiSpan annotations (#405)
Browse files Browse the repository at this point in the history
* remove LabeledMultiSpan annotation

* remove MultiLabeledMultiSpan annotation

* remove _post_init_multi_span()
  • Loading branch information
ArneBinder authored Feb 19, 2024
1 parent c16a510 commit eef7905
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 86 deletions.
27 changes: 0 additions & 27 deletions src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def _post_init_multi_label(self):
)


def _post_init_multi_span(self):
if isinstance(self.slices, list):
object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices))


def _post_init_arguments_and_roles(self):
if len(self.arguments) != len(self.roles):
raise ValueError(
Expand Down Expand Up @@ -92,28 +87,6 @@ def __post_init__(self) -> None:
_post_init_multi_label(self)


@dataclass(eq=True, frozen=True)
class LabeledMultiSpan(Annotation):
slices: Tuple[Tuple[int, int], ...]
label: str
score: float = field(default=1.0, compare=False)

def __post_init__(self) -> None:
_post_init_multi_span(self)
_post_init_single_label(self)


@dataclass(eq=True, frozen=True)
class MultiLabeledMultiSpan(Annotation):
slices: Tuple[Tuple[int, int], ...]
label: Tuple[str, ...]
score: Optional[Tuple[float, ...]] = field(default=None, compare=False)

def __post_init__(self) -> None:
_post_init_multi_span(self)
_post_init_multi_label(self)


@dataclass(eq=True, frozen=True)
class BinaryRelation(Annotation):
head: Annotation
Expand Down
59 changes: 0 additions & 59 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
from pytorch_ie.annotations import (
BinaryRelation,
Label,
LabeledMultiSpan,
LabeledSpan,
MultiLabel,
MultiLabeledBinaryRelation,
MultiLabeledMultiSpan,
MultiLabeledSpan,
Span,
)
Expand Down Expand Up @@ -126,63 +124,6 @@ def test_multilabeled_span():
MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3))


def test_labeled_multi_span():
labeled_multi_span1 = LabeledMultiSpan(slices=((1, 2), (3, 4)), label="label1")
assert labeled_multi_span1.slices == ((1, 2), (3, 4))
assert labeled_multi_span1.label == "label1"
assert labeled_multi_span1.score == pytest.approx(1.0)

labeled_multi_span2 = LabeledMultiSpan(
slices=((5, 6), (7, 8)),
label="label2",
score=0.5,
)
assert labeled_multi_span2.slices == ((5, 6), (7, 8))
assert labeled_multi_span2.label == "label2"
assert labeled_multi_span2.score == pytest.approx(0.5)

assert labeled_multi_span2.asdict() == {
"_id": labeled_multi_span2._id,
"slices": ((5, 6), (7, 8)),
"label": "label2",
"score": 0.5,
}

_test_annotation_reconstruction(labeled_multi_span2)


def test_multilabeled_multi_span():
multilabeled_multi_span1 = MultiLabeledMultiSpan(
slices=((1, 2), (3, 4)), label=("label1", "label2")
)
assert multilabeled_multi_span1.slices == ((1, 2), (3, 4))
assert multilabeled_multi_span1.label == ("label1", "label2")
assert multilabeled_multi_span1.score == pytest.approx((1.0, 1.0))

multilabeled_multi_span2 = MultiLabeledMultiSpan(
slices=((5, 6), (7, 8)), label=("label3", "label4"), score=(0.4, 0.5)
)
assert multilabeled_multi_span2.slices == ((5, 6), (7, 8))
assert multilabeled_multi_span2.label == ("label3", "label4")
assert multilabeled_multi_span2.score == pytest.approx((0.4, 0.5))

assert multilabeled_multi_span2.asdict() == {
"_id": multilabeled_multi_span2._id,
"slices": ((5, 6), (7, 8)),
"label": ("label3", "label4"),
"score": (0.4, 0.5),
}

_test_annotation_reconstruction(multilabeled_multi_span2)

with pytest.raises(
ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.")
):
MultiLabeledMultiSpan(
slices=((9, 10), (11, 12)), label=("label5", "label6"), score=(0.1, 0.2, 0.3)
)


def test_binary_relation():
head = Span(start=1, end=2)
tail = Span(start=3, end=4)
Expand Down

0 comments on commit eef7905

Please sign in to comment.