Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement MultiModalSequenceTaggingTaskModule #233

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,26 @@ def fromdict(
tmp_dct["tail"] = resolve_annotation(tmp_dct["tail"], store=annotation_store)

return cls(**tmp_dct)


def _post_init_bbox(self):
    """Normalize ``self.bbox`` to a tuple and check that it has four entries.

    Intended to be called from ``__post_init__``. The attribute is written
    through ``object.__setattr__`` so the assignment also works on frozen
    dataclasses, whose own ``__setattr__`` rejects writes.

    Raises:
        ValueError: if the bounding box does not contain exactly 4 values.
    """
    box = self.bbox
    if not isinstance(box, tuple):
        # Any other iterable (e.g. a list) is converted and stored back.
        box = tuple(box)
        object.__setattr__(self, "bbox", box)
    if len(box) != 4:
        raise ValueError("bounding box has to consist of 4 values.")


@dataclass(eq=True, frozen=True)
class OcrAnnotation(Annotation):
    """Annotation that pairs a piece of text with a four-value bounding box.

    NOTE(review): the coordinate convention of ``bbox`` (corner order, units)
    is not visible in this file -- confirm against the code producing it.
    """

    # Four integers describing the box; coerced to a tuple in __post_init__.
    bbox: Tuple[int, int, int, int]
    # Text associated with the box (presumably the OCR-recognized string).
    text: str

    def __post_init__(self) -> None:
        # Convert bbox to a tuple if needed and reject boxes whose length is
        # not 4 (the helper raises ValueError in that case).
        _post_init_bbox(self)


class OcrLabeledSpan(LabeledSpan):
    """Labeled span whose ``start``/``end`` index into a sequence of OCR annotations."""

    def __str__(self) -> str:
        """Return the string form of the list of texts covered by this span.

        Yields an empty string when the span has no target.
        """
        words = self.target
        if words is None:
            return ""
        # The span is a slice over the target sequence, not over characters.
        covered = words[self.start : self.end]
        texts = [word.text for word in covered]
        return str(texts)
32 changes: 30 additions & 2 deletions src/pytorch_ie/documents.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
import dataclasses
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

from pytorch_ie.core import Document
from pytorch_ie.annotations import OcrAnnotation, OcrLabeledSpan
from pytorch_ie.core import AnnotationList, Document, annotation_field


@dataclasses.dataclass
class TextDocument(Document):
    """Document whose content is a single text string."""

    # The text of the document.
    text: str
    # Optional identifier; None when not provided.
    id: Optional[str] = None
    # Free-form metadata; default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class OcrDocument(Document):
    """Document backed by an image plus word-level OCR annotations."""

    # 3d: channel x row x col
    image: Tuple[
        Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...], Tuple[Tuple[int, ...], ...]
    ]
    image_width: int
    image_height: int
    image_format: str
    # Annotation layer declared with the image as its target.
    words: AnnotationList[OcrAnnotation] = annotation_field(target="image")

    def __post_init__(self):
        """Convert a non-tuple ``image`` to nested tuples, then run the base hook."""
        pixels = self.image
        # Data loaded from a dataset arrives as nested lists because JSON has
        # no tuple type; only the outermost container is inspected here.
        if not isinstance(pixels, tuple):
            as_tuples = tuple(tuple(map(tuple, channel)) for channel in pixels)
            object.__setattr__(self, "image", as_tuples)
        super().__post_init__()


@dataclasses.dataclass
class OcrDocumentWithEntities(OcrDocument):
    """OcrDocument with an additional layer of labeled spans over its words."""

    # Span annotations declared with the "words" layer as their target.
    entities: AnnotationList[OcrLabeledSpan] = annotation_field(target="words")
1 change: 1 addition & 0 deletions src/pytorch_ie/taskmodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .multi_modal_sequence_tagging import MultiModalSequenceTaggingTaskModule
from .simple_transformer_text_classification import SimpleTransformerTextClassificationTaskModule
from .transformer_re_text_classification import TransformerRETextClassificationTaskModule
from .transformer_seq2seq import TransformerSeq2SeqTaskModule
Expand Down
Loading