Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update References with pointer to target object #34

Merged
merged 3 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 74 additions & 10 deletions src/ome_autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def __post_init__(self) -> None:
self.body = indent(dedent(self.body), " " * 4)


@dataclass
class ClassOverride:
base_type: Optional[str] = None
imports: Optional[str] = None
fields: Optional[str] = None
body: Optional[str] = None

def __post_init__(self) -> None:
if self.imports:
self.imports = dedent(self.imports)
if self.fields:
self.fields = indent(dedent(self.fields), " " * 4)
if self.body:
self.body = indent(dedent(self.body), " " * 4)


# Maps XSD TypeName to Override configuration, used to control output for that type.
OVERRIDES = {
"MetadataOnly": Override(type_="bool", default="False"),
Expand Down Expand Up @@ -224,6 +240,48 @@ class UUID:
}


# Maps XSD TypeName to ClassOverride configuration, used to control dataclass
# generation.
CLASS_OVERRIDES = {
"OME": ClassOverride(
imports="""
from typing import Any
import weakref
from ome_types.util import collect_ids, collect_references
""",
body="""
def __post_init_post_parse__(self: Any, *args: Any) -> None:
ids = collect_ids(self)
for ref in collect_references(self):
ref.ref_ = weakref.ref(ids[ref.id])
""",
),
"Reference": ClassOverride(
imports="""
from dataclasses import field
from typing import Any, Optional
from .simple_types import LSID
""",
# FIXME Figure out typing for ref_ (weakref). Even with the "correct"
# typing, Pydantic has a problem.
fields="""
id: LSID
ref_: Any = field(default=None, init=False)
""",
# FIXME Could make `ref` abstract and implement stronger-typed overrides
# in subclasses.
body="""
@property
def ref(self) -> Any:
if self.ref_ is None:
raise ValueError("references not yet resolved on root OME object")
return self.ref_()
""",
),
"BinData": ClassOverride(base_type="object", fields="value: str"),
}


def black_format(text: str, line_length: int = 79) -> str:
return black.format_str(text, mode=black.FileMode(line_length=line_length))

Expand Down Expand Up @@ -275,23 +333,29 @@ def local_import(item_type: str) -> str:


def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
class_override = CLASS_OVERRIDES.get(component.local_name, None)
lines = ["from ome_types.dataclasses import ome_dataclass", ""]
# FIXME: Refactor to remove BinData special-case.
if component.local_name == "BinData":
base_type = None
elif isinstance(component, XsdType):
if isinstance(component, XsdType):
base_type = component.base_type
else:
base_type = component.type.base_type

if base_type and not hasattr(base_type, "python_type"):
if class_override and class_override.base_type:
if class_override.base_type == "object":
base_name = ""
else:
base_name = f"({class_override.base_type})"
base_type = None
elif base_type and not hasattr(base_type, "python_type"):
base_name = f"({base_type.local_name})"
if base_type.is_complex():
lines += [local_import(base_type.local_name)]
else:
lines += [f"from .simple_types import {base_type.local_name}"]
else:
base_name = ""
if class_override and class_override.imports:
lines.append(class_override.imports)

base_members = set()
_basebase = base_type
Expand All @@ -310,16 +374,17 @@ def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
lines[0] += ", AUTO_SEQUENCE"

lines += ["@ome_dataclass", f"class {component.local_name}{base_name}:"]
# FIXME: Refactor to remove BinData special-case.
if component.local_name == "BinData":
lines.append(" value: str")
if class_override and class_override.fields:
lines.append(class_override.fields)
lines += members.lines(
indent=1,
force_defaults=" = EMPTY # type: ignore"
if cannot_have_required_args
else None,
)

if class_override and class_override.body:
lines.append(class_override.body)
lines += members.body()

return lines
Expand Down Expand Up @@ -755,8 +820,7 @@ def _abstract_class(self) -> List[str]:
return lines

def lines(self) -> str:
# FIXME: Refactor to remove BinData special-case.
if not self.is_complex and self.elem.local_name != "BinData":
if not self.is_complex:
lines = self._simple_class()
elif self.elem.abstract:
lines = self._abstract_class()
Expand Down
4 changes: 3 additions & 1 deletion src/ome_types/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def new_post_init(self: Any, *args: Any) -> None:
def modify_repr(_cls: Type[Any]) -> None:
"""Improved dataclass repr function.

Only show non-default values, and summarize containers.
Only show non-default non-internal values, and summarize containers.
"""
# let classes still create their own
if _cls.__repr__ is not object.__repr__:
Expand All @@ -102,6 +102,8 @@ def new_repr(self: Any) -> str:
name = self.__class__.__qualname__
lines = []
for f in sorted(fields(self), key=lambda f: f.name not in ("name", "id")):
if f.name.endswith("_"):
continue
# https://github.com/python/mypy/issues/6910
if f.default_factory is not MISSING: # type: ignore
default = f.default_factory() # type: ignore
Expand Down
50 changes: 50 additions & 0 deletions src/ome_types/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import dataclasses
import weakref
from typing import Any, Dict, List

from .model.simple_types import LSID
from .model.reference import Reference


def collect_references(value: Any) -> List[Reference]:
"""Return a list of all References contained in value.

Recursively walks all dataclass fields and iterates over lists. The base
case is when value is either a Reference object, or an uninteresting type
that we don't need to inspect further.

"""
references: List[Reference] = []
if isinstance(value, Reference):
references.append(value)
elif isinstance(value, list):
for v in value:
references.extend(collect_references(v))
elif dataclasses.is_dataclass(value):
for f in dataclasses.fields(value):
references.extend(collect_references(getattr(value, f.name)))
# Do nothing for uninteresting types
return references


def collect_ids(value: Any) -> Dict[LSID, Any]:
"""Return a map of all model objects contained in value, keyed by id.

Recursively walks all dataclass fields and iterates over lists. The base
case is when value is neither a dataclass nor a list.

"""
ids: Dict[LSID, Any] = {}
if isinstance(value, list):
for v in value:
ids.update(collect_ids(v))
elif dataclasses.is_dataclass(value):
for f in dataclasses.fields(value):
if f.name == "id" and not isinstance(value, Reference):
# We don't need to recurse on the id string, so just record it
# and move on.
ids[value.id] = value
else:
ids.update(collect_ids(getattr(value, f.name)))
# Do nothing for uninteresting types.
return ids