Skip to content

Commit

Permalink
wip on validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Jun 29, 2023
1 parent af3ec09 commit b7b857c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 14 deletions.
5 changes: 5 additions & 0 deletions src/ome_autogen/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
("Instrument", f"{MIXIN_MODULE}._instrument.InstrumentMixin", False),
("Reference", f"{MIXIN_MODULE}._reference.ReferenceMixin", True),
("BinData", f"{MIXIN_MODULE}._bin_data.BinDataMixin", True),
(
"StructuredAnnotations",
f"{MIXIN_MODULE}._structured_annotations.StructuredAnnotationsMixin",
True,
),
]

OUTPUT_PACKAGE = "ome_types.model.ome_2016_06"
Expand Down
12 changes: 9 additions & 3 deletions src/ome_types/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def to_dict(source: OME | Path | str | bytes) -> dict[str, Any]:
)


def _class_factory(cls: type, kwargs: Any):
kwargs.setdefault("validation", "strict")
return cls(**kwargs)


def from_xml(
xml: Path | str | bytes,
*,
Expand All @@ -67,10 +72,11 @@ def from_xml(
xml = str(xml)

OME_type = _get_ome(xml)
parser = XmlParser(**(parser_kwargs or {}))
parser_kwargs = {"config": ParserConfig(class_factory=_class_factory)}
_parser = XmlParser(**(parser_kwargs or {}))
if isinstance(xml, bytes):
return parser.from_bytes(xml, OME_type)
return parser.parse(xml, OME_type)
return _parser.from_bytes(xml, OME_type)
return _parser.parse(xml, OME_type)


def to_xml(
Expand Down
33 changes: 22 additions & 11 deletions src/ome_types/_mixins/_base_type.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import contextlib
import re
import warnings
from datetime import datetime
from enum import Enum
from textwrap import indent
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Set, Type, cast

import pydantic
from pydantic import BaseModel, ValidationError, validator
from pydantic import BaseModel, PrivateAttr, ValidationError, validator

from ome_types.units import ureg

Expand Down Expand Up @@ -49,9 +51,11 @@ class Config:
underscore_attrs_are_private = True
use_enum_values = False
validate_all = True
validation_mode: str = "strict"

# allow use with weakref
__slots__: ClassVar[Set[str]] = {"__weakref__"} # type: ignore
_validation_mode: str = PrivateAttr("strict")

def __init__(__pydantic_self__, **data: Any) -> None:
if "id" in __pydantic_self__.__fields__:
Expand All @@ -68,7 +72,6 @@ def __init__(__pydantic_self__, **data: Any) -> None:
data[key] = _AUTO_SEQUENCE
else:
data.pop(key, None)
print(data)
super().__init__(**data)

def __init_subclass__(cls) -> None:
Expand Down Expand Up @@ -111,30 +114,38 @@ def __repr__(self) -> str:

@validator("id", pre=True, always=True, check_fields=False)
@classmethod
def validate_id(cls, value: Any) -> Any:
def _validate_id(cls, value: Any, values=None, config=None) -> Any:
"""Pydantic validator for ID fields in OME models.
If no value is provided, this validator provides and integer ID, and stores the
maximum previously-seen value on the class.
"""
# get the required LSID field from the annotation
current_count = _COUNTERS.setdefault(cls, 0)
# FIXME: clean this up
id_field = cls.__fields__["id"]
id_regex = cast(str, id_field.field_info.regex)
id_name = id_regex.split(":")[-3]
current_count = _COUNTERS.setdefault(cls, -1)
if isinstance(value, str):
# parse the id and update the counter
v_id = value.rsplit(":", 1)[-1]
*name, v_id = value.rsplit(":", 1)
if not re.match(id_regex, value):
warnings.warn(f"Casting invalid {id_name}ID", stacklevel=2)
return v_id if v_id.isnumeric() else _AUTO_SEQUENCE

with contextlib.suppress(ValueError):
_COUNTERS[cls] = max(current_count, int(v_id))
return value

if isinstance(value, int):
_COUNTERS[cls] = max(current_count, value)
return f"{cls.__name__}:{value}"

if value is _AUTO_SEQUENCE:
elif value is _AUTO_SEQUENCE:
# just increment the counter
_COUNTERS[cls] += 1
return f"{cls.__name__}:{_COUNTERS[cls]}"
value = _COUNTERS[cls]
else:
raise ValueError(f"Invalid ID value: {value!r}, {type(value)}")

raise ValueError(f"Invalid ID value: {value!r}")
return f"{id_name}:{value}"

# @classmethod
# def snake_name(cls) -> str:
Expand Down
26 changes: 26 additions & 0 deletions src/ome_types/_mixins/_structured_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING, Iterable, Sequence

from ome_types.model.ome_2016_06 import Annotation

from ._base_type import OMEType

if TYPE_CHECKING:
from ome_types.model.ome_2016_06 import StructuredAnnotations


class StructuredAnnotationsMixin(OMEType, Sequence):
def _iter_annotations(self: "StructuredAnnotations") -> Iterable[Annotation]:
for x in self.__fields__.values():
if issubclass(x.type_, Annotation):
yield from getattr(self, x.name)
else:
breakpoint()

def __getitem__(self: "StructuredAnnotations", key) -> Annotation:
return list(self._iter_annotations())[key]

def __len__(self) -> int:
return len(list(self._iter_annotations()))

def append(self: "StructuredAnnotations", value: Annotation) -> None:
raise NotImplementedError

0 comments on commit b7b857c

Please sign in to comment.