diff --git a/test_rotation.py b/test_rotation.py new file mode 100644 index 00000000..3ee1c65d --- /dev/null +++ b/test_rotation.py @@ -0,0 +1,220 @@ +import pathlib + +import numpy as np +import matplotlib.pyplot as plt +import skimage.io +from ashlar import reg, utils, thumbnail + + +class TestMetadata(reg.Metadata): + def __init__( + self, + path, + tile_size, + overlap, + pixel_size, + channel=0, + zarr=None, + img=None, + series=None, + ): + self.path = pathlib.Path(path) + self._tile_size = np.array(tile_size) + self.overlap = overlap + self._pixel_size = pixel_size + self.channel = channel + self.zarr = zarr + self.img = img + self.series = series + self.deconstruct_mosaic() + + def deconstruct_mosaic(self): + if self.zarr is not None: + self.mosaic = self.zarr + + if self.img is not None: + self.mosaic = self.img + + if self.zarr is None and self.img is None: + self.mosaic = skimage.io.imread(self.path, key=self.channel) + + m_shape = self.mosaic.shape + + step_shape = (1 - self.overlap) * self._tile_size + # round position to integer since no subpixel needed for already stitched image + step_shape = np.around(step_shape).astype("int") + overlap_shape = np.around(self.overlap * self._tile_size).astype(int) + m_limit = m_shape - overlap_shape + + self._slice_positions = ( + np.mgrid[: m_limit[0] : step_shape[0], : m_limit[1] : step_shape[1]] + .reshape(2, -1) + .T + ) + + self._positions = self._slice_positions.astype(float) + + if self.series is not None: + self._slice_positions = self._slice_positions[self.series] + self._positions = self._positions[self.series] + + @property + def _num_images(self): + return len(self._positions) + + @property + def num_channels(self): + return 1 + + @property + def pixel_size(self): + return self._pixel_size + + @property + def pixel_dtype(self): + return self.zarr.dtype + + @property + def mosaic_shape(self): + return self.zarr.shape + + def tile_size(self, i): + return self._tile_size + + +class TestReader(reg.Reader): + def __init__( + self, + path=None, + tile_size=(1000, 1000), + overlap=0.1, + pixel_size=1, + channel=0, + zarr=None, + img=None, + series=None, + flip_x=False, + flip_y=False, + angle=0, + center_crop_shape=None, + noise=0, + ): + path = "" if path is None else path + self.metadata = TestMetadata( + path, tile_size, overlap, pixel_size, channel, zarr, img, series + ) + self.path = pathlib.Path(path) + self.mosaic = self.metadata.mosaic + self.flip_x = flip_x + self.flip_y = flip_y + self.angle = angle + self.noise = noise + + def read(self, series, c): + position = self.metadata._slice_positions[series] + assert np.issubdtype(position.dtype, np.integer) + r, c = position + h, w = self.metadata._tile_size + img = self.mosaic[r : r + h, c : c + w] + if self.noise: + r = np.random.RandomState(seed=series) + noise_img = r.randint(0, self.noise + 1, size=img.shape) + img = np.clip(img + noise_img, img.min(), img.max()).astype(img.dtype) + if not np.all(img.shape == (h, w)): + img_h, img_w = img.shape + pad_h, pad_w = np.clip([h - img_h, w - img_w], 0, None) + img = np.pad(img, [(0, pad_h), (0, pad_w)]) + if self.flip_x: + img = np.fliplr(img) + if self.flip_y: + img = np.flipud(img) + if self.angle != 0: + img = skimage.transform.rotate(img, self.angle, center=(0, 0), resize=True) + return img + + +def align_cycles(reader1, reader2, scale=0.05): + import skimage.transform + + if not hasattr(reader1, "thumbnail"): + raise ValueError("reader1 does not have a thumbnail") + if not hasattr(reader2, "thumbnail"): + raise ValueError("reader2 does not have a thumbnail") + img1 = reader1.thumbnail + img2 = reader2.thumbnail + padded_shape = np.array((img1.shape, img2.shape)).max(axis=0) + img1 = skimage.transform.warp(img1, np.eye(3), output_shape=padded_shape) + img2 = skimage.transform.warp(img2, np.eye(3), output_shape=padded_shape) + angle = utils.register_angle(img1, img2, sigma=1) + if angle != 0: + print(f"\r estimated cycle rotation = {angle:.2f} degrees") + rotation_center = 0.5 * np.array(padded_shape[::-1]) - 0.5 + img2 = skimage.transform.rotate( + img2, angle, resize=False, center=rotation_center + ) + shifts = thumbnail.calculate_image_offset(img1, img2, int(1 / scale)) + print(f"\r estimated shift {shifts / scale}") + tform_steps = [ + ("translation", -reader2.metadata.origin[::-1]), + ("scale", scale), + ("translation", -rotation_center), + ("rotation", np.deg2rad(-angle)), + ("translation", rotation_center), + ("translation", shifts[::-1]), + ("scale", 1 / scale), + ("translation", reader1.metadata.origin[::-1]), + ] + tform = skimage.transform.AffineTransform() + for step in tform_steps: + tform += skimage.transform.AffineTransform(**{step[0]: step[1]}) + + return tform + + +import numpy as np +import skimage.data +import skimage.transform +from ashlar import thumbnail + +TILE_SIZE = (108, 128) + +img = skimage.data.astronaut()[..., 1] +c1r = TestReader(img=img, tile_size=TILE_SIZE, overlap=0.25, noise=1) + +affine = skimage.transform.AffineTransform +#tform = affine( +# translation=200 * (np.random.random(2) - 0.5), +# rotation=np.deg2rad(-10 * (np.random.random(1) - 0.5)[0]), +#) +tform = ( + affine(translation=(-250, -280)) + + affine(rotation=np.deg2rad(88)) + + affine(translation=(250, 280)) +) + +# apply known transform to image +img2 = skimage.transform.warp(img, tform.inverse, preserve_range=True).astype(img.dtype) +c2r = TestReader(img=img2, tile_size=TILE_SIZE, overlap=0.25, noise=1) + +# set random stage origin +c1r.metadata._positions += 2000 * (np.random.random(2) - 0.5) +c2r.metadata._positions += 2000 * (np.random.random(2) - 0.5) + +# randomly perturb stage positions +c1r.metadata._positions += np.random.random_sample(c1r.metadata._positions.shape) * 5 +c2r.metadata._positions += np.random.random_sample(c2r.metadata._positions.shape) * 5 + +a1 = reg.EdgeAligner(c1r, verbose=True) +a1.run() +print() +a2 = reg.LayerAligner(c2r, a1, verbose=True) +a2.run() + +fig, (ax1, ax2) = plt.subplots(1, 2) +ax1.imshow(a1.reader.thumbnail, cmap='gray', vmax=2e5) +ax2.imshow(a2.reader.thumbnail, cmap='gray', vmax=2e5) +for i, (y, x) in enumerate(a1.metadata.centers - a1.metadata.origin): + ax1.annotate(str(i), (x,y), ha='center', va='center', color='yellow') +for i, (y, x) in enumerate(a2.metadata.centers - a2.metadata.origin): + ax2.annotate(str(i), (x,y), ha='center', va='center', color='magenta') +plt.show()