Skip to content

Commit

Permalink
wip2
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Jun 26, 2023
1 parent eda6b22 commit d79a438
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 37 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include = ["src", "tests", "CHANGELOG.md"]

[tool.hatch.build.hooks.custom]
# requirements to run the autogen script in hatch_build.py
dependencies = ["black", "ruff", "xsdata-pydantic[cli]", "mypy"]
dependencies = ["black", "ruff", "xsdata[cli]>=23.6", "xsdata-pydantic", "mypy"]

# https://peps.python.org/pep-0621/
[project]
Expand Down Expand Up @@ -82,7 +82,7 @@ write_to = "src/ome_types/_version.py"
line-length = 88
src = ["src", "tests"]
target-version = "py38"
extend-select = [
select = [
"E", # style errors
"F", # flakes
"D", # pydocstyle
Expand All @@ -95,9 +95,10 @@ extend-select = [
"A001", # flake8-builtins
"RUF", # ruff-specific rules
]
extend-ignore = [
ignore = [
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D106", # Missing docstring in public nested class
"D107", # Missing docstring in __init__
"D203", # 1 blank line required before class docstring
"D212", # Multi-line docstring summary should start at the first line
Expand Down
17 changes: 17 additions & 0 deletions src/ome_autogen/_class_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any

from pydantic import BaseModel
from xsdata_pydantic.compat import Pydantic


class OME(Pydantic):
def is_model(self, obj: Any) -> bool:
clazz = obj if isinstance(obj, type) else type(obj)
if isinstance(clazz, BaseModel):
clazz.update_forward_refs()
return True

return False



52 changes: 43 additions & 9 deletions src/ome_autogen/_generator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from xsdata.codegen.writer import CodeWriter
from typing import Any

from xsdata.codegen.models import Attr
from xsdata.formats.dataclass.filters import Filters
from xsdata.models.config import GeneratorConfig
from xsdata_pydantic.generator import PydanticFilters, PydanticGenerator
from xsdata.formats.dataclass.generator import DataclassGenerator
from xsdata.models.config import GeneratorConfig, OutputFormat
from xsdata_pydantic.generator import PydanticFilters

PRESERVED_NAMES = {"OME", "ROIRef", "XMLAnnotation", "ROI"}


class OmeGenerator(PydanticGenerator):
"""Python pydantic dataclasses code generator."""

KEY = "ome"

class OmeGenerator(DataclassGenerator):
@classmethod
def init_filters(cls, config: GeneratorConfig) -> Filters:
return OmeFilters(config)
Expand All @@ -23,5 +22,40 @@ def class_name(self, name: str) -> str:
def field_name(self, name: str, class_name: str) -> str:
return super().field_name(name, class_name)

@classmethod
def build_class_annotation(cls, fmt: OutputFormat) -> str:
# remove the @dataclass decorator
return ""

def field_definition(
self,
attr: Attr,
ns_map: dict,
parent_namespace: str | None,
parents: list[str],
) -> str:
"""Return the field definition with any extra metadata."""
# updated to use pydantic Field
default_value = self.field_default_value(attr, ns_map)
metadata = self.field_metadata(attr, parent_namespace, parents)

kwargs: dict[str, Any] = {}
if attr.fixed or attr.is_prohibited:
kwargs["init"] = False

if default_value is not False and not attr.is_prohibited:
key = self.FACTORY_KEY if attr.is_factory else self.DEFAULT_KEY
kwargs[key] = default_value

if metadata:
kwargs["metadata"] = metadata

return f"Field({self.format_arguments(kwargs, 4)})"

@classmethod
def build_import_patterns(cls) -> dict[str, dict]:
patterns = Filters.build_import_patterns()
patterns.pop("dataclasses")
patterns.update({"pydantic": {"Field": [" = Field("]}})

CodeWriter.register_generator(OmeGenerator.KEY, OmeGenerator)
return {key: patterns[key] for key in sorted(patterns)}
58 changes: 38 additions & 20 deletions src/ome_autogen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@
from typing import Any, Callable

from xsdata.codegen.transformer import SchemaTransformer
from xsdata.codegen.writer import CodeWriter
from xsdata.formats.dataclass.compat import class_types
from xsdata.logger import logger
from xsdata.models.config import (
CompoundFields,
DocstringStyle,
ExtensionType,
GeneratorConfig,
GeneratorConventions,
GeneratorExtension,
GeneratorExtensions,
GeneratorOutput,
NameConvention,
OutputFormat,
StructureStyle,
)
from xsdata_pydantic.hooks import class_type, cli # noqa: F401

from ome_autogen._generator import OmeGenerator
from ome_autogen._util import camel_to_snake, cd, get_plural_names, resolve_source

from ._generator import OmeGenerator

SRC_PATH = Path(__file__).parent.parent
SCHEMA_FILE = SRC_PATH / "ome_types" / "ome-2016-06.xsd"
PACKAGE = f"ome_types2.model.{SCHEMA_FILE.stem.replace('-', '_')}"
Expand All @@ -30,8 +38,27 @@
"S105", # Possible hardcoded password
]

OME_BASE_EXTENSION = GeneratorExtension(
type=ExtensionType.CLASS,
class_name=".*",
import_string="ome_types2.model._base_type.OMEType",
)


# These are critical to be able to use the format="OME"
OME_FORMAT = "OME"
CodeWriter.register_generator(OME_FORMAT, OmeGenerator)
from xsdata_pydantic.compat import Pydantic

class_types.register(OME_FORMAT, Pydantic())


class OmeNameCase(Enum):
"""Mimic the xsdata NameConvention enum, to modify snake case function.
We want adjacent capital letters to remain caps.
"""

OME_SNAKE = "omeSnakeCase"

def __call__(self, string: str, **kwargs: Any) -> str:
Expand All @@ -54,22 +81,20 @@ def convert_schema(
) -> None:
"""Convert the OME schema to a python model."""
if debug:
from xsdata.logger import logger

logger.setLevel("DEBUG")

output = GeneratorOutput(
package=output_package,
format=OutputFormat(value=OmeGenerator.KEY, slots=False),
structure_style=StructureStyle.CLUSTERS,
docstring_style=DocstringStyle.NUMPY,
compound_fields=CompoundFields(enabled=False),
)
config = GeneratorConfig(
output=output,
conventions=GeneratorConventions(
field_name=NameConvention(OmeNameCase.OME_SNAKE, "value")
output=GeneratorOutput(
package=output_package,
format=OutputFormat(value="pydantic_base_model", slots=False),
structure_style=StructureStyle.CLUSTERS,
docstring_style=DocstringStyle.NUMPY,
compound_fields=CompoundFields(enabled=False),
),
# conventions=GeneratorConventions(
# field_name=NameConvention(OmeNameCase.OME_SNAKE, "value")
# ),
# extensions=GeneratorExtensions(extension=[OME_BASE_EXTENSION]),
)

uris = sorted(resolve_source(str(schema_file), recursive=False))
Expand All @@ -91,13 +116,6 @@ def convert_schema(

package_dir = Path(output_dir) / output_package.replace(".", "/")

# Fix bug in xsdata output
# https://github.com/tefra/xsdata/pull/806
light_source = next(package_dir.rglob("light_source.py"))
src = light_source.read_text()
src = src.replace("UnitsPower.M_W,", "UnitsPower.M_W_1,")
light_source.write_text(src)

if not do_linting:
print(f"\033[92m\033[1m✓ OME python model created at {output_package}\033[0m")
return
Expand Down
6 changes: 4 additions & 2 deletions src/ome_types2/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from typing import TYPE_CHECKING, Any, cast
from xml.etree import ElementTree as ET

from xsdata.formats.dataclass.parsers import XmlParser
from xsdata.formats.dataclass.parsers.config import ParserConfig

# from xsdata.formats.dataclass.parsers import XmlParser
from xsdata_pydantic.base_model.bindings import XmlParser

if TYPE_CHECKING:
from typing import TypedDict

Expand All @@ -34,7 +36,7 @@ def _get_ome(xml: str | bytes) -> type[OME]:
root = ET.fromstring(xml) # noqa: S314

if root.tag == OME_2016_06:
from ome_types2.model.ome_2016_06 import OME
from ome_types2.model import OME

return OME
raise NotImplementedError(f"Unknown root tag: {root.tag}")
Expand Down
2 changes: 1 addition & 1 deletion src/ome_types2/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .ome_2016_06 import OME
from .ome_2016_06 import Ome as OME

__all__ = ["OME"]
Loading

0 comments on commit d79a438

Please sign in to comment.