From 05a15c43746aecfda98f5935e7f4d1276050798f Mon Sep 17 00:00:00 2001 From: Jeremy Muhlich Date: Mon, 24 Aug 2020 11:41:02 -0400 Subject: [PATCH] Update References with pointer to target object (#34) * Update References with pointer to target object * Suppress internal trailing-underscore fields in repr * Fix handling of multiple references to the same ID --- src/ome_autogen.py | 84 +++++++++++++++++++++++++++++++----- src/ome_types/dataclasses.py | 4 +- src/ome_types/util.py | 50 +++++++++++++++++++++ 3 files changed, 127 insertions(+), 11 deletions(-) create mode 100644 src/ome_types/util.py diff --git a/src/ome_autogen.py b/src/ome_autogen.py index 4ddf81b0..fb3d790f 100644 --- a/src/ome_autogen.py +++ b/src/ome_autogen.py @@ -50,6 +50,22 @@ def __post_init__(self) -> None: self.body = indent(dedent(self.body), " " * 4) +@dataclass +class ClassOverride: + base_type: Optional[str] = None + imports: Optional[str] = None + fields: Optional[str] = None + body: Optional[str] = None + + def __post_init__(self) -> None: + if self.imports: + self.imports = dedent(self.imports) + if self.fields: + self.fields = indent(dedent(self.fields), " " * 4) + if self.body: + self.body = indent(dedent(self.body), " " * 4) + + # Maps XSD TypeName to Override configuration, used to control output for that type. OVERRIDES = { "MetadataOnly": Override(type_="bool", default="False"), @@ -224,6 +240,48 @@ class UUID: } +# Maps XSD TypeName to ClassOverride configuration, used to control dataclass +# generation. +CLASS_OVERRIDES = { + "OME": ClassOverride( + imports=""" + from typing import Any + import weakref + from ome_types.util import collect_ids, collect_references + """, + body=""" + def __post_init_post_parse__(self: Any, *args: Any) -> None: + ids = collect_ids(self) + for ref in collect_references(self): + ref.ref_ = weakref.ref(ids[ref.id]) + """, + ), + "Reference": ClassOverride( + imports=""" + from dataclasses import field + from typing import Any, Optional + from .simple_types import LSID + """, + # FIXME Figure out typing for ref_ (weakref). Even with the "correct" + # typing, Pydantic has a problem. + fields=""" + id: LSID + ref_: Any = field(default=None, init=False) + """, + # FIXME Could make `ref` abstract and implement stronger-typed overrides + # in subclasses. + body=""" + @property + def ref(self) -> Any: + if self.ref_ is None: + raise ValueError("references not yet resolved on root OME object") + return self.ref_() + """, + ), + "BinData": ClassOverride(base_type="object", fields="value: str"), +} + + def black_format(text: str, line_length: int = 79) -> str: return black.format_str(text, mode=black.FileMode(line_length=line_length)) @@ -275,16 +333,20 @@ def local_import(item_type: str) -> str: def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]: + class_override = CLASS_OVERRIDES.get(component.local_name, None) lines = ["from ome_types.dataclasses import ome_dataclass", ""] - # FIXME: Refactor to remove BinData special-case. - if component.local_name == "BinData": - base_type = None - elif isinstance(component, XsdType): + if isinstance(component, XsdType): base_type = component.base_type else: base_type = component.type.base_type - if base_type and not hasattr(base_type, "python_type"): + if class_override and class_override.base_type: + if class_override.base_type == "object": + base_name = "" + else: + base_name = f"({class_override.base_type})" + base_type = None + elif base_type and not hasattr(base_type, "python_type"): base_name = f"({base_type.local_name})" if base_type.is_complex(): lines += [local_import(base_type.local_name)] @@ -292,6 +354,8 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]: lines += [f"from .simple_types import {base_type.local_name}"] else: base_name = "" + if class_override and class_override.imports: + lines.append(class_override.imports) base_members = set() _basebase = base_type @@ -310,9 +374,8 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]: lines[0] += ", AUTO_SEQUENCE" lines += ["@ome_dataclass", f"class {component.local_name}{base_name}:"] - # FIXME: Refactor to remove BinData special-case. - if component.local_name == "BinData": - lines.append(" value: str") + if class_override and class_override.fields: + lines.append(class_override.fields) lines += members.lines( indent=1, force_defaults=" = EMPTY # type: ignore" @@ -320,6 +383,8 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]: else None, ) + if class_override and class_override.body: + lines.append(class_override.body) lines += members.body() return lines @@ -755,8 +820,7 @@ def _abstract_class(self) -> List[str]: return lines def lines(self) -> str: - # FIXME: Refactor to remove BinData special-case. - if not self.is_complex and self.elem.local_name != "BinData": + if not self.is_complex: lines = self._simple_class() elif self.elem.abstract: lines = self._abstract_class() diff --git a/src/ome_types/dataclasses.py b/src/ome_types/dataclasses.py index 612f979d..475b641b 100644 --- a/src/ome_types/dataclasses.py +++ b/src/ome_types/dataclasses.py @@ -92,7 +92,7 @@ def new_post_init(self: Any, *args: Any) -> None: def modify_repr(_cls: Type[Any]) -> None: """Improved dataclass repr function. - Only show non-default values, and summarize containers. + Only show non-default non-internal values, and summarize containers. """ # let classes still create their own if _cls.__repr__ is not object.__repr__: @@ -102,6 +102,8 @@ def new_repr(self: Any) -> str: name = self.__class__.__qualname__ lines = [] for f in sorted(fields(self), key=lambda f: f.name not in ("name", "id")): + if f.name.endswith("_"): + continue # https://github.com/python/mypy/issues/6910 if f.default_factory is not MISSING: # type: ignore default = f.default_factory() # type: ignore diff --git a/src/ome_types/util.py b/src/ome_types/util.py new file mode 100644 index 00000000..f7f0271a --- /dev/null +++ b/src/ome_types/util.py @@ -0,0 +1,50 @@ +import dataclasses +import weakref +from typing import Any, Dict, List + +from .model.simple_types import LSID +from .model.reference import Reference + + +def collect_references(value: Any) -> List[Reference]: + """Return a list of all References contained in value. + + Recursively walks all dataclass fields and iterates over lists. The base + case is when value is either a Reference object, or an uninteresting type + that we don't need to inspect further. + + """ + references: List[Reference] = [] + if isinstance(value, Reference): + references.append(value) + elif isinstance(value, list): + for v in value: + references.extend(collect_references(v)) + elif dataclasses.is_dataclass(value): + for f in dataclasses.fields(value): + references.extend(collect_references(getattr(value, f.name))) + # Do nothing for uninteresting types + return references + + +def collect_ids(value: Any) -> Dict[LSID, Any]: + """Return a map of all model objects contained in value, keyed by id. + + Recursively walks all dataclass fields and iterates over lists. The base + case is when value is neither a dataclass nor a list. + + """ + ids: Dict[LSID, Any] = {} + if isinstance(value, list): + for v in value: + ids.update(collect_ids(v)) + elif dataclasses.is_dataclass(value): + for f in dataclasses.fields(value): + if f.name == "id" and not isinstance(value, Reference): + # We don't need to recurse on the id string, so just record it + # and move on. + ids[value.id] = value + else: + ids.update(collect_ids(getattr(value, f.name))) + # Do nothing for uninteresting types. + return ids