Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add from_array convenience functions #219

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ no_implicit_reexport = true
ignore_missing_imports = true
disallow_untyped_defs = true
plugins = "pydantic.mypy"
enable_incomplete_feature = ["Unpack"]

[tool.pydantic-mypy]
init_forbid_extra = true
Expand Down
17 changes: 17 additions & 0 deletions src/ome_types/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,20 @@ def _get_root_ome_type(xml: FileLike | AnyElementTree) -> type[OMEType]:
return getattr(model, localname)
except AttributeError:
raise ValueError(f"Unknown root element {localname!r}") from None


def camel_to_snake(name: str) -> str:
    """Convert a CamelCase name to snake_case, keeping uppercase runs together.

    Adapted from https://stackoverflow.com/a/1176023

    A leading "@" (as in "@any_element") is stripped, the result is
    lower-cased, and spaces are replaced by underscores.

    Note: an equivalent helper exists in ome_autogen._util, but nothing from
    that module should be imported at runtime, so it is duplicated here.
    """
    import re

    # Applied in order: first split an uppercase run from a following
    # capitalized word ("HTTPResponse" -> "HTTP_Response"), then split a
    # lowercase letter or digit from a following capital ("sizeX" -> "size_X").
    substitutions = (
        ("([A-Z]+)([A-Z][a-z]+)", r"\1_\2"),
        ("([a-z0-9])([A-Z])", r"\1_\2"),
    )

    snake = name.lstrip("@")  # remove leading @ from "@any_element"
    for pattern, replacement in substitutions:
        snake = re.sub(pattern, replacement, snake)
    return snake.lower().replace(" ", "_")
289 changes: 289 additions & 0 deletions src/ome_types/_from_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
from __future__ import annotations

import warnings
from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast

import numpy as np

from ome_types import model as m
from ome_types._mixins._validators import numpy_dtype_to_pixel_type

if TYPE_CHECKING:
import datetime

import numpy.typing as npt
from typing_extensions import Literal, TypedDict, TypeVar, Unpack

Kt = TypeVar("Kt")
Vt = TypeVar("Vt", covariant=True)

from ome_types.model._color import ColorType

DimsOrderStr = Literal["XYZCT", "XYZTC", "XYCTZ", "XYCZT", "XYTCZ", "XYTZC"]

# TODO: these should be autogenerated

class ImagePixelsKwargs(TypedDict, total=False):
    """Optional keyword arguments accepted by `ome_image` via `**img_kwargs`.

    Every key is optional (`total=False`).  `ome_image` pops
    `dimension_order` itself and forwards each remaining key to either
    `m.Image` or `m.Pixels`, depending on which model annotates that name.
    """

    acquisition_date: datetime.datetime | None
    description: str | None
    name: str | None
    dimension_order: m.Pixels_DimensionOrder | DimsOrderStr
    physical_size_x: float | None
    physical_size_y: float | None
    physical_size_z: float | None
    physical_size_x_unit: m.UnitsLength | str
    physical_size_y_unit: m.UnitsLength | str
    physical_size_z_unit: m.UnitsLength | str
    time_increment: float | None
    time_increment_unit: m.UnitsTime | str

class ChannelKwargs(TypedDict, total=False):
    """Optional keyword arguments describing a single channel.

    `ome_channels` turns each such mapping into one `m.Channel`.
    `samples_per_pixel` is deliberately absent: `ome_channels` sets it
    itself from its own `samples_per_pixel` argument.
    """

    acquisition_mode: m.Channel_AcquisitionMode | None | str
    color: m.Color | ColorType | None
    contrast_method: m.Channel_ContrastMethod | None | str
    emission_wavelength_unit: m.UnitsLength | str
    emission_wavelength: float | None
    excitation_wavelength_unit: m.UnitsLength | str
    excitation_wavelength: float | None
    fluor: str | None
    illumination_type: m.Channel_IlluminationType | str | None
    name: str | None
    nd_filter: float | None
    pinhole_size_unit: m.UnitsLength | str
    pinhole_size: float | None
    pockel_cell_setting: int | None
    # samples_per_pixel : None | int  # will be derived from the array

# same as above, but in {name: Sequence[values]} format
class ChannelTable(TypedDict, total=False):
    """Column-oriented channel description: `{field: sequence of values}`.

    `ome_channels` converts this form to one mapping per channel with
    `_dol2lod`; a value that is not a sequence (or is a `str`) is repeated
    for every channel.
    """

    acquisition_mode: Sequence[m.Channel_AcquisitionMode | None | str]
    color: Sequence[m.Color | ColorType | None] | m.Color
    contrast_method: Sequence[m.Channel_ContrastMethod | None | str]
    emission_wavelength_unit: Sequence[m.UnitsLength | str]
    emission_wavelength: Sequence[float | None] | float
    excitation_wavelength_unit: Sequence[m.UnitsLength | str]
    excitation_wavelength: Sequence[float | None] | float
    fluor: Sequence[str | None]
    illumination_type: Sequence[m.Channel_IlluminationType | str | None]
    name: Sequence[str | None]
    nd_filter: Sequence[float | None] | float
    pinhole_size_unit: Sequence[m.UnitsLength | str]
    pinhole_size: Sequence[float | None] | float
    pockel_cell_setting: Sequence[int | None]

class PlaneKwargs(TypedDict, total=False):
    """Optional keyword arguments describing a single plane.

    `ome_planes` passes each such mapping to `m.Plane` and adds the
    `the_c` / `the_z` / `the_t` indices itself.
    """

    delta_t: float | None
    delta_t_unit: m.UnitsTime | str
    exposure_time: float | None
    exposure_time_unit: m.UnitsTime | str
    position_x: float | None
    position_x_unit: m.UnitsLength | str
    position_y: float | None
    position_y_unit: m.UnitsLength | str
    position_z: float | None
    position_z_unit: m.UnitsLength | str

class PlaneTable(TypedDict, total=False):
    """Column-oriented plane description: `{field: sequence of values}`.

    `ome_planes` converts this form to one mapping per plane with
    `_dol2lod`; a value that is not a sequence (or is a `str`) is repeated
    for every plane.
    """

    delta_t: Sequence[float | None] | float
    delta_t_unit: Sequence[m.UnitsTime | str]
    exposure_time: Sequence[float | None] | float
    exposure_time_unit: Sequence[m.UnitsTime | str]
    position_x: Sequence[float | None] | float
    position_x_unit: Sequence[m.UnitsLength | str]
    position_y: Sequence[float | None] | float
    position_y_unit: Sequence[m.UnitsLength | str]
    position_z: Sequence[float | None] | float
    position_z_unit: Sequence[m.UnitsLength | str]


def ome_image(
    shape: Sequence[int],
    dtype: npt.DTypeLike,
    axes: Sequence[str] = "",
    *,
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    planes: Sequence[PlaneKwargs] | PlaneTable = (),
    **img_kwargs: Unpack[ImagePixelsKwargs],
) -> m.Image:
    """Build an `m.Image` (with `m.Pixels`) describing an array of `shape`/`dtype`.

    Parameters
    ----------
    shape :
        Array shape; at most 6 dimensions.
    dtype :
        Anything accepted by `numpy_dtype_to_pixel_type`.
    axes :
        One axis name per entry of `shape`.  When empty, axes are derived from
        the dimension order by `_determine_axes`.
    channels :
        Channel metadata, forwarded to `ome_channels`.
    planes :
        Plane metadata, forwarded to `ome_planes`.
    **img_kwargs :
        `dimension_order` is consumed here; every other key is forwarded to
        `m.Image` or `m.Pixels`, whichever annotates that name.  Keys that
        neither model annotates are ignored.

    Raises
    ------
    ValueError
        If `shape` has more than 6 dimensions (or if `_determine_axes`
        rejects the axes).
    """
    shape = tuple(int(i) for i in shape)
    ndim = len(shape)
    if ndim > 6:
        raise ValueError(f"shape must have at most 6 dimensions, not {ndim}")

    # unify axes argument with dimension_order and validate
    axes, dims_order = _determine_axes(
        axes, shape, img_kwargs.pop("dimension_order", None)
    )

    # determine pixel axis sizes ------------------------------------
    nc, nz, nt, nsamp = (shape[axes.index(x)] if x in axes else 1 for x in "CZTS")
    sizes = {f"size_{ax}": 1 for ax in "xyczt"}
    for ax, size in zip(axes, shape):
        if ax != "S":
            sizes[f"size_{ax.lower()}"] = size
    # size_c counts samples across all channels.  Computing it here (rather
    # than only when a "C" axis is present) keeps it correct for arrays that
    # have a samples axis but no channel axis, e.g. "YXS" RGB images.
    sizes["size_c"] = nc * nsamp

    # position of C, Z and T among the last three letters of the dimension order
    czt_order = tuple(dims_order.value[2:].index(ax) for ax in "CZT")

    # pull out the kwargs that belong to Image and Pixels
    _img_kwargs, _pix_kwargs = {}, {}
    for k, v in img_kwargs.items():
        if k in m.Image.__annotations__:
            _img_kwargs[k] = v
        elif k in m.Pixels.__annotations__:
            _pix_kwargs[k] = v

    # TODO: validate against shape and dtype here
    return m.Image(
        pixels=m.Pixels(
            dimension_order=dims_order,
            **sizes,
            type=numpy_dtype_to_pixel_type(dtype),
            **_convert_keys_to_snake_case(_pix_kwargs),
            channels=ome_channels(channels, nc, nsamp),
            planes=ome_planes((nc, nz, nt), czt_order, planes),
        ),
        **_convert_keys_to_snake_case(_img_kwargs),
    )


def ome_image_like(
    array: npt.NDArray,
    axes: Sequence[str] = "",
    *,
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    planes: Sequence[PlaneKwargs] | PlaneTable = (),
    **img_kwargs: Unpack[ImagePixelsKwargs],
) -> m.Image:
    """Build an `m.Image` describing `array`.

    Thin wrapper around `ome_image` that takes `shape` and `dtype` from
    `array.shape` and `array.dtype`; all other arguments are forwarded
    unchanged (so `channels` accepts the same forms `ome_image` does).
    """
    return ome_image(
        shape=array.shape,
        dtype=array.dtype,
        axes=axes,
        channels=channels,
        planes=planes,
        **img_kwargs,
    )


def _determine_axes(
    axes: Sequence[str] | None,
    shape: Sequence[int],
    dimension_order: m.Pixels_DimensionOrder | str | None,
) -> tuple[str, m.Pixels_DimensionOrder]:
    """Reconcile array `axes` with an OME dimension order.

    Parameters
    ----------
    axes :
        One axis name per dimension of `shape`; only the first character of
        each name is used, upper-cased.  If empty/None, axes are taken from
        the (reversed) dimension order.
    shape :
        Array shape; only its length is used here.
    dimension_order :
        Requested OME dimension order; "XYCZT" is used when falsy.

    Returns
    -------
    tuple[str, m.Pixels_DimensionOrder]
        The normalized axes string and the dimension order to use.

    Raises
    ------
    ValueError
        If a 6-dimensional shape lacks an "S" axis, if the number of axes
        differs from the number of dimensions, or if no dimension order starts
        with the reversed axes.
    """
    _dims_order = m.Pixels_DimensionOrder(dimension_order or "XYCZT")
    ndim = len(shape)
    if not axes:
        reversed_order = _dims_order.value[::-1]
        # Explicit start index instead of `[-ndim:]`, which would return the
        # whole string for ndim == 0.
        axes = reversed_order[max(len(reversed_order) - ndim, 0) :]
        if ndim == 6:
            axes += "S"
        return axes, _dims_order

    # Normalize before any membership test so that lower-case or long axis
    # names (e.g. "s", "samples") are recognized as "S".
    axes = "".join(x[0] for x in axes).upper()

    if ndim == 6 and "S" not in axes:
        raise ValueError(
            "shape has 6 dimensions, so axes must be specified with 'S' in it"
        )
    if len(axes) != len(shape):
        raise ValueError(f"Axes {axes!r} do not match shape {shape!r}")

    # OME orders are written fastest-axis first and have no samples axis
    ome_axes = axes[::-1].replace("S", "")

    candidates = [
        order for order in m.Pixels_DimensionOrder if order.value.startswith(ome_axes)
    ]
    if not candidates:
        raise ValueError(f"Could not determine dimension order from axes {axes!r}")

    # An explicitly requested order that is compatible with the axes wins;
    # otherwise fall back to the first compatible order.
    if dimension_order and _dims_order in candidates:
        return axes, _dims_order

    order = candidates[0]
    if dimension_order:
        warnings.warn(
            f"Provided OME dimension_order {dimension_order!r} does not match "
            f"provided (reversed) axes {axes[::-1]!r}. Using {order.value!r}",
            stacklevel=2,
        )
    return axes, order


def ome_channels(
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    max_channels: int | None = None,
    samples_per_pixel: int = 1,
) -> list[m.Channel]:
    """Build a list of `m.Channel` objects.

    Parameters
    ----------
    channels :
        Existing `m.Channel` objects, one mapping of keyword arguments per
        channel, or a single `{field: sequence}` table.  When empty, default
        channels are created.
    max_channels :
        Maximum number of `m.Channel` objects to return; extra entries in
        `channels` are dropped.  When `channels` is empty this many default
        channels are created (one if falsy).
    samples_per_pixel :
        Set on every returned channel, overriding any value in `channels`.
    """
    if not channels:
        # `max_channels` counts Channel objects (as in the slicing below), so
        # it must not be divided by samples_per_pixel: that gave zero channels
        # for e.g. one RGB channel (max_channels=1, samples_per_pixel=3).
        return [
            m.Channel(samples_per_pixel=samples_per_pixel)
            for _ in range(max_channels or 1)
        ]

    # convert dict of lists to list of dicts
    if isinstance(channels, dict):
        channels = cast("Sequence[ChannelKwargs]", _dol2lod(channels, max_channels))

    channel_list: list[m.Channel] = []
    # limit to max_channels (based on previous shape analysis)
    # TODO: should we warn if too many channels are provided?
    for channel in channels[:max_channels]:
        if isinstance(channel, m.Channel):
            kwargs: dict = channel.dict()
        else:
            kwargs = _convert_keys_to_snake_case(channel)
        kwargs["samples_per_pixel"] = samples_per_pixel
        channel_list.append(m.Channel(**kwargs))
    return channel_list


def ome_planes(
    n_czt: tuple[int, int, int],
    czt_order: tuple[int, int, int],
    planes: Sequence[PlaneKwargs] | PlaneTable,
) -> list[m.Plane]:
    """Build one `m.Plane` per (c, z, t) combination.

    Parameters
    ----------
    n_czt :
        Number of channels, z-slices and timepoints, in that order.
    czt_order :
        Position of C, Z and T (in that order) among the last three letters of
        the dimension order, e.g. (1, 0, 2) for "XYZCT".
    planes :
        One mapping of keyword arguments per plane, or a single
        `{field: sequence}` table.  When empty, planes carrying only their
        indices are created.

    Raises
    ------
    ValueError
        If fewer planes are provided than `n_czt` implies.  Surplus planes
        are dropped with a warning.
    """
    plane_count = int(np.prod(n_czt))

    if isinstance(planes, dict):
        # convert dict of lists to list of dicts
        planes = cast("Sequence[PlaneKwargs]", _dol2lod(planes, plane_count))

    if not planes:
        planes = [{} for _ in range(plane_count)]
    elif len(planes) > plane_count:
        warnings.warn(
            f"Provided {len(planes)} planes, but expected {plane_count}",
            stacklevel=2,
        )
        planes = planes[:plane_count]
    elif len(planes) < plane_count:
        raise ValueError(f"Provided {len(planes)} planes, but expected {plane_count}")

    # Arrange the sizes in dimension-order (fastest-varying first) so that the
    # unravelled index can be read back through `czt_order`.  Unravelling over
    # `n_czt` itself is only right when the order is already C, Z, T.
    ordered_shape = [1, 1, 1]
    for size, position in zip(n_czt, czt_order):
        ordered_shape[position] = size

    plane_list = []
    for idx, plane in enumerate(planes):
        unraveled = np.unravel_index(idx, tuple(ordered_shape), order="F")
        c, z, t = (int(unraveled[i]) for i in czt_order)
        plane_list.append(m.Plane(**dict(**plane, the_c=c, the_z=z, the_t=t)))
    return plane_list


def _dol2lod(dol: Mapping[str, Any], max_items: int | None = None) -> list[dict]:
    """Convert a dict of sequences into a list of dicts (one per row).

    Values that are not sequences, and `str` values, are treated as scalars
    and repeated `max_items` times (once if `max_items` is falsy).  Columns
    of unequal length are padded with `None`.  The input mapping is left
    untouched.
    """
    repeat = max_items or 1
    # Build a fresh column mapping instead of assigning into `dol`, so the
    # caller's mapping is never mutated (and immutable mappings work too).
    columns = {
        key: [val] * repeat
        if isinstance(val, str) or not isinstance(val, Sequence)
        else val
        for key, val in dol.items()
    }
    return [dict(zip(columns, row)) for row in zip_longest(*columns.values())]


def _convert_keys_to_snake_case(d: Mapping[str, Vt]) -> dict[str, Vt]:
    """Return a new dict with every key of `d` passed through `camel_to_snake`."""
    from ome_types._conversion import camel_to_snake

    converted = {}
    for key, value in d.items():
        converted[camel_to_snake(key)] = value
    return converted
49 changes: 32 additions & 17 deletions src/ome_types/_mixins/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Sequence

import numpy as np

if TYPE_CHECKING:
from numpy.typing import DTypeLike

from ome_types.model import ( # type: ignore
BinData,
Pixels,
Expand All @@ -32,15 +36,14 @@ def bin_data_root_validator(cls: "BinData", values: dict) -> Dict[str, Any]:

# @root_validator(pre=True)
def pixels_root_validator(cls: "Pixels", value: dict) -> dict:
    """Normalize a boolean `metadata_only` entry in the raw Pixels values.

    `False` removes the key; `True` is replaced by a `MetadataOnly`
    instance.  A missing key or a non-bool value is left as is.  `value` is
    modified in place and returned.
    """
    flag = value.get("metadata_only")
    # A missing key yields None, which is not a bool, so it falls through.
    if isinstance(flag, bool):
        if flag:
            # type ignore in case the autogeneration hasn't been built
            from ome_types.model import MetadataOnly  # type: ignore

            value["metadata_only"] = MetadataOnly()
        else:
            value.pop("metadata_only")

    return value

Expand Down Expand Up @@ -76,13 +79,25 @@ def xml_value_validator(cls: "XMLAnnotation", v: Any) -> "XMLAnnotation.Value":
return v


# maps OME PixelType names to numpy dtype names
NP_DTYPE_MAP: "dict[str, str]" = {
    "float": "float32",
    "double": "float64",
    "complex": "complex64",
    "double-complex": "complex128",
    "bit": "bool",  # ?
}
# inverse lookup: numpy dtype name -> OME PixelType name
REV_NP_DTYPE_MAP: "dict[str, str]" = {v: k for k, v in NP_DTYPE_MAP.items()}


def pixel_type_to_numpy_dtype(self: "PixelType") -> "DTypeLike":
    """Get a numpy dtype string for this pixel type.

    Names listed in `NP_DTYPE_MAP` are translated; any other value is
    returned unchanged.
    """
    return NP_DTYPE_MAP.get(self.value, self.value)


def numpy_dtype_to_pixel_type(dtype: "DTypeLike") -> "PixelType":
    """Look up the PixelType that corresponds to a numpy dtype.

    The dtype's numpy name is translated through `REV_NP_DTYPE_MAP`; names
    without an entry there are passed to `PixelType` unchanged.
    """
    from ome_types.model import PixelType

    numpy_name = np.dtype(dtype).name
    ome_name = REV_NP_DTYPE_MAP.get(numpy_name, numpy_name)
    return PixelType(value=ome_name)
Loading