-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultiscales.py
308 lines (277 loc) · 10.6 KB
/
multiscales.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
from typing import Any, Dict, List, Sequence, Tuple, Optional
import numpy as np
from xarray import DataArray
from pydantic_ome_ngff.latest.axes import Axis
from pydantic_ome_ngff.latest.multiscales import MultiscaleDataset, Multiscale
from pydantic_ome_ngff.latest.coordinateTransformations import (
VectorScaleTransform,
VectorTranslationTransform,
CoordinateTransform,
)
import builtins
import warnings
from xarray_ome_ngff.core import ureg
def multiscale_metadata(
    arrays: Sequence[DataArray],
    array_paths: Optional[List[str]] = None,
    name: Optional[str] = None,
    type: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    normalize_units: bool = True,
    infer_axis_type: bool = True,
):
    """
    Create Multiscale metadata from a collection of xarray.DataArrays

    Parameters
    ----------
    arrays: sequence of DataArray
        The arrays that represent the multiscale collection of images.
    array_paths: sequence of strings, optional
        The path of each array in the group.
    name: string, optional
        The name of the multiscale collection. Used to populate the 'name' field of
        Multiscale metadata.
    type: string, optional
        The type of the multiscale collection. Used to populate the 'type' field of
        Multiscale metadata.
    metadata: dict, optional
        Additional metadata associated with this multiscale collection. Used to populate
        the 'metadata' field of Multiscale metadata.
    normalize_units: bool, defaults to True
        Whether to normalize units to standard names, e.g. 'nm' -> 'nanometer'
    infer_axis_type: bool, defaults to True
        Whether to infer the `type` field of the axis from units, e.g. if units are
        "nanometer" then the type of the axis can safely be assumed to be "space".
        This keyword argument is ignored if `type` is not None in the array coordinate
        metadata. If axis type inference fails, `type` will be set to None.

    Returns
    -------
    An instance of Multiscale metadata.

    Raises
    ------
    ValueError
        If `arrays` is empty, contains non-DataArray elements, contains arrays of
        differing dimensionality, or if `array_paths` has a different length than
        `arrays`.
    """
    # guard against empty input, which would otherwise fail later with an
    # uninformative IndexError on `ranks[0]`
    if len(arrays) == 0:
        raise ValueError("`arrays` must contain at least one DataArray.")
    for arr in arrays:
        if not isinstance(arr, DataArray):
            raise ValueError(
                f"""
                This function requires a list of xarray.DataArrays. Got an element with
                type = '{builtins.type(arr)}' instead.
                """
            )
    ranks = [a.ndim for a in arrays]
    if len(set(ranks)) > 1:
        raise ValueError(
            f"""
            All arrays must have the same number of dimensions. Found arrays with different
            numbers of dimensions: {set(ranks)}.
            """
        )
    # sort arrays by decreasing size so the highest-resolution image comes first
    arrays_sorted = tuple(
        sorted(arrays, key=lambda arr: np.prod(arr.shape), reverse=True)
    )
    # the top-level transform is an identity scale across all dimensions
    base_transforms = [VectorScaleTransform(scale=[1] * ranks[0])]
    # per-array axes and (scale, translation) transforms, derived from coordinates
    axes, transforms = tuple(
        zip(
            *(
                coords_to_transforms(
                    tuple(array.coords.values()),
                    normalize_units=normalize_units,
                    infer_axis_type=infer_axis_type,
                )
                for array in arrays_sorted
            )
        )
    )
    if array_paths is None:
        # fall back to the name of each array; assumes arrays are named — TODO confirm
        paths = [d.name for d in arrays_sorted]
    else:
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if len(array_paths) != len(arrays):
            raise ValueError(
                f"""
                Length of array_paths {len(array_paths)} doesn't match {len(arrays)}
                """
            )
        paths = array_paths
    datasets = [
        MultiscaleDataset(path=p, coordinateTransformations=t)
        for p, t in zip(paths, transforms)
    ]
    return Multiscale(
        name=name,
        type=type,
        axes=axes[0],
        datasets=datasets,
        metadata=metadata,
        coordinateTransformations=base_transforms,
    )
def coords_to_transforms(
    coords: Tuple[DataArray, ...],
    normalize_units: bool = True,
    infer_axis_type: bool = True,
) -> Tuple[Tuple[Axis, ...], Tuple[VectorScaleTransform, VectorTranslationTransform]]:
    """
    Generate Axes and CoordinateTransformations from the coordinates of an
    xarray.DataArray.

    Parameters
    ----------
    coords: tuple of DataArray
        One coordinate array per dimension. Scale and translation
        transform parameters will be inferred from the coordinates for each dimension.
        Note that no effort is made to ensure that the coordinates represent a regular
        grid. Axis types are inferred by querying the attributes of each
        coordinate for the 'type' key. Axis units are inferred by querying the
        attributes of each coordinate for the 'unit' key, and if that key is not present
        then the 'units' key is queried. Axis names are inferred from the dimensions
        of the array.
    normalize_units: bool, defaults to True
        If True, unit strings will be normalized to a canonical representation using the
        `pint` library. For example, the abbreviation "nm" will be normalized to
        "nanometer".
    infer_axis_type: bool, defaults to True
        Whether to infer the axis type from the units. This will have no effect if
        the coordinate has 'type' in its attrs, or if its unit is None.

    Returns
    -------
    A tuple with two elements. The first value is a tuple of Axis objects with
    length equal to the number of dimensions of the input array. The second value is
    a tuple with two elements which contains a VectorScaleTransform and a
    VectorTranslationTransform, both of which are derived by inspecting the
    coordinates of the input array.
    """
    translate = []
    scale = []
    axes = []
    for coord in coords:
        # bug fix: the walrus must be parenthesized — without parentheses,
        # `ndim := len(coord.dims) != 1` binds the *boolean* to `ndim`, so the
        # error message below reported `True` instead of the dimensionality
        if (ndim := len(coord.dims)) != 1:
            raise ValueError(
                f"""
                Each coordinate must have one and only one dimension.
                Got a coordinate with {ndim}.
                """
            )
        dim = coord.dims[0]
        # translation is the first coordinate value
        translate.append(float(coord[0]))
        # impossible to infer a scale coordinate from a coordinate with 1 sample
        if len(coord) > 1:
            scale.append(abs(float(coord[1]) - float(coord[0])))
        else:
            scale.append(1)
        unit = coord.attrs.get("unit", None)
        units = coord.attrs.get("units", None)
        if unit is None and units is not None:
            warnings.warn(
                f"""
                The key 'unit' was unset, but 'units' was found in array attrs, with a value
                of '{units}'. The 'unit' property of the corresponding axis will be set to
                '{units}', but this behavior may change in the future.
                """
            )
            unit = units
        elif units is not None:
            warnings.warn(
                f"""
                Both 'unit' and 'units' were found in array attrs, with values '{unit}' and
                '{units}', respectively. The value associated with 'unit' ({unit}) will be
                used in the axis metadata.
                """
            )
        if normalize_units and unit is not None:
            unit = ureg.get_name(unit, case_sensitive=True)
        type = coord.attrs.get("type", None)
        # only attempt inference when a unit is available — passing None to
        # pint's get_dimensionality would raise instead of leaving type unset
        if type is None and infer_axis_type and unit is not None:
            unit_dimensionality = ureg.get_dimensionality(unit)
            if len(unit_dimensionality) > 1:
                # bug fix: previously this branch fell through to the
                # [length]/[time] checks, which could set a type despite the
                # warning promising that `type` would be None
                warnings.warn(
                    f"""
                    Failed to infer the type of axis with unit = "{unit}", because it
                    appears that unit "{unit}" is a compound unit, which cannot be mapped
                    to a single axis type. "type" will be set to None for this axis.
                    """,
                    RuntimeWarning,
                )
            elif "[length]" in unit_dimensionality:
                type = "space"
            elif "[time]" in unit_dimensionality:
                type = "time"
            else:
                warnings.warn(
                    f"""
                    Failed to infer the type of axis with unit = "{unit}", because it could
                    not be mapped to either a time or space dimension. "type" will be set to
                    None for this axis.
                    """,
                    RuntimeWarning,
                )
        axes.append(
            Axis(
                name=dim,
                unit=unit,
                type=type,
            )
        )
    transforms = (
        VectorScaleTransform(scale=scale),
        VectorTranslationTransform(translation=translate),
    )
    return axes, transforms
def transforms_to_coords(
    axes: List[Axis], transforms: List[CoordinateTransform], shape: Tuple[int, ...]
) -> List[DataArray]:
    """
    Given an output shape, convert a sequence of Axis objects and a corresponding
    sequence of coordinateTransform objects into xarray-compatible coordinates.

    Parameters
    ----------
    axes: list of Axis
        One Axis per dimension; supplies the name and unit of each coordinate.
    transforms: list of CoordinateTransform
        Transforms applied, in order, to the base coordinate of each dimension.
        Only 'scale', 'translation', and 'identity' transforms are supported;
        transforms carrying a 'path' attribute are rejected.
    shape: tuple of int
        The shape of the array the coordinates are generated for.

    Returns
    -------
    A list of DataArrays, one per dimension, each carrying its axis unit in attrs.

    Raises
    ------
    ValueError
        If the number of axes does not match the length of `shape`, if a transform
        carries a string `path`, if a transform's parameter length does not match
        the number of axes, or if a transform type is unrecognized.
    """
    if len(axes) != len(shape):
        raise ValueError(
            f"""Length of axes must match length of shape.
            Got {len(axes)} axes but shape has {len(shape)} elements"""
        )
    result = []
    for idx, axis in enumerate(axes):
        # base coordinate before transforms: [0, 1, ..., shape[idx] - 1]
        base_coord = np.arange(shape[idx], dtype="float")
        name = axis.name
        unit = axis.unit
        # apply transforms in order
        for tx in transforms:
            # isinstance is the idiomatic (and subclass-safe) check, replacing
            # `type(...) == str`
            if isinstance(getattr(tx, "path", None), str):
                raise ValueError(
                    f"""
                    Problematic transform: {tx}.
                    This library does not handle transforms with paths.
                    """
                )
            if tx.type == "translation":
                if len(tx.translation) != len(axes):
                    raise ValueError(
                        f"""
                        Translation parameter has length {len(tx.translation)}. This does
                        not match the number of axes {len(axes)}.
                        """
                    )
                base_coord += tx.translation[idx]
            elif tx.type == "scale":
                if len(tx.scale) != len(axes):
                    raise ValueError(
                        f"""
                        Scale parameter has length {len(tx.scale)}. This does not match the
                        number of axes {len(axes)}.
                        """
                    )
                base_coord *= tx.scale[idx]
            elif tx.type == "identity":
                pass
            else:
                # bug fix in message: the accepted type string is "translation",
                # not "translate"
                raise ValueError(
                    f"""
                    Transform type {tx.type} not recognized. Must be one of scale,
                    translation, or identity
                    """
                )
        result.append(
            DataArray(
                base_coord,
                attrs={"unit": unit},
                dims=(name,),
            )
        )
    return result