Skip to content

Commit

Permalink
Update References with pointer to target object (#34)
Browse files Browse the repository at this point in the history
* Update References with pointer to target object

* Suppress internal trailing-underscore fields in repr

* Fix handling of multiple references to the same ID
  • Loading branch information
jmuhlich authored Aug 24, 2020
1 parent a9d5fc4 commit 05a15c4
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 11 deletions.
84 changes: 74 additions & 10 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def __post_init__(self) -> None:
self.body = indent(dedent(self.body), " " * 4)


@dataclass
class ClassOverride:
base_type: Optional[str] = None
imports: Optional[str] = None
fields: Optional[str] = None
body: Optional[str] = None

def __post_init__(self) -> None:
if self.imports:
self.imports = dedent(self.imports)
if self.fields:
self.fields = indent(dedent(self.fields), " " * 4)
if self.body:
self.body = indent(dedent(self.body), " " * 4)


# Maps XSD TypeName to Override configuration, used to control output for that type.
OVERRIDES = {
"MetadataOnly": Override(type_="bool", default="False"),
Expand Down Expand Up @@ -224,6 +240,48 @@ class UUID:
}


# Maps XSD TypeName to ClassOverride configuration, used to control dataclass
# generation.
CLASS_OVERRIDES = {
"OME": ClassOverride(
imports="""
from typing import Any
import weakref
from ome_types.util import collect_ids, collect_references
""",
body="""
def __post_init_post_parse__(self: Any, *args: Any) -> None:
ids = collect_ids(self)
for ref in collect_references(self):
ref.ref_ = weakref.ref(ids[ref.id])
""",
),
"Reference": ClassOverride(
imports="""
from dataclasses import field
from typing import Any, Optional
from .simple_types import LSID
""",
# FIXME Figure out typing for ref_ (weakref). Even with the "correct"
# typing, Pydantic has a problem.
fields="""
id: LSID
ref_: Any = field(default=None, init=False)
""",
# FIXME Could make `ref` abstract and implement stronger-typed overrides
# in subclasses.
body="""
@property
def ref(self) -> Any:
if self.ref_ is None:
raise ValueError("references not yet resolved on root OME object")
return self.ref_()
""",
),
"BinData": ClassOverride(base_type="object", fields="value: str"),
}


def black_format(text: str, line_length: int = 79) -> str:
return black.format_str(text, mode=black.FileMode(line_length=line_length))

Expand Down Expand Up @@ -275,23 +333,29 @@ def local_import(item_type: str) -> str:


def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
class_override = CLASS_OVERRIDES.get(component.local_name, None)
lines = ["from ome_types.dataclasses import ome_dataclass", ""]
# FIXME: Refactor to remove BinData special-case.
if component.local_name == "BinData":
base_type = None
elif isinstance(component, XsdType):
if isinstance(component, XsdType):
base_type = component.base_type
else:
base_type = component.type.base_type

if base_type and not hasattr(base_type, "python_type"):
if class_override and class_override.base_type:
if class_override.base_type == "object":
base_name = ""
else:
base_name = f"({class_override.base_type})"
base_type = None
elif base_type and not hasattr(base_type, "python_type"):
base_name = f"({base_type.local_name})"
if base_type.is_complex():
lines += [local_import(base_type.local_name)]
else:
lines += [f"from .simple_types import {base_type.local_name}"]
else:
base_name = ""
if class_override and class_override.imports:
lines.append(class_override.imports)

base_members = set()
_basebase = base_type
Expand All @@ -310,16 +374,17 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
lines[0] += ", AUTO_SEQUENCE"

lines += ["@ome_dataclass", f"class {component.local_name}{base_name}:"]
# FIXME: Refactor to remove BinData special-case.
if component.local_name == "BinData":
lines.append(" value: str")
if class_override and class_override.fields:
lines.append(class_override.fields)
lines += members.lines(
indent=1,
force_defaults=" = EMPTY # type: ignore"
if cannot_have_required_args
else None,
)

if class_override and class_override.body:
lines.append(class_override.body)
lines += members.body()

return lines
Expand Down Expand Up @@ -755,8 +820,7 @@ def _abstract_class(self) -> List[str]:
return lines

def lines(self) -> str:
# FIXME: Refactor to remove BinData special-case.
if not self.is_complex and self.elem.local_name != "BinData":
if not self.is_complex:
lines = self._simple_class()
elif self.elem.abstract:
lines = self._abstract_class()
Expand Down
4 changes: 3 additions & 1 deletion src/ome_types/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def new_post_init(self: Any, *args: Any) -> None:
def modify_repr(_cls: Type[Any]) -> None:
"""Improved dataclass repr function.
Only show non-default values, and summarize containers.
Only show non-default non-internal values, and summarize containers.
"""
# let classes still create their own
if _cls.__repr__ is not object.__repr__:
Expand All @@ -102,6 +102,8 @@ def new_repr(self: Any) -> str:
name = self.__class__.__qualname__
lines = []
for f in sorted(fields(self), key=lambda f: f.name not in ("name", "id")):
if f.name.endswith("_"):
continue
# https://github.com/python/mypy/issues/6910
if f.default_factory is not MISSING: # type: ignore
default = f.default_factory() # type: ignore
Expand Down
50 changes: 50 additions & 0 deletions src/ome_types/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import dataclasses
import weakref
from typing import Any, Dict, List

from .model.simple_types import LSID
from .model.reference import Reference


def collect_references(value: Any) -> List[Reference]:
"""Return a list of all References contained in value.
Recursively walks all dataclass fields and iterates over lists. The base
case is when value is either a Reference object, or an uninteresting type
that we don't need to inspect further.
"""
references: List[Reference] = []
if isinstance(value, Reference):
references.append(value)
elif isinstance(value, list):
for v in value:
references.extend(collect_references(v))
elif dataclasses.is_dataclass(value):
for f in dataclasses.fields(value):
references.extend(collect_references(getattr(value, f.name)))
# Do nothing for uninteresting types
return references


def collect_ids(value: Any) -> Dict[LSID, Any]:
"""Return a map of all model objects contained in value, keyed by id.
Recursively walks all dataclass fields and iterates over lists. The base
case is when value is neither a dataclass nor a list.
"""
ids: Dict[LSID, Any] = {}
if isinstance(value, list):
for v in value:
ids.update(collect_ids(v))
elif dataclasses.is_dataclass(value):
for f in dataclasses.fields(value):
if f.name == "id" and not isinstance(value, Reference):
# We don't need to recurse on the id string, so just record it
# and move on.
ids[value.id] = value
else:
ids.update(collect_ids(getattr(value, f.name)))
# Do nothing for uninteresting types.
return ids

0 comments on commit 05a15c4

Please sign in to comment.