Skip to content

Commit

Permalink
Make light sources and shapes work (#17)
Browse files Browse the repository at this point in the history
* Make light sources and shapes work

* Fix MetadataOnly and empty BinData parsing

* rename MyConverter -> OMEConverter

Co-authored-by: Talley Lambert <[email protected]>
  • Loading branch information
jmuhlich and tlambert03 authored Jul 22, 2020
1 parent ee9966e commit 4012024
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 44 deletions.
191 changes: 158 additions & 33 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import os
import re
import shutil
from textwrap import dedent, indent
from pathlib import Path
from typing import Generator, Iterable, List, Set, Union, Iterator, Tuple
from typing import Generator, Iterable, List, Set, Union, Iterator, Tuple, Optional
from dataclasses import dataclass

import black
import isort
Expand All @@ -19,29 +21,140 @@
XsdComponent,
)

TIFF_UUID = """
from typing import Optional
from .simple_types import UniversallyUniqueIdentifier
# FIXME: Work out a better way to implement these override hacks.


@dataclass
class UUID:
file_name: str
value: UniversallyUniqueIdentifier
"""

# FIXME: hacks
# OVERIDE is a mapping of XSD TypeName to a tuple of desired output for that type
# where the tuple is (type, default value, imports/strings needed)
OVERRIDE = {
"MetadataOnly": ("bool", "False", None),
"XMLAnnotation": ("Optional[str]", "None", "from typing import Optional\n\n",),
"BinData/Length": ("int", None, None),
"ROI/Union": (
"List[Shape] = Field(..., min_items=1)",
None,
"from pydantic import Field\nfrom .shape import Shape",
class Override:
type_: str
default: Optional[str] = None
imports: Optional[str] = None
body: Optional[str] = None

def __post_init__(self) -> None:
if self.imports:
self.imports = dedent(self.imports)
if self.body:
self.body = indent(dedent(self.body), " " * 4)


# Maps XSD TypeName to Override configuration, used to control output for that type.
OVERRIDES = {
"MetadataOnly": Override(type_="bool", default="False"),
"XMLAnnotation": Override(
type_="Optional[str]", default="None", imports="from typing import Optional",
),
"BinData/Length": Override(type_="int"),
# FIXME: hard-coded LightSource subclass lists
"Instrument/LightSourceGroup": Override(
type_="List[LightSource]",
default="field(default_factory=list)",
imports="""
from typing import Dict, Union, Any
from pydantic import validator
from .light_source import LightSource
from .laser import Laser
from .arc import Arc
from .filament import Filament
from .light_emitting_diode import LightEmittingDiode
from .generic_excitation_source import GenericExcitationSource
_light_source_types: Dict[str, type] = {
"laser": Laser,
"arc": Arc,
"filament": Filament,
"light_emitting_diode": LightEmittingDiode,
"generic_excitation_source": GenericExcitationSource,
}
""",
body="""
@validator("light_source_group", pre=True, each_item=True)
def validate_light_source_group(
cls, value: Union[LightSource, Dict[Any, Any]]
) -> LightSource:
if isinstance(value, LightSource):
return value
elif isinstance(value, dict):
try:
_type = value.pop("_type")
except KeyError:
raise ValueError(
"dict initialization requires _type"
) from None
try:
light_source_cls = _light_source_types[_type]
except KeyError:
raise ValueError(
f"unknown LightSource type '{_type}'"
) from None
return light_source_cls(**value)
else:
raise ValueError("invalid type for light_source_group values")
""",
),
"ROI/Union": Override(
type_="List[Shape]",
default="field(default_factory=list)",
imports="""
from typing import Dict, Union, Any
from pydantic import validator
from .shape import Shape
from .point import Point
from .line import Line
from .rectangle import Rectangle
from .ellipse import Ellipse
from .polyline import Polyline
from .polygon import Polygon
from .mask import Mask
from .label import Label
_shape_types: Dict[str, type] = {
"point": Point,
"line": Line,
"rectangle": Rectangle,
"ellipse": Ellipse,
"polyline": Polyline,
"polygon": Polygon,
"mask": Mask,
"label": Label,
}
""",
body="""
@validator("union", pre=True, each_item=True)
def validate_union(
cls, value: Union[Shape, Dict[Any, Any]]
) -> Shape:
if isinstance(value, Shape):
return value
elif isinstance(value, dict):
try:
_type = value.pop("_type")
except KeyError:
raise ValueError(
"dict initialization requires _type"
) from None
try:
shape_cls = _shape_types[_type]
except KeyError:
raise ValueError(f"unknown Shape type '{_type}'") from None
return shape_cls(**value)
else:
raise ValueError("invalid type for union values")
""",
),
"TiffData/UUID": Override(
type_="Optional[UUID]",
default="None",
imports="""
from typing import Optional
from .simple_types import UniversallyUniqueIdentifier
@dataclass
class UUID:
file_name: str
value: UniversallyUniqueIdentifier
""",
),
"TiffData/UUID": ("Optional[UUID]", "None", TIFF_UUID),
}


Expand Down Expand Up @@ -138,6 +251,8 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
f'required argument: {m.identifier!r}")'
]

lines += members.body()

return lines


Expand Down Expand Up @@ -237,12 +352,12 @@ def key(self) -> str:
p = p.parent
name = p.local_name
name = f"{name}/{self.component.local_name}"
if name not in OVERRIDE and self.component.local_name in OVERRIDE:
if name not in OVERRIDES and self.component.local_name in OVERRIDES:
return self.component.local_name
return name

def locals(self) -> Set[str]:
if self.key in OVERRIDE:
if self.key in OVERRIDES:
return set()
if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
return set()
Expand All @@ -259,8 +374,8 @@ def locals(self) -> Set[str]:
return locals_

def imports(self) -> Set[str]:
if self.key in OVERRIDE:
_imp = OVERRIDE[self.key][2]
if self.key in OVERRIDES:
_imp = OVERRIDES[self.key].imports
return set([_imp]) if _imp else set()
if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
return set(["from typing import Any"])
Expand All @@ -284,16 +399,21 @@ def imports(self) -> Set[str]:
imports.add(f"from .simple_types import {self.type.local_name}")

if self.component.ref is not None:
if self.component.ref.local_name not in OVERRIDE:
if self.component.ref.local_name not in OVERRIDES:
imports.add(local_import(self.component.ref.local_name))

return imports

def body(self) -> str:
if self.key in OVERRIDES:
return OVERRIDES[self.key].body or ""
return ""

@property
def type_string(self) -> str:
"""single type, without Optional, etc..."""
if self.key in OVERRIDE:
return OVERRIDE[self.key][0]
if self.key in OVERRIDES:
return OVERRIDES[self.key].type_
if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
return "Any"
if self.component.ref is not None:
Expand Down Expand Up @@ -322,7 +442,7 @@ def type_string(self) -> str:
@property
def full_type_string(self) -> str:
"""full type, like Optional[List[str]]"""
if self.key in OVERRIDE and self.type_string:
if self.key in OVERRIDES and self.type_string:
return f": {self.type_string}"
type_string = self.type_string
if not type_string:
Expand All @@ -335,8 +455,8 @@ def full_type_string(self) -> str:

@property
def default_val_str(self) -> str:
if self.key in OVERRIDE:
default = OVERRIDE[self.key][1]
if self.key in OVERRIDES:
default = OVERRIDES[self.key].default
return f" = {default}" if default else ""
if not self.is_optional:
return ""
Expand Down Expand Up @@ -414,6 +534,11 @@ def locals(self) -> List[str]:
return list(set.union(*[m.locals() for m in self._members]))
return []

def body(self) -> List[str]:
if self._members:
return [m.body() for m in self._members]
return []

def has_non_default_args(self) -> bool:
return any(not m.default_val_str for m in self._members)

Expand Down Expand Up @@ -552,7 +677,7 @@ def convert_schema(url: str = _url, target_dir: str = _target) -> None:
init_imports = []
simples: List[GlobalElem] = []
for elem in sorted(schema.types.values(), key=sort_types):
if elem.local_name in OVERRIDE:
if elem.local_name in OVERRIDES:
continue
converter = GlobalElem(elem)
if not elem.is_complex():
Expand All @@ -563,7 +688,7 @@ def convert_schema(url: str = _url, target_dir: str = _target) -> None:
converter.write(filename=targetfile)

for elem in schema.elements.values():
if elem.local_name in OVERRIDE:
if elem.local_name in OVERRIDES:
continue
converter = GlobalElem(elem)
targetfile = os.path.join(target_dir, converter.fname)
Expand Down
62 changes: 58 additions & 4 deletions src/ome_types/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ def get_schema(xml: str) -> xmlschema.XMLSchema:
with open(local, "rb") as f:
__cache__[version] = pickle.load(f)
else:
__cache__[version] = xmlschema.XMLSchema(url)
schema = xmlschema.XMLSchema(url)

# FIXME Hack to work around xmlschema poor support for keyrefs to
# substitution groups
ns = "{http://www.openmicroscopy.org/Schemas/OME/2016-06}"
ls_sgs = schema.maps.substitution_groups[f"{ns}LightSourceGroup"]
ls_id_maps = schema.maps.identities[f"{ns}LightSourceIDKey"]
ls_id_maps.elements = {e: None for e in ls_sgs}

__cache__[version] = schema
with open(local, "wb") as f:
pickle.dump(__cache__[version], f)
return __cache__[version]
Expand All @@ -40,7 +49,7 @@ def validate(xml: str, schema: Optional[xmlschema.XMLSchema] = None) -> None:
schema.validate(xml)


class MyConverter(XMLSchemaConverter):
class OMEConverter(XMLSchemaConverter):
def __init__(self, namespaces: Optional[Dict[str, Any]] = None):
super().__init__(namespaces, attr_prefix="")

Expand All @@ -51,15 +60,60 @@ def map_qname(self, qname: str) -> str:
def element_decode(self, data, xsd_element, xsd_type=None, level=0): # type: ignore
"""Converts a decoded element data to a data structure."""
result = super().element_decode(data, xsd_element, xsd_type, level)
if result and "$" in result:
if isinstance(result, dict) and "$" in result:
result["value"] = result.pop("$")
# FIXME: Work out a better way to deal with concrete extensions of
# abstract types.
if xsd_element.local_name == "MetadataOnly":
result = True
elif xsd_element.local_name == "BinData":
if result["length"] == 0 and "value" not in result:
result["value"] = ""
elif xsd_element.local_name == "Instrument":
light_sources = []
for _type in (
"laser",
"arc",
"filament",
"light_emitting_diode",
"generic_excitation_source",
):
if _type in result:
values = result.pop(_type)
if isinstance(values, dict):
values = [values]
for v in values:
v["_type"] = _type
light_sources.extend(values)
if light_sources:
result["light_source_group"] = light_sources
elif xsd_element.local_name == "Union":
shapes = []
for _type in (
"point",
"line",
"rectangle",
"ellipse",
"polyline",
"polygon",
"mask",
"label",
):
if _type in result:
values = result.pop(_type)
if isinstance(values, dict):
values = [values]
for v in values:
v["_type"] = _type
shapes.extend(values)
result = shapes
return result


def to_dict( # type: ignore
xml: str,
schema: Optional[xmlschema.XMLSchema] = None,
converter: XMLSchemaConverter = MyConverter,
converter: XMLSchemaConverter = OMEConverter,
**kwargs,
) -> Dict[str, Any]:
schema = schema or get_schema(xml)
Expand Down
7 changes: 0 additions & 7 deletions testing/test_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,8 @@ def model(tmp_path_factory, request):


SHOULD_FAIL = {
"ROI",
"commentannotation",
"filter",
"hcs",
"instrument",
"instrument-units-alternate",
"instrument-units-default",
"mapannotation",
"metadata-only",
"spim",
"tagannotation",
"timestampannotation",
Expand Down

0 comments on commit 4012024

Please sign in to comment.