diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6bf3cf16..dee14d79 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,34 +23,101 @@ jobs: - run: pipx run check-manifest test: - uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2 - with: - os: ${{ matrix.os }} - python-version: ${{ matrix.python-version }} - coverage-upload: artifact - qt: pyqt6 + name: ${{ matrix.os }} py${{ matrix.python-version }} ${{ matrix.gui }} ${{ matrix.canvas }} + runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest, macos-latest, windows-latest] + # using 3.12 as main current version, until 3.13 support + # is ubiquitous in upstream dependencies + python-version: ["3.10", "3.12"] + gui: [pyside, pyqt, jup, wxpython] + canvas: [vispy, pygfx] exclude: + # unsolved intermittent segfaults on this combo + - python-version: "3.10" + gui: pyside + # wxpython does not build wheels for ubuntu or macos-latest py3.10 - os: ubuntu-latest - python-version: "3.11" # unknown CI segfault... + gui: wxpython + - os: macos-latest + gui: wxpython + python-version: "3.10" + include: + # test a couple more python variants, without + # full os/gui/canvas matrix coverage + - os: ubuntu-latest + python-version: "3.13" + gui: jup + canvas: vispy + - os: ubuntu-latest + python-version: "3.13" + gui: jup + canvas: pygfx + # pyside6 is struggling with 3.9 + - os: ubuntu-latest + python-version: "3.9" + gui: pyqt + canvas: vispy + - os: macos-13 + gui: wxpython + python-version: "3.9" + canvas: vispy + - os: windows-latest + gui: jup + python-version: "3.9" + canvas: pygfx + + steps: + - uses: actions/checkout@v4 + - name: ๐Ÿ Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + + - name: ๐Ÿ“ฆ Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e '.[test,${{ matrix.gui }},${{ matrix.canvas }}]' + + - uses: pyvista/setup-headless-display-action@v3 + with: + qt: ${{ matrix.gui == 'pyside' || matrix.gui == 'pyqt' }} + + - name: Install llvmpipe and lavapipe for offscreen canvas + if: matrix.os == 'ubuntu-latest' && matrix.canvas == 'pygfx' + run: | + sudo apt-get update -y -qq + sudo apt install -y libegl1-mesa-dev libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers + + - name: install pytest-qt + if: matrix.gui == 'pyside' || matrix.gui == 'pyqt' + run: pip install pytest-qt + + - name: ๐Ÿงช Test + run: | + pytest --cov --cov-report=xml -v --color yes tests + + - uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} test-array-libs: uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2 with: os: ${{ matrix.os }} python-version: ${{ matrix.python-version }} - extras: "test,third_party_arrays" + extras: "test,vispy,third_party_arrays" coverage-upload: artifact + pip-post-installs: "pytest-qt" qt: pyqt6 strategy: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: ["3.9", "3.12"] + python-version: ["3.10", "3.12"] upload_coverage: if: always() diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95129cae..1c17375c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,3 +29,6 @@ repos: files: "^src/" additional_dependencies: - numpy + - pydantic + - psygnal + - IPython diff --git a/README.md b/README.md index 94121b7a..7c429fa7 100644 --- a/README.md +++ b/README.md @@ -60,15 +60,25 @@ See examples for each of these array types in [examples](./examples/) ## Installation +To just get started using Qt and vispy: + +```python +pip install ndv[qt] +``` + +For Jupyter, without requiring Qt, you can use: + +```python +pip install ndv[jupyter] +``` + +If you'd like more control over the backend, you can install the optional dependencies directly. + The only required dependencies are `numpy` and `superqt[cmap,iconify]`. You will also need a Qt backend (PyQt or PySide) and one of either [vispy](https://github.com/vispy/vispy) or [pygfx](https://github.com/pygfx/pygfx), which can be installed through extras `ndv[,]`: -```python -pip install ndv[pyqt,vispy] -``` - > [!TIP] > If you have both vispy and pygfx installed, `ndv` will default to using vispy, > but you can override this with the environment variable diff --git a/examples/custom_store.py b/examples/custom_store.py index 9d3fbff6..114a3cf4 100644 --- a/examples/custom_store.py +++ b/examples/custom_store.py @@ -7,13 +7,15 @@ import ndv if TYPE_CHECKING: - from ndv import Indices, Sizes + from collections.abc import Hashable, Mapping, Sequence class MyArrayThing: + """Some custom data type that we want to visualize.""" + def __init__(self, shape: tuple[int, ...]) -> None: self.shape = shape - self._data = np.random.randint(0, 256, shape) + self._data = np.random.randint(0, 256, shape).astype(np.uint16) def __getitem__(self, item: Any) -> np.ndarray: return self._data[item] # type: ignore [no-any-return] @@ -22,16 +24,32 @@ def __getitem__(self, item: Any) -> np.ndarray: class MyWrapper(ndv.DataWrapper[MyArrayThing]): @classmethod def supports(cls, data: Any) -> bool: + """Return True if the data is supported by this wrapper""" if isinstance(data, MyArrayThing): return True return False - def sizes(self) -> Sizes: - """Return a mapping of {dim: size} for the data""" - return {f"dim_{k}": v for k, v in enumerate(self.data.shape)} + @property + def dims(self) -> tuple[Hashable, ...]: + """Return the dimensions of the data""" + return tuple(f"dim_{k}" for k in range(len(self.data.shape))) + + @property + def coords(self) -> dict[Hashable, Sequence]: + """Return a mapping of {dim: coords} for the data""" + return {f"dim_{k}": range(v) for k, v in enumerate(self.data.shape)} + + @property + def dtype(self) -> np.dtype: + """Return the dtype of the data""" + return self.data._data.dtype + + def isel(self, indexers: Mapping[int, int | slice]) -> np.ndarray: + """Select a subset of the data. - def isel(self, indexers: Indices) -> Any: - """Convert mapping of {dim: index} to conventional indexing""" + `indexers` is a mapping of {dim: index} where index is either an integer or a + slice. + """ idx = tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) return self.data[idx] diff --git a/examples/histogram.py b/examples/histogram.py deleted file mode 100644 index 68e15c78..00000000 --- a/examples/histogram.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np -from qtpy.QtCore import QTimer -from qtpy.QtWidgets import ( - QApplication, - QPushButton, - QVBoxLayout, - QWidget, -) - -from ndv.histogram.model import StatsModel -from ndv.histogram.views._vispy import VispyHistogramView - -if TYPE_CHECKING: - from typing import Any - - -class Controller: - """A (Qt) wrapper around another HistogramView with some additional controls.""" - - def __init__(self) -> None: - self._wdg = QWidget() - self._model = StatsModel() - self._view = VispyHistogramView() - - # A HistogramView is both a StatsView and a LUTView - # StatModel <-> StatsView - self._model.events.histogram.connect( - lambda data: self._view.set_histogram(*data) - ) - # LutModel <-> LutView (TODO) - # LutView -> LutModel (TODO: Currently LutView <-> LutView) - self._view.gammaChanged.connect(self._view.set_gamma) - self._view.climsChanged.connect(self._view.set_clims) - - # Vertical box - self._vert = QPushButton("Vertical") - self._vert.setCheckable(True) - self._vert.toggled.connect(self._view.set_vertical) - - # Log box - self._log = QPushButton("Logarithmic") - self._log.setCheckable(True) - self._log.toggled.connect(self._view.set_range_log) - - # Data updates - self._data_btn = QPushButton("Change Data") - self._data_btn.setCheckable(True) - self._data_btn.toggled.connect( - lambda toggle: self.timer.blockSignals(not toggle) - ) - - def _update_data() -> None: - """Replaces the displayed data.""" - self._model.data = np.random.normal(10, 10, 10000) - - self.timer = QTimer() - self.timer.setInterval(10) - self.timer.blockSignals(True) - self.timer.timeout.connect(_update_data) - self.timer.start() - - # Layout - self._layout = QVBoxLayout(self._wdg) - self._layout.addWidget(self._view.view()) - self._layout.addWidget(self._vert) - self._layout.addWidget(self._log) - self._layout.addWidget(self._data_btn) - - def view(self) -> Any: - """Returns an object that can be displayed by the active backend.""" - return self._wdg - - -app = QApplication.instance() or QApplication([]) - -widget = Controller() -widget.view().show() -app.exec() diff --git a/examples/notebook.ipynb b/examples/notebook.ipynb new file mode 100644 index 00000000..f0f1b705 --- /dev/null +++ b/examples/notebook.ipynb @@ -0,0 +1,86 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f2fe7ce4f36847ffaba8c042bf85b76d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "92985e7df82f47cabfe3ab5ce9ab0498", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='.ndarray (60, 2, 256, 256), uint16, 15.73MB'), CanvasBackend(css_height='600px', cโ€ฆ" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ndv import data, imshow\n", + "\n", + "viewer = imshow(data.cells3d())\n", + "viewer.model.channel_mode = \"composite\"\n", + "viewer.model.current_index.update({0: 32})" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "viewer.model.visible_axes = (0, 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "viewer.model.default_lut.cmap = \"cubehelix\"\n", + "viewer.model.channel_mode = \"grayscale\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/numpy_arr.py b/examples/numpy_arr.py index 7a4ee67a..d48bd687 100644 --- a/examples/numpy_arr.py +++ b/examples/numpy_arr.py @@ -8,4 +8,4 @@ print(e) img = ndv.data.nd_sine_wave((10, 3, 8, 512, 512)) -ndv.imshow(img) +viewer = ndv.imshow(img) diff --git a/examples/tensorstore_arr.py b/examples/tensorstore_arr.py index 7314cddc..f0a088d5 100644 --- a/examples/tensorstore_arr.py +++ b/examples/tensorstore_arr.py @@ -1,22 +1,12 @@ from __future__ import annotations +import ndv + try: - import tensorstore as ts + from ndv.data import cosem_dataset + + ts_array = cosem_dataset() except ImportError: raise ImportError("Please install tensorstore to run this example") - -import ndv - -ts_array = ts.open( - { - "driver": "n5", - "kvstore": { - "driver": "s3", - "bucket": "janelia-cosem-datasets", - "path": "jrc_hela-3/jrc_hela-3.n5/labels/er-mem_pred/s4/", - }, - }, -).result() -ts_array = ts_array[ts.d[:].label["z", "y", "x"]] -ndv.imshow(ts_array[ts.d[("y", "x", "z")].transpose[:]]) +ndv.imshow(ts_array) diff --git a/examples/xarray_arr.py b/examples/xarray_arr.py index b19e139f..330e2d6e 100644 --- a/examples/xarray_arr.py +++ b/examples/xarray_arr.py @@ -7,4 +7,4 @@ import ndv da = xr.tutorial.open_dataset("air_temperature").air -ndv.imshow(da, cmap="thermal") +ndv.imshow(da, default_lut={"cmap": "thermal"}) diff --git a/examples/zarr_arr.py b/examples/zarr_arr.py index e3a759bc..a951422f 100644 --- a/examples/zarr_arr.py +++ b/examples/zarr_arr.py @@ -12,4 +12,4 @@ URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr" zarr_arr = zarr.open(URL, mode="r") -ndv.imshow(zarr_arr["s0"]) +ndv.imshow(zarr_arr["s0"].astype("uint16")) diff --git a/pyproject.toml b/pyproject.toml index 334d820c..5d351fd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,10 @@ description = "simple nd image viewer" readme = "README.md" requires-python = ">=3.9" license = { text = "BSD-3-Clause" } -authors = [{ name = "Talley Lambert", email = "talley.lambert@gmail.com" }] +authors = [ + { name = "Talley Lambert", email = "talley.lambert@gmail.com" }, + { name = "Gabriel Selzer", email = "gjselzer@wisc.edu" }, +] classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", @@ -32,14 +35,45 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Typing :: Typed", ] -dependencies = ["qtpy", "numpy", "superqt[cmap,iconify]", "psygnal"] +dependencies = [ + "qtpy", + "numpy", + "superqt[cmap,iconify]", + "pydantic", + "psygnal", + "typing_extensions", +] # https://peps.python.org/pep-0621/#dependencies-optional-dependencies [project.optional-dependencies] +# Supported GUI frontends +jup = ["ipywidgets", "jupyter", "jupyter_rfb", "glfw"] pyqt = ["pyqt6"] -vispy = ["vispy", "pyopengl"] -pyside = ["pyside6"] -pygfx = ["pygfx"] +pyside = ["pyside6<6.8"] +wxpython = ["wxpython"] + +# Supported Canavs backends +vispy = ["vispy>=0.14.3", "pyopengl"] +pygfx = ["pygfx>=0.6.0"] + +# ready to go bundles with vispy +qt = ["ndv[vispy,pyqt]", "imageio[tifffile]"] +jupyter = ["ndv[jup,vispy]", "imageio[tifffile]"] +wx = ["ndv[vispy,wxpython]", "imageio[tifffile]"] + +test = ["imageio[tifffile]", "pytest-cov", "pytest"] +dev = [ + "ndv[test,vispy,pygfx,pyqt,jupyter]", + "pytest-qt", + "ipython", + "mypy", + "pdbpp", + "pre-commit", + "rich", + "ruff", + "ipykernel", +] + third_party_arrays = [ "aiohttp", # for zarr example "jax[cpu]", @@ -50,17 +84,9 @@ third_party_arrays = [ "numpy<2.0", # for tensorstore (at least) "torch", "xarray", - "zarr", -] -test = [ - "ndv[vispy,pygfx]", + "zarr<3", "dask[array]", - "imageio[tifffile]", - "pytest-cov", - "pytest-qt", - "pytest", ] -dev = ["ipython", "mypy", "pdbpp", "pre-commit", "rich", "ruff"] [project.urls] homepage = "https://github.com/pyapp-kit/ndv" @@ -87,11 +113,12 @@ select = [ "B", # flake8-bugbear "A001", # flake8-builtins "RUF", # ruff-specific rules - "TCH", # flake8-type-checking + "TC", # flake8-type-checking "TID", # flake8-tidy-imports ] ignore = [ "D401", # First line should be in imperative mood + "D10", # Missing docstring... ] [tool.ruff.lint.per-file-ignores] @@ -112,6 +139,11 @@ disallow_any_generics = false disallow_subclassing_any = false show_error_codes = true pretty = true +plugins = ["pydantic.mypy"] + +[[tool.mypy.overrides]] +module = ["jupyter_rfb.*", "vispy.*", "ipywidgets.*"] +ignore_missing_imports = true # https://docs.pytest.org/ [tool.pytest.ini_options] @@ -123,9 +155,9 @@ filterwarnings = [ "ignore:Method has been garbage collected::superqt", # occasionally happens on linux CI with vispy "ignore:Got wrong number of dimensions", - # requires pygfx > 0.2.0 - "ignore:This version of pygfx does not yet support additive blending", "ignore:Unable to import recommended hash", + # CI-only error on jupyter, vispy, macos + "ignore:.*Failed to find a suitable pixel format", ] markers = ["allow_leaks: mark test to allow widget leaks"] @@ -152,3 +184,6 @@ ignore = [".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*"] [tool.typos.default] extend-ignore-identifiers-re = ["(?i)nd2?.*", "(?i)ome", ".*ser_schema"] + +[tool.typos.files] +extend-exclude = ["**/*.ipynb"] diff --git a/src/ndv/__init__.py b/src/ndv/__init__.py index ca5bcfb9..5f309980 100644 --- a/src/ndv/__init__.py +++ b/src/ndv/__init__.py @@ -6,24 +6,11 @@ __version__ = version("ndv") except PackageNotFoundError: __version__ = "uninstalled" -__author__ = "Talley Lambert" -__email__ = "talley.lambert@example.com" - -from typing import TYPE_CHECKING from . import data +from ._views import run_app +from .controllers import ArrayViewer +from .models import DataWrapper from .util import imshow -from .viewer._data_wrapper import DataWrapper -from .viewer._viewer import NDViewer - -__all__ = ["DataWrapper", "NDViewer", "data", "imshow"] - - -if TYPE_CHECKING: - # these may be used externally, but are not guaranteed to be available at runtime - # they must be used inside a TYPE_CHECKING block - from .viewer._dims_slider import DimKey as DimKey - from .viewer._dims_slider import Index as Index - from .viewer._dims_slider import Indices as Indices - from .viewer._dims_slider import Sizes as Sizes +__all__ = ["ArrayViewer", "DataWrapper", "data", "imshow", "run_app"] diff --git a/src/ndv/_types.py b/src/ndv/_types.py new file mode 100644 index 00000000..27660487 --- /dev/null +++ b/src/ndv/_types.py @@ -0,0 +1,109 @@ +"""General model for ndv.""" + +from __future__ import annotations + +from collections.abc import Hashable, Sequence +from contextlib import suppress +from enum import Enum, IntFlag, auto +from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, cast + +from pydantic import PlainSerializer, PlainValidator +from typing_extensions import TypeAlias + +if TYPE_CHECKING: + from qtpy.QtCore import Qt + from qtpy.QtWidgets import QWidget + + from ndv._views.bases._view_base import Viewable + + +def _maybe_int(val: Any) -> Any: + # try to convert to int if possible + with suppress(ValueError, TypeError): + val = int(float(val)) + return val + + +def _to_slice(val: Any) -> slice: + # slices are returned as is + if isinstance(val, slice): + if not all( + isinstance(i, (int, type(None))) for i in (val.start, val.stop, val.step) + ): + raise TypeError(f"Slice start/stop/step must all be integers: {val!r}") + return val + # single integers are converted to slices starting at that index + if isinstance(val, int): + return slice(val, val + 1) + # sequences are interpreted as arguments to the slice constructor + if isinstance(val, Sequence): + return slice(*(int(x) if x is not None else None for x in val)) + raise TypeError(f"Expected int or slice, got {type(val)}") + + +Slice = Annotated[slice, PlainValidator(_to_slice)] + +# An axis key is any hashable object that can be used to index an axis +# In many cases it will be an integer, but for some labeled arrays it may be a string +# or other hashable object. It is up to the DataWrapper to convert these keys to +# actual integer indices. +AxisKey: TypeAlias = Annotated[ + Hashable, PlainValidator(_maybe_int), PlainSerializer(str, return_type=str) +] + + +class MouseButton(IntFlag): + LEFT = auto() + MIDDLE = auto() + RIGHT = auto() + + +class MouseMoveEvent(NamedTuple): + """Event emitted when the user moves the cursor.""" + + x: float + y: float + + +class MousePressEvent(NamedTuple): + """Event emitted when mouse button is pressed.""" + + x: float + y: float + btn: MouseButton = MouseButton.LEFT + + +class MouseReleaseEvent(NamedTuple): + """Event emitted when mouse button is released.""" + + x: float + y: float + btn: MouseButton = MouseButton.LEFT + + +class CursorType(Enum): + DEFAULT = "default" + V_ARROW = "v_arrow" + H_ARROW = "h_arrow" + ALL_ARROW = "all_arrow" + BDIAG_ARROW = "bdiag_arrow" + FDIAG_ARROW = "fdiag_arrow" + + def apply_to(self, widget: Viewable) -> None: + """Applies the cursor type to the given widget.""" + native = widget.frontend_widget() + if hasattr(native, "setCursor"): + cast("QWidget", native).setCursor(self.to_qt()) + + def to_qt(self) -> Qt.CursorShape: + """Converts CursorType to Qt.CursorShape.""" + from qtpy.QtCore import Qt + + return { + CursorType.DEFAULT: Qt.CursorShape.ArrowCursor, + CursorType.V_ARROW: Qt.CursorShape.SizeVerCursor, + CursorType.H_ARROW: Qt.CursorShape.SizeHorCursor, + CursorType.ALL_ARROW: Qt.CursorShape.SizeAllCursor, + CursorType.BDIAG_ARROW: Qt.CursorShape.SizeBDiagCursor, + CursorType.FDIAG_ARROW: Qt.CursorShape.SizeFDiagCursor, + }[self] diff --git a/src/ndv/_views/__init__.py b/src/ndv/_views/__init__.py new file mode 100644 index 00000000..e73a7d97 --- /dev/null +++ b/src/ndv/_views/__init__.py @@ -0,0 +1,15 @@ +from ._app import ( + get_array_canvas_class, + get_array_view_class, + get_histogram_canvas_class, + gui_frontend, + run_app, +) + +__all__ = [ + "get_array_canvas_class", + "get_array_view_class", + "get_histogram_canvas_class", + "gui_frontend", + "run_app", +] diff --git a/src/ndv/_views/_app.py b/src/ndv/_views/_app.py new file mode 100644 index 00000000..6742bc9d --- /dev/null +++ b/src/ndv/_views/_app.py @@ -0,0 +1,583 @@ +from __future__ import annotations + +import importlib.util +import os +import sys +import traceback +from contextlib import suppress +from enum import Enum +from functools import cache +from types import MethodType +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, cast + +from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent + +if TYPE_CHECKING: + from collections.abc import Container + from types import TracebackType + + from IPython.core.interactiveshell import InteractiveShell + + from ndv._views.bases import ArrayCanvas, ArrayView, HistogramCanvas + from ndv._views.bases.graphics._mouseable import Mouseable + + +GUI_ENV_VAR = "NDV_GUI_FRONTEND" +"""Preferred GUI frontend. If not set, the first available GUI frontend is used.""" + +CANVAS_ENV_VAR = "NDV_CANVAS_BACKEND" +"""Preferred canvas backend. If not set, the first available canvas backend is used.""" + +DEBUG_EXCEPTIONS = "NDV_DEBUG_EXCEPTIONS" +"""Whether to drop into a debugger when an exception is raised. Default False.""" + +EXIT_ON_EXCEPTION = "NDV_EXIT_ON_EXCEPTION" +"""Whether to exit the application when an exception is raised. Default False.""" + +IPYTHON_GUI_QT = "NDV_IPYTHON_GUI_QT" +"""Whether to use gui_qt magic when running in IPython. Default True.""" + + +class GuiFrontend(str, Enum): + QT = "qt" + JUPYTER = "jupyter" + WX = "wx" + + +class CanvasBackend(str, Enum): + VISPY = "vispy" + PYGFX = "pygfx" + + +class GuiProvider(Protocol): + @staticmethod + def is_running() -> bool: ... + @staticmethod + def create_app() -> bool: ... + @staticmethod + def array_view_class() -> type[ArrayView]: ... + @staticmethod + def exec() -> None: ... + @staticmethod + def filter_mouse_events(canvas: Any, receiver: Mouseable) -> Callable[[], None]: ... + + +class CanvasProvider(Protocol): + @staticmethod + def is_imported() -> bool: ... + @staticmethod + def is_available() -> bool: ... + @staticmethod + def array_canvas_class() -> type[ArrayCanvas]: ... + @staticmethod + def histogram_canvas_class() -> type[HistogramCanvas]: ... + + +class QtProvider(GuiProvider): + """Provider for PyQt5/PySide2/PyQt6/PySide6.""" + + _APP_INSTANCE: ClassVar[Any] = None + + @staticmethod + def is_running() -> bool: + for mod_name in ("PyQt5", "PySide2", "PySide6", "PyQt6"): + if mod := sys.modules.get(f"{mod_name}.QtWidgets"): + if qapp := getattr(mod, "QApplication", None): + return qapp.instance() is not None + return False + + @staticmethod + def create_app() -> Any: + from qtpy.QtWidgets import QApplication + + if (qapp := QApplication.instance()) is None: + # if we're running in IPython + # start the %gui qt magic if NDV_IPYTHON_GUI_QT!=0 + if (ipy_shell := _ipython_shell()) and ( + os.getenv(IPYTHON_GUI_QT, "true").lower() not in ("0", "false", "no") + ): + ipy_shell.enable_gui("qt") # type: ignore [no-untyped-call] + # otherwise create a new QApplication + else: + # must be stored in a class variable to prevent garbage collection + QtProvider._APP_INSTANCE = qapp = QApplication(sys.argv) + qapp.setOrganizationName("ndv") + qapp.setApplicationName("ndv") + + _install_excepthook() + return qapp + + @staticmethod + def exec() -> None: + from qtpy.QtWidgets import QApplication + + app = QApplication.instance() or QtProvider.create_app() + + for wdg in QApplication.topLevelWidgets(): + wdg.raise_() + + if ipy_shell := _ipython_shell(): + # if we're already in an IPython session with %gui qt, don't block + if str(ipy_shell.active_eventloop).startswith("qt"): + return + + app.exec() + + @staticmethod + def array_view_class() -> type[ArrayView]: + from ._qt._array_view import QtArrayView + + return QtArrayView + + @staticmethod + def filter_mouse_events(canvas: Any, receiver: Mouseable) -> Callable[[], None]: + from qtpy.QtCore import QEvent, QObject + from qtpy.QtGui import QMouseEvent + + if not isinstance(canvas, QObject): + raise TypeError(f"Expected canvas to be QObject, got {type(canvas)}") + + class Filter(QObject): + def eventFilter(self, obj: QObject | None, qevent: QEvent | None) -> bool: + """Event filter installed on the canvas to handle mouse events. + + here is where we get a chance to intercept mouse events before allowing + the canvas to respond to them. Return `True` to prevent the event from + being passed to the canvas. + """ + if qevent is None: + return False # pragma: no cover + + try: + # use children in case backend has a subwidget stealing events. + children: Container = canvas.children() + except RuntimeError: + # native is likely dead + return False + + intercept = False + if obj is canvas or obj in children: + if isinstance(qevent, QMouseEvent): + pos = qevent.pos() + etype = qevent.type() + if etype == QEvent.Type.MouseMove: + mme = MouseMoveEvent(x=pos.x(), y=pos.y()) + intercept |= receiver.on_mouse_move(mme) + receiver.mouseMoved.emit(mme) + elif etype == QEvent.Type.MouseButtonPress: + mpe = MousePressEvent(x=pos.x(), y=pos.y()) + intercept |= receiver.on_mouse_press(mpe) + receiver.mousePressed.emit(mpe) + elif etype == QEvent.Type.MouseButtonRelease: + mre = MouseReleaseEvent(x=pos.x(), y=pos.y()) + intercept |= receiver.on_mouse_release(mre) + receiver.mouseReleased.emit(mre) + return intercept + + f = Filter() + canvas.installEventFilter(f) + return lambda: canvas.removeEventFilter(f) + + +class WxProvider(GuiProvider): + """Provider for wxPython.""" + + @staticmethod + def is_running() -> bool: + if wx := sys.modules.get("wx"): + return wx.App.Get() is not None + return False + + @staticmethod + def create_app() -> Any: + import wx + + if (wxapp := wx.App.Get()) is None: + wxapp = wx.App() + + _install_excepthook() + return wxapp + + @staticmethod + def exec() -> None: + import wx + + app = wx.App.Get() or WxProvider.create_app() + app.MainLoop() + _install_excepthook() + + @staticmethod + def array_view_class() -> type[ArrayView]: + from ._wx._array_view import WxArrayView + + return WxArrayView + + @staticmethod + def filter_mouse_events(canvas: Any, receiver: Mouseable) -> Callable[[], None]: + from wx import EVT_LEFT_DOWN, EVT_LEFT_UP, EVT_MOTION, EvtHandler, MouseEvent + + if not isinstance(canvas, EvtHandler): + raise TypeError( + f"Expected vispy canvas to be wx EvtHandler, got {type(canvas)}" + ) + + # TIP: event.Skip() allows the event to propagate to other handlers. + + def on_mouse_move(event: MouseEvent) -> None: + mme = MouseMoveEvent(x=event.GetX(), y=event.GetY()) + if not receiver.on_mouse_move(mme): + receiver.mouseMoved.emit(mme) + event.Skip() + + def on_mouse_press(event: MouseEvent) -> None: + mpe = MousePressEvent(x=event.GetX(), y=event.GetY()) + if not receiver.on_mouse_press(mpe): + receiver.mousePressed.emit(mpe) + event.Skip() + + def on_mouse_release(event: MouseEvent) -> None: + mre = MouseReleaseEvent(x=event.GetX(), y=event.GetY()) + if not receiver.on_mouse_release(mre): + receiver.mouseReleased.emit(mre) + event.Skip() + + canvas.Bind(EVT_MOTION, on_mouse_move) + canvas.Bind(EVT_LEFT_DOWN, on_mouse_press) + canvas.Bind(EVT_LEFT_UP, on_mouse_release) + + def _unbind() -> None: + canvas.Unbind(EVT_MOTION, on_mouse_move) + canvas.Unbind(EVT_LEFT_DOWN, on_mouse_press) + canvas.Unbind(EVT_LEFT_UP, on_mouse_release) + + return _unbind + + +class JupyterProvider(GuiProvider): + """Provider for Jupyter notebooks/lab (NOT ipython).""" + + @staticmethod + def is_running() -> bool: + if ipy_shell := _ipython_shell(): + return bool(ipy_shell.__class__.__name__ == "ZMQInteractiveShell") + return False + + @staticmethod + def create_app() -> Any: + if not JupyterProvider.is_running() and not os.getenv("PYTEST_CURRENT_TEST"): + # if we got here, it probably means that someone used + # NDV_GUI_FRONTEND=jupyter without actually being in a juptyer notebook + # we allow it in tests, but not in normal usage. + raise RuntimeError( # pragma: no cover + "Jupyter is not running a notebook shell. Cannot create app." + ) + return None + + @staticmethod + def exec() -> None: + pass + + @staticmethod + def array_view_class() -> type[ArrayView]: + from ._jupyter._array_view import JupyterArrayView + + return JupyterArrayView + + @staticmethod + def filter_mouse_events(canvas: Any, receiver: Mouseable) -> Callable[[], None]: + from jupyter_rfb import RemoteFrameBuffer + + if not isinstance(canvas, RemoteFrameBuffer): + raise TypeError( + f"Expected canvas to be RemoteFrameBuffer, got {type(canvas)}" + ) + + # patch the handle_event from _jupyter_rfb.CanvasBackend + # to intercept various mouse events. + super_handle_event = canvas.handle_event + + def handle_event(self: RemoteFrameBuffer, ev: dict) -> None: + etype = ev["event_type"] + if etype == "pointer_move": + mme = MouseMoveEvent(x=ev["x"], y=ev["y"]) + receiver.on_mouse_move(mme) + receiver.mouseMoved.emit(mme) + elif etype == "pointer_down": + mpe = MousePressEvent(x=ev["x"], y=ev["y"]) + receiver.on_mouse_press(mpe) + receiver.mousePressed.emit(mpe) + elif etype == "pointer_up": + mre = MouseReleaseEvent(x=ev["x"], y=ev["y"]) + receiver.on_mouse_release(mre) + receiver.mouseReleased.emit(mre) + super_handle_event(ev) + + canvas.handle_event = MethodType(handle_event, canvas) + return lambda: setattr(canvas, "handle_event", super_handle_event) + + +class VispyProvider(CanvasProvider): + @staticmethod + def is_imported() -> bool: + return "vispy" in sys.modules + + @staticmethod + def is_available() -> bool: + return importlib.util.find_spec("vispy") is not None + + @staticmethod + def array_canvas_class() -> type[ArrayCanvas]: + from vispy.app import use_app + + from ndv._views._vispy._array_canvas import VispyArrayCanvas + + # these may not be necessary, since we likely have already called + # create_app by this point and vispy will autodetect that. + # it's an extra precaution + _frontend = gui_frontend() + if _frontend == GuiFrontend.JUPYTER: + use_app("jupyter_rfb") + elif _frontend == GuiFrontend.WX: + use_app("wx") + # there is no `use_app('qt')`... it's all specific to pyqt/pyside, etc... + # so we just let vispy autodetect it + + return VispyArrayCanvas + + @staticmethod + def histogram_canvas_class() -> type[HistogramCanvas]: + from ndv._views._vispy._histogram import VispyHistogramCanvas + + return VispyHistogramCanvas + + +class PygfxProvider(CanvasProvider): + @staticmethod + def is_imported() -> bool: + return "pygfx" in sys.modules + + @staticmethod + def is_available() -> bool: + return importlib.util.find_spec("pygfx") is not None + + @staticmethod + def array_canvas_class() -> type[ArrayCanvas]: + from ndv._views._pygfx._array_canvas import GfxArrayCanvas + + return GfxArrayCanvas + + @staticmethod + def histogram_canvas_class() -> type[HistogramCanvas]: + raise RuntimeError("Histogram not supported for pygfx") + + +def _ipython_shell() -> InteractiveShell | None: + if (ipy := sys.modules.get("IPython")) and (shell := ipy.get_ipython()): + return cast("InteractiveShell", shell) + return None + + +# -------------------- Provider selection -------------------- + +# list of available GUI frontends and canvas backends, tried in order + +GUI_PROVIDERS: dict[GuiFrontend, GuiProvider] = { + GuiFrontend.QT: QtProvider, + GuiFrontend.WX: WxProvider, + GuiFrontend.JUPYTER: JupyterProvider, +} +CANVAS_PROVIDERS: dict[CanvasBackend, CanvasProvider] = { + CanvasBackend.VISPY: VispyProvider, + CanvasBackend.PYGFX: PygfxProvider, +} + + +@cache # not allowed to change +def gui_frontend() -> GuiFrontend: + """Return the preferred GUI frontend. + + This is determined first by the NDV_GUI_FRONTEND environment variable, after which + GUI_PROVIDERS are tried in order until one is found that is either already running, + or available. + """ + requested = os.getenv(GUI_ENV_VAR, "").lower() + valid = {x.value for x in GuiFrontend} + if requested: + if requested not in valid: + raise ValueError( + f"Invalid GUI frontend: {requested!r}. Valid options: {valid}" + ) + key = GuiFrontend(requested) + # ensure the app is created for explicitly requested frontends + provider = GUI_PROVIDERS[key] + if not provider.is_running(): + provider.create_app() + return key + + for key, provider in GUI_PROVIDERS.items(): + if provider.is_running(): + return key + + errors: list[tuple[GuiFrontend, BaseException]] = [] + for key, provider in GUI_PROVIDERS.items(): + try: + provider.create_app() + return key + except Exception as e: + errors.append((key, e)) + + raise RuntimeError( # pragma: no cover + f"Could not find an appropriate GUI frontend: {valid!r}. Tried:\n\n" + + "\n".join(f"- {key.value}: {err}" for key, err in errors) + ) + + +def canvas_backend(requested: str | None) -> CanvasBackend: + """Return the preferred canvas backend. + + This is determined first by the NDV_CANVAS_BACKEND environment variable, after which + CANVAS_PROVIDERS are tried in order until one is found that is either already + imported or available + """ + backend = requested or os.getenv(CANVAS_ENV_VAR, "").lower() + + valid = {x.value for x in CanvasBackend} + if backend: + if backend not in valid: + raise ValueError( + f"Invalid canvas backend: {backend!r}. Valid options: {valid}" + ) + return CanvasBackend(backend) + + for key, provider in CANVAS_PROVIDERS.items(): + if provider.is_imported(): + return key + errors: list[tuple[CanvasBackend, BaseException]] = [] + for key, provider in CANVAS_PROVIDERS.items(): + try: + if provider.is_available(): + return key + except Exception as e: + errors.append((key, e)) + + raise RuntimeError( # pragma: no cover + f"Could not find an appropriate canvas backend: {valid!r}. Tried:\n\n" + + "\n".join(f"- {key.value}: {err}" for key, err in errors) + ) + + +# TODO: add a way to set the frontend via an environment variable +# (for example, it should be possible to use qt frontend in a jupyter notebook) +def get_array_view_class() -> type[ArrayView]: + if (frontend := gui_frontend()) not in GUI_PROVIDERS: # pragma: no cover + raise NotImplementedError(f"No GUI frontend found for {frontend}") + return GUI_PROVIDERS[frontend].array_view_class() + + +def get_array_canvas_class(backend: str | None = None) -> type[ArrayCanvas]: + _backend = canvas_backend(backend) + if _backend not in CANVAS_PROVIDERS: # pragma: no cover + raise NotImplementedError(f"No canvas backend found for {_backend}") + return CANVAS_PROVIDERS[_backend].array_canvas_class() + + +def get_histogram_canvas_class(backend: str | None = None) -> type[HistogramCanvas]: + _backend = canvas_backend(backend) + if _backend not in CANVAS_PROVIDERS: # pragma: no cover + raise NotImplementedError(f"No canvas backend found for {_backend}") + return CANVAS_PROVIDERS[_backend].histogram_canvas_class() + + +def filter_mouse_events(canvas: Any, receiver: Mouseable) -> Callable[[], None]: + """Intercept mouse events on `scene_canvas` and forward them to `receiver`. + + Parameters + ---------- + canvas : Any + The front-end canvas widget to intercept mouse events from. + receiver : Mouseable + The object to forward mouse events to. + + Returns + ------- + Callable[[], None] + A function that can be called to remove the event filter. + """ + return GUI_PROVIDERS[gui_frontend()].filter_mouse_events(canvas, receiver) + + +def run_app() -> None: + """Start the GUI application event loop.""" + GUI_PROVIDERS[gui_frontend()].exec() + + +# -------------------- Exception handling -------------------- + + +def _install_excepthook() -> None: + """Install a custom excepthook that does not raise sys.exit(). + + This is necessary to prevent the application from closing when an exception + is raised. + """ + if hasattr(sys, "_original_excepthook_"): + # don't install the excepthook more than once + return + sys._original_excepthook_ = sys.excepthook # type: ignore + sys.excepthook = ndv_excepthook + + +def _print_exception( + exc_type: type[BaseException], + exc_value: BaseException, + exc_traceback: TracebackType | None, +) -> None: + try: + import psygnal + from rich.console import Console + from rich.traceback import Traceback + + tb = Traceback.from_exception( + exc_type, exc_value, exc_traceback, suppress=[psygnal], max_frames=10 + ) + Console(stderr=True).print(tb) + except ImportError: + traceback.print_exception(exc_type, value=exc_value, tb=exc_traceback) + + +def ndv_excepthook( + exc_type: type[BaseException], exc_value: BaseException, tb: TracebackType | None +) -> None: + _print_exception(exc_type, exc_value, tb) + if not tb: + return + + if ( + (debugpy := sys.modules.get("debugpy")) + and debugpy.is_client_connected() + and ("pydevd" in sys.modules) + ): + with suppress(Exception): + import threading + + import pydevd + + py_db = pydevd.get_global_debugger() + thread = threading.current_thread() + additional_info = py_db.set_additional_thread_info(thread) + additional_info.is_tracing += 1 + + try: + arg = (exc_type, exc_value, tb) + py_db.stop_on_unhandled_exception(py_db, thread, additional_info, arg) + finally: + additional_info.is_tracing -= 1 + elif os.getenv(DEBUG_EXCEPTIONS): + # Default to pdb if no better option is available + import pdb + + pdb.post_mortem(tb) + + if os.getenv(EXIT_ON_EXCEPTION): + print(f"\n{EXIT_ON_EXCEPTION} is set, exiting.") + sys.exit(1) diff --git a/src/ndv/_views/_jupyter/__init__.py b/src/ndv/_views/_jupyter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/_views/_jupyter/_array_view.py b/src/ndv/_views/_jupyter/_array_view.py new file mode 100644 index 00000000..6888c619 --- /dev/null +++ b/src/ndv/_views/_jupyter/_array_view.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, cast + +import cmap +import ipywidgets as widgets + +from ndv._views.bases import ArrayView, LutView +from ndv.models._array_display_model import ChannelMode + +if TYPE_CHECKING: + from collections.abc import Container, Hashable, Mapping, Sequence + + from vispy.app.backends import _jupyter_rfb + + from ndv._types import AxisKey + +# not entirely sure why it's necessary to specifically annotat signals as : PSignal +# i think it has to do with type variance? + + +class JupyterLutView(LutView): + def __init__(self) -> None: + # WIDGETS + self._visible = widgets.Checkbox(value=True) + self._cmap = widgets.Dropdown( + options=["gray", "green", "magenta", "cubehelix"], value="gray" + ) + self._clims = widgets.FloatRangeSlider( + value=[0, 2**16], + min=0, + max=2**16, + step=1, + orientation="horizontal", + readout=True, + readout_format=".0f", + ) + self._auto_clim = widgets.ToggleButton( + value=True, + description="Auto", + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Auto scale", + icon="check", + ) + + # LAYOUT + + self.layout = widgets.HBox( + [self._visible, self._cmap, self._clims, self._auto_clim] + ) + + # CONNECTIONS + self._visible.observe(self._on_visible_changed, names="value") + self._cmap.observe(self._on_cmap_changed, names="value") + self._clims.observe(self._on_clims_changed, names="value") + self._auto_clim.observe(self._on_autoscale_changed, names="value") + + # ------------------ emit changes to the controller ------------------ + + def _on_clims_changed(self, change: dict[str, Any]) -> None: + self.climsChanged.emit(self._clims.value) + + def _on_visible_changed(self, change: dict[str, Any]) -> None: + self.visibilityChanged.emit(self._visible.value) + + def _on_cmap_changed(self, change: dict[str, Any]) -> None: + self.cmapChanged.emit(cmap.Colormap(self._cmap.value)) + + def _on_autoscale_changed(self, change: dict[str, Any]) -> None: + self.autoscaleChanged.emit(self._auto_clim.value) + + # ------------------ receive changes from the controller --------------- + + def set_channel_name(self, name: str) -> None: + self._visible.description = name + + # NOTE: it's important to block signals when setting values from the controller + # to avoid loops, unnecessary updates, and unexpected behavior + + def set_auto_scale(self, auto: bool) -> None: + with self.autoscaleChanged.blocked(): + self._auto_clim.value = auto + + def set_colormap(self, cmap: cmap.Colormap) -> None: + with self.cmapChanged.blocked(): + self._cmap.value = cmap.name.split(":")[-1] # FIXME: this is a hack + + def set_clims(self, clims: tuple[float, float]) -> None: + with self.climsChanged.blocked(): + self._clims.value = clims + + def set_channel_visible(self, visible: bool) -> None: + with self.visibilityChanged.blocked(): + self._visible.value = visible + + def set_gamma(self, gamma: float) -> None: + pass + + def set_visible(self, visible: bool) -> None: + # show or hide the actual widget itself + self.layout.layout.display = "flex" if visible else "none" + + def close(self) -> None: + self.layout.close() + + def frontend_widget(self) -> Any: + return self.layout + + +class JupyterArrayView(ArrayView): + def __init__( + self, canvas_widget: _jupyter_rfb.CanvasBackend, **kwargs: Any + ) -> None: + # WIDGETS + self._canvas_widget = canvas_widget + + self._sliders: dict[Hashable, widgets.IntSlider] = {} + self._slider_box = widgets.VBox([]) + self._data_info_label = widgets.Label() + self._hover_info_label = widgets.Label() + + # the button that controls the display mode of the channels + self._channel_mode_combo = widgets.Dropdown( + options=[x.value for x in ChannelMode], value=str(ChannelMode.GRAYSCALE) + ) + + self._channel_mode_combo.observe(self._on_channel_mode_changed, names="value") + + # LAYOUT + + self.layout = widgets.VBox( + [ + self._data_info_label, + self._canvas_widget, + self._hover_info_label, + self._slider_box, + self._channel_mode_combo, + ] + ) + + # CONNECTIONS + + self._channel_mode_combo.observe(self._on_channel_mode_changed, names="value") + + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: + """Update sliders with the given coordinate ranges.""" + sliders = [] + self._sliders.clear() + for axis, _coords in coords.items(): + if not isinstance(_coords, range): + raise NotImplementedError("Only range is supported for now") + + sld = widgets.IntSlider( + value=_coords.start, + min=_coords.start, + max=_coords.stop - 1, + step=_coords.step, + description=str(axis), + continuous_update=True, + orientation="horizontal", + ) + sld.observe(self._on_slider_change, "value") + sliders.append(sld) + self._sliders[axis] = sld + self._slider_box.children = sliders + + self.currentIndexChanged.emit() + + def hide_sliders( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + """Hide sliders based on visible axes.""" + for ax, slider in self._sliders.items(): + if ax in axes_to_hide: + slider.layout.display = "none" + elif show_remainder: + slider.layout.display = "flex" + + def current_index(self) -> Mapping[AxisKey, int | slice]: + """Return the current value of the sliders.""" + return {axis: slider.value for axis, slider in self._sliders.items()} + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + """Set the current value of the sliders.""" + changed = False + # this type ignore is only necessary because we had to override the signal + # to be a PSignal in the class def above :( + with self.currentIndexChanged.blocked(): + for axis, val in value.items(): + if isinstance(val, slice): + raise NotImplementedError("Slices are not supported yet") + + if sld := self._sliders.get(axis): + if sld.value != val: + sld.value = val + changed = True + else: # pragma: no cover + warnings.warn(f"Axis {axis} not found in sliders", stacklevel=2) + if changed: + self.currentIndexChanged.emit() + + def add_lut_view(self) -> JupyterLutView: + """Add a LUT view to the viewer.""" + wdg = JupyterLutView() + self.layout.children = (*self.layout.children, wdg.layout) + return wdg + + def remove_lut_view(self, view: LutView) -> None: + """Remove a LUT view from the viewer.""" + view = cast("JupyterLutView", view) + self.layout.children = tuple( + wdg for wdg in self.layout.children if wdg != view.frontend_widget() + ) + + def set_data_info(self, data_info: str) -> None: + self._data_info_label.value = data_info + + def set_hover_info(self, hover_info: str) -> None: + self._hover_info_label.value = hover_info + + def set_channel_mode(self, mode: ChannelMode) -> None: + with self.channelModeChanged.blocked(): + self._channel_mode_combo.value = mode.value + + def _on_slider_change(self, change: dict[str, Any]) -> None: + """Emit signal when a slider value changes.""" + self.currentIndexChanged.emit() + + def _on_channel_mode_changed(self, change: dict[str, Any]) -> None: + """Emit signal when the channel mode changes.""" + self.channelModeChanged.emit(ChannelMode(change["new"])) + + def add_histogram(self, widget: Any) -> None: + """Add a histogram widget to the viewer.""" + warnings.warn("Histograms are not supported in Jupyter frontend", stacklevel=2) + + def remove_histogram(self, widget: Any) -> None: + """Remove a histogram widget from the viewer.""" + + def frontend_widget(self) -> Any: + return self.layout + + def set_visible(self, visible: bool) -> None: + # show or hide the actual widget itself + from IPython import display + + if visible: + display.display(self.layout) # type: ignore [no-untyped-call] + else: + display.clear_output() # type: ignore [no-untyped-call] + + def close(self) -> None: + self.layout.close() diff --git a/src/ndv/_views/_pygfx/__init__.py b/src/ndv/_views/_pygfx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/viewer/_backends/_pygfx.py b/src/ndv/_views/_pygfx/_array_canvas.py similarity index 81% rename from src/ndv/viewer/_backends/_pygfx.py rename to src/ndv/_views/_pygfx/_array_canvas.py index b6307b03..97ee47b1 100755 --- a/src/ndv/viewer/_backends/_pygfx.py +++ b/src/ndv/_views/_pygfx/_array_canvas.py @@ -1,26 +1,29 @@ from __future__ import annotations import warnings +from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, Literal, cast from weakref import WeakKeyDictionary -import cmap +import cmap as _cmap import numpy as np import pygfx import pylinalg as la -from qtpy.QtCore import QSize, Qt -from wgpu.gui.qt import QWgpuCanvas -from ._protocols import PCanvas +from ndv._types import CursorType +from ndv._views._app import filter_mouse_events +from ndv._views.bases import ArrayCanvas, CanvasElement, ImageHandle if TYPE_CHECKING: from collections.abc import Sequence + from typing import TypeAlias from pygfx.materials import ImageBasicMaterial from pygfx.resources import Texture - from qtpy.QtWidgets import QWidget + from wgpu.gui.jupyter import JupyterWgpuCanvas + from wgpu.gui.qt import QWgpuCanvas - from ._protocols import CanvasElement + WgpuCanvas: TypeAlias = QWgpuCanvas | JupyterWgpuCanvas def _is_inside(bounding_box: np.ndarray, pos: Sequence[float]) -> bool: @@ -32,58 +35,55 @@ def _is_inside(bounding_box: np.ndarray, pos: Sequence[float]) -> bool: ) -class PyGFXImageHandle: +class PyGFXImageHandle(ImageHandle): def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: self._image = image self._render = render self._grid = cast("Texture", image.geometry.grid) self._material = cast("ImageBasicMaterial", image.material) - @property def data(self) -> np.ndarray: return self._grid.data # type: ignore [no-any-return] - @data.setter - def data(self, data: np.ndarray) -> None: + def set_data(self, data: np.ndarray) -> None: self._grid.data[:] = data self._grid.update_range((0, 0, 0), self._grid.size) - @property def visible(self) -> bool: return bool(self._image.visible) - @visible.setter - def visible(self, visible: bool) -> None: + def set_visible(self, visible: bool) -> None: self._image.visible = visible self._render() - @property def can_select(self) -> bool: return False - @property def selected(self) -> bool: return False - @selected.setter - def selected(self, selected: bool) -> None: + def set_selected(self, selected: bool) -> None: raise NotImplementedError("Images cannot be selected") - @property def clim(self) -> Any: return self._material.clim - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: + def set_clims(self, clims: tuple[float, float]) -> None: self._material.clim = clims self._render() - @property - def cmap(self) -> cmap.Colormap: + def gamma(self) -> float: + return 1 + + def set_gamma(self, gamma: float) -> None: + # self._material.gamma = gamma + # self._render() + warnings.warn("Gamma correction is not supported in pygfx", stacklevel=2) + + def cmap(self) -> _cmap.Colormap: return self._cmap - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: + def set_cmap(self, cmap: _cmap.Colormap) -> None: self._cmap = cmap self._material.map = cmap.to_pygfx() self._render() @@ -98,7 +98,7 @@ def remove(self) -> None: if (par := self._image.parent) is not None: par.remove(self._image) - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: return None @@ -160,51 +160,44 @@ def visible(self, visible: bool) -> None: handles.visible = self.selected self._render() - @property def can_select(self) -> bool: return True - @property def selected(self) -> bool: if self._handles: return bool(self._handles.visible) # Can't be selected without handles return False - @selected.setter - def selected(self, selected: bool) -> None: + def set_selected(self, selected: bool) -> None: if self._handles: self._handles.visible = selected - @property def color(self) -> Any: if self._fill: - return cmap.Color(self._fill.material.color) - return cmap.Color("transparent") + return _cmap.Color(self._fill.material.color) + return _cmap.Color("transparent") - @color.setter - def color(self, color: cmap.Color | None = None) -> None: + def set_color(self, color: _cmap.Color | None = None) -> None: if self._fill: if color is None: - color = cmap.Color("transparent") - if not isinstance(color, cmap.Color): - color = cmap.Color(color) + color = _cmap.Color("transparent") + if not isinstance(color, _cmap.Color): + color = _cmap.Color(color) self._fill.material.color = color.rgba self._render() - @property def border_color(self) -> Any: if self._outline: - return cmap.Color(self._outline.material.color) - return cmap.Color("transparent") + return _cmap.Color(self._outline.material.color) + return _cmap.Color("transparent") - @border_color.setter - def border_color(self, color: cmap.Color | None = None) -> None: + def set_border_color(self, color: _cmap.Color | None = None) -> None: if self._outline: if color is None: - color = cmap.Color("yellow") - if not isinstance(color, cmap.Color): - color = cmap.Color(color) + color = _cmap.Color("yellow") + if not isinstance(color, _cmap.Color): + color = _cmap.Color(color) self._outline.material.color = color.rgba self._render() @@ -220,7 +213,7 @@ def remove(self) -> None: if (par := self.parent) is not None: par.remove(self) - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: # To be implemented by subclasses raise NotImplementedError("Must be implemented in subclasses") @@ -378,7 +371,7 @@ def _handle_hover_idx(self, pos: Sequence[float]) -> int | None: return i return None - def cursor_at(self, canvas_pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, canvas_pos: Sequence[float]) -> CursorType | None: # Convert canvas -> world world_pos = self._canvas_to_world(canvas_pos) # Step 1: Check if over handle @@ -386,33 +379,58 @@ def cursor_at(self, canvas_pos: Sequence[float]) -> Qt.CursorShape | None: if np.array_equal( self._positions[idx], self._positions.min(axis=0) ) or np.array_equal(self._positions[idx], self._positions.max(axis=0)): - return Qt.CursorShape.SizeFDiagCursor - return Qt.CursorShape.SizeBDiagCursor + return CursorType.FDIAG_ARROW + return CursorType.BDIAG_ARROW # Step 2: Check if over ROI if self._outline: roi_bb = self._outline.geometry.get_bounding_box() if _is_inside(roi_bb, world_pos): - return Qt.CursorShape.SizeAllCursor + return CursorType.ALL_ARROW return None -class _QWgpuCanvas(QWgpuCanvas): - def installEventFilter(self, filter: Any) -> None: - self._subwidget.installEventFilter(filter) +def get_canvas_class() -> WgpuCanvas: + from ndv._views._app import GuiFrontend, gui_frontend + + frontend = gui_frontend() + if frontend == GuiFrontend.QT: + from qtpy.QtCore import QSize + from wgpu.gui import qt + + class QWgpuCanvas(qt.QWgpuCanvas): + def installEventFilter(self, filter: Any) -> None: + self._subwidget.installEventFilter(filter) + + def sizeHint(self) -> QSize: + return QSize(self.width(), self.height()) - def sizeHint(self) -> QSize: - return QSize(self.width(), self.height()) + return QWgpuCanvas + if frontend == GuiFrontend.JUPYTER: + from wgpu.gui.jupyter import JupyterWgpuCanvas + return JupyterWgpuCanvas + if frontend == GuiFrontend.WX: + from wgpu.gui.wx import WxWgpuCanvas -class PyGFXViewerCanvas(PCanvas): + return WxWgpuCanvas + + +class GfxArrayCanvas(ArrayCanvas): """pygfx-based canvas wrapper.""" def __init__(self) -> None: self._current_shape: tuple[int, ...] = () self._last_state: dict[Literal[2, 3], Any] = {} - self._canvas = _QWgpuCanvas(size=(600, 600)) + cls = get_canvas_class() + self._canvas = cls(size=(600, 600)) + # this filter needs to remain in scope for the lifetime of the canvas + # or mouse events will not be intercepted + # the returned function can be called to remove the filter, (and it also + # closes on the event filter and keeps it in scope). + self._disconnect_mouse_events = filter_mouse_events(self._canvas, self) + self._renderer = pygfx.renderers.WgpuRenderer(self._canvas) try: # requires https://github.com/pygfx/pygfx/pull/752 @@ -428,10 +446,10 @@ def __init__(self) -> None: self._camera: pygfx.Camera | None = None self._ndim: Literal[2, 3] | None = None - self._elements: WeakKeyDictionary = WeakKeyDictionary() + self._elements = WeakKeyDictionary[pygfx.WorldObject, CanvasElement]() - def qwidget(self) -> QWidget: - return cast("QWidget", self._canvas) + def frontend_widget(self) -> Any: + return self._canvas def set_ndim(self, ndim: Literal[2, 3]) -> None: """Set the number of dimensions of the displayed data.""" @@ -444,7 +462,8 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: self._ndim = ndim if ndim == 3: self._camera = cam = pygfx.PerspectiveCamera(0, 1) - cam.show_object(self._scene, up=(0, -1, 0), view_dir=(0, 0, 1)) + with suppress(ValueError): + cam.show_object(self._scene, up=(0, -1, 0), view_dir=(0, 0, 1)) controller = pygfx.OrbitController(cam, register_events=self._renderer) zoom = "zoom" # FIXME: there is still an issue with rotational centration. @@ -466,9 +485,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: if state := self._last_state.get(ndim): cam.set_state(state) - def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> PyGFXImageHandle: + def add_image(self, data: np.ndarray | None = None) -> PyGFXImageHandle: """Add a new Image node to the scene.""" tex = pygfx.Texture(data, dim=2) image = pygfx.Image( @@ -486,14 +503,10 @@ def add_image( # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. handle = PyGFXImageHandle(image, self.refresh) - if cmap is not None: - handle.cmap = cmap self._elements[image] = handle return handle - def add_volume( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> PyGFXImageHandle: + def add_volume(self, data: np.ndarray | None = None) -> PyGFXImageHandle: tex = pygfx.Texture(data, dim=3) vol = pygfx.Volume( pygfx.Geometry(grid=tex), @@ -511,16 +524,14 @@ def add_volume( # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. handle = PyGFXImageHandle(vol, self.refresh) - if cmap is not None: - handle.cmap = cmap self._elements[vol] = handle return handle def add_roi( self, vertices: Sequence[tuple[float, float]] | None = None, - color: cmap.Color | None = None, - border_color: cmap.Color | None = None, + color: _cmap.Color | None = None, + border_color: _cmap.Color | None = None, ) -> PyGFXRoiHandle: """Add a new Rectangular ROI node to the scene.""" handle = RectangularROIHandle(self.refresh, self.canvas_to_world) @@ -528,8 +539,8 @@ def add_roi( self._scene.add(handle) if vertices: handle.vertices = vertices - handle.color = color - handle.border_color = border_color + handle.set_color(color) + handle.set_border_color(border_color) self._elements[handle] = handle return handle @@ -562,7 +573,8 @@ def set_range( self.refresh() def refresh(self) -> None: - self._canvas.update() + with suppress(AttributeError): + self._canvas.update() self._canvas.request_draw(self._animate) def _animate(self) -> None: @@ -611,7 +623,14 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: pos = self.canvas_to_world((pos_xy[0], pos_xy[1])) for c in self._scene.children: bb = c.get_bounding_box() - if _is_inside(bb, pos): - element = cast("CanvasElement", self._elements.get(c)) + if _is_inside(bb, pos) and (element := self._elements.get(c)): elements.append(element) return elements + + def set_visible(self, visible: bool) -> None: + """Set the visibility of the canvas.""" + self._canvas.visible = visible + + def close(self) -> None: + self._disconnect_mouse_events() + self._canvas.close() diff --git a/src/ndv/_views/_qt/__init__.py b/src/ndv/_views/_qt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/_views/_qt/_array_view.py b/src/ndv/_views/_qt/_array_view.py new file mode 100644 index 00000000..518efd22 --- /dev/null +++ b/src/ndv/_views/_qt/_array_view.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, cast + +from qtpy.QtCore import Qt, Signal +from qtpy.QtWidgets import ( + QCheckBox, + QComboBox, + QFormLayout, + QFrame, + QHBoxLayout, + QPushButton, + QSplitter, + QVBoxLayout, + QWidget, +) +from superqt import QCollapsible, QElidingLabel, QLabeledRangeSlider, QLabeledSlider +from superqt.cmap import QColormapComboBox +from superqt.iconify import QIconifyIcon +from superqt.utils import signals_blocked + +from ndv._views.bases import ArrayView, LutView +from ndv.models._array_display_model import ChannelMode + +if TYPE_CHECKING: + from collections.abc import Container, Hashable, Mapping, Sequence + + import cmap + from qtpy.QtGui import QIcon + + from ndv._types import AxisKey + +SLIDER_STYLE = """ +QSlider::groove:horizontal { + height: 15px; + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(128, 128, 128, 0.25), + stop:1 rgba(128, 128, 128, 0.1) + ); + border-radius: 3px; +} + +QSlider::handle:horizontal { + width: 38px; + background: #999999; + border-radius: 3px; +} + +QSlider::sub-page:horizontal { + background: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 100, 100, 0.25), + stop:1 rgba(100, 100, 100, 0.1) + ); +} + +QLabel { font-size: 12px; } + +QRangeSlider { qproperty-barColor: qlineargradient( + x1:0, y1:0, x2:0, y2:1, + stop:0 rgba(100, 80, 120, 0.2), + stop:1 rgba(100, 80, 120, 0.4) + )} +""" + + +class _CmapCombo(QColormapComboBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent, allow_user_colormaps=True, add_colormap_text="Add...") + self.setMinimumSize(140, 21) + # self.setStyleSheet("background-color: transparent;") + + def showPopup(self) -> None: + super().showPopup() + popup = self.findChild(QFrame) + popup.setMinimumWidth(self.width() + 100) + popup.move(popup.x(), popup.y() - self.height() - popup.height()) + + # TODO: upstream me + def setCurrentColormap(self, cmap_: cmap.Colormap) -> None: + """Adds the color to the QComboBox and selects it.""" + for idx in range(self.count()): + if item := self.itemColormap(idx): + if item.name == cmap_.name: + self.setCurrentIndex(idx) + else: + self.addColormap(cmap_) + + +class _QLUTWidget(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.visible = QCheckBox() + + self.cmap = _CmapCombo() + self.cmap.setFocusPolicy(Qt.FocusPolicy.NoFocus) + self.cmap.addColormaps(["gray", "green", "magenta"]) + + self.clims = QLabeledRangeSlider(Qt.Orientation.Horizontal) + + WHITE_SS = SLIDER_STYLE + "SliderLabel { font-size: 10px; color: white;}" + self.clims.setStyleSheet(WHITE_SS) + self.clims.setHandleLabelPosition( + QLabeledRangeSlider.LabelPosition.LabelsOnHandle + ) + self.clims.setEdgeLabelMode(QLabeledRangeSlider.EdgeLabelMode.NoLabel) + self.clims.setRange(0, 2**16) # TODO: expose + + self.auto_clim = QPushButton("Auto") + self.auto_clim.setMaximumWidth(42) + self.auto_clim.setCheckable(True) + + layout = QHBoxLayout(self) + layout.setSpacing(5) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.visible) + layout.addWidget(self.cmap) + layout.addWidget(self.clims) + layout.addWidget(self.auto_clim) + + +class QLutView(LutView): + def __init__(self) -> None: + super().__init__() + self._qwidget = _QLUTWidget() + # TODO: use emit_fast + self._qwidget.visible.toggled.connect(self.visibilityChanged.emit) + self._qwidget.cmap.currentColormapChanged.connect(self.cmapChanged.emit) + self._qwidget.clims.valueChanged.connect(self.climsChanged.emit) + self._qwidget.auto_clim.toggled.connect(self.autoscaleChanged.emit) + + def frontend_widget(self) -> QWidget: + return self._qwidget + + def set_channel_name(self, name: str) -> None: + self._qwidget.visible.setText(name) + + def set_auto_scale(self, auto: bool) -> None: + self._qwidget.auto_clim.setChecked(auto) + + def set_colormap(self, cmap: cmap.Colormap) -> None: + self._qwidget.cmap.setCurrentColormap(cmap) + + def set_clims(self, clims: tuple[float, float]) -> None: + self._qwidget.clims.setValue(clims) + + def set_gamma(self, gamma: float) -> None: + pass + + def set_channel_visible(self, visible: bool) -> None: + self._qwidget.visible.setChecked(visible) + + def set_visible(self, visible: bool) -> None: + self._qwidget.setVisible(visible) + + def close(self) -> None: + self._qwidget.close() + + +class _QDimsSliders(QWidget): + currentIndexChanged = Signal() + + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._sliders: dict[Hashable, QLabeledSlider] = {} + self.setStyleSheet(SLIDER_STYLE) + + layout = QFormLayout(self) + layout.setSpacing(2) + layout.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.AllNonFixedFieldsGrow) + layout.setContentsMargins(0, 0, 0, 0) + + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: + """Update sliders with the given coordinate ranges.""" + layout = cast("QFormLayout", self.layout()) + for axis, _coords in coords.items(): + sld = QLabeledSlider(Qt.Orientation.Horizontal) + sld.valueChanged.connect(self.currentIndexChanged) + if isinstance(_coords, range): + sld.setRange(_coords.start, _coords.stop - 1) + sld.setSingleStep(_coords.step) + else: + sld.setRange(0, len(_coords) - 1) + layout.addRow(str(axis), sld) + self._sliders[axis] = sld + self.currentIndexChanged.emit() + + def hide_dimensions( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + layout = cast("QFormLayout", self.layout()) + for ax, slider in self._sliders.items(): + if ax in axes_to_hide: + layout.setRowVisible(slider, False) + elif show_remainder: + layout.setRowVisible(slider, True) + + def current_index(self) -> Mapping[AxisKey, int | slice]: + """Return the current value of the sliders.""" + return {axis: slider.value() for axis, slider in self._sliders.items()} + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + """Set the current value of the sliders.""" + changed = False + # only emit signal if the value actually changed + # NOTE: this may be unnecessary, since usually the only thing calling + # set_current_index is the controller, which already knows the value + # however, we use this method directly in testing and it's nice to ensure. + with signals_blocked(self): + for axis, val in value.items(): + if isinstance(val, slice): + raise NotImplementedError("Slices are not supported yet") + if slider := self._sliders.get(axis): + if slider.value() != val: + changed = True + slider.setValue(val) + else: # pragma: no cover + warnings.warn(f"Axis {axis} not found in sliders", stacklevel=2) + if changed: + self.currentIndexChanged.emit() + + +class _UpCollapsible(QCollapsible): + def __init__( + self, + title: str = "", + parent: QWidget | None = None, + expandedIcon: QIcon | str | None = "โ–ผ", + collapsedIcon: QIcon | str | None = "โ–ฒ", + ): + super().__init__(title, parent, expandedIcon, collapsedIcon) + # little hack to make the lut collapsible take up less space + layout = cast("QVBoxLayout", self.layout()) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(0) + if ( + # look-before-leap on private attribute that may change + hasattr(self, "_content") and (inner := self._content.layout()) is not None + ): + inner.setContentsMargins(0, 4, 0, 0) + inner.setSpacing(0) + + self.setDuration(100) + + # this is a little hack to allow the buttons on the main view (below) + # share the same row as the LUT toggle button + layout.removeWidget(self._toggle_btn) + self.btn_row = QHBoxLayout() + self.btn_row.setContentsMargins(0, 0, 0, 0) + self.btn_row.setSpacing(0) + self.btn_row.addWidget(self._toggle_btn) + self.btn_row.addStretch() + layout.addLayout(self.btn_row) + + def setContent(self, content: QWidget) -> None: + """Replace central widget (the widget that gets expanded/collapsed).""" + self._content = content + # this is different from upstream + cast("QVBoxLayout", self.layout()).insertWidget(0, self._content) + self._animation.setTargetObject(content) + + +# this is a PView ... but that would make a metaclass conflict +class _QArrayViewer(QWidget): + def __init__(self, canvas_widget: QWidget, parent: QWidget | None = None): + super().__init__(parent) + + self.dims_sliders = _QDimsSliders(self) + + # place to display dataset summary + self.data_info_label = QElidingLabel("", parent=self) + # place to display arbitrary text + self.hover_info_label = QElidingLabel("", self) + + # the button that controls the display mode of the channels + # not using QEnumComboBox because we want to exclude some values for now + self.channel_mode_combo = QComboBox(self) + self.channel_mode_combo.addItems( + [ChannelMode.GRAYSCALE.value, ChannelMode.COMPOSITE.value] + ) + + # button to reset the zoom of the canvas + # TODO: unify icons across all the view frontends in a new file + set_range_icon = QIconifyIcon("fluent:full-screen-maximize-24-filled") + self.set_range_btn = QPushButton(set_range_icon, "", self) + + # button to add a histogram + add_histogram_icon = QIconifyIcon("foundation:graph-bar") + self.histogram_btn = QPushButton(add_histogram_icon, "", self) + + self.luts = _UpCollapsible( + "LUTs", + parent=self, + expandedIcon=QIconifyIcon("bi:chevron-up", color="#888888"), + collapsedIcon=QIconifyIcon("bi:chevron-down", color="#888888"), + ) + self._btn_layout = self.luts.btn_row + self._btn_layout.setParent(None) + self.luts.expand() + + self._btn_layout.addWidget(self.channel_mode_combo) + # self._btns.addWidget(self._ndims_btn) + self._btn_layout.addWidget(self.histogram_btn) + self._btn_layout.addWidget(self.set_range_btn) + # self._btns.addWidget(self._add_roi_btn) + + # above the canvas + info_widget = QWidget() + info = QHBoxLayout(info_widget) + info.setContentsMargins(0, 0, 0, 2) + info.setSpacing(0) + info.addWidget(self.data_info_label) + info_widget.setFixedHeight(16) + + left = QWidget() + left_layout = QVBoxLayout(left) + left_layout.setSpacing(2) + left_layout.setContentsMargins(0, 0, 0, 0) + left_layout.addWidget(info_widget) + left_layout.addWidget(canvas_widget, 1) + left_layout.addWidget(self.hover_info_label) + left_layout.addWidget(self.dims_sliders) + left_layout.addWidget(self.luts) + left_layout.addLayout(self._btn_layout) + + self.splitter = QSplitter(Qt.Orientation.Vertical, self) + self.splitter.addWidget(left) + + layout = QVBoxLayout(self) + layout.setSpacing(2) + layout.setContentsMargins(6, 6, 6, 6) + layout.addWidget(self.splitter) + + +class QtArrayView(ArrayView): + def __init__(self, canvas_widget: QWidget) -> None: + self._qwidget = qwdg = _QArrayViewer(canvas_widget) + qwdg.histogram_btn.clicked.connect(self._on_add_histogram_clicked) + + # TODO: use emit_fast + qwdg.dims_sliders.currentIndexChanged.connect(self.currentIndexChanged.emit) + qwdg.channel_mode_combo.currentTextChanged.connect( + self._on_channel_mode_changed + ) + qwdg.set_range_btn.clicked.connect(self.resetZoomClicked.emit) + + def add_lut_view(self) -> QLutView: + view = QLutView() + self._qwidget.luts.addWidget(view.frontend_widget()) + return view + + def remove_lut_view(self, view: LutView) -> None: + self._qwidget.luts.removeWidget(cast("QLutView", view).frontend_widget()) + + def _on_channel_mode_changed(self, text: str) -> None: + self.channelModeChanged.emit(ChannelMode(text)) + + def _on_add_histogram_clicked(self) -> None: + splitter = self._qwidget.splitter + if hasattr(self, "_hist"): + if not (sizes := splitter.sizes())[-1]: + splitter.setSizes([self._qwidget.height() - 100, 100]) + else: + splitter.setSizes([sum(sizes), 0]) + else: + self.histogramRequested.emit() + + def add_histogram(self, widget: QWidget) -> None: + if hasattr(self, "_hist"): + raise RuntimeError("Only one histogram can be added at a time") + self._hist = widget + self._qwidget.splitter.addWidget(widget) + self._qwidget.splitter.setSizes([self._qwidget.height() - 100, 100]) + + def remove_histogram(self, widget: QWidget) -> None: + widget.setParent(None) + widget.deleteLater() + + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: + """Update sliders with the given coordinate ranges.""" + self._qwidget.dims_sliders.create_sliders(coords) + + def hide_sliders( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + """Hide sliders based on visible axes.""" + self._qwidget.dims_sliders.hide_dimensions(axes_to_hide, show_remainder) + + def current_index(self) -> Mapping[AxisKey, int | slice]: + """Return the current value of the sliders.""" + return self._qwidget.dims_sliders.current_index() + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + """Set the current value of the sliders.""" + self._qwidget.dims_sliders.set_current_index(value) + + def set_data_info(self, text: str) -> None: + """Set the data info text, above the canvas.""" + self._qwidget.data_info_label.setText(text) + + def set_hover_info(self, text: str) -> None: + """Set the hover info text, below the canvas.""" + self._qwidget.hover_info_label.setText(text) + + def set_channel_mode(self, mode: ChannelMode) -> None: + """Set the channel mode button text.""" + self._qwidget.channel_mode_combo.setCurrentText(mode.value) + + def set_visible(self, visible: bool) -> None: + self._qwidget.setVisible(visible) + + def close(self) -> None: + self._qwidget.close() + + def frontend_widget(self) -> QWidget: + return self._qwidget diff --git a/src/ndv/viewer/_save_button.py b/src/ndv/_views/_qt/_save_button.py similarity index 95% rename from src/ndv/viewer/_save_button.py rename to src/ndv/_views/_qt/_save_button.py index 0ce45116..b86dff28 100644 --- a/src/ndv/viewer/_save_button.py +++ b/src/ndv/_views/_qt/_save_button.py @@ -7,7 +7,7 @@ from superqt.iconify import QIconifyIcon if TYPE_CHECKING: - from ._data_wrapper import DataWrapper + from ndv.models import DataWrapper class SaveButton(QPushButton): diff --git a/src/ndv/viewer/spin.gif b/src/ndv/_views/_qt/spin.gif similarity index 100% rename from src/ndv/viewer/spin.gif rename to src/ndv/_views/_qt/spin.gif diff --git a/src/ndv/_views/_vispy/__init__.py b/src/ndv/_views/_vispy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/_views/_vispy/_array_canvas.py similarity index 75% rename from src/ndv/viewer/_backends/_vispy.py rename to src/ndv/_views/_vispy/_array_canvas.py index e74a05e7..aac99b2a 100755 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/_views/_vispy/_array_canvas.py @@ -5,25 +5,31 @@ from typing import TYPE_CHECKING, Any, Literal, cast from weakref import WeakKeyDictionary -import cmap +import cmap as _cmap import numpy as np import vispy import vispy.scene import vispy.visuals -from qtpy.QtCore import Qt from vispy import scene from vispy.color import Color from vispy.util.quaternion import Quaternion -from ._protocols import PCanvas +from ndv._types import CursorType +from ndv._views._app import filter_mouse_events +from ndv._views._vispy._utils import supports_float_textures +from ndv._views.bases import ArrayCanvas +from ndv._views.bases.graphics._canvas_elements import ( + CanvasElement, + ImageHandle, + RoiHandle, +) if TYPE_CHECKING: from collections.abc import Sequence from typing import Callable - from qtpy.QtWidgets import QWidget + import vispy.app - from ._protocols import CanvasElement turn = np.sin(np.pi / 4) DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) @@ -36,8 +42,8 @@ def __init__( self, parent: RectangularROI, on_move: Callable[[Sequence[float]], None] | None = None, - cursor: Qt.CursorShape - | Callable[[Sequence[float]], Qt.CursorShape] = Qt.CursorShape.SizeAllCursor, + cursor: CursorType + | Callable[[Sequence[float]], CursorType] = CursorType.ALL_ARROW, ) -> None: super().__init__(parent=parent) self.unfreeze() @@ -47,12 +53,12 @@ def __init__( if on_move: self.on_move.append(on_move) # cusror preference function - if isinstance(cursor, Qt.CursorShape): - self._cursor_at = cast( - "Callable[[Sequence[float]], Qt.CursorShape]", lambda _: cursor - ) - else: - self._cursor_at = cursor + if not callable(cursor): + + def cursor(_: Any) -> CursorType: + return cursor + + self._cursor_at = cursor self._selected = False # NB VisPy asks that the data is a 2D array self._pos = np.array([[0, 0]], dtype=np.float32) @@ -84,7 +90,7 @@ def selected(self, selected: bool) -> None: self._selected = selected self.parent.selected = selected - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: return self._cursor_at(self.pos) @@ -137,15 +143,15 @@ def __init__( self._selected = False self.freeze() - def _handle_cursor_pref(self, handle_pos: Sequence[float]) -> Qt.CursorShape: + def _handle_cursor_pref(self, handle_pos: Sequence[float]) -> CursorType: # Bottom left handle if handle_pos[0] < self.center[0] and handle_pos[1] < self.center[1]: - return Qt.CursorShape.SizeFDiagCursor + return CursorType.FDIAG_ARROW # Top right handle if handle_pos[0] > self.center[0] and handle_pos[1] > self.center[1]: - return Qt.CursorShape.SizeFDiagCursor + return CursorType.FDIAG_ARROW # Top left, bottom right - return Qt.CursorShape.SizeBDiagCursor + return CursorType.BDIAG_ARROW def move_top_left(self, pos: Sequence[float]) -> None: self._handles[3].pos = [pos[0], self._handles[3].pos[1]] @@ -239,26 +245,24 @@ def move(self, pos: Sequence[float]) -> None: ] self.center = new_center - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: - return Qt.CursorShape.SizeAllCursor + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: + return CursorType.ALL_ARROW # ------------------- End EditableROI interface ------------------------- -class VispyImageHandle: - def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: +class VispyImageHandle(ImageHandle): + def __init__(self, visual: scene.Image | scene.Volume) -> None: self._visual = visual self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 - @property def data(self) -> np.ndarray: try: return self._visual._data # type: ignore [no-any-return] except AttributeError: return self._visual._last_data # type: ignore [no-any-return] - @data.setter - def data(self, data: np.ndarray) -> None: + def set_data(self, data: np.ndarray) -> None: if not data.ndim == self._ndim: warnings.warn( f"Got wrong number of dimensions ({data.ndim}) for vispy " @@ -268,50 +272,46 @@ def data(self, data: np.ndarray) -> None: return self._visual.set_data(data) - @property def visible(self) -> bool: return bool(self._visual.visible) - @visible.setter - def visible(self, visible: bool) -> None: + def set_visible(self, visible: bool) -> None: self._visual.visible = visible - @property + # TODO: shouldn't be needed def can_select(self) -> bool: return False - @property def selected(self) -> bool: return False - @selected.setter - def selected(self, selected: bool) -> None: + def set_selected(self, selected: bool) -> None: raise NotImplementedError("Images cannot be selected") - @property def clim(self) -> Any: return self._visual.clim - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: + def set_clims(self, clims: tuple[float, float]) -> None: with suppress(ZeroDivisionError): self._visual.clim = clims - @property - def cmap(self) -> cmap.Colormap: - return self._cmap + def gamma(self) -> float: + return self._visual.gamma # type: ignore [no-any-return] + + def set_gamma(self, gamma: float) -> None: + self._visual.gamma = gamma + + def cmap(self) -> _cmap.Colormap: + return self._cmap # FIXME - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: + def set_cmap(self, cmap: _cmap.Colormap) -> None: self._cmap = cmap self._visual.cmap = cmap.to_vispy() - @property def transform(self) -> np.ndarray: raise NotImplementedError - @transform.setter - def transform(self, transform: np.ndarray) -> None: + def set_transform(self, transform: np.ndarray) -> None: raise NotImplementedError def start_move(self, pos: Sequence[float]) -> None: @@ -323,34 +323,29 @@ def move(self, pos: Sequence[float]) -> None: def remove(self) -> None: self._visual.parent = None - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: return None # FIXME: Unfortunate naming :) -class VispyHandleHandle: +class VispyHandleHandle(CanvasElement): def __init__(self, handle: Handle, parent: CanvasElement) -> None: self._handle = handle self._parent = parent - @property def visible(self) -> bool: return cast("bool", self._handle.visible) - @visible.setter - def visible(self, visible: bool) -> None: + def set_visible(self, visible: bool) -> None: self._handle.visible = visible - @property def can_select(self) -> bool: return True - @property def selected(self) -> bool: return self._handle.selected - @selected.setter - def selected(self, selected: bool) -> None: + def set_selected(self, selected: bool) -> None: self._handle.selected = selected def start_move(self, pos: Sequence[float]) -> None: @@ -362,40 +357,33 @@ def move(self, pos: Sequence[float]) -> None: def remove(self) -> None: self._parent.remove() - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: return self._handle.cursor_at(pos) -class VispyRoiHandle: +class VispyRoiHandle(RoiHandle): def __init__(self, roi: RectangularROI) -> None: self._roi = roi - @property def vertices(self) -> Sequence[Sequence[float]]: return self._roi.vertices - @vertices.setter - def vertices(self, vertices: Sequence[Sequence[float]]) -> None: + def set_vertices(self, vertices: Sequence[Sequence[float]]) -> None: self._roi.vertices = vertices - @property def visible(self) -> bool: return bool(self._roi.visible) - @visible.setter - def visible(self, visible: bool) -> None: + def set_visible(self, visible: bool) -> None: self._roi.visible = visible - @property def can_select(self) -> bool: return True - @property def selected(self) -> bool: return self._roi.selected - @selected.setter - def selected(self, selected: bool) -> None: + def set_selected(self, selected: bool) -> None: self._roi.selected = selected def start_move(self, pos: Sequence[float]) -> None: @@ -404,41 +392,33 @@ def start_move(self, pos: Sequence[float]) -> None: def move(self, pos: Sequence[float]) -> None: self._roi.move(pos) - @property def color(self) -> Any: return self._roi.color - @color.setter - def color(self, color: Any | None = None) -> None: + def set_color(self, color: _cmap.Color | None) -> None: if color is None: - color = cmap.Color("transparent") - if not isinstance(color, cmap.Color): - color = cmap.Color(color) + color = _cmap.Color("transparent") # NB: To enable dragging the shape within the border, # we require a positive alpha. alpha = max(color.alpha, 1e-6) self._roi.color = Color(color.hex, alpha=alpha) - @property - def border_color(self) -> Any: - return self._roi.border_color + def border_color(self) -> _cmap.Color: + return _cmap.Color(self._roi.border_color.rgba) - @border_color.setter - def border_color(self, color: Any | None = None) -> None: + def set_border_color(self, color: _cmap.Color | None) -> None: if color is None: - color = cmap.Color("yellow") - if not isinstance(color, cmap.Color): - color = cmap.Color(color) + color = _cmap.Color("yellow") self._roi.border_color = Color(color.hex, alpha=color.alpha) def remove(self) -> None: self._roi.parent = None - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: return self._roi.cursor_at(pos) -class VispyViewerCanvas(PCanvas): +class VispyArrayCanvas(ArrayCanvas): """Vispy-based viewer for data. All vispy-specific code is encapsulated in this class (and non-vispy canvases @@ -447,6 +427,13 @@ class VispyViewerCanvas(PCanvas): def __init__(self) -> None: self._canvas = scene.SceneCanvas(size=(600, 600)) + + # this filter needs to remain in scope for the lifetime of the canvas + # or mouse events will not be intercepted + # the returned function can be called to remove the filter, (and it also + # closes on the event filter and keeps it in scope). + self._disconnect_mouse_events = filter_mouse_events(self._canvas.native, self) + self._last_state: dict[Literal[2, 3], Any] = {} central_wdg: scene.Widget = self._canvas.central_widget @@ -454,6 +441,7 @@ def __init__(self) -> None: self._ndim: Literal[2, 3] | None = None self._elements: WeakKeyDictionary = WeakKeyDictionary() + self._txt_fmt = "auto" if supports_float_textures() else None @property def _camera(self) -> vispy.scene.cameras.BaseCamera: @@ -480,48 +468,65 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: cam.set_state(state) self._view.camera = cam - def qwidget(self) -> QWidget: - return cast("QWidget", self._canvas.native) + def frontend_widget(self) -> Any: + return self._canvas.native + + def set_visible(self, visible: bool) -> None: ... + + def close(self) -> None: + self._disconnect_mouse_events() + self._canvas.close() def refresh(self) -> None: self._canvas.update() - def add_image( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> VispyImageHandle: + def add_image(self, data: np.ndarray | None = None) -> VispyImageHandle: """Add a new Image node to the scene.""" - img = scene.visuals.Image(data, parent=self._view.scene) + data = _downcast(data) + try: + img = scene.visuals.Image( + data, parent=self._view.scene, texture_format=self._txt_fmt + ) + except ValueError as e: + warnings.warn(f"{e}. Falling back to CPUScaledTexture", stacklevel=2) + img = scene.visuals.Image(data, parent=self._view.scene) + img.set_gl_state("additive", depth_test=False) img.interactive = True handle = VispyImageHandle(img) self._elements[img] = handle if data is not None: self.set_range() - if cmap is not None: - handle.cmap = cmap return handle - def add_volume( - self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None - ) -> VispyImageHandle: - vol = scene.visuals.Volume( - data, parent=self._view.scene, interpolation="nearest" - ) + def add_volume(self, data: np.ndarray | None = None) -> VispyImageHandle: + data = _downcast(data) + try: + vol = scene.visuals.Volume( + data, + parent=self._view.scene, + interpolation="nearest", + texture_format=self._txt_fmt, + ) + except ValueError as e: + warnings.warn(f"{e}. Falling back to CPUScaledTexture", stacklevel=2) + vol = scene.visuals.Volume( + data, parent=self._view.scene, interpolation="nearest" + ) + vol.set_gl_state("additive", depth_test=False) vol.interactive = True handle = VispyImageHandle(vol) self._elements[vol] = handle if data is not None: self.set_range() - if cmap is not None: - handle.cmap = cmap return handle def add_roi( self, vertices: Sequence[tuple[float, float]] | None = None, - color: cmap.Color | None = None, - border_color: cmap.Color | None = None, + color: _cmap.Color | None = None, + border_color: _cmap.Color | None = None, ) -> VispyRoiHandle: """Add a new Rectangular ROI node to the scene.""" roi = RectangularROI(parent=self._view.scene) @@ -530,10 +535,10 @@ def add_roi( for h in roi._handles: self._elements[h] = VispyHandleHandle(h, handle) if vertices: - handle.vertices = vertices + handle.set_vertices(vertices) self.set_range() - handle.color = color - handle.border_color = border_color + handle.set_color(color) + handle.set_border_color(border_color) return handle def set_range( @@ -547,6 +552,9 @@ def set_range( When called with no arguments, the range is set to the full extent of the data. """ + # temporary + self._camera.set_range() + return _x = [0.0, 0.0] _y = [0.0, 0.0] _z = [0.0, 0.0] @@ -593,3 +601,15 @@ def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: if (handle := self._elements.get(vis)) is not None: elements.append(handle) return elements + + +def _downcast(data: np.ndarray | None) -> np.ndarray | None: + """Downcast >32bit data to 32bit.""" + # downcast to 32bit, preserving int/float + if data is not None: + if np.issubdtype(data.dtype, np.integer) and data.dtype.itemsize > 2: + warnings.warn("Downcasting integer data to uint16.", stacklevel=2) + data = data.astype(np.uint16) + elif np.issubdtype(data.dtype, np.floating) and data.dtype.itemsize > 4: + data = data.astype(np.float32) + return data diff --git a/src/ndv/_views/_vispy/_histogram.py b/src/ndv/_views/_vispy/_histogram.py new file mode 100644 index 00000000..38437eae --- /dev/null +++ b/src/ndv/_views/_vispy/_histogram.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +from enum import Enum, auto +from typing import TYPE_CHECKING, Any + +import numpy as np +from vispy import scene + +from ndv._types import CursorType +from ndv._views._app import filter_mouse_events +from ndv._views.bases import HistogramCanvas + +from ._plot_widget import PlotWidget + +if TYPE_CHECKING: + from collections.abc import Sequence + + import cmap + import numpy.typing as npt + + from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent + + +class Grabbable(Enum): + NONE = auto() + LEFT_CLIM = auto() + RIGHT_CLIM = auto() + GAMMA = auto() + + +class VispyHistogramCanvas(HistogramCanvas): + """A HistogramCanvas utilizing VisPy.""" + + def __init__(self, *, vertical: bool = False) -> None: + # ------------ data and state ------------ # + + self._values: Sequence[float] | np.ndarray | None = None + self._bin_edges: Sequence[float] | np.ndarray | None = None + self._clims: tuple[float, float] | None = None + self._gamma: float = 1 + + # the currently grabbed object + self._grabbed: Grabbable = Grabbable.NONE + # whether the y-axis is logarithmic + self._log_base: float | None = None + # whether the histogram is vertical + self._vertical: bool = vertical + # The values of the left and right edges on the canvas (respectively) + self._domain: tuple[float, float] | None = None + # The values of the bottom and top edges on the canvas (respectively) + self._range: tuple[float, float] | None = None + + # ------------ VisPy Canvas ------------ # + + self._canvas = scene.SceneCanvas() + self._disconnect_mouse_events = filter_mouse_events(self._canvas.native, self) + + ## -- Visuals -- ## + + # NB We directly use scene.Mesh, instead of scene.Histogram, + # so that we can control the calculation of the histogram ourselves + self._hist_mesh = scene.Mesh(color="#888888") + + # The Lut Line visualizes both the clims (vertical line segments connecting the + # first two and last two points, respectively) and the gamma curve + # (the polyline between all remaining points) + self._lut_line = scene.LinePlot( + data=(0), # Dummy value to prevent resizing errors + color="k", + connect="strip", + symbol=None, + line_kind="-", + width=1.5, + marker_size=10.0, + edge_color="k", + face_color="b", + edge_width=1.0, + ) + self._lut_line.visible = False + self._lut_line.order = -1 + + # The gamma handle appears halfway between the clims + self._gamma_handle_pos: np.ndarray = np.ndarray((1, 2)) + self._gamma_handle = scene.Markers( + pos=self._gamma_handle_pos, + size=6, + edge_width=0, + ) + self._gamma_handle.visible = False + self._gamma_handle.order = -2 + + # One transform to rule them all! + self._handle_transform = scene.transforms.STTransform() + self._lut_line.transform = self._handle_transform + self._gamma_handle.transform = self._handle_transform + + ## -- Plot -- ## + self.plot = PlotWidget() + self.plot.lock_axis("y") + self._canvas.central_widget.add_widget(self.plot) + self.node_tform = self.plot.node_transform(self.plot._view.scene) + + self.plot._view.add(self._hist_mesh) + self.plot._view.add(self._lut_line) + self.plot._view.add(self._gamma_handle) + + self.set_vertical(vertical) + + def refresh(self) -> None: + self._canvas.update() + + def set_visible(self, visible: bool) -> None: ... + + def close(self) -> None: + self._disconnect_mouse_events() + self._canvas.close() + + # ------------- LutView Protocol methods ------------- # + + def set_channel_name(self, name: str) -> None: + # Nothing to do + # TODO: maybe show text somewhere + pass + + def set_channel_visible(self, visible: bool) -> None: + self._lut_line.visible = visible + self._gamma_handle.visible = visible + + def set_colormap(self, lut: cmap.Colormap) -> None: + if self._hist_mesh is not None: + self._hist_mesh.color = lut.color_stops[-1].color.hex + + def set_gamma(self, gamma: float) -> None: + if gamma < 0: + raise ValueError("gamma must be non-negative!") + self._gamma = gamma + self._update_lut_lines() + + def set_clims(self, clims: tuple[float, float]) -> None: + if clims[1] < clims[0]: + clims = (clims[1], clims[0]) + self._clims = clims + self._update_lut_lines() + + def set_auto_scale(self, autoscale: bool) -> None: + # Nothing to do (yet) + pass + + # ------------- HistogramView Protocol methods ------------- # + + def set_data(self, values: np.ndarray, bin_edges: np.ndarray) -> None: + """Set the histogram values and bin edges. + + These inputs follow the same format as the return value of numpy.histogram. + """ + self._values, self._bin_edges = values, bin_edges + self._update_histogram() + + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = 0, + ) -> None: + if x: + if x[0] > x[1]: + x = (x[1], x[0]) + elif self._bin_edges is not None: + x = self._bin_edges[0], self._bin_edges[-1] + if y: + if y[0] > y[1]: + y = (y[1], y[0]) + elif self._values is not None: + y = (0, np.max(self._values)) + self._range = y + self._domain = x + self._resize() + + def set_vertical(self, vertical: bool) -> None: + self._vertical = vertical + self._update_histogram() + self.plot.lock_axis("x" if vertical else "y") + # When vertical, smaller values should appear at the top of the canvas + self.plot.camera.flip = [False, vertical, False] + self._update_lut_lines() + self._resize() + + def set_log_base(self, base: float | None) -> None: + if base != self._log_base: + self._log_base = base + self._update_histogram() + self._update_lut_lines() + self._resize() + + def frontend_widget(self) -> Any: + return self._canvas.native + + def canvas_to_world( + self, pos_xy: tuple[float, float] + ) -> tuple[float, float, float]: + """Map XY canvas position (pixels) to XYZ coordinate in world space.""" + raise NotImplementedError + + def elements_at(self, pos_xy: tuple[float, float]) -> list: + raise NotImplementedError + + # ------------- Private methods ------------- # + + def _update_histogram(self) -> None: + """ + Updates the displayed histogram with current View parameters. + + NB: Much of this code is graciously borrowed from: + + https://github.com/vispy/vispy/blob/af847424425d4ce51f144a4d1c75ab4033fe39be/vispy/visuals/histogram.py#L28 + """ + if self._values is None or self._bin_edges is None: + return # pragma: no cover + values = self._values + if self._log_base: + # Replace zero values with 1 + values = np.where(values == 0, 1, values) + values = np.log(values) / np.log(self._log_base) + + verts, faces = _hist_counts_to_mesh(values, self._bin_edges, self._vertical) + self._hist_mesh.set_data(vertices=verts, faces=faces) + + # FIXME: This should be called internally upon set_data, right? + # Looks like https://github.com/vispy/vispy/issues/1899 + self._hist_mesh._bounds_changed() + + def _update_lut_lines(self, npoints: int = 256) -> None: + if self._clims is None or self._gamma is None: + return # pragma: no cover + + # 2 additional points for each of the two vertical clims lines + X = np.empty(npoints + 4) + Y = np.empty(npoints + 4) + if self._vertical: + # clims lines + X[0:2], Y[0:2] = (1, 0.5), self._clims[0] + X[-2:], Y[-2:] = (0.5, 0), self._clims[1] + # gamma line + X[2:-2] = np.linspace(0, 1, npoints) ** self._gamma + Y[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) + midpoint = np.array([(2**-self._gamma, np.mean(self._clims))]) + else: + # clims lines + X[0:2], Y[0:2] = self._clims[0], (1, 0.5) + X[-2:], Y[-2:] = self._clims[1], (0.5, 0) + # gamma line + X[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) + Y[2:-2] = np.linspace(0, 1, npoints) ** self._gamma + midpoint = np.array([(np.mean(self._clims), 2**-self._gamma)]) + + # TODO: Move to self.edit_cmap + color = np.linspace(0.2, 0.8, npoints + 4).repeat(4).reshape(-1, 4) + c1, c2 = [0.4] * 4, [0.7] * 4 + color[0:3] = [c1, c2, c1] + color[-3:] = [c1, c2, c1] + + self._lut_line.set_data((X, Y), marker_size=0, color=color) + + self._gamma_handle_pos[:] = midpoint[0] + self._gamma_handle.set_data(pos=self._gamma_handle_pos) + + # FIXME: These should be called internally upon set_data, right? + # Looks like https://github.com/vispy/vispy/issues/1899 + self._lut_line._bounds_changed() + for v in self._lut_line._subvisuals: + v._bounds_changed() + self._gamma_handle._bounds_changed() + + def get_cursor(self, pos: tuple[float, float]) -> CursorType: + nearby = self._find_nearby_node(pos) + + if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: + return CursorType.V_ARROW if self._vertical else CursorType.H_ARROW + elif nearby is Grabbable.GAMMA: + return CursorType.H_ARROW if self._vertical else CursorType.V_ARROW + else: + x, y = self._to_plot_coords(pos) + x1, x2 = self.plot.xaxis.axis.domain + y1, y2 = self.plot.yaxis.axis.domain + if (x1 < x <= x2) and (y1 <= y <= y2): + return CursorType.ALL_ARROW + else: + return CursorType.DEFAULT + + def on_mouse_press(self, event: MousePressEvent) -> bool: + pos = event.x, event.y + # check whether the user grabbed a node + self._grabbed = self._find_nearby_node(pos) + if self._grabbed != Grabbable.NONE: + # disconnect the pan/zoom mouse events until handle is dropped + self.plot.camera.interactive = False + return False + + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + self._grabbed = Grabbable.NONE + self.plot.camera.interactive = True + return False + + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + """Called whenever mouse moves over canvas.""" + pos = event.x, event.y + if self._clims is None: + return False # pragma: no cover + + if self._grabbed in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: + if self._vertical: + c = self._to_plot_coords(pos)[1] + else: + c = self._to_plot_coords(pos)[0] + if self._grabbed is Grabbable.LEFT_CLIM: + newlims = (min(self._clims[1], c), self._clims[1]) + elif self._grabbed is Grabbable.RIGHT_CLIM: + newlims = (self._clims[0], max(self._clims[0], c)) + self.climsChanged.emit(newlims) + return False + + if self._grabbed is Grabbable.GAMMA: + y0, y1 = ( + self.plot.xaxis.axis.domain + if self._vertical + else self.plot.yaxis.axis.domain + ) + y = self._to_plot_coords(pos)[0 if self._vertical else 1] + if y < np.maximum(y0, 0) or y > y1: + return False + self.gammaChanged.emit(-np.log2(y / y1)) + return False + + self.get_cursor(pos).apply_to(self) + return False + + def _find_nearby_node( + self, pos: tuple[float, float], tolerance: int = 5 + ) -> Grabbable: + """Describes whether the event is near a clim.""" + click_x, click_y = pos + + # NB Computations are performed in canvas-space + # for easier tolerance computation. + plot_to_canvas = self.node_tform.imap + gamma_to_plot = self._handle_transform.map + + if self._clims is not None: + if self._vertical: + click = click_y + right = plot_to_canvas([0, self._clims[1]])[1] + left = plot_to_canvas([0, self._clims[0]])[1] + else: + click = click_x + right = plot_to_canvas([self._clims[1], 0])[0] + left = plot_to_canvas([self._clims[0], 0])[0] + + # Right bound always selected on overlap + if bool(abs(right - click) < tolerance): + return Grabbable.RIGHT_CLIM + if bool(abs(left - click) < tolerance): + return Grabbable.LEFT_CLIM + + if self._gamma_handle_pos is not None: + gx, gy = plot_to_canvas(gamma_to_plot(self._gamma_handle_pos[0]))[:2] + if bool(abs(gx - click_x) < tolerance and abs(gy - click_y) < tolerance): + return Grabbable.GAMMA + + return Grabbable.NONE + + def _to_plot_coords(self, pos: Sequence[float]) -> tuple[float, float]: + """Return the plot coordinates of the given position.""" + x, y = self.node_tform.map(pos)[:2] + return x, y + + def _resize(self) -> None: + self.plot.camera.set_range( + x=self._range if self._vertical else self._domain, + y=self._domain if self._vertical else self._range, + # FIXME: Bitten by https://github.com/vispy/vispy/issues/1483 + # It's pretty visible in logarithmic mode + margin=1e-30, + ) + if self._vertical: + scale = 0.98 * self.plot.xaxis.axis.domain[1] + self._handle_transform.scale = (scale, 1) + else: + scale = 0.98 * self.plot.yaxis.axis.domain[1] + self._handle_transform.scale = (1, scale) + + def setVisible(self, visible: bool) -> None: ... + + +def _hist_counts_to_mesh( + values: Sequence[float] | npt.NDArray, + bin_edges: Sequence[float] | npt.NDArray, + vertical: bool = False, +) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.uint32]]: + """Convert histogram counts to mesh vertices and faces for plotting.""" + n_edges = len(bin_edges) + X, Y = (1, 0) if vertical else (0, 1) + + # 4-5 + # | | + # 1-2/7-8 + # |/| | | + # 0-3-6-9 + # construct vertices + vertices = np.zeros((3 * n_edges - 2, 3), np.float32) + vertices[:, X] = np.repeat(bin_edges, 3)[1:-1] + vertices[1::3, Y] = values + vertices[2::3, Y] = values + vertices[vertices == float("-inf")] = 0 + + # construct triangles + faces = np.zeros((2 * n_edges - 2, 3), np.uint32) + offsets = 3 * np.arange(n_edges - 1, dtype=np.uint32)[:, np.newaxis] + faces[::2] = np.array([0, 2, 1]) + offsets + faces[1::2] = np.array([2, 0, 3]) + offsets + return vertices, faces diff --git a/src/ndv/_views/_vispy/_plot_widget.py b/src/ndv/_views/_vispy/_plot_widget.py new file mode 100644 index 00000000..4b6145be --- /dev/null +++ b/src/ndv/_views/_vispy/_plot_widget.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, cast + +from vispy import scene + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import TypeVar + + # just here cause vispy has poor type hints + T = TypeVar("T") + + class Grid(scene.Grid, Generic[T]): + def add_view( + self, + row: int | None = None, + col: int | None = None, + row_span: int = 1, + col_span: int = 1, + **kwargs: Any, + ) -> scene.ViewBox: + super().add_view(...) + + def add_widget( + self, + widget: None | scene.Widget = None, + row: int | None = None, + col: int | None = None, + row_span: int = 1, + col_span: int = 1, + **kwargs: Any, + ) -> scene.Widget: + super().add_widget(...) + + def __getitem__(self, idxs: int | tuple[int, int]) -> T: + return super().__getitem__(idxs) # type: ignore [no-any-return] + + class WidgetKwargs(TypedDict, total=False): + pos: tuple[float, float] + size: tuple[float, float] + border_color: str + border_width: float + bgcolor: str + padding: float + margin: float + + class TextVisualKwargs(TypedDict, total=False): + text: str + color: str + bold: bool + italic: bool + face: str + font_size: float + pos: tuple[float, float] | tuple[float, float, float] + rotation: float + method: Literal["cpu", "gpu"] + depth_test: bool + + class AxisWidgetKwargs(TypedDict, total=False): + orientation: Literal["left", "bottom"] + tick_direction: tuple[int, int] + axis_color: str + tick_color: str + text_color: str + minor_tick_length: float + major_tick_length: float + tick_width: float + tick_label_margin: float + tick_font_size: float + axis_width: float + axis_label: str + axis_label_margin: float + axis_font_size: float + font_size: float # overrides tick_font_size and axis_font_size + + +__all__ = ["PlotWidget"] + + +DEFAULT_AXIS_KWARGS: AxisWidgetKwargs = { + "text_color": "w", + "axis_color": "w", + "tick_color": "w", + "tick_width": 1, + "tick_font_size": 8, + "tick_label_margin": 12, + "axis_label_margin": 50, + "minor_tick_length": 2, + "major_tick_length": 5, + "axis_width": 1, + "axis_font_size": 10, +} + + +class Component(str, Enum): + PAD_LEFT = "pad_left" + PAD_RIGHT = "pad_right" + PAD_BOTTOM = "pad_bottom" + TITLE = "title" + CBAR_TOP = "cbar_top" + CBAR_LEFT = "cbar_left" + CBAR_RIGHT = "cbar_right" + CBAR_BOTTOM = "cbar_bottom" + YAXIS = "yaxis" + XAXIS = "xaxis" + XLABEL = "xlabel" + YLABEL = "ylabel" + + def __str__(self) -> str: + return self.value + + +class PlotWidget(scene.Widget): + """Widget to facilitate plotting. + + Parameters + ---------- + fg_color : str + The default color for the plot. + xlabel : str + The x-axis label. + ylabel : str + The y-axis label. + title : str + The title of the plot. + lock_axis : {'x', 'y', None} + Prevent panning and zooming along a particular axis. + **widget_kwargs : dict + Keyword arguments to pass to the parent class. + """ + + def __init__( + self, + fg_color: str = "k", + xlabel: str = "", + ylabel: str = "", + title: str = "", + lock_axis: Literal["x", "y", None] = None, + **widget_kwargs: Any, + ) -> None: + self._fg_color = fg_color + self._visuals: list[scene.VisualNode] = [] + super().__init__(**widget_kwargs) + self.unfreeze() + self.grid = cast("Grid", self.add_grid(spacing=0, margin=10)) + + title_kwargs: TextVisualKwargs = {"font_size": 14, "color": "w"} + label_kwargs: TextVisualKwargs = {"font_size": 10, "color": "w"} + self._title = scene.Label(str(title), **title_kwargs) + self._xlabel = scene.Label(str(xlabel), **label_kwargs) + self._ylabel = scene.Label(str(ylabel), rotation=-90, **label_kwargs) + + axis_kwargs: AxisWidgetKwargs = DEFAULT_AXIS_KWARGS + self.yaxis = scene.AxisWidget(orientation="left", **axis_kwargs) + self.xaxis = scene.AxisWidget(orientation="bottom", **axis_kwargs) + + # 2D Plot layout: + # + # c0 c1 c2 c3 c4 c5 c6 + # +----------+-------+-------+-------+---------+---------+-----------+ + # r0 | | | title | | | + # | +-----------------------+---------+---------+ | + # r1 | | | cbar | | | + # |----------+-------+-------+-------+---------+---------+ ----------| + # r2 | pad_left | cbar | ylabel| yaxis | view | cbar | pad_right | + # |----------+-------+-------+-------+---------+---------+ ----------| + # r3 | | | xaxis | | | + # | +-----------------------+---------+---------+ | + # r4 | | | xlabel | | | + # | +-----------------------+---------+---------+ | + # r5 | | | cbar | | | + # |---------+------------------------+---------+---------+-----------| + # r6 | | pad_bottom | | + # +---------+------------------------+---------+---------+-----------+ + + self._grid_wdgs: dict[Component, scene.Widget] = {} + for name, row, col, widget in [ + (Component.PAD_LEFT, 2, 0, None), + (Component.PAD_RIGHT, 2, 6, None), + (Component.PAD_BOTTOM, 6, 4, None), + (Component.TITLE, 0, 4, self._title), + (Component.CBAR_TOP, 1, 4, None), + (Component.CBAR_LEFT, 2, 1, None), + (Component.CBAR_RIGHT, 2, 5, None), + (Component.CBAR_BOTTOM, 5, 4, None), + (Component.YAXIS, 2, 3, self.yaxis), + (Component.XAXIS, 3, 4, self.xaxis), + (Component.XLABEL, 4, 4, self._xlabel), + (Component.YLABEL, 2, 2, self._ylabel), + ]: + self._grid_wdgs[name] = wdg = self.grid.add_widget(widget, row=row, col=col) + # If we don't set max size, they will expand to fill the entire grid + # occluding pretty much everything else. + if str(name).startswith(("cbar", "pad")): + if name in { + Component.PAD_LEFT, + Component.PAD_RIGHT, + Component.CBAR_LEFT, + Component.CBAR_RIGHT, + }: + wdg.width_max = 2 + else: + wdg.height_max = 2 + + # The main view into which plots are added + self._view = self.grid.add_view(row=2, col=4) + + # NOTE: this is a mess of hardcoded values... not sure whether they will work + # cross-platform. Note that `width_max` and `height_max` of 2 is actually + # *less* visible than 0 for some reason. They should also be extracted into + # some sort of `hide/show` logic for each component + # TODO: dynamic max based on max tick value? + self._grid_wdgs[Component.YAXIS].width_max = 40 # otherwise it takes too much + self._grid_wdgs[Component.PAD_LEFT].width_max = 20 # otherwise you get clipping + self._grid_wdgs[Component.XAXIS].height_max = 20 # otherwise it takes too much + self.ylabel = ylabel + self.xlabel = xlabel + self.title = title + + # VIEWBOX (this has to go last, see vispy #1748) + self.camera = self._view.camera = PanZoom1DCamera(lock_axis) + # this has to come after camera is set + self.xaxis.link_view(self._view) + self.yaxis.link_view(self._view) + self.freeze() + + @property + def title(self) -> str: + """The title label.""" + return self._title.text # type: ignore [no-any-return] + + @title.setter + def title(self, text: str) -> None: + """Set the title of the plot.""" + self._title.text = text + wdg = self._grid_wdgs[Component.TITLE] + wdg.height_min = wdg.height_max = 30 if text else 2 + + @property + def xlabel(self) -> str: + """The x-axis label.""" + return self._xlabel.text # type: ignore [no-any-return] + + @xlabel.setter + def xlabel(self, text: str) -> None: + """Set the x-axis label.""" + self._xlabel.text = text + wdg = self._grid_wdgs[Component.XLABEL] + wdg.height_min = wdg.height_max = 40 if text else 2 + + @property + def ylabel(self) -> str: + """The y-axis label.""" + return self._ylabel.text # type: ignore [no-any-return] + + @ylabel.setter + def ylabel(self, text: str) -> None: + """Set the x-axis label.""" + self._ylabel.text = text + wdg = self._grid_wdgs[Component.YLABEL] + wdg.width_min = wdg.width_max = 20 if text else 2 + + def lock_axis(self, axis: Literal["x", "y", None]) -> None: + """Prevent panning and zooming along a particular axis.""" + self.camera._axis = axis + # self.camera.set_range() + + +class PanZoom1DCamera(scene.cameras.PanZoomCamera): + """Camera that allows panning and zooming along one axis only. + + Parameters + ---------- + axis : {'x', 'y', None} + The axis along which to allow panning and zooming. + *args : tuple + Positional arguments to pass to the parent class. + **kwargs : dict + Keyword arguments to pass to the parent class. + """ + + def __init__( + self, axis: Literal["x", "y", None] = None, *args: Any, **kwargs: Any + ) -> None: + self._axis: Literal["x", "y", None] = axis + super().__init__(*args, **kwargs) + + @property + def axis_index(self) -> Literal[0, 1, None]: + """Return the index of the axis along which to pan and zoom.""" + if self._axis in ("x", 0): + return 0 + elif self._axis in ("y", 1): + return 1 + return None + + def zoom( + self, + factor: float | tuple[float, float], + center: tuple[float, ...] | None = None, + ) -> None: + """Zoom the camera by `factor` around `center`.""" + if self.axis_index is None: + super().zoom(factor, center=center) + return + + if isinstance(factor, (float, int)): + factor = (factor, factor) + _factor = list(factor) + _factor[self.axis_index] = 1 + super().zoom(_factor, center=center) + + def pan(self, pan: Sequence[float]) -> None: + """Pan the camera by `pan`.""" + if self.axis_index is None: + super().pan(pan) + return + _pan = list(pan) + _pan[self.axis_index] = 0 + super().pan(*_pan) + + def set_range( + self, + x: tuple | None = None, + y: tuple | None = None, + z: tuple | None = None, + margin: float = 0, # overriding to create a different default from super() + ) -> None: + """Reset the camera view to the specified range.""" + super().set_range(x, y, z, margin) diff --git a/src/ndv/_views/_vispy/_utils.py b/src/ndv/_views/_vispy/_utils.py new file mode 100644 index 00000000..f5ec027a --- /dev/null +++ b/src/ndv/_views/_vispy/_utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from contextlib import contextmanager +from functools import cache +from typing import TYPE_CHECKING + +from vispy.app import Canvas +from vispy.gloo import gl +from vispy.gloo.context import get_current_canvas + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@contextmanager +def _opengl_context() -> Iterator[None]: + """Assure we are running with a valid OpenGL context. + + Only create a Canvas if one doesn't exist. Creating and closing a + Canvas causes vispy to process Qt events which can cause problems. + Ideally call opengl_context() on start after creating your first + Canvas. However it will work either way. + """ + canvas = Canvas(show=False) if get_current_canvas() is None else None + try: + yield + finally: + if canvas is not None: + canvas.close() + + +@cache +def get_gl_extensions() -> set[str]: + """Get basic info about the Gl capabilities of this machine.""" + with _opengl_context(): + return set(filter(None, gl.glGetParameter(gl.GL_EXTENSIONS).split())) + + +FLOAT_EXT = {"GL_ARB_texture_float", "GL_ATI_texture_float", "GL_NV_float_buffer"} + + +@cache +def supports_float_textures() -> bool: + """Check if the current OpenGL context supports float textures.""" + return bool(FLOAT_EXT.intersection(get_gl_extensions())) diff --git a/src/ndv/_views/_wx/__init__.py b/src/ndv/_views/_wx/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/ndv/_views/_wx/__init__.py @@ -0,0 +1 @@ + diff --git a/src/ndv/_views/_wx/_array_view.py b/src/ndv/_views/_wx/_array_view.py new file mode 100644 index 00000000..b8ee969f --- /dev/null +++ b/src/ndv/_views/_wx/_array_view.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, cast + +import wx +import wx.lib.newevent +from psygnal import Signal + +from ndv._views._wx._labeled_slider import WxLabeledSlider +from ndv._views.bases._array_view import ArrayView +from ndv._views.bases._lut_view import LutView +from ndv.models._array_display_model import ChannelMode + +from .range_slider import RangeSlider + +if TYPE_CHECKING: + from collections.abc import Container, Hashable, Mapping, Sequence + + import cmap + + from ndv._types import AxisKey + + +# mostly copied from _qt.qt_view._QLUTWidget +class _WxLUTWidget(wx.Panel): + def __init__(self, parent: wx.Window) -> None: + super().__init__(parent) + + self.visible = wx.CheckBox(self, label="Visible") + self.visible.SetValue(True) + + # Placeholder for the custom colormap combo box + self.cmap = wx.ComboBox( + self, choices=["gray", "green", "magenta"], style=wx.CB_DROPDOWN + ) + + # Placeholder for the QLabeledRangeSlider equivalent + self.clims = RangeSlider(self, style=wx.SL_HORIZONTAL) + self.clims.SetMax(65000) + self.clims.SetValue(0, 65000) + + self.auto_clim = wx.ToggleButton(self, label="Auto") + + # Layout + sizer = wx.BoxSizer(wx.HORIZONTAL) + sizer.Add(self.visible, 0, wx.ALIGN_CENTER_VERTICAL, 5) + sizer.Add(self.cmap, 0, wx.ALIGN_CENTER_VERTICAL, 5) + sizer.Add(self.clims, 1, wx.ALIGN_CENTER_VERTICAL, 5) + sizer.Add(self.auto_clim, 0, wx.ALIGN_CENTER_VERTICAL, 5) + + self.SetSizer(sizer) + self.Layout() + + +class WxLutView(LutView): + def __init__(self, parent: wx.Window) -> None: + super().__init__() + self._wxwidget = wdg = _WxLUTWidget(parent) + # TODO: use emit_fast + wdg.visible.Bind(wx.EVT_CHECKBOX, self._on_visible_changed) + wdg.cmap.Bind(wx.EVT_COMBOBOX, self._on_cmap_changed) + wdg.clims.Bind(wx.EVT_SLIDER, self._on_clims_changed) + wdg.auto_clim.Bind(wx.EVT_TOGGLEBUTTON, self._on_autoscale_changed) + + # Event Handlers + def _on_visible_changed(self, event: wx.CommandEvent) -> None: + self.visibilityChanged.emit(self._wxwidget.visible.GetValue()) + + def _on_cmap_changed(self, event: wx.CommandEvent) -> None: + self.cmapChanged.emit(self._wxwidget.cmap.GetValue()) + + def _on_clims_changed(self, event: wx.CommandEvent) -> None: + self.climsChanged.emit(self._wxwidget.clims.GetValues()) + + def _on_autoscale_changed(self, event: wx.CommandEvent) -> None: + self.autoscaleChanged.emit(self._wxwidget.auto_clim.GetValue()) + + # Public Methods + def frontend_widget(self) -> wx.Window: + return self._wxwidget + + def set_channel_name(self, name: str) -> None: + self._wxwidget.visible.SetLabel(name) + + def set_auto_scale(self, auto: bool) -> None: + self._wxwidget.auto_clim.SetValue(auto) + + def set_colormap(self, cmap: cmap.Colormap) -> None: + name = cmap.name.split(":")[-1] # FIXME: this is a hack + self._wxwidget.cmap.SetValue(name) + + def set_clims(self, clims: tuple[float, float]) -> None: + self._wxwidget.clims.SetValue(*clims) + + def set_channel_visible(self, visible: bool) -> None: + self._wxwidget.visible.SetValue(visible) + + def set_visible(self, visible: bool) -> None: + if visible: + self._wxwidget.Show() + else: + self._wxwidget.Hide() + + def close(self) -> None: + self._wxwidget.Close() + + +# mostly copied from _qt.qt_view._QDimsSliders +class _WxDimsSliders(wx.Panel): + currentIndexChanged = Signal() + + def __init__(self, parent: wx.Window) -> None: + super().__init__(parent) + + self._sliders: dict[AxisKey, WxLabeledSlider] = {} + self.layout = wx.BoxSizer(wx.VERTICAL) + self.SetSizer(self.layout) + + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: + """Update sliders with the given coordinate ranges.""" + for axis, _coords in coords.items(): + slider = WxLabeledSlider(self) + slider.label.SetLabel(str(axis)) + slider.slider.Bind(wx.EVT_SLIDER, self._on_slider_changed) + + if isinstance(_coords, range): + slider.setRange(_coords.start, _coords.stop - 1) + slider.setSingleStep(_coords.step) + else: + slider.setRange(0, len(_coords) - 1) + + self.layout.Add(slider, 0, wx.EXPAND | wx.ALL, 5) + self._sliders[axis] = slider + self.currentIndexChanged.emit() + + def hide_dimensions( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + for ax, slider in self._sliders.items(): + if ax in axes_to_hide: + slider.Hide() + elif show_remainder: + slider.Show() + + self.Layout() + + def current_index(self) -> Mapping[AxisKey, int | slice]: + """Return the current value of the sliders.""" + return {axis: slider.value() for axis, slider in self._sliders.items()} + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + """Set the current value of the sliders.""" + changed = False + with self.currentIndexChanged.blocked(): + for axis, val in value.items(): + if isinstance(val, slice): + raise NotImplementedError("Slices are not supported yet") + if slider := self._sliders.get(axis): + if slider.value() != val: + changed = True + slider.setValue(val) + else: + warnings.warn(f"Axis {axis} not found in sliders", stacklevel=2) + + if changed: + self.currentIndexChanged.emit() + + def _on_slider_changed(self, event: wx.CommandEvent) -> None: + self.currentIndexChanged.emit() + + +class _WxArrayViewer(wx.Frame): + def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None): + super().__init__(parent) + + # FIXME: pygfx backend needs this to be canvas_widget._subwidget + if hasattr(canvas_widget, "_subwidget"): + canvas_widget = canvas_widget._subwidget + + if (parent := canvas_widget.GetParent()) and parent is not self: + canvas_widget.Reparent(self) # Reparent canvas_widget to this frame + if parent: + parent.Destroy() + canvas_widget.Show() + + self._canvas = canvas_widget + + # Dynamic sliders for dimensions + self.dims_sliders = _WxDimsSliders(self) + + # Labels for data and hover information + self._data_info_label = wx.StaticText(self, label="") + self._hover_info_label = wx.StaticText(self, label="") + + # Channel mode combo box + self.channel_mode_combo = wx.ComboBox( + self, choices=[x.value for x in ChannelMode], style=wx.CB_DROPDOWN + ) + + # Reset zoom button + self.reset_zoom_btn = wx.Button(self, label="Reset Zoom") + + # LUT layout (simple vertical grouping for LUT widgets) + self.luts = wx.BoxSizer(wx.VERTICAL) + + btns = wx.BoxSizer(wx.HORIZONTAL) + btns.Add(self.channel_mode_combo, 0, wx.RIGHT, 5) + btns.Add(self.reset_zoom_btn, 0, wx.RIGHT, 5) + + # Layout for the panel + main_sizer = wx.BoxSizer(wx.VERTICAL) + main_sizer.Add(self._data_info_label, 0, wx.EXPAND | wx.BOTTOM, 5) + main_sizer.Add(self._canvas, 1, wx.EXPAND | wx.ALL, 5) + main_sizer.Add(self._hover_info_label, 0, wx.EXPAND | wx.BOTTOM, 5) + main_sizer.Add(self.dims_sliders, 0, wx.EXPAND | wx.BOTTOM, 5) + main_sizer.Add(self.luts, 0, wx.EXPAND, 5) + main_sizer.Add(btns, 0, wx.EXPAND, 5) + + self.SetSizer(main_sizer) + self.SetInitialSize(wx.Size(600, 800)) + self.Layout() + + +class WxArrayView(ArrayView): + def __init__(self, canvas_widget: wx.Window, parent: wx.Window = None) -> None: + self._wxwidget = wdg = _WxArrayViewer(canvas_widget, parent) + + # TODO: use emit_fast + wdg.dims_sliders.currentIndexChanged.connect(self.currentIndexChanged.emit) + wdg.channel_mode_combo.Bind(wx.EVT_COMBOBOX, self._on_channel_mode_changed) + wdg.reset_zoom_btn.Bind(wx.EVT_BUTTON, self._on_reset_zoom_clicked) + + def _on_channel_mode_changed(self, event: wx.CommandEvent) -> None: + mode = self._wxwidget.channel_mode_combo.GetValue() + self.channelModeChanged.emit(mode) + + def _on_reset_zoom_clicked(self, event: wx.CommandEvent) -> None: + self.resetZoomClicked.emit() + + def frontend_widget(self) -> wx.Window: + return self._wxwidget + + def add_lut_view(self) -> WxLutView: + view = WxLutView(self.frontend_widget()) + self._wxwidget.luts.Add(view._wxwidget, 0, wx.EXPAND | wx.BOTTOM, 5) + self._wxwidget.Layout() + return view + + def remove_lut_view(self, lut: LutView) -> None: + wxwdg = cast("_WxLUTWidget", lut.frontend_widget()) + self._wxwidget.luts.Detach(wxwdg) + wxwdg.Destroy() + self._wxwidget.Layout() + + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: + self._wxwidget.dims_sliders.create_sliders(coords) + self._wxwidget.Layout() + + def hide_sliders( + self, axes_to_hide: Container[Hashable], show_remainder: bool = True + ) -> None: + self._wxwidget.dims_sliders.hide_dimensions(axes_to_hide, show_remainder) + self._wxwidget.Layout() + + def current_index(self) -> Mapping[AxisKey, int | slice]: + return self._wxwidget.dims_sliders.current_index() + + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: + self._wxwidget.dims_sliders.set_current_index(value) + + def set_data_info(self, text: str) -> None: + self._wxwidget._data_info_label.SetLabel(text) + + def set_hover_info(self, text: str) -> None: + self._wxwidget._hover_info_label.SetLabel(text) + + def set_channel_mode(self, mode: ChannelMode) -> None: + self._wxwidget.channel_mode_combo.SetValue(mode) + + def set_visible(self, visible: bool) -> None: + if visible: + self._wxwidget.Show() + else: + self._wxwidget.Hide() + + def close(self) -> None: + self._wxwidget.Close() diff --git a/src/ndv/_views/_wx/_labeled_slider.py b/src/ndv/_views/_wx/_labeled_slider.py new file mode 100644 index 00000000..ccad7b14 --- /dev/null +++ b/src/ndv/_views/_wx/_labeled_slider.py @@ -0,0 +1,29 @@ +import wx + + +class WxLabeledSlider(wx.Panel): + """A simple labeled slider widget for wxPython.""" + + def __init__(self, parent: wx.Window) -> None: + super().__init__(parent) + + self.label = wx.StaticText(self) + self.slider = wx.Slider(self, style=wx.HORIZONTAL) + + sizer = wx.BoxSizer(wx.HORIZONTAL) + sizer.Add(self.label, 0, wx.ALIGN_CENTER_VERTICAL | wx.RIGHT, 5) + sizer.Add(self.slider, 1, wx.EXPAND) + self.SetSizer(sizer) + + def setRange(self, min_val: int, max_val: int) -> None: + self.slider.SetMin(min_val) + self.slider.SetMax(max_val) + + def setValue(self, value: int) -> None: + self.slider.SetValue(value) + + def value(self) -> int: + return self.slider.GetValue() # type: ignore [no-any-return] + + def setSingleStep(self, step: int) -> None: + self.slider.SetLineSize(step) diff --git a/src/ndv/_views/_wx/range_slider.py b/src/ndv/_views/_wx/range_slider.py new file mode 100644 index 00000000..a6fad08c --- /dev/null +++ b/src/ndv/_views/_wx/range_slider.py @@ -0,0 +1,368 @@ +"""Adapted from https://gist.github.com/gabrieldp/e19611abead7f6617872d33866c568a3. + +credit: +Gabriel Pasa +gabrieldp +""" + +from __future__ import annotations + +from typing import Any + +import wx + + +def fraction_to_value(fraction: float, min_value: float, max_value: float) -> float: + return (max_value - min_value) * fraction + min_value + + +def value_to_fraction(value: float, min_value: float, max_value: float) -> float: + return float(value - min_value) / (max_value - min_value) + + +class SliderThumb: + def __init__(self, parent: RangeSlider, value: int): + self.parent = parent + self.dragged = False + self.mouse_over = False + self.thumb_poly = ((0, 0), (0, 13), (5, 18), (10, 13), (10, 0)) + self.thumb_shadow_poly = ((0, 14), (4, 18), (6, 18), (10, 14)) + min_coords = [float("Inf"), float("Inf")] + max_coords = [-float("Inf"), -float("Inf")] + for pt in list(self.thumb_poly) + list(self.thumb_shadow_poly): + for i_coord, coord in enumerate(pt): + if coord > max_coords[i_coord]: + max_coords[i_coord] = coord + if coord < min_coords[i_coord]: + min_coords[i_coord] = coord + self.size = (max_coords[0] - min_coords[0], max_coords[1] - min_coords[1]) + + self.value = value + self.normal_color = wx.Colour((0, 120, 215)) + self.normal_shadow_color = wx.Colour((120, 180, 228)) + self.dragged_color = wx.Colour((204, 204, 204)) + self.dragged_shadow_color = wx.Colour((222, 222, 222)) + self.mouse_over_color = wx.Colour((23, 23, 23)) + self.mouse_over_shadow_color = wx.Colour((132, 132, 132)) + + def GetPosition(self) -> tuple[int, int]: + min_x = self.GetMin() + max_x = self.GetMax() + parent_size = self.parent.GetSize() + min_value = self.parent.GetMin() + max_value = self.parent.GetMax() + fraction = value_to_fraction(self.value, min_value, max_value) + low = int(fraction_to_value(fraction, min_x, max_x)) + high = int(parent_size[1] / 2 + 1) + return low, high + + def SetPosition(self, pos: tuple[int, int]) -> None: + pos_x = pos[0] + # Limit movement by the position of the other thumb + who_other, other_thumb = self.GetOtherThumb() + other_pos = other_thumb.GetPosition() + if who_other == "low": + pos_x = int( + max(other_pos[0] + other_thumb.size[0] / 2 + self.size[0] / 2, pos_x) + ) + else: + pos_x = int( + min(other_pos[0] - other_thumb.size[0] / 2 - self.size[0] / 2, pos_x) + ) + # Limit movement by slider boundaries + min_x = self.GetMin() + max_x = self.GetMax() + pos_x = min(max(pos_x, min_x), max_x) + + fraction = value_to_fraction(pos_x, min_x, max_x) + self.value = int( + fraction_to_value(fraction, self.parent.GetMin(), self.parent.GetMax()) + ) + # Post event notifying that position changed + self.PostEvent() + + def GetValue(self) -> int: + return self.value + + def SetValue(self, value: int) -> None: + self.value = value + # Post event notifying that value changed + self.PostEvent() + + def PostEvent(self) -> None: + event = wx.PyCommandEvent(wx.EVT_SLIDER.typeId, self.parent.GetId()) + event.SetEventObject(self.parent) + wx.PostEvent(self.parent.GetEventHandler(), event) + + def GetMin(self) -> int: + return self.parent.border_width + int(self.size[0] / 2) + + def GetMax(self) -> int: + parent_w = int(self.parent.GetSize()[0]) + return parent_w - self.parent.border_width - int(self.size[0] / 2) + + def IsMouseOver(self, mouse_pos: wx.Point) -> bool: + in_hitbox = True + my_pos = self.GetPosition() + for i_coord, mouse_coord in enumerate(mouse_pos): + boundary_low = my_pos[i_coord] - self.size[i_coord] / 2 + boundary_high = my_pos[i_coord] + self.size[i_coord] / 2 + in_hitbox = in_hitbox and (boundary_low <= mouse_coord <= boundary_high) + return in_hitbox + + def GetOtherThumb(self) -> tuple[str, SliderThumb]: + if self.parent.thumbs["low"] != self: + return "low", self.parent.thumbs["low"] + else: + return "high", self.parent.thumbs["high"] + + def OnPaint(self, dc: wx.BufferedPaintDC) -> None: + if self.dragged or not self.parent.IsEnabled(): + thumb_color = self.dragged_color + thumb_shadow_color = self.dragged_shadow_color + elif self.mouse_over: + thumb_color = self.mouse_over_color + thumb_shadow_color = self.mouse_over_shadow_color + else: + thumb_color = self.normal_color + thumb_shadow_color = self.normal_shadow_color + my_pos = self.GetPosition() + + # Draw thumb shadow (or anti-aliasing effect) + dc.SetBrush(wx.Brush(thumb_shadow_color, style=wx.BRUSHSTYLE_SOLID)) + dc.SetPen(wx.Pen(thumb_shadow_color, width=1, style=wx.PENSTYLE_SOLID)) + dc.DrawPolygon( + points=self.thumb_shadow_poly, + xoffset=int(my_pos[0] - self.size[0] / 2), + yoffset=int(my_pos[1] - self.size[1] / 2), + ) + # Draw thumb itself + dc.SetBrush(wx.Brush(thumb_color, style=wx.BRUSHSTYLE_SOLID)) + dc.SetPen(wx.Pen(thumb_color, width=1, style=wx.PENSTYLE_SOLID)) + dc.DrawPolygon( + points=self.thumb_poly, + xoffset=int(my_pos[0] - self.size[0] / 2), + yoffset=int(my_pos[1] - self.size[1] / 2), + ) + + +class RangeSlider(wx.Panel): + def __init__( + self, + parent: wx.Window, + id: int = wx.ID_ANY, + lowValue: int | None = None, + highValue: int | None = None, + minValue: int = 0, + maxValue: int = 100, + pos: wx.Point = wx.DefaultPosition, + size: wx.Size = wx.DefaultSize, + style: int = wx.SL_HORIZONTAL, + validator: wx.Validator = wx.DefaultValidator, + name: str = "rangeSlider", + ) -> None: + if style != wx.SL_HORIZONTAL: + raise NotImplementedError("Styles not implemented") + if validator != wx.DefaultValidator: + raise NotImplementedError("Validator not implemented") + super().__init__(parent=parent, id=id, pos=pos, size=size, name=name) + self.SetMinSize(size=(max(50, size[0]), max(26, size[1]))) + if minValue > maxValue: + minValue, maxValue = maxValue, minValue + self.min_value = minValue + self.max_value = maxValue + if lowValue is None: + lowValue = self.min_value + if highValue is None: + highValue = self.max_value + if lowValue > highValue: + lowValue, highValue = highValue, lowValue + lowValue = max(lowValue, self.min_value) + highValue = min(highValue, self.max_value) + self.track_height = 8 + self.border_width = 8 + + self.thumbs = { + "low": SliderThumb(parent=self, value=lowValue), + "high": SliderThumb(parent=self, value=highValue), + } + self.thumb_width = self.thumbs["low"].size[0] + + # Aesthetic definitions + self.slider_background_color = wx.Colour((231, 234, 234)) + self.slider_outline_color = wx.Colour((214, 214, 214)) + self.selected_range_color = wx.Colour((0, 120, 215)) + self.selected_range_outline_color = wx.Colour((0, 120, 215)) + + # Bind events + self.Bind(wx.EVT_LEFT_DOWN, self.OnMouseDown) + self.Bind(wx.EVT_LEFT_UP, self.OnMouseUp) + self.Bind(wx.EVT_MOTION, self.OnMouseMotion) + self.Bind(wx.EVT_MOUSE_CAPTURE_LOST, self.OnMouseLost) + self.Bind(wx.EVT_ENTER_WINDOW, self.OnMouseEnter) + self.Bind(wx.EVT_LEAVE_WINDOW, self.OnMouseLeave) + self.Bind(wx.EVT_PAINT, self.OnPaint) + self.Bind(wx.EVT_ERASE_BACKGROUND, self.OnEraseBackground) + self.Bind(wx.EVT_SIZE, self.OnResize) + + def Enable(self, enable: bool = True) -> None: + super().Enable(enable) + self.Refresh() + + def Disable(self) -> None: + super().Disable() + self.Refresh() + + def SetValueFromMousePosition(self, click_pos: wx.Point) -> None: + for thumb in self.thumbs.values(): + if thumb.dragged: + thumb.SetPosition(click_pos) + + def OnMouseDown(self, evt: wx.Event) -> None: + if not self.IsEnabled(): + return + click_pos = evt.GetPosition() + for thumb in self.thumbs.values(): + if thumb.IsMouseOver(click_pos): + thumb.dragged = True + thumb.mouse_over = False + break + self.SetValueFromMousePosition(click_pos) + self.CaptureMouse() + self.Refresh() + + def OnMouseUp(self, evt: wx.Event) -> None: + if not self.IsEnabled(): + return + self.SetValueFromMousePosition(evt.GetPosition()) + for thumb in self.thumbs.values(): + thumb.dragged = False + if self.HasCapture(): + self.ReleaseMouse() + self.Refresh() + + def OnMouseLost(self, evt: wx.Event) -> None: + for thumb in self.thumbs.values(): + thumb.dragged = False + thumb.mouse_over = False + self.Refresh() + + def OnMouseMotion(self, evt: wx.Event) -> None: + if not self.IsEnabled(): + return + refresh_needed = False + mouse_pos = evt.GetPosition() + if evt.Dragging() and evt.LeftIsDown(): + self.SetValueFromMousePosition(mouse_pos) + refresh_needed = True + else: + for thumb in self.thumbs.values(): + old_mouse_over = thumb.mouse_over + thumb.mouse_over = thumb.IsMouseOver(mouse_pos) + if old_mouse_over != thumb.mouse_over: + refresh_needed = True + if refresh_needed: + self.Refresh() + + def OnMouseEnter(self, evt: wx.Event) -> None: + if not self.IsEnabled(): + return + mouse_pos = evt.GetPosition() + for thumb in self.thumbs.values(): + if thumb.IsMouseOver(mouse_pos): + thumb.mouse_over = True + self.Refresh() + break + + def OnMouseLeave(self, evt: wx.Event) -> None: + if not self.IsEnabled(): + return + for thumb in self.thumbs.values(): + thumb.mouse_over = False + self.Refresh() + + def OnResize(self, evt: wx.Event) -> None: + self.Refresh() + + def OnPaint(self, evt: wx.Event) -> None: + w, h = self.GetSize() + # BufferedPaintDC should reduce flickering + dc = wx.BufferedPaintDC(self) + background_brush = wx.Brush(self.GetBackgroundColour(), wx.SOLID) + dc.SetBackground(background_brush) + dc.Clear() + # Draw slider + dc.SetPen(wx.Pen(self.slider_outline_color, width=1, style=wx.PENSTYLE_SOLID)) + dc.SetBrush(wx.Brush(self.slider_background_color, style=wx.BRUSHSTYLE_SOLID)) + + dc.DrawRectangle( + int(self.border_width), + int(h / 2 - self.track_height / 2), + int(w - 2 * self.border_width), + self.track_height, + ) + # Draw selected range + if self.IsEnabled(): + dc.SetPen( + wx.Pen( + self.selected_range_outline_color, width=1, style=wx.PENSTYLE_SOLID + ) + ) + dc.SetBrush(wx.Brush(self.selected_range_color, style=wx.BRUSHSTYLE_SOLID)) + else: + dc.SetPen( + wx.Pen(self.slider_outline_color, width=1, style=wx.PENSTYLE_SOLID) + ) + dc.SetBrush(wx.Brush(self.slider_outline_color, style=wx.BRUSHSTYLE_SOLID)) + low_pos = self.thumbs["low"].GetPosition()[0] + high_pos = self.thumbs["high"].GetPosition()[0] + dc.DrawRectangle( + int(low_pos), + int(h / 2 - self.track_height / 4), + int(high_pos - low_pos), + int(self.track_height / 2), + ) + # Draw thumbs + for thumb in self.thumbs.values(): + thumb.OnPaint(dc) + evt.Skip() + + def OnEraseBackground(self, evt: Any) -> None: + # This should reduce flickering + pass + + def GetValues(self) -> tuple[int, int]: + return self.thumbs["low"].value, self.thumbs["high"].value + + def SetValue(self, lowValue: float, highValue: float) -> None: + if lowValue > highValue: + lowValue, highValue = highValue, lowValue + lowValue = max(lowValue, self.min_value) + highValue = min(highValue, self.max_value) + self.thumbs["low"].SetValue(int(lowValue)) + self.thumbs["high"].SetValue(int(highValue)) + self.Refresh() + + def GetMax(self) -> int: + return self.max_value + + def GetMin(self) -> int: + return self.min_value + + def SetMax(self, maxValue: int) -> None: + if maxValue < self.min_value: + maxValue = self.min_value + _, old_high = self.GetValues() + if old_high > maxValue: + self.thumbs["high"].SetValue(maxValue) + self.max_value = maxValue + self.Refresh() + + def SetMin(self, minValue: int) -> None: + if minValue > self.max_value: + minValue = self.max_value + old_low, _ = self.GetValues() + if old_low < minValue: + self.thumbs["low"].SetValue(minValue) + self.min_value = minValue + self.Refresh() diff --git a/src/ndv/_views/bases/__init__.py b/src/ndv/_views/bases/__init__.py new file mode 100644 index 00000000..a722dd71 --- /dev/null +++ b/src/ndv/_views/bases/__init__.py @@ -0,0 +1,18 @@ +from ._array_view import ArrayView +from ._lut_view import LutView +from ._view_base import Viewable +from .graphics._canvas import ArrayCanvas, HistogramCanvas +from .graphics._canvas_elements import CanvasElement, ImageHandle, RoiHandle +from .graphics._mouseable import Mouseable + +__all__ = [ + "ArrayCanvas", + "ArrayView", + "CanvasElement", + "HistogramCanvas", + "ImageHandle", + "LutView", + "Mouseable", + "RoiHandle", + "Viewable", +] diff --git a/src/ndv/_views/bases/_array_view.py b/src/ndv/_views/bases/_array_view.py new file mode 100644 index 00000000..d61ac3ed --- /dev/null +++ b/src/ndv/_views/bases/_array_view.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from psygnal import Signal + +from ndv._views.bases._view_base import Viewable +from ndv.models._array_display_model import ChannelMode + +if TYPE_CHECKING: + from collections.abc import Container, Hashable, Mapping, Sequence + + from ndv._types import AxisKey + from ndv._views.bases._lut_view import LutView + + +class ArrayView(Viewable): + """ABC for ND Array viewers widget. + + Currently, this is the "main" widget that contains the array display and + all the controls for interacting with the array, includings sliders, LUTs, + and histograms. + """ + + currentIndexChanged = Signal() + resetZoomClicked = Signal() + histogramRequested = Signal() + channelModeChanged = Signal(ChannelMode) + + @abstractmethod + def __init__(self, canvas_widget: Any, **kwargs: Any) -> None: ... + @abstractmethod + def create_sliders(self, coords: Mapping[int, Sequence]) -> None: ... + @abstractmethod + def current_index(self) -> Mapping[AxisKey, int | slice]: ... + @abstractmethod + def set_current_index(self, value: Mapping[AxisKey, int | slice]) -> None: ... + @abstractmethod + def set_channel_mode(self, mode: ChannelMode) -> None: ... + @abstractmethod + def set_data_info(self, data_info: str) -> None: ... + @abstractmethod + def set_hover_info(self, hover_info: str) -> None: ... + @abstractmethod + def hide_sliders( + self, axes_to_hide: Container[Hashable], *, show_remainder: bool = ... + ) -> None: ... + @abstractmethod + def add_lut_view(self) -> LutView: ... + @abstractmethod + def remove_lut_view(self, view: LutView) -> None: ... + + def add_histogram(self, widget: Any) -> None: + raise NotImplementedError + + def remove_histogram(self, widget: Any) -> None: + raise NotImplementedError diff --git a/src/ndv/_views/bases/_lut_view.py b/src/ndv/_views/bases/_lut_view.py new file mode 100644 index 00000000..d621c92b --- /dev/null +++ b/src/ndv/_views/bases/_lut_view.py @@ -0,0 +1,77 @@ +from abc import abstractmethod +from typing import final + +import cmap +from psygnal import Signal + +from ._view_base import Viewable + + +class LutView(Viewable): + """Manages LUT properties (contrast, colormap, etc...) in a view object.""" + + visibilityChanged = Signal(bool) + autoscaleChanged = Signal(bool) + cmapChanged = Signal(cmap.Colormap) + climsChanged = Signal(tuple) + gammaChanged = Signal(float) + + @abstractmethod + def set_channel_name(self, name: str) -> None: + """Set the name of the channel to `name`.""" + + @abstractmethod + def set_auto_scale(self, checked: bool) -> None: + """Set the autoscale button to checked if `checked` is True.""" + + @abstractmethod + def set_colormap(self, cmap: cmap.Colormap) -> None: + """Set the colormap to `cmap`. + + Usually corresponds to a dropdown menu. + """ + + @abstractmethod + def set_clims(self, clims: tuple[float, float]) -> None: + """Set the (low, high) contrast limits to `clims`. + + Usually this will be a range slider or two text boxes. + """ + + @abstractmethod + def set_channel_visible(self, visible: bool) -> None: + """Check or uncheck the visibility indicator of the LUT. + + Usually corresponsds to a checkbox. + """ + + def set_gamma(self, gamma: float) -> None: + """Set the gamma value of the LUT.""" + return None + + # These methods apply a value to the view without re-emitting the signal. + + @final + def set_auto_scale_without_signal(self, auto: bool) -> None: + with self.autoscaleChanged.blocked(): + self.set_auto_scale(auto) + + @final + def set_colormap_without_signal(self, cmap: cmap.Colormap) -> None: + with self.cmapChanged.blocked(): + self.set_colormap(cmap) + + @final + def set_clims_without_signal(self, clims: tuple[float, float]) -> None: + with self.climsChanged.blocked(): + self.set_clims(clims) + + @final + def set_gamma_without_signal(self, gamma: float) -> None: + with self.gammaChanged.blocked(): + self.set_gamma(gamma) + + @final + def set_channel_visible_without_signal(self, visible: bool) -> None: + with self.visibilityChanged.blocked(): + self.set_channel_visible(visible) diff --git a/src/ndv/_views/bases/_view_base.py b/src/ndv/_views/bases/_view_base.py new file mode 100644 index 00000000..a9e87496 --- /dev/null +++ b/src/ndv/_views/bases/_view_base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class Viewable(ABC): + """ABC representing anything that can be viewed on screen. + + For example, a widget, a window, a frame, canvas, etc. + """ + + @abstractmethod + def frontend_widget(self) -> Any: + """Return the native object backing the viewable objects.""" + + @abstractmethod + def set_visible(self, visible: bool) -> None: + """Sets the visibility of the view/widget itself.""" + + @abstractmethod + def close(self) -> None: + """Close the view/widget.""" diff --git a/src/ndv/_views/bases/graphics/__init__.py b/src/ndv/_views/bases/graphics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/_views/bases/graphics/_canvas.py b/src/ndv/_views/bases/graphics/_canvas.py new file mode 100644 index 00000000..642227ff --- /dev/null +++ b/src/ndv/_views/bases/graphics/_canvas.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Literal + +import numpy as np + +from ndv._views.bases._lut_view import LutView +from ndv._views.bases._view_base import Viewable + +from ._mouseable import Mouseable + +if TYPE_CHECKING: + from collections.abc import Sequence + + import cmap + import numpy as np + + from ndv._views.bases.graphics._canvas_elements import ( + CanvasElement, + ImageHandle, + RoiHandle, + ) + + +class GraphicsCanvas(Viewable, Mouseable): + @abstractmethod + def refresh(self) -> None: ... + @abstractmethod + def set_range( + self, + x: tuple[float, float] | None = None, + y: tuple[float, float] | None = None, + z: tuple[float, float] | None = None, + margin: float = ..., + ) -> None: ... + @abstractmethod + def canvas_to_world( + self, pos_xy: tuple[float, float] + ) -> tuple[float, float, float]: + """Map XY canvas position (pixels) to XYZ coordinate in world space.""" + + @abstractmethod + def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: ... + + +class ArrayCanvas(GraphicsCanvas): + @abstractmethod + def set_ndim(self, ndim: Literal[2, 3]) -> None: ... + @abstractmethod + @abstractmethod + def add_image(self, data: np.ndarray | None = ...) -> ImageHandle: ... + @abstractmethod + def add_volume(self, data: np.ndarray | None = ...) -> ImageHandle: ... + @abstractmethod + def add_roi( + self, + vertices: Sequence[tuple[float, float]] | None = None, + color: cmap.Color | None = None, + border_color: cmap.Color | None = None, + ) -> RoiHandle: ... + + +class HistogramCanvas(GraphicsCanvas, LutView): + """A histogram-based view for LookUp Table (LUT) adjustment.""" + + def set_vertical(self, vertical: bool) -> None: + """If True, orient axes vertically (x-axis on left).""" + + def set_log_base(self, base: float | None) -> None: + """Sets the axis scale of the range. + + Properties + ---------- + enabled : bool + If true, the range will be displayed with a logarithmic (base 10) + scale. If false, the range will be displayed with a linear scale. + """ + + def set_data(self, values: np.ndarray, bin_edges: np.ndarray) -> None: + """Sets the histogram data. + + Properties + ---------- + values : np.ndarray + The histogram values. + bin_edges : np.ndarray + The bin edges of the histogram. + """ diff --git a/src/ndv/_views/bases/graphics/_canvas_elements.py b/src/ndv/_views/bases/graphics/_canvas_elements.py new file mode 100644 index 00000000..4e7ced55 --- /dev/null +++ b/src/ndv/_views/bases/graphics/_canvas_elements.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any + +from ._mouseable import Mouseable + +if TYPE_CHECKING: + from collections.abc import Sequence + + import cmap as _cmap + import numpy as np + + from ndv._types import CursorType + + +class CanvasElement(Mouseable): + """Protocol defining an interactive element on the Canvas.""" + + @abstractmethod + def visible(self) -> bool: + """Defines whether the element is visible on the canvas.""" + + @abstractmethod + def set_visible(self, visible: bool) -> None: + """Sets element visibility.""" + + @abstractmethod + def can_select(self) -> bool: + """Defines whether the element can be selected.""" + + @abstractmethod + def selected(self) -> bool: + """Returns element selection status.""" + + @abstractmethod + def set_selected(self, selected: bool) -> None: + """Sets element selection status.""" + + def cursor_at(self, pos: Sequence[float]) -> CursorType | None: + """Returns the element's cursor preference at the provided position.""" + + def start_move(self, pos: Sequence[float]) -> None: + """ + Behavior executed at the beginning of a "move" operation. + + In layman's terms, this is the behavior executed during the the "click" + of a "click-and-drag". + """ + + def move(self, pos: Sequence[float]) -> None: + """ + Behavior executed throughout a "move" operation. + + In layman's terms, this is the behavior executed during the "drag" + of a "click-and-drag". + """ + + def remove(self) -> None: + """Removes the element from the canvas.""" + + +class ImageHandle(CanvasElement): + @abstractmethod + def data(self) -> np.ndarray: ... + @abstractmethod + def set_data(self, data: np.ndarray) -> None: ... + @abstractmethod + def clim(self) -> Any: ... + @abstractmethod + def set_clims(self, clims: tuple[float, float]) -> None: ... + @abstractmethod + def gamma(self) -> float: ... + @abstractmethod + def set_gamma(self, gamma: float) -> None: ... + @abstractmethod + def cmap(self) -> _cmap.Colormap: ... + @abstractmethod + def set_cmap(self, cmap: _cmap.Colormap) -> None: ... + + +class RoiHandle(CanvasElement): + @abstractmethod + def vertices(self) -> Sequence[Sequence[float]]: ... + @abstractmethod + def set_vertices(self, data: Sequence[Sequence[float]]) -> None: ... + @abstractmethod + def color(self) -> Any: ... + @abstractmethod + def set_color(self, color: _cmap.Color | None) -> None: ... + @abstractmethod + def border_color(self) -> Any: ... + @abstractmethod + def set_border_color(self, color: _cmap.Color | None) -> None: ... diff --git a/src/ndv/_views/bases/graphics/_mouseable.py b/src/ndv/_views/bases/graphics/_mouseable.py new file mode 100644 index 00000000..d8f540e7 --- /dev/null +++ b/src/ndv/_views/bases/graphics/_mouseable.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from psygnal import Signal + +from ndv._types import MouseMoveEvent, MousePressEvent, MouseReleaseEvent + + +class Mouseable: + """Mixin class for objects that can be interacted with using the mouse. + + The signals here are to be emitted by the view object that inherits this class; + usually by intercepting native mouse events with `filter_mouse_events`. + + The methods allow the object to handle its own mouse events before emitting the + signals. If the method returns `True`, the event is considered handled and should + not be passed to the next receiver in the chain. + """ + + mouseMoved = Signal(MouseMoveEvent) + mousePressed = Signal(MousePressEvent) + mouseReleased = Signal(MouseReleaseEvent) + + def on_mouse_move(self, event: MouseMoveEvent) -> bool: + return False + + def on_mouse_press(self, event: MousePressEvent) -> bool: + return False + + def on_mouse_release(self, event: MouseReleaseEvent) -> bool: + return False diff --git a/src/ndv/controllers/__init__.py b/src/ndv/controllers/__init__.py new file mode 100644 index 00000000..0b280cd5 --- /dev/null +++ b/src/ndv/controllers/__init__.py @@ -0,0 +1,3 @@ +from ._array_viewer import ArrayViewer + +__all__ = ["ArrayViewer"] diff --git a/src/ndv/controllers/_array_viewer.py b/src/ndv/controllers/_array_viewer.py new file mode 100644 index 00000000..bd92430c --- /dev/null +++ b/src/ndv/controllers/_array_viewer.py @@ -0,0 +1,377 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any + +import numpy as np + +from ndv._views import _app +from ndv.controllers._channel_controller import ChannelController +from ndv.models._array_display_model import ArrayDisplayModel, ChannelMode +from ndv.models._data_display_model import _ArrayDataDisplayModel +from ndv.models._lut_model import LUTModel +from ndv.models.data_wrappers import DataWrapper + +if TYPE_CHECKING: + from typing import Any, Unpack + + from typing_extensions import TypeAlias + + from ndv._types import MouseMoveEvent + from ndv._views.bases import ArrayView, HistogramCanvas + from ndv.models._array_display_model import ArrayDisplayModelKwargs + + LutKey: TypeAlias = int | None + + +# primary "Controller" (and public API) for viewing an array + + +class ArrayViewer: + """Viewer dedicated to displaying a single n-dimensional array. + + This wraps a model, view, and controller into a single object, and defines the + public API. + + Parameters + ---------- + data : DataWrapper | Any + Data to be displayed. + display_model : ArrayDisplayModel, optional + Just the display model to use. If provided, `data_or_model` must be an array + or `DataWrapper`... and kwargs will be ignored. + **kwargs: ArrayDisplayModelKwargs + Keyword arguments to pass to the `ArrayDisplayModel` constructor. If + `display_model` is provided, these will be ignored. + """ + + def __init__( + self, + data: Any | DataWrapper = None, + /, + display_model: ArrayDisplayModel | None = None, + **kwargs: Unpack[ArrayDisplayModelKwargs], + ) -> None: + if display_model is not None and kwargs: + warnings.warn( + "When display_model is provided, kwargs are be ignored.", + stacklevel=2, + ) + + # mapping of channel keys to their respective controllers + # where None is the default channel + self._lut_controllers: dict[LutKey, ChannelController] = {} + + # get and create the front-end and canvas classes + frontend_cls = _app.get_array_view_class() + canvas_cls = _app.get_array_canvas_class() + self._canvas = canvas_cls() + self._canvas.set_ndim(2) + + self._histogram: HistogramCanvas | None = None + self._view = frontend_cls(self._canvas.frontend_widget()) + + display_model = display_model or ArrayDisplayModel(**kwargs) + self._data_model = _ArrayDataDisplayModel( + data_wrapper=data, display=display_model + ) + self._set_model_connected(self._data_model.display) + + self._view.currentIndexChanged.connect(self._on_view_current_index_changed) + self._view.resetZoomClicked.connect(self._on_view_reset_zoom_clicked) + self._view.histogramRequested.connect(self._add_histogram) + self._view.channelModeChanged.connect(self._on_view_channel_mode_changed) + self._canvas.mouseMoved.connect(self._on_canvas_mouse_moved) + + if self._data_model.data_wrapper is not None: + self._fully_synchronize_view() + + # -------------- public attributes and methods ------------------------- + + @property + def view(self) -> ArrayView: + """Return the front-end view object. + + To access the actual native widget, use `self.view.frontend_widget()`. + If you directly access the frontend widget, you're on your own :) no guarantees + can be made about synchronization with the model. However, it is exposed for + experimental and custom use cases. + """ + return self._view + + @property + def display_model(self) -> ArrayDisplayModel: + """Return the current ArrayDisplayModel.""" + return self._data_model.display + + @display_model.setter + def display_model(self, model: ArrayDisplayModel) -> None: + """Set the ArrayDisplayModel.""" + if not isinstance(model, ArrayDisplayModel): + raise TypeError("model must be an ArrayDisplayModel") + self._set_model_connected(self._data_model.display, False) + self._data_model.display = model + self._set_model_connected(self._data_model.display) + self._fully_synchronize_view() + + @property + def data_wrapper(self) -> Any: + """Return data being displayed.""" + return self._data_model.data_wrapper + + @property + def data(self) -> Any: + """Return data being displayed.""" + if self._data_model.data_wrapper is None: + return None # pragma: no cover + # returning the actual data, not the wrapper + return self._data_model.data_wrapper.data + + @data.setter + def data(self, data: Any) -> None: + """Set the data to be displayed.""" + if data is None: + self._data_model.data_wrapper = None + else: + self._data_model.data_wrapper = DataWrapper.create(data) + self._fully_synchronize_view() + + def show(self) -> None: + """Show the viewer.""" + self._view.set_visible(True) + + def hide(self) -> None: + """Show the viewer.""" + self._view.set_visible(False) + + def close(self) -> None: + """Close the viewer.""" + self._view.set_visible(False) + + def clone(self) -> ArrayViewer: + """Return a new ArrayViewer instance with the same data and display model. + + Currently, this is a shallow copy. Modifying one viewer will affect the state + of the other. + """ + # TODO: provide deep_copy option + return ArrayViewer( + self._data_model.data_wrapper, display_model=self.display_model + ) + + # --------------------- PRIVATE ------------------------------------------ + + def _add_histogram(self) -> None: + histogram_cls = _app.get_histogram_canvas_class() # will raise if not supported + self._histogram = histogram_cls() + self._view.add_histogram(self._histogram.frontend_widget()) + for view in self._lut_controllers.values(): + view.add_lut_view(self._histogram) + # FIXME: hack + if handles := view.handles: + data = handles[0].data() + counts, edges = _calc_hist_bins(data) + self._histogram.set_data(counts, edges) + + if self.data is not None: + self._update_hist_domain_for_dtype() + + def _update_hist_domain_for_dtype( + self, dtype: np.typing.DTypeLike | None = None + ) -> None: + if self._histogram is None: + return + if dtype is None: + if (wrapper := self._data_model.data_wrapper) is None: + return + dtype = wrapper.dtype + else: + dtype = np.dtype(dtype) + if dtype.kind in "iu": + iinfo = np.iinfo(dtype) + self._histogram.set_range(x=(iinfo.min, iinfo.max)) + + def _set_model_connected( + self, model: ArrayDisplayModel, connect: bool = True + ) -> None: + """Connect or disconnect the model to/from the viewer. + + We do this in a single method so that we are sure to connect and disconnect + the same events in the same order. (but it's kinda ugly) + """ + _connect = "connect" if connect else "disconnect" + + for obj, callback in [ + (model.events.visible_axes, self._on_model_visible_axes_changed), + # the current_index attribute itself is immutable + (model.current_index.value_changed, self._on_model_current_index_changed), + (model.events.channel_mode, self._on_model_channel_mode_changed), + # TODO: lut values themselves are mutable evented objects... + # so we need to connect to their events as well + # (model.luts.value_changed, ...), + ]: + getattr(obj, _connect)(callback) + + # ------------------ Model callbacks ------------------ + + def _fully_synchronize_view(self) -> None: + """Fully re-synchronize the view with the model.""" + display_model = self._data_model.display + with self.view.currentIndexChanged.blocked(): + self._view.create_sliders(self._data_model.normed_data_coords) + self._view.set_channel_mode(display_model.channel_mode) + if self.data is not None: + self._update_visible_sliders() + if cur_index := display_model.current_index: + self._view.set_current_index(cur_index) + # reconcile view sliders with model + self._on_view_current_index_changed() + if wrapper := self._data_model.data_wrapper: + self._view.set_data_info(wrapper.summary_info()) + + self._clear_canvas() + self._update_canvas() + for lut_ctr in self._lut_controllers.values(): + lut_ctr._update_view_from_model() + self._update_hist_domain_for_dtype() + + def _on_model_visible_axes_changed(self) -> None: + self._update_visible_sliders() + self._update_canvas() + + def _on_model_current_index_changed(self) -> None: + value = self._data_model.display.current_index + self._view.set_current_index(value) + self._update_canvas() + + def _on_model_channel_mode_changed(self, mode: ChannelMode) -> None: + self._view.set_channel_mode(mode) + self._update_visible_sliders() + show_channel_luts = mode in {ChannelMode.COLOR, ChannelMode.COMPOSITE} + for lut_ctrl in self._lut_controllers.values(): + for view in lut_ctrl.lut_views: + if lut_ctrl.key is None: + view.set_visible(not show_channel_luts) + else: + view.set_visible(show_channel_luts) + # redraw + self._clear_canvas() + self._update_canvas() + + def _clear_canvas(self) -> None: + for lut_ctrl in self._lut_controllers.values(): + # self._view.remove_lut_view(lut_ctrl.lut_view) + while lut_ctrl.handles: + lut_ctrl.handles.pop().remove() + # do we need to cleanup the lut views themselves? + + # ------------------ View callbacks ------------------ + + def _on_view_current_index_changed(self) -> None: + """Update the model when slider value changes.""" + self._data_model.display.current_index.update(self._view.current_index()) + + def _on_view_reset_zoom_clicked(self) -> None: + """Reset the zoom level of the canvas.""" + self._canvas.set_range() + + def _on_canvas_mouse_moved(self, event: MouseMoveEvent) -> None: + """Respond to a mouse move event in the view.""" + x, y, _z = self._canvas.canvas_to_world((event.x, event.y)) + + # collect and format intensity values at the current mouse position + channel_values = self._get_values_at_world_point(int(x), int(y)) + vals = [] + for ch, value in channel_values.items(): + # restrict to 2 decimal places, but remove trailing zeros + fval = f"{value:.2f}".rstrip("0").rstrip(".") + fch = f"{ch}: " if ch is not None else "" + vals.append(f"{fch}{fval}") + text = f"[{y:.0f}, {x:.0f}] " + ",".join(vals) + self._view.set_hover_info(text) + + def _on_view_channel_mode_changed(self, mode: ChannelMode) -> None: + self._data_model.display.channel_mode = mode + + # ------------------ Helper methods ------------------ + + def _update_visible_sliders(self) -> None: + """Update which sliders are visible based on the current data and model.""" + hidden_sliders: tuple[int, ...] = self._data_model.normed_visible_axes + if self._data_model.display.channel_mode.is_multichannel(): + if ch := self._data_model.normed_channel_axis: + hidden_sliders += (ch,) + + self._view.hide_sliders(hidden_sliders, show_remainder=True) + + def _update_canvas(self) -> None: + """Force the canvas to fetch and update the displayed data. + + This is called (frequently) when anything changes that requires a redraw. + It fetches the current data slice from the model and updates the image handle. + """ + if not self._data_model.data_wrapper: + return # pragma: no cover + + display_model = self._data_model.display + # TODO: make asynchronous + for future in self._data_model.request_sliced_data(): + response = future.result() + key = response.channel_key + data = response.data + + if (lut_ctrl := self._lut_controllers.get(key)) is None: + if key is None: + model = display_model.default_lut + elif key in display_model.luts: + model = display_model.luts[key] + else: + # we received a new channel key that has not been set in the model + # so we create a new LUT model for it + model = display_model.luts[key] = LUTModel() + + lut_views = [self._view.add_lut_view()] + if self._histogram is not None: + lut_views.append(self._histogram) + self._lut_controllers[key] = lut_ctrl = ChannelController( + key=key, + model=model, + views=lut_views, + ) + + if not lut_ctrl.handles: + # we don't yet have any handles for this channel + lut_ctrl.add_handle(self._canvas.add_image(data)) + else: + lut_ctrl.update_texture_data(data) + if self._histogram is not None: + # TODO: once data comes in in chunks, we'll need a proper stateful + # stats object that calculates the histogram incrementally + counts, bin_edges = _calc_hist_bins(data) + # TODO: currently this is updating the histogram on *any* + # channel index... so it doesn't work with composite mode + self._histogram.set_data(counts, bin_edges) + self._histogram.set_range() + + self._canvas.refresh() + + def _get_values_at_world_point(self, x: int, y: int) -> dict[LutKey, float]: + # TODO: handle 3D data + if ( + x < 0 or y < 0 + ) or self._data_model.display.n_visible_axes != 2: # pragma: no cover + return {} + + values: dict[LutKey, float] = {} + for key, ctrl in self._lut_controllers.items(): + if (value := ctrl.get_value_at_index((y, x))) is not None: + values[key] = value + + return values + + +def _calc_hist_bins(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + maxval = np.iinfo(data.dtype).max + counts = np.bincount(data.flatten(), minlength=maxval + 1) + bin_edges = np.arange(maxval + 2) - 0.5 + return counts, bin_edges diff --git a/src/ndv/controllers/_channel_controller.py b/src/ndv/controllers/_channel_controller.py new file mode 100644 index 00000000..75ed3e2b --- /dev/null +++ b/src/ndv/controllers/_channel_controller.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + import cmap + import numpy as np + + from ndv._views.bases import LutView + from ndv._views.bases.graphics._canvas_elements import ImageHandle + from ndv.models._lut_model import LUTModel + + LutKey = int | None + + +class ChannelController: + """Controller for a single channel in the viewer. + + This manages the connection between the LUT model (settings like colormap, + contrast limits and visibility) and the LUT view (the front-end widget that + allows the user to interact with these settings), as well as the image handle + that displays the data, all for a single "channel" extracted from the data. + """ + + def __init__(self, key: LutKey, model: LUTModel, views: Sequence[LutView]) -> None: + self.key = key + self.lut_views: list[LutView] = [] + self.lut_model = model + self.handles: list[ImageHandle] = [] + + for v in views: + self.add_lut_view(v) + + # connect model changes to view callbacks that update the view + self.lut_model.events.cmap.connect(self._on_model_cmap_changed) + self.lut_model.events.clims.connect(self._on_model_clims_changed) + self.lut_model.events.autoscale.connect(self._on_model_autoscale_changed) + self.lut_model.events.visible.connect(self._on_model_visible_changed) + self.lut_model.events.gamma.connect(self._on_model_gamma_changed) + + def add_lut_view(self, view: LutView) -> None: + """Add a LUT view to the controller.""" + self.lut_views.append(view) + # connect view changes to controller callbacks that update the model + view.visibilityChanged.connect(self._on_view_lut_visible_changed) + view.autoscaleChanged.connect(self._on_view_lut_autoscale_changed) + view.cmapChanged.connect(self._on_view_lut_cmap_changed) + view.climsChanged.connect(self._on_view_lut_clims_changed) + view.gammaChanged.connect(self._on_view_lut_gamma_changed) + self._update_view_from_model(view) + + def _on_model_clims_changed(self, clims: tuple[float, float]) -> None: + """The contrast limits in the model have changed.""" + for v in self.lut_views: + v.set_clims_without_signal(clims) + for handle in self.handles: + handle.set_clims(clims) + + def _on_model_gamma_changed(self, gamma: float) -> None: + """The gamma value in the model has changed.""" + for view in self.lut_views: + view.set_gamma_without_signal(gamma) + for handle in self.handles: + handle.set_gamma(gamma) + + def _on_model_autoscale_changed(self, autoscale: bool) -> None: + """The autoscale setting in the model has changed.""" + for view in self.lut_views: + view.set_auto_scale_without_signal(autoscale) + if autoscale: + for handle in self.handles: + d = handle.data() + handle.set_clims((d.min(), d.max())) + + def _on_model_cmap_changed(self, cmap: cmap.Colormap) -> None: + """The colormap in the model has changed.""" + for view in self.lut_views: + view.set_colormap_without_signal(cmap) + for handle in self.handles: + handle.set_cmap(cmap) + + def _on_model_visible_changed(self, visible: bool) -> None: + """The visibility in the model has changed.""" + for view in self.lut_views: + view.set_channel_visible_without_signal(visible) + for handle in self.handles: + handle.set_visible(visible) + + def _update_view_from_model(self, *views: LutView) -> None: + """Make sure the view matches the model.""" + _views: Iterable[LutView] = views or self.lut_views + for view in _views: + view.set_colormap_without_signal(self.lut_model.cmap) + if self.lut_model.clims: + view.set_clims_without_signal(self.lut_model.clims) + # TODO: handle more complex autoscale types + view.set_auto_scale_without_signal(bool(self.lut_model.autoscale)) + view.set_channel_visible_without_signal(True) + name = str(self.key) if self.key is not None else "" + view.set_channel_name(name) + + def _on_view_lut_visible_changed(self, visible: bool, key: LutKey = None) -> None: + """The visibility checkbox in the LUT widget has changed.""" + for handle in self.handles: + handle.set_visible(visible) + + def _on_view_lut_autoscale_changed( + self, autoscale: bool, key: LutKey = None + ) -> None: + """The autoscale checkbox in the LUT widget has changed.""" + self.lut_model.autoscale = autoscale + for view in self.lut_views: + view.set_auto_scale_without_signal(autoscale) + + if autoscale: + # TODO: or should we have a global min/max across all handles for this key? + for handle in self.handles: + data = handle.data() + # update the model with the new clim values + self.lut_model.clims = (data.min(), data.max()) + + def _on_view_lut_cmap_changed( + self, cmap: cmap.Colormap, key: LutKey = None + ) -> None: + """The colormap in the LUT widget has changed.""" + for handle in self.handles: + handle.set_cmap(cmap) # actually apply it to the Image texture + self.lut_model.cmap = cmap # update the model as well + + def _on_view_lut_clims_changed(self, clims: tuple[float, float]) -> None: + """The contrast limits slider in the LUT widget has changed.""" + self.lut_model.clims = clims + # when the clims are manually adjusted in the view, we turn off autoscale + self.lut_model.autoscale = False + + def _on_view_lut_gamma_changed(self, gamma: float) -> None: + """The gamma slider in the LUT widget has changed.""" + self.lut_model.gamma = gamma + + def update_texture_data(self, data: np.ndarray) -> None: + """Update the data in the image handle.""" + # WIP: + # until we have a more sophisticated way to handle updating data + # for multiple handles, we'll just update the first one + if not (handles := self.handles): + return + handles[0].set_data(data) + # if this image handle is visible and autoscale is on, then we need + # to update the clim values + if self.lut_model.autoscale: + self.lut_model.clims = (data.min(), data.max()) + # lut_view.setClims((data.min(), data.max())) + # technically... the LutView may also emit a signal that the + # controller listens to, and then updates the image handle + # but this next line is more direct + # self._handles[None].clim = (data.min(), data.max()) + + def add_handle(self, handle: ImageHandle) -> None: + """Add an image texture handle to the controller.""" + self.handles.append(handle) + handle.set_cmap(self.lut_model.cmap) + if self.lut_model.autoscale: + data = handle.data() + self.lut_model.clims = (data.min(), data.max()) + if self.lut_model.clims: + handle.set_clims(self.lut_model.clims) + + def get_value_at_index(self, idx: tuple[int, ...]) -> float | None: + """Get the value of the data at the given index.""" + if not (handles := self.handles): + return None + # only getting one handle per channel for now + handle = handles[0] + with suppress(IndexError): # skip out of bounds + # here, we're retrieving the value from the in-memory data + # stored by the backend visual, rather than querying the data itself + # this is a quick workaround to get the value without having to + # worry about other dimensions in the data source (since the + # texture has already been reduced to 2D). But a more complete + # implementation would gather the full current nD index and query + # the data source directly. + return handle.data()[idx] # type: ignore [no-any-return] + return None diff --git a/src/ndv/data.py b/src/ndv/data.py index 545774d9..40fb8671 100644 --- a/src/ndv/data.py +++ b/src/ndv/data.py @@ -2,9 +2,11 @@ from __future__ import annotations +from typing import Any + import numpy as np -__all__ = ["cells3d", "nd_sine_wave"] +__all__ = ["astronaut", "cat", "cells3d", "cosem_dataset", "nd_sine_wave"] def nd_sine_wave( @@ -12,7 +14,7 @@ def nd_sine_wave( amplitude: float = 240, base_frequency: float = 5, ) -> np.ndarray: - """5D dataset.""" + """5D dataset: (10, 3, 5, 512, 512), float64.""" # Unpack the dimensions if not len(shape) == 5: raise ValueError("Shape must have 5 dimensions") @@ -52,7 +54,7 @@ def nd_sine_wave( def cells3d() -> np.ndarray: - """Load cells3d data from scikit-image.""" + """Load cells3d from scikit-image (60, 2, 256, 256) uint16.""" try: from imageio.v2 import volread except ImportError as e: @@ -61,4 +63,55 @@ def cells3d() -> np.ndarray: ) from e url = "https://gitlab.com/scikit-image/data/-/raw/2cdc5ce89b334d28f06a58c9f0ca21aa6992a5ba/cells3d.tif" - return volread(url) # type: ignore [no-any-return] + data = np.asarray(volread(url)) + + # this data has been stretched to 16 bit, and lacks certain intensity values + # add a small random integer to each pixel ... so the histogram is not silly + data = (data + np.random.randint(-24, 24, data.shape)).astype(np.uint16) + return data + + +def cat() -> np.ndarray: + """Load RGB cat data (300, 451, 3), uint8.""" + return _imread("imageio:chelsea.png") + + +def astronaut() -> np.ndarray: + """Load RGB data (512, 512, 3), uint8.""" + return _imread("imageio:astronaut.png") + + +def _imread(uri: str) -> np.ndarray: + try: + import imageio.v3 as iio + except ImportError: + raise ImportError("Please install imageio fetch data") from None + return iio.imread(uri) # type: ignore [no-any-return] + + +def cosem_dataset( + uri: str = "", + dataset: str = "jrc_hela-3", + label: str = "er-mem_pred", + level: int = 4, +) -> Any: + try: + import tensorstore as ts + except ImportError: + raise ImportError("Please install tensorstore to fetch cosem data") from None + + if not uri: + uri = f"{dataset}/{dataset}.n5/labels/{label}/s{level}/" + + ts_array = ts.open( + { + "driver": "n5", + "kvstore": { + "driver": "s3", + "bucket": "janelia-cosem-datasets", + "path": uri, + }, + }, + ).result() + ts_array = ts_array[ts.d[:].label["z", "y", "x"]] + return ts_array[ts.d[("y", "x", "z")].transpose[:]] diff --git a/src/ndv/histogram/model.py b/src/ndv/histogram/model.py deleted file mode 100644 index 939fa477..00000000 --- a/src/ndv/histogram/model.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Model protocols for data display.""" - -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import ClassVar, cast - -import numpy as np -from psygnal import SignalGroupDescriptor - - -@dataclass -class StatsModel: - """A model of the statistics of a dataset. - - TODO: Consider refactoring into a protocol allowing subclassing for - e.g. faster histogram computation, different data types? - """ - - events: ClassVar[SignalGroupDescriptor] = SignalGroupDescriptor() - - standard_deviation: float | None = None - average: float | None = None - # TODO: Is the generality nice, or should we just say np.ndarray? - histogram: tuple[Sequence[int], Sequence[float]] | None = None - bins: int = 256 - - _data: np.ndarray | None = None - - @property - def data(self) -> np.ndarray: - """Returns the data backing this StatsModel.""" - if self._data is not None: - return self._data - raise RuntimeError("Data has not yet been set!") - - @data.setter - def data(self, data: np.ndarray) -> None: - """Sets the data backing this StatsModel.""" - if data is None: - return - self._data = data - self.histogram = cast( - tuple[Sequence[int], Sequence[float]], - np.histogram(self._data, bins=self.bins), - ) - self.average = np.average(self._data) - self.standard_deviation = np.std(self._data) diff --git a/src/ndv/histogram/view.py b/src/ndv/histogram/view.py deleted file mode 100644 index 8164ffdf..00000000 --- a/src/ndv/histogram/view.py +++ /dev/null @@ -1,178 +0,0 @@ -"""View interfaces for data display.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -import cmap -from psygnal import Signal - -if TYPE_CHECKING: - from collections.abc import Sequence - from typing import Any - - -class StatsView(Protocol): - """A view of the statistics of a dataset.""" - - def set_histogram( - self, values: Sequence[float], bin_edges: Sequence[float] - ) -> None: - """Defines the distribution of the dataset. - - Properties - ---------- - values : Sequence[int] - A length (n) sequence of values representing clustered counts of data - points. values[i] defines the number of data points falling between - bin_edges[i] and bin_edges[i+1]. - bin_edges : Sequence[float] - A length (n+1) sequence of values defining the intervals partitioning - all data points. Must be non-decreasing. - """ - ... - - def set_std_dev(self, std_dev: float) -> None: - """Defines the standard deviation of the dataset. - - Properties - ---------- - std_dev : float - The standard deviation. - """ - ... - - def set_average(self, avg: float) -> None: - """Defines the average value of the dataset. - - Properties - ---------- - std_dev : float - The average value of the dataset. - """ - ... - - def view(self) -> Any: - """The native object that can be displayed.""" - ... - - -class LutView(Protocol): - """An (interactive) view of a LookUp Table (LUT).""" - - cmapChanged: Signal = Signal(cmap.Colormap) - gammaChanged: Signal = Signal(float) - climsChanged: Signal = Signal(tuple[float, float]) - autoscaleChanged: Signal = Signal(object) - - def set_visibility(self, visible: bool) -> None: - """Defines whether this view is visible. - - Properties - ---------- - visible : bool - True iff the view should be visible. - """ - ... - - def set_cmap(self, lut: cmap.Colormap) -> None: - """Defines the colormap backing the view. - - Properties - ---------- - lut : cmap.Colormap - The object mapping scalar values to RGB(A) colors. - """ - ... - - def set_gamma(self, gamma: float) -> None: - """Defines the exponent used for gamma correction. - - Properties - ---------- - gamma : float - The exponent used for gamma correction - """ - ... - - def set_clims(self, clims: tuple[float, float]) -> None: - """Defines the input clims. - - The contrast limits (clims) are the input values mapped to the minimum and - maximum (respectively) of the LUT. - - Properties - ---------- - clims : tuple[float, float] - The clims - """ - ... - - def set_autoscale(self, autoscale: bool | tuple[float, float]) -> None: - """Defines whether autoscale has been enabled. - - Autoscale defines whether the contrast limits (clims) are adjusted when the - data changes. - - Properties - ---------- - autoscale : bool | tuple[float, float] - If a boolean, true iff clims automatically changed on dataset alteration. - If a tuple, indicated that clims automatically changed. Values denote - the fraction of the dataset located below and above the lower and - upper clims, respectively. - """ - ... - - def view(self) -> Any: - """The native object that can be displayed.""" - ... - - -class HistogramView(StatsView, LutView): - """A histogram-based view for LookUp Table (LUT) adjustment.""" - - def set_domain(self, bounds: tuple[float, float] | None) -> None: - """Sets the domain of the view. - - Properties - ---------- - bounds : tuple[float, float] | None - If a tuple, sets the displayed extremes of the x axis to the passed - values. If None, sets them to the extent of the data instead. - """ - ... - - def set_range(self, bounds: tuple[float, float] | None) -> None: - """Sets the range of the view. - - Properties - ---------- - bounds : tuple[float, float] | None - If a tuple, sets the displayed extremes of the y axis to the passed - values. If None, sets them to the extent of the data instead. - """ - ... - - def set_vertical(self, vertical: bool) -> None: - """Sets the axis of the domain. - - Properties - ---------- - vertical : bool - If true, views the domain along the y axis and the range along the x - axis. If false, views the domain along the x axis and the range along - the y axis. - """ - ... - - def set_range_log(self, enabled: bool) -> None: - """Sets the axis scale of the range. - - Properties - ---------- - enabled : bool - If true, the range will be displayed with a logarithmic (base 10) - scale. If false, the range will be displayed with a linear scale. - """ - ... diff --git a/src/ndv/histogram/views/_vispy.py b/src/ndv/histogram/views/_vispy.py deleted file mode 100644 index 217ce48c..00000000 --- a/src/ndv/histogram/views/_vispy.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright (c) Vispy Development Team. All Rights Reserved. -# Distributed under the (new) BSD License. See LICENSE.txt for more info. -from __future__ import annotations - -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast - -import numpy as np -from qtpy.QtCore import Qt -from vispy import scene - -from ndv.histogram.view import HistogramView - -if TYPE_CHECKING: - from typing import Unpack - - import numpy.typing as npt - - -class Grabbable(Enum): - NONE = auto() - LEFT_CLIM = auto() - RIGHT_CLIM = auto() - GAMMA = auto() - - -if TYPE_CHECKING: - # just here cause vispy has poor type hints - from collections.abc import Sequence - - import cmap - from vispy.app.canvas import MouseEvent - - class Grid(scene.Grid): - def add_view( - self, - row: int | None = None, - col: int | None = None, - row_span: int = 1, - col_span: int = 1, - **kwargs: Any, - ) -> scene.ViewBox: - super().add_view(...) - - def add_widget( - self, - widget: None | scene.Widget = None, - row: int | None = None, - col: int | None = None, - row_span: int = 1, - col_span: int = 1, - **kwargs: Any, - ) -> scene.Widget: - super().add_widget(...) - - class WidgetKwargs(TypedDict, total=False): - pos: tuple[float, float] - size: tuple[float, float] - border_color: str - border_width: float - bgcolor: str - padding: float - margin: float - - class TextVisualKwargs(TypedDict, total=False): - text: str - color: str - bold: bool - italic: bool - face: str - font_size: float - pos: tuple[float, float] | tuple[float, float, float] - rotation: float - method: Literal["cpu", "gpu"] - depth_test: bool - - class AxisWidgetKwargs(TypedDict, total=False): - orientation: Literal["left", "bottom"] - tick_direction: tuple[int, int] - axis_color: str - tick_color: str - text_color: str - minor_tick_length: float - major_tick_length: float - tick_width: float - tick_label_margin: float - tick_font_size: float - axis_width: float - axis_label: str - axis_label_margin: float - axis_font_size: float - font_size: float # overrides tick_font_size and axis_font_size - - -__all__ = ["PlotWidget"] - - -DEFAULT_AXIS_KWARGS: AxisWidgetKwargs = { - "text_color": "w", - "axis_color": "w", - "tick_color": "w", - "tick_width": 1, - "tick_font_size": 8, - "tick_label_margin": 12, - "axis_label_margin": 50, - "minor_tick_length": 2, - "major_tick_length": 5, - "axis_width": 1, - "axis_font_size": 10, -} - - -class Component(str, Enum): - PAD_LEFT = "pad_left" - PAD_RIGHT = "pad_right" - PAD_BOTTOM = "pad_bottom" - TITLE = "title" - CBAR_TOP = "cbar_top" - CBAR_LEFT = "cbar_left" - CBAR_RIGHT = "cbar_right" - CBAR_BOTTOM = "cbar_bottom" - YAXIS = "yaxis" - XAXIS = "xaxis" - XLABEL = "xlabel" - YLABEL = "ylabel" - - def __str__(self) -> str: - return self.value - - -class PlotWidget(scene.Widget): - """Widget to facilitate plotting. - - Parameters - ---------- - fg_color : str - The default color for the plot. - xlabel : str - The x-axis label. - ylabel : str - The y-axis label. - title : str - The title of the plot. - lock_axis : {'x', 'y', None} - Prevent panning and zooming along a particular axis. - **widget_kwargs : dict - Keyword arguments to pass to the parent class. - """ - - def __init__( - self, - fg_color: str = "k", - xlabel: str = "", - ylabel: str = "", - title: str = "", - lock_axis: Literal["x", "y", None] = None, - **widget_kwargs: Unpack[WidgetKwargs], - ) -> None: - self._fg_color = fg_color - self._visuals: list[scene.VisualNode] = [] - super().__init__(**widget_kwargs) - self.unfreeze() - self.grid = cast("Grid", self.add_grid(spacing=0, margin=10)) - - title_kwargs: TextVisualKwargs = {"font_size": 14, "color": "w"} - label_kwargs: TextVisualKwargs = {"font_size": 10, "color": "w"} - self._title = scene.Label(str(title), **title_kwargs) - self._xlabel = scene.Label(str(xlabel), **label_kwargs) - self._ylabel = scene.Label(str(ylabel), rotation=-90, **label_kwargs) - - axis_kwargs: AxisWidgetKwargs = DEFAULT_AXIS_KWARGS - self.yaxis = scene.AxisWidget(orientation="left", **axis_kwargs) - self.xaxis = scene.AxisWidget(orientation="bottom", **axis_kwargs) - - # 2D Plot layout: - # - # c0 c1 c2 c3 c4 c5 c6 - # +----------+-------+-------+-------+---------+---------+-----------+ - # r0 | | | title | | | - # | +-----------------------+---------+---------+ | - # r1 | | | cbar | | | - # |----------+-------+-------+-------+---------+---------+ ----------| - # r2 | pad_left | cbar | ylabel| yaxis | view | cbar | pad_right | - # |----------+-------+-------+-------+---------+---------+ ----------| - # r3 | | | xaxis | | | - # | +-----------------------+---------+---------+ | - # r4 | | | xlabel | | | - # | +-----------------------+---------+---------+ | - # r5 | | | cbar | | | - # |---------+------------------------+---------+---------+-----------| - # r6 | | pad_bottom | | - # +---------+------------------------+---------+---------+-----------+ - - self._grid_wdgs: dict[Component, scene.Widget] = {} - for name, row, col, widget in [ - (Component.PAD_LEFT, 2, 0, None), - (Component.PAD_RIGHT, 2, 6, None), - (Component.PAD_BOTTOM, 6, 4, None), - (Component.TITLE, 0, 4, self._title), - (Component.CBAR_TOP, 1, 4, None), - (Component.CBAR_LEFT, 2, 1, None), - (Component.CBAR_RIGHT, 2, 5, None), - (Component.CBAR_BOTTOM, 5, 4, None), - (Component.YAXIS, 2, 3, self.yaxis), - (Component.XAXIS, 3, 4, self.xaxis), - (Component.XLABEL, 4, 4, self._xlabel), - (Component.YLABEL, 2, 2, self._ylabel), - ]: - self._grid_wdgs[name] = wdg = self.grid.add_widget(widget, row=row, col=col) - # If we don't set max size, they will expand to fill the entire grid - # occluding pretty much everything else. - if str(name).startswith(("cbar", "pad")): - if name in { - Component.PAD_LEFT, - Component.PAD_RIGHT, - Component.CBAR_LEFT, - Component.CBAR_RIGHT, - }: - wdg.width_max = 2 - else: - wdg.height_max = 2 - - # The main view into which plots are added - self._view = self.grid.add_view(row=2, col=4) - - # NOTE: this is a mess of hardcoded values... not sure whether they will work - # cross-platform. Note that `width_max` and `height_max` of 2 is actually - # *less* visible than 0 for some reason. They should also be extracted into - # some sort of `hide/show` logic for each component - self._grid_wdgs[Component.YAXIS].width_max = 30 # otherwise it takes too much - self._grid_wdgs[Component.PAD_LEFT].width_max = 20 # otherwise you get clipping - self._grid_wdgs[Component.XAXIS].height_max = 20 # otherwise it takes too much - self.ylabel = ylabel - self.xlabel = xlabel - self.title = title - - # VIEWBOX (this has to go last, see vispy #1748) - self.camera = self._view.camera = PanZoom1DCamera(lock_axis) - # this has to come after camera is set - self.xaxis.link_view(self._view) - self.yaxis.link_view(self._view) - self.freeze() - - @property - def title(self) -> str: - """The title label.""" - return self._title.text # type: ignore [no-any-return] - - @title.setter - def title(self, text: str) -> None: - """Set the title of the plot.""" - self._title.text = text - wdg = self._grid_wdgs[Component.TITLE] - wdg.height_min = wdg.height_max = 30 if text else 2 - - @property - def xlabel(self) -> str: - """The x-axis label.""" - return self._xlabel.text # type: ignore [no-any-return] - - @xlabel.setter - def xlabel(self, text: str) -> None: - """Set the x-axis label.""" - self._xlabel.text = text - wdg = self._grid_wdgs[Component.XLABEL] - wdg.height_min = wdg.height_max = 40 if text else 2 - - @property - def ylabel(self) -> str: - """The y-axis label.""" - return self._ylabel.text # type: ignore [no-any-return] - - @ylabel.setter - def ylabel(self, text: str) -> None: - """Set the x-axis label.""" - self._ylabel.text = text - wdg = self._grid_wdgs[Component.YLABEL] - wdg.width_min = wdg.width_max = 20 if text else 2 - - def lock_axis(self, axis: Literal["x", "y", None]) -> None: - """Prevent panning and zooming along a particular axis.""" - self.camera._axis = axis - # self.camera.set_range() - - -class PanZoom1DCamera(scene.cameras.PanZoomCamera): - """Camera that allows panning and zooming along one axis only. - - Parameters - ---------- - axis : {'x', 'y', None} - The axis along which to allow panning and zooming. - *args : tuple - Positional arguments to pass to the parent class. - **kwargs : dict - Keyword arguments to pass to the parent class. - """ - - def __init__( - self, axis: Literal["x", "y", None] = None, *args: Any, **kwargs: Any - ) -> None: - self._axis: Literal["x", "y", None] = axis - super().__init__(*args, **kwargs) - - @property - def axis_index(self) -> Literal[0, 1, None]: - """Return the index of the axis along which to pan and zoom.""" - if self._axis in ("x", 0): - return 0 - elif self._axis in ("y", 1): - return 1 - return None - - def zoom( - self, - factor: float | tuple[float, float], - center: tuple[float, ...] | None = None, - ) -> None: - """Zoom the camera by `factor` around `center`.""" - if self.axis_index is None: - super().zoom(factor, center=center) - return - - if isinstance(factor, (float, int)): - factor = (factor, factor) - _factor = list(factor) - _factor[self.axis_index] = 1 - super().zoom(_factor, center=center) - - def pan(self, pan: Sequence[float]) -> None: - """Pan the camera by `pan`.""" - if self.axis_index is None: - super().pan(pan) - return - _pan = list(pan) - _pan[self.axis_index] = 0 - super().pan(*_pan) - - def set_range( - self, - x: tuple | None = None, - y: tuple | None = None, - z: tuple | None = None, - margin: float = 0, # overriding to create a different default from super() - ) -> None: - """Reset the camera view to the specified range.""" - super().set_range(x, y, z, margin) - - -# Note: the only need for this superclass is the Signals that are defined on -# the protocols. Otherwise, they are just Protocols without any @abstractmethods. -class VispyHistogramView(HistogramView): - """A HistogramView on a VisPy SceneCanvas.""" - - def __init__(self) -> None: - # ------------ data and state ------------ # - - self._values: Sequence[float] | np.ndarray | None = None - self._bin_edges: Sequence[float] | np.ndarray | None = None - self._clims: tuple[float, float] | None = None - self._gamma: float = 1 - - # the currently grabbed object - self._grabbed: Grabbable = Grabbable.NONE - # whether the y-axis is logarithmic - self._log_y: bool = False - # whether the histogram is vertical - self._vertical: bool = False - # The values of the left and right edges on the canvas (respectively) - self._domain: tuple[float, float] | None = None - # The values of the bottom and top edges on the canvas (respectively) - self._range: tuple[float, float] | None = None - - # ------------ VisPy Canvas ------------ # - - self._canvas = scene.SceneCanvas() - self._canvas.unfreeze() - self._canvas.on_mouse_press = self.on_mouse_press - self._canvas.on_mouse_move = self.on_mouse_move - self._canvas.on_mouse_release = self.on_mouse_release - self._canvas.freeze() - - ## -- Visuals -- ## - - # NB We directly use scene.Mesh, instead of scene.Histogram, - # so that we can control the calculation of the histogram ourselves - self._hist_mesh = scene.Mesh(color="red") - - # The Lut Line visualizes both the clims (vertical line segments connecting the - # first two and last two points, respectively) and the gamma curve - # (the polyline between all remaining points) - self._lut_line = scene.LinePlot( - data=(0), # Dummy value to prevent resizing errors - color="k", - connect="strip", - symbol=None, - line_kind="-", - width=1.5, - marker_size=10.0, - edge_color="k", - face_color="b", - edge_width=1.0, - ) - self._lut_line.visible = False - self._lut_line.order = -1 - - # The gamma handle appears halfway between the clims - self._gamma_handle_pos: np.ndarray = np.ndarray((1, 2)) - self._gamma_handle = scene.Markers( - pos=self._gamma_handle_pos, - size=6, - edge_width=0, - ) - self._gamma_handle.visible = False - self._gamma_handle.order = -2 - - # One transform to rule them all! - self._handle_transform = scene.transforms.STTransform() - self._lut_line.transform = self._handle_transform - self._gamma_handle.transform = self._handle_transform - - ## -- Plot -- ## - self.plot = PlotWidget() - self.plot.lock_axis("y") - self._canvas.central_widget.add_widget(self.plot) - self.node_tform = self.plot.node_transform(self.plot._view.scene) - - self.plot._view.add(self._hist_mesh) - self.plot._view.add(self._lut_line) - self.plot._view.add(self._gamma_handle) - - # ------------- StatsView Protocol methods ------------- # - - def set_histogram( - self, values: Sequence[float], bin_edges: Sequence[float] - ) -> None: - """Set the histogram values and bin edges. - - These inputs follow the same format as the return value of numpy.histogram. - """ - self._values = values - self._bin_edges = bin_edges - self._update_histogram() - if self._clims is None: - self.set_clims((self._bin_edges[0], self._bin_edges[-1])) - self._resize() - - def set_std_dev(self, std_dev: float) -> None: - # Nothing to do. - # TODO: maybe show text somewhere - pass - - def set_average(self, average: float) -> None: - # Nothing to do - # TODO: maybe show text somewhere - pass - - def view(self) -> Any: - return self._canvas.native - - # ------------- LutView Protocol methods ------------- # - - def set_visibility(self, visible: bool) -> None: - if self._hist_mesh is None: - return # pragma: no cover - self._hist_mesh.visible = visible - self._lut_line.visible = visible - self._gamma_handle.visible = visible - - def set_cmap(self, lut: cmap.Colormap) -> None: - if self._hist_mesh is not None: - self._hist_mesh.color = lut.color_stops[-1].color.hex - - def set_gamma(self, gamma: float) -> None: - if gamma < 0: - raise ValueError("gamma must be non-negative!") - self._gamma = gamma - self._update_lut_lines() - - def set_clims(self, clims: tuple[float, float]) -> None: - if clims[1] < clims[0]: - clims = (clims[1], clims[0]) - self._clims = clims - self._update_lut_lines() - - def set_autoscale(self, autoscale: bool | tuple[float, float]) -> None: - # Nothing to do (yet) - pass - - # ------------- HistogramView Protocol methods ------------- # - - def set_domain(self, bounds: tuple[float, float] | None) -> None: - if bounds is not None: - if bounds[0] is None or bounds[1] is None: - # TODO: Sensible defaults? - raise ValueError("Domain min/max cannot be None!") - if bounds[0] > bounds[1]: - bounds = (bounds[1], bounds[0]) - self._domain = bounds - self._resize() - - def set_range(self, bounds: tuple[float, float] | None) -> None: - if bounds is not None: - if bounds[0] is None or bounds[1] is None: - # TODO: Sensible defaults? - raise ValueError("Range min/max cannot be None!") - if bounds[0] > bounds[1]: - bounds = (bounds[1], bounds[0]) - self._range = bounds - self._resize() - - def set_vertical(self, vertical: bool) -> None: - self._vertical = vertical - self._update_histogram() - self.plot.lock_axis("x" if vertical else "y") - # When vertical, smaller values should appear at the top of the canvas - self.plot.camera.flip = [False, vertical, False] - self._update_lut_lines() - self._resize() - - def set_range_log(self, enabled: bool) -> None: - if enabled != self._log_y: - self._log_y = enabled - self._update_histogram() - self._update_lut_lines() - self._resize() - - # ------------- Private methods ------------- # - - def _update_histogram(self) -> None: - """ - Updates the displayed histogram with current View parameters. - - NB: Much of this code is graciously borrowed from: - - https://github.com/vispy/vispy/blob/af847424425d4ce51f144a4d1c75ab4033fe39be/vispy/visuals/histogram.py#L28 - """ - if self._values is None or self._bin_edges is None: - return # pragma: no cover - values = self._values - if self._log_y: - # Replace zero values with 1 (which will be log10(1) = 0) - values = np.where(values == 0, 1, values) - values = np.log10(values) - - verts, faces = _hist_counts_to_mesh(values, self._bin_edges, self._vertical) - self._hist_mesh.set_data(vertices=verts, faces=faces) - - # FIXME: This should be called internally upon set_data, right? - # Looks like https://github.com/vispy/vispy/issues/1899 - self._hist_mesh._bounds_changed() - - def _update_lut_lines(self, npoints: int = 256) -> None: - if self._clims is None or self._gamma is None: - return # pragma: no cover - - # 2 additional points for each of the two vertical clims lines - X = np.empty(npoints + 4) - Y = np.empty(npoints + 4) - if self._vertical: - # clims lines - X[0:2], Y[0:2] = (1, 0.5), self._clims[0] - X[-2:], Y[-2:] = (0.5, 0), self._clims[1] - # gamma line - X[2:-2] = np.linspace(0, 1, npoints) ** self._gamma - Y[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) - midpoint = np.array([(2**-self._gamma, np.mean(self._clims))]) - else: - # clims lines - X[0:2], Y[0:2] = self._clims[0], (1, 0.5) - X[-2:], Y[-2:] = self._clims[1], (0.5, 0) - # gamma line - X[2:-2] = np.linspace(self._clims[0], self._clims[1], npoints) - Y[2:-2] = np.linspace(0, 1, npoints) ** self._gamma - midpoint = np.array([(np.mean(self._clims), 2**-self._gamma)]) - - # TODO: Move to self.edit_cmap - color = np.linspace(0.2, 0.8, npoints + 4).repeat(4).reshape(-1, 4) - c1, c2 = [0.4] * 4, [0.7] * 4 - color[0:3] = [c1, c2, c1] - color[-3:] = [c1, c2, c1] - - self._lut_line.set_data((X, Y), marker_size=0, color=color) - self._lut_line.visible = True - - self._gamma_handle_pos[:] = midpoint[0] - self._gamma_handle.set_data(pos=self._gamma_handle_pos) - self._gamma_handle.visible = True - - # FIXME: These should be called internally upon set_data, right? - # Looks like https://github.com/vispy/vispy/issues/1899 - self._lut_line._bounds_changed() - for v in self._lut_line._subvisuals: - v._bounds_changed() - self._gamma_handle._bounds_changed() - - def on_mouse_press(self, event: MouseEvent) -> None: - if event.pos is None: - return # pragma: no cover - # check whether the user grabbed a node - self._grabbed = self._find_nearby_node(event) - if self._grabbed != Grabbable.NONE: - # disconnect the pan/zoom mouse events until handle is dropped - self.plot.camera.interactive = False - - def on_mouse_release(self, event: MouseEvent) -> None: - self._grabbed = Grabbable.NONE - self.plot.camera.interactive = True - - def on_mouse_move(self, event: MouseEvent) -> None: - """Called whenever mouse moves over canvas.""" - if event.pos is None: - return # pragma: no cover - if self._clims is None: - return # pragma: no cover - - if self._grabbed in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: - newlims = list(self._clims) - if self._vertical: - c = self._to_plot_coords(event.pos)[1] - else: - c = self._to_plot_coords(event.pos)[0] - if self._grabbed is Grabbable.LEFT_CLIM: - newlims[0] = min(newlims[1], c) - elif self._grabbed is Grabbable.RIGHT_CLIM: - newlims[1] = max(newlims[0], c) - self.climsChanged.emit(newlims) - return - elif self._grabbed is Grabbable.GAMMA: - y0, y1 = ( - self.plot.xaxis.axis.domain - if self._vertical - else self.plot.yaxis.axis.domain - ) - y = self._to_plot_coords(event.pos)[0 if self._vertical else 1] - if y < np.maximum(y0, 0) or y > y1: - return - self.gammaChanged.emit(-np.log2(y / y1)) - return - - # TODO: try to remove the Qt aspect here so that we can use - # this for Jupyter as well - self._canvas.native.unsetCursor() - - nearby = self._find_nearby_node(event) - - if nearby in [Grabbable.LEFT_CLIM, Grabbable.RIGHT_CLIM]: - if self._vertical: - cursor = Qt.CursorShape.SplitVCursor - else: - cursor = Qt.CursorShape.SplitHCursor - self._canvas.native.setCursor(cursor) - elif nearby is Grabbable.GAMMA: - if self._vertical: - cursor = Qt.CursorShape.SplitHCursor - else: - cursor = Qt.CursorShape.SplitVCursor - self._canvas.native.setCursor(cursor) - else: - x, y = self._to_plot_coords(event.pos) - x1, x2 = self.plot.xaxis.axis.domain - y1, y2 = self.plot.yaxis.axis.domain - if (x1 < x <= x2) and (y1 <= y <= y2): - self._canvas.native.setCursor(Qt.CursorShape.SizeAllCursor) - - def _find_nearby_node(self, event: MouseEvent, tolerance: int = 5) -> Grabbable: - """Describes whether the event is near a clim.""" - click_x, click_y = event.pos - - # NB Computations are performed in canvas-space - # for easier tolerance computation. - plot_to_canvas = self.node_tform.imap - gamma_to_plot = self._handle_transform.map - - if self._clims is not None: - if self._vertical: - click = click_y - right = plot_to_canvas([0, self._clims[1]])[1] - left = plot_to_canvas([0, self._clims[0]])[1] - else: - click = click_x - right = plot_to_canvas([self._clims[1], 0])[0] - left = plot_to_canvas([self._clims[0], 0])[0] - - # Right bound always selected on overlap - if bool(abs(right - click) < tolerance): - return Grabbable.RIGHT_CLIM - if bool(abs(left - click) < tolerance): - return Grabbable.LEFT_CLIM - - if self._gamma_handle_pos is not None: - gx, gy = plot_to_canvas(gamma_to_plot(self._gamma_handle_pos[0]))[:2] - if bool(abs(gx - click_x) < tolerance and abs(gy - click_y) < tolerance): - return Grabbable.GAMMA - - return Grabbable.NONE - - def _to_plot_coords(self, pos: Sequence[float]) -> tuple[float, float]: - """Return the plot coordinates of the given position.""" - x, y = self.node_tform.map(pos)[:2] - return x, y - - def _resize(self) -> None: - self.plot.camera.set_range( - x=self._range if self._vertical else self._domain, - y=self._domain if self._vertical else self._range, - # FIXME: Bitten by https://github.com/vispy/vispy/issues/1483 - # It's pretty visible in logarithmic mode - margin=1e-30, - ) - if self._vertical: - scale = 0.98 * self.plot.xaxis.axis.domain[1] - self._handle_transform.scale = (scale, 1) - else: - scale = 0.98 * self.plot.yaxis.axis.domain[1] - self._handle_transform.scale = (1, scale) - - -def _hist_counts_to_mesh( - values: Sequence[float] | npt.NDArray, - bin_edges: Sequence[float] | npt.NDArray, - vertical: bool = False, -) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.uint32]]: - """Convert histogram counts to mesh vertices and faces for plotting.""" - n_edges = len(bin_edges) - X, Y = (1, 0) if vertical else (0, 1) - - # 4-5 - # | | - # 1-2/7-8 - # |/| | | - # 0-3-6-9 - # construct vertices - vertices = np.zeros((3 * n_edges - 2, 3), np.float32) - vertices[:, X] = np.repeat(bin_edges, 3)[1:-1] - vertices[1::3, Y] = values - vertices[2::3, Y] = values - vertices[vertices == float("-inf")] = 0 - - # construct triangles - faces = np.zeros((2 * n_edges - 2, 3), np.uint32) - offsets = 3 * np.arange(n_edges - 1, dtype=np.uint32)[:, np.newaxis] - faces[::2] = np.array([0, 2, 1]) + offsets - faces[1::2] = np.array([2, 0, 3]) + offsets - - return vertices, faces diff --git a/src/ndv/models/__init__.py b/src/ndv/models/__init__.py new file mode 100644 index 00000000..9ecdde27 --- /dev/null +++ b/src/ndv/models/__init__.py @@ -0,0 +1,7 @@ +"""Models for NDV.""" + +from ._array_display_model import ArrayDisplayModel +from ._lut_model import LUTModel +from .data_wrappers._data_wrapper import DataWrapper + +__all__ = ["ArrayDisplayModel", "DataWrapper", "LUTModel"] diff --git a/src/ndv/models/_array_display_model.py b/src/ndv/models/_array_display_model.py new file mode 100644 index 00000000..035afa92 --- /dev/null +++ b/src/ndv/models/_array_display_model.py @@ -0,0 +1,186 @@ +"""General model for ndv.""" + +import warnings +from enum import Enum +from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union, cast + +from pydantic import Field, computed_field, model_validator +from typing_extensions import Self, TypeAlias + +from ndv._types import AxisKey, Slice + +from ._base_model import NDVModel +from ._lut_model import LUTModel +from ._mapping import ValidatedEventedDict +from ._reducer import ReducerType + +if TYPE_CHECKING: + from collections.abc import Mapping + + import cmap + + from ._lut_model import AutoscaleType + + class LutModelKwargs(TypedDict, total=False): + visible: bool + cmap: cmap.Colormap | cmap._colormap.ColorStopsLike + clims: tuple[float, float] | None + gamma: float + autoscale: AutoscaleType + + class ArrayDisplayModelKwargs(TypedDict, total=False): + visible_axes: tuple[AxisKey, AxisKey, AxisKey] | tuple[AxisKey, AxisKey] + current_index: Mapping[AxisKey, Union[int, slice]] + channel_mode: "ChannelMode" | Literal["grayscale", "composite", "color", "rgba"] + channel_axis: Optional[AxisKey] + reducers: Mapping[AxisKey | None, ReducerType] + luts: Mapping[int | None, LUTModel | LutModelKwargs] + default_lut: LUTModel | LutModelKwargs + + +# map of axis to index/slice ... i.e. the current subset of data being displayed +IndexMap: TypeAlias = ValidatedEventedDict[AxisKey, Union[int, Slice]] +# map of index along channel axis to LUTModel object +LutMap: TypeAlias = ValidatedEventedDict[Union[int, None], LUTModel] +# map of axis to reducer +Reducers: TypeAlias = ValidatedEventedDict[Union[AxisKey, None], ReducerType] +# used for visible_axes +TwoOrThreeAxisTuple: TypeAlias = Union[ + tuple[AxisKey, AxisKey, AxisKey], tuple[AxisKey, AxisKey] +] + + +def _default_luts() -> LutMap: + colors = ["green", "magenta", "cyan", "red", "blue", "yellow"] + return ValidatedEventedDict( + (i, LUTModel(cmap=color)) for i, color in enumerate(colors) + ) + + +class ChannelMode(str, Enum): + """Channel display mode. + + Attributes + ---------- + GRAYSCALE : str + The array is displayed as a single channel, with a single lookup table applied. + In this mode, there effective *is* no channel axis: all non-visible dimensions + have sliders, and there is a single LUT control (the `default_lut`). + COMPOSITE : str + Display all (or a subset of) channels as a composite image, with a different + lookup table applied to each channel. In this mode, the slider for the channel + axis is hidden by default, and LUT controls for each channel are shown. + COLOR : str + Display a single channel at a time as a color image, with a channel-specific + lookup table applied. In this mode, the slider for the channel axis is shown, + and the user can select which channel to display. LUT controls are shown for + all channels. + RGBA : str + The array is displayed as an RGB image, with a single lookup table applied. + In this mode, the slider for the channel axis is hidden, and a single LUT + control is shown. Only valid when channel axis has length <= 4. + RGB : str + Alias for RGBA. + """ + + GRAYSCALE = "grayscale" + COMPOSITE = "composite" + COLOR = "color" + RGBA = "rgba" + + def __str__(self) -> str: + return self.value + + def is_multichannel(self) -> bool: + """Return whether this mode displays multiple channels. + + If `is_multichannel` is True, then the `channel_axis` slider should be hidden. + """ + return self in (ChannelMode.COMPOSITE, ChannelMode.RGBA) + + +ChannelMode._member_map_["RGB"] = ChannelMode.RGBA # ChannelMode["RGB"] +ChannelMode._value2member_map_["rgb"] = ChannelMode.RGBA # ChannelMode("rgb") + + +class ArrayDisplayModel(NDVModel): + """Model of how to display an array. + + In the following types, `AxisKey` can be either an integer index or a string label. + + Parameters + ---------- + visible_axes : tuple[AxisKey, ...] + Ordered list of axes to visualize, from slowest to fastest. + e.g. ('z', -2, -1) + current_index : Mapping[AxisKey, int | Slice] + The currently displayed position/slice along each dimension. + e.g. {0: 0, 'time': slice(10, 20)} + Not all axes need be present, and axes not present are assumed to + be slice(None), meaning it is up to the controller of this model to restrict + indices to an efficient range for retrieval. + If the number of non-singleton axes is greater than `n_visible_axes`, + then reducers are used to reduce the data along the remaining axes. + NOTE: In terms of requesting data, there is a slight "delocalization" of state + here in that we probably also want to avoid requesting data for channel + positions that are not visible. + reducers : Mapping[AxisKey | None, ReducerType] + Callable to reduce data along axes remaining after slicing. + Ideally, the ufunc should accept an `axis` argument. + (TODO: what happens if not?) + channel_mode : ChannelMode + How to display channel information: + - `GRAYSCALE`: ignore channel axis, use `default_lut` + - `COMPOSITE`: display all channels as a composite image, using `luts` + - `COLOR`: display a single channel at a time, using `luts` + - `RGBA`: display as an RGB image, using `default_lut` (except for cmap) + + If `channel_mode` is set to anything other than `GRAYSCALE`, then `channel_axis` + must be set to a valid axis; if no `channel_axis` is set, at the time of + display, the `DataWrapper` MAY guess the `channel_axis`, and set it on the + model. + channel_axis : AxisKey | None + The dimension index or name of the channel dimension. + The implication of setting channel_axis is that *all* elements along the channel + dimension are shown, with different LUTs applied to each channel. + If None, then a single lookup table is used for all channels (`luts[None]`). + NOTE: it is an error for channel_axis to be in `visible_axes` (or ignore it?) + luts : Mapping[int | None, LUTModel] + Instructions for how to display each channel of the array. + Keys represent position along the dimension specified by `channel_axis`. + Values are `LUT` objects that specify how to display the channel. + The special key `None` is used to represent a fallback LUT for all channels, + and is used when `channel_axis` is None. It should always be present + """ + + visible_axes: TwoOrThreeAxisTuple = (-2, -1) + current_index: IndexMap = Field(default_factory=IndexMap, frozen=True) + + channel_mode: ChannelMode = ChannelMode.GRAYSCALE + channel_axis: Optional[AxisKey] = None + + # map of axis to reducer (function that can reduce dimensionality along that axis) + reducers: Reducers = Field(default_factory=Reducers, frozen=True) + default_reducer: ReducerType = "numpy.max" # type: ignore [assignment] # FIXME + + # map of index along channel axis to LUTModel object + luts: LutMap = Field(default_factory=_default_luts) + default_lut: LUTModel = Field(default_factory=LUTModel, frozen=True) + + @computed_field # type: ignore [prop-decorator] + @property + def n_visible_axes(self) -> Literal[2, 3]: + """Number of dims is derived from the length of `visible_axes`.""" + return cast(Literal[2, 3], len(self.visible_axes)) + + @model_validator(mode="after") + def _validate_model(self) -> "Self": + # prevent channel_axis from being in visible_axes + if self.channel_axis in self.visible_axes: + warnings.warn( + "Channel_axis cannot be in visible_axes. Setting channel_axis to None.", + UserWarning, + stacklevel=2, + ) + self.channel_axis = None + return self diff --git a/src/ndv/models/_base_model.py b/src/ndv/models/_base_model.py new file mode 100644 index 00000000..310ef498 --- /dev/null +++ b/src/ndv/models/_base_model.py @@ -0,0 +1,15 @@ +from typing import ClassVar + +from psygnal import SignalGroupDescriptor +from pydantic import BaseModel, ConfigDict + + +class NDVModel(BaseModel): + """Base evented model for NDV models.""" + + model_config = ConfigDict( + validate_assignment=True, + validate_default=True, + extra="forbid", + ) + events: ClassVar[SignalGroupDescriptor] = SignalGroupDescriptor() diff --git a/src/ndv/models/_data_display_model.py b/src/ndv/models/_data_display_model.py new file mode 100644 index 00000000..58b72909 --- /dev/null +++ b/src/ndv/models/_data_display_model.py @@ -0,0 +1,205 @@ +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from dataclasses import dataclass, field +from typing import Any, Optional, Union, cast + +import numpy as np +from pydantic import Field + +from ndv.models._array_display_model import ArrayDisplayModel, ChannelMode +from ndv.models._base_model import NDVModel + +from .data_wrappers import DataWrapper + +__all__ = ["DataRequest", "DataResponse", "_ArrayDataDisplayModel"] + + +@dataclass +class DataRequest: + """Request object for data slicing.""" + + wrapper: DataWrapper + index: Mapping[int, Union[int, slice]] + visible_axes: tuple[int, ...] + channel_axis: Optional[int] + + +@dataclass +class DataResponse: + """Response object for data requests.""" + + data: np.ndarray = field(repr=False) + channel_key: Optional[int] + request: Optional[DataRequest] = None + + +class _ArrayDataDisplayModel(NDVModel): + """Utility class combining ArrayDisplayModel model with a DataWrapper. + + The `ArrayDisplayModel` can be thought of as an "instruction" for how to display + some data, while the `DataWrapper` is the actual data. This class combines the two + and provides a way to access the data in a normalized way (i.e. be converting + AxisKeys in the display model to positive integers, based on the available + dimensions of the DataWrapper). This makes it easier to index into the data, even + with named axes, which this class also helps manage with the `request_sliced_data` + method. + + Having this class composed of the two other models (rather than inheriting from + `ArrayDisplayModel`) allows for multiple models to share the same underlying + display model (e.g. for linked views). + + Attributes + ---------- + display : ArrayDisplayModel + The display model. Provides instructions for how to display the data. + data_wrapper : DataWrapper | None + The data wrapper. Provides the actual data to be displayed + """ + + display: ArrayDisplayModel = Field(default_factory=ArrayDisplayModel) + data_wrapper: Optional[DataWrapper] = None + + def model_post_init(self, __context: Any) -> None: + # connect the channel mode change signal to the channel axis guessing method + self.display.events.channel_mode.connect(self._on_channel_mode_change) + + def _on_channel_mode_change(self) -> None: + # if the mode is not grayscale, and the channel axis is not set, + # we let the data wrapper guess the channel axis + if ( + self.display.channel_mode != ChannelMode.GRAYSCALE + and self.display.channel_axis is None + and self.data_wrapper is not None + ): + # only use the guess if it's not already in the visible axes + guess = self.data_wrapper.guess_channel_axis() + if guess not in self.normed_visible_axes: + self.display.channel_axis = guess + + # Properties for normalized data access ----------------------------------------- + # these all use positive integers as axis keys + + def _ensure_wrapper(self) -> DataWrapper: + if self.data_wrapper is None: + raise ValueError("Cannot normalize axes. No data is set.") + return self.data_wrapper + + @property + def normed_data_coords(self) -> Mapping[int, Sequence]: + """Return the coordinates of the data as positive integers.""" + if (wrapper := self.data_wrapper) is None: + return {} + return {wrapper.normalized_axis_key(d): wrapper.coords[d] for d in wrapper.dims} + + @property + def normed_visible_axes(self) -> "tuple[int, int, int] | tuple[int, int]": + """Return the visible axes as positive integers.""" + wrapper = self._ensure_wrapper() + return tuple( # type: ignore [return-value] + wrapper.normalized_axis_key(ax) for ax in self.display.visible_axes + ) + + @property + def normed_current_index(self) -> Mapping[int, Union[int, slice]]: + """Return the current index with positive integer axis keys.""" + wrapper = self._ensure_wrapper() + return { + wrapper.normalized_axis_key(ax): v + for ax, v in self.display.current_index.items() + } + + @property + def normed_channel_axis(self) -> "int | None": + """Return the channel axis as positive integers.""" + if self.display.channel_axis is None: + return None + wrapper = self._ensure_wrapper() + return wrapper.normalized_axis_key(self.display.channel_axis) + + # Indexing and Data Slicing ----------------------------------------------------- + + def current_slice_requests(self) -> list[DataRequest]: + """Return the current index request for the data. + + This reconciles the `current_index` and `visible_axes` attributes of the display + with the available dimensions of the data to return a valid index request. + In the returned mapping, the keys are the normalized (non-negative integer) + axis indices and the values are either integers or slices (where axes present + in `visible_axes` are guaranteed to be slices rather than integers). + """ + if self.data_wrapper is None: + return [] + + requested_slice = dict(self.normed_current_index) + for ax in self.normed_visible_axes: + if not isinstance(requested_slice.get(ax), slice): + requested_slice[ax] = slice(None) + + # if we need to request multiple channels (composite mode or RGB), + # ensure that the channel axis is also sliced + if c_ax := self.normed_channel_axis: + if self.display.channel_mode.is_multichannel(): + if not isinstance(requested_slice.get(c_ax), slice): + requested_slice[c_ax] = slice(None) + else: + # somewhat of a hack. + # we heed DataRequest.channel_axis to be None if we want the view + # to use the default_lut + c_ax = None + + # ensure that all axes are slices, so that we don't lose any dimensions. + # data will be squeezed to remove singleton dimensions later after + # transposing according to the order of visible axes + # (this operation happens below in `current_data_slice`) + for ax, val in requested_slice.items(): + if isinstance(val, int): + requested_slice[ax] = slice(val, val + 1) + + return [ + DataRequest( + wrapper=self.data_wrapper, + index=requested_slice, + visible_axes=self.normed_visible_axes, + channel_axis=c_ax, + ) + ] + + # TODO: make async + def request_sliced_data(self) -> list[Future[DataResponse]]: + """Return the slice of data requested by the current index (synchronous).""" + if self.data_wrapper is None: + raise ValueError("Data not set") + + if not (requests := self.current_slice_requests()): + return [] + + futures: list[Future[DataResponse]] = [] + for req in requests: + data = req.wrapper.isel(req.index) + + # for transposing according to the order of visible axes + vis_ax = req.visible_axes + t_dims = vis_ax + tuple(i for i in range(data.ndim) if i not in vis_ax) + + if (ch_ax := req.channel_axis) is not None: + ch_indices: Iterable[Optional[int]] = range(data.shape[ch_ax]) + else: + ch_indices = (None,) + + for i in ch_indices: + if i is None: + ch_data = data + else: + ch_keepdims = (slice(None),) * cast(int, ch_ax) + (i,) + (None,) + ch_data = data[ch_keepdims] + future = Future[DataResponse]() + future.set_result( + DataResponse( + data=ch_data.transpose(*t_dims).squeeze(), + channel_key=i, + request=req, + ) + ) + futures.append(future) + + return futures diff --git a/src/ndv/models/_lut_model.py b/src/ndv/models/_lut_model.py new file mode 100644 index 00000000..c361ce16 --- /dev/null +++ b/src/ndv/models/_lut_model.py @@ -0,0 +1,55 @@ +from typing import Any, Callable, Optional, Union + +import numpy.typing as npt +from cmap import Colormap +from pydantic import Field, model_validator +from typing_extensions import TypeAlias + +from ._base_model import NDVModel + +AutoscaleType: TypeAlias = Union[ + Callable[[npt.ArrayLike], tuple[float, float]], tuple[float, float], bool +] + + +class LUTModel(NDVModel): + """Representation of how to display a channel of an array. + + Parameters + ---------- + visible : bool + Whether to display this channel. + NOTE: This has implications for data retrieval, as we may not want to request + channels that are not visible. See current_index above. + cmap : Colormap + Colormap to use for this channel. + clims : tuple[float, float] | None + Contrast limits for this channel. + TODO: What does `None` imply? Autoscale? + gamma : float + Gamma correction for this channel. By default, 1.0. + autoscale : bool | tuple[float, float] + Whether/how to autoscale the colormap. + If `False`, then autoscaling is disabled. + If `True` or `(0, 1)` then autoscale using the min/max of the data. + If a tuple, then the first element is the lower quantile and the second element + is the upper quantile. + If a callable, then it should be a function that takes an array and returns a + tuple of (min, max) values to use for scaling. + + NaN values should be ignored (n.b. nanmax is slower and should only be used if + necessary). + """ + + visible: bool = True + cmap: Colormap = Field(default_factory=lambda: Colormap("gray")) + clims: Optional[tuple[float, float]] = None + gamma: float = 1.0 + autoscale: AutoscaleType = Field(default=True, union_mode="left_to_right") + + @model_validator(mode="before") + def _validate_model(cls, v: Any) -> Any: + # cast bare string/colormap inputs to cmap declaration + if isinstance(v, (str, Colormap)): + return {"cmap": v} + return v diff --git a/src/ndv/models/_mapping.py b/src/ndv/models/_mapping.py new file mode 100644 index 00000000..340f24f4 --- /dev/null +++ b/src/ndv/models/_mapping.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator, MutableMapping +from contextlib import suppress +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Protocol, + TypeVar, + cast, + get_args, + overload, +) + +from psygnal import Signal +from pydantic import TypeAdapter +from pydantic_core import SchemaValidator, core_schema + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") +_VT_co = TypeVar("_VT_co", covariant=True) +_NULL = object() + + +class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): + def keys(self) -> Iterable[_KT]: ... + def __getitem__(self, key: _KT, /) -> _VT_co: ... + + +class ValidatedEventedDict(MutableMapping[_KT, _VT]): + item_added = Signal(str, object) # key, new_value + item_removed = Signal(str, object) # key, old_value + item_changed = Signal(str, object, object) # key, new_value, old_value + value_changed = Signal() + + # long ugly overloads to support all possible ways to initialize a ValidatedDict + @overload + def __init__(self) -> None: ... + @overload + def __init__( # type: ignore[misc] + self: dict[str, _VT], + key_validator: Callable[[Any], _KT] | None = None, + value_validator: Callable[[Any], _VT] | None = None, + **kwargs: _VT, + ) -> None: ... + @overload + def __init__( + self, + map: SupportsKeysAndGetItem[_KT, _VT], + /, + key_validator: Callable[[Any], _KT] | None = None, + value_validator: Callable[[Any], _VT] | None = None, + ) -> None: ... + @overload + def __init__( # type: ignore[misc] + self: dict[str, _VT], + map: SupportsKeysAndGetItem[str, _VT], + /, + key_validator: Callable[[Any], _KT] | None = ..., + value_validator: Callable[[Any], _VT] | None = ..., + validate_lookup: bool = ..., + **kwargs: _VT, + ) -> None: ... + @overload + def __init__( + self, + iterable: Iterable[tuple[_KT, _VT]], + /, + key_validator: Callable[[Any], _KT] | None = ..., + value_validator: Callable[[Any], _VT] | None = ..., + validate_lookup: bool = ..., + ) -> None: ... + @overload + def __init__( # type: ignore[misc] + self: dict[str, _VT], + iterable: Iterable[tuple[str, _VT]], + /, + key_validator: Callable[[Any], _KT] | None = ..., + value_validator: Callable[[Any], _VT] | None = ..., + validate_lookup: bool = ..., + **kwargs: _VT, + ) -> None: ... + def __init__( # type: ignore[misc] # does not accept all possible overloads + self, + *args: Any, + key_validator: Callable[[Any], _KT] | None = None, + value_validator: Callable[[Any], _VT] | None = None, + validate_lookup: bool = False, + **kwargs: Any, + ) -> None: + self._key_validator = key_validator + self._value_validator = value_validator + self._validate_lookup = validate_lookup + _d = {} + for k, v in dict(*args, **kwargs).items(): + if self._key_validator is not None: + k = self._key_validator(k) + if self._value_validator is not None: + v = self._value_validator(v) + _d[k] = v + self._dict: dict[_KT, _VT] = _d + + def __missing__(self, key: _KT) -> _VT: + raise KeyError(key) + + # ---------------- abstract interface ---------------- + + def __getitem__(self, key: _KT) -> _VT: + if self._validate_lookup: + key = self._validate_key(key) + try: + return self._dict[key] + except KeyError: + return self.__missing__(key) + + # def __setitem__(self, key: _KT, value: _VT) -> None: + # we allow Any here because validation may change the type of the value. + def __setitem__(self, key: Any, value: Any) -> None: + key = self._validate_key(key) + value = self._validate_value(value) + before = self._dict.get(key, _NULL) + self._dict[key] = value + # if the value is the same as before, try to exit early without emitting signals + # but catch exceptions that may be raised during __eq__ (like numpy) + if before is not _NULL: + with suppress(Exception): + if before == value: + return + self.item_changed.emit(key, value, before) + else: + self.item_added.emit(key, value) + self.value_changed.emit() + + def __delitem__(self, key: _KT) -> None: + if self._validate_lookup: + key = self._validate_key(key) + # TODO: maybe add removing signal (before actual removal) if needed + item = self._dict.pop(key) + self.item_removed.emit(key, item) + self.value_changed.emit() + + def __len__(self) -> int: + return len(self._dict) + + def __iter__(self) -> Iterator[_KT]: + return iter(self._dict) + + # batch operations, with a single value_changed ------------------------------- + + @overload + def assign(self, m: SupportsKeysAndGetItem[_KT, _VT], /, **kwargs: _VT) -> None: ... + @overload + def assign(self, m: Iterable[tuple[_KT, _VT]], /, **kwargs: _VT) -> None: ... + @overload + def assign(self, **kwargs: _VT) -> None: ... + def assign(self, *args: Any, **kwargs: _VT) -> None: # type: ignore[misc] + """Override state with the key/value pairs from the mapping or iterable. + + Similar to update, but clears the dictionary first (without signals), replacing + the contents with the key/value pairs from the mapping or iterable, and then + emitting a single value_changed signal at the end. + """ + with self.value_changed.blocked(): + self.clear() + self.update(*args, **kwargs) + self.value_changed.emit() + + @overload + def update(self, m: SupportsKeysAndGetItem[_KT, _VT], /, **kwargs: _VT) -> None: ... + @overload + def update(self, m: Iterable[tuple[_KT, _VT]], /, **kwargs: _VT) -> None: ... + @overload + def update(self, **kwargs: _VT) -> None: ... + def update(self, *args: Any, **kwargs: _VT) -> None: # type: ignore[misc] + """Update the dictionary with the key/value pairs from the mapping or iterable. + + only emit a single value_changed signal at the end. + """ + with self.value_changed.blocked(): + super().update(*args, **kwargs) + # TODO: only emit if anything was caught + self.value_changed.emit() + + def clear(self) -> None: + """Clear the dictionary. + + only emit a single value_changed signal at the end. + """ + with self.value_changed.blocked(): + super().clear() + # TODO: only emit if anything was caught + self.value_changed.emit() + + # ----------------------------------------------------- + + @cached_property + def _validate_key(self) -> Callable[[Any], _KT]: + """Return a function that validates keys.""" + if self._key_validator is not None: + return self._key_validator + # No key validator was provided during init. Try to get the key type from the + # class type hint and return a validator function for it. + # __orig_class__ is not available during __init__ + # https://discuss.python.org/t/runtime-access-to-type-parameters/37517 + cls = getattr(self, "__orig_class__", None) or type(self) + if args := get_args(cls): + return TypeAdapter(args[0]).validator.validate_python + # fall back to identity function + return lambda x: x + + @cached_property + def _validate_value(self) -> Callable[[Any], _VT]: + """Return a function that validates values.""" + if self._value_validator is not None: + return self._value_validator + # No value validator was provided during init. Try to get the value type from + # the class type hint and return a validator function for it. + # __orig_class__ is not available during __init__ + # https://discuss.python.org/t/runtime-access-to-type-parameters/37517 + cls = getattr(self, "__orig_class__", None) or type(self) + if len(args := get_args(cls)) > 1: + return TypeAdapter(args[1]).validator.validate_python + # fall back to identity function + return lambda x: x + + def __repr__(self) -> str: + return f"{type(self).__name__}({self._dict!r})" + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """Return the Pydantic core schema for this object. + + In this method, we parse the key and value types of the `source_type`, which + be something like ValidatedDict[KT, VT]. And then create a validator function + that creates a new instance of the ValidatedDict during assignment, passing in + the key and value validator functions (from pydantic). + + Parameters + ---------- + source_type : Any + The source type. This will usually be `cls`. + handler : GetCoreSchemaHandler + Handler to call into the next CoreSchema schema generation function. + """ + # get key/value types from the source_type type hint. + key_type = val_type = Any + if args := get_args(source_type): + key_type = args[0] + if len(args) > 1: + val_type = args[1] + + # get key/value core schemas for the key/value types. + keys_schema = _get_schema(key_type, handler) + values_schema = _get_schema(val_type, handler) + validate_key = SchemaValidator(keys_schema).validate_python + validate_value = SchemaValidator(values_schema).validate_python + + # define function that creates new instance during assignment + # passing in the validator functions. + def _new(*args: Any, **kwargs: Any) -> ValidatedEventedDict[_KT, _VT]: + return cls( # type: ignore + *args, + key_validator=validate_key, + value_validator=validate_value, + **kwargs, + ) + + # this schema for this validated dict + dict_schema = core_schema.dict_schema( + keys_schema=keys_schema, + values_schema=values_schema, + ) + # wrap the schema with a validator function that creates a new instance, + # passing in the key/value validators. + return core_schema.no_info_after_validator_function( + function=_new, + schema=dict_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: x._dict, return_schema=dict_schema + ), + ) + + +def _get_schema(hint: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + # check if the hint already has a core schema attached to it. + if hasattr(hint, "__pydantic_core_schema__"): + return cast("core_schema.CoreSchema", hint.__pydantic_core_schema__) + # otherwise, call the handler to get the core schema. + return handler.generate_schema(hint) diff --git a/src/ndv/models/_reducer.py b/src/ndv/models/_reducer.py new file mode 100644 index 00000000..c374a7e8 --- /dev/null +++ b/src/ndv/models/_reducer.py @@ -0,0 +1,69 @@ +from collections.abc import Sequence +from typing import Any, Callable, Protocol, SupportsIndex, Union, cast + +import numpy as np +import numpy.typing as npt +from pydantic_core import core_schema +from typing_extensions import TypeAlias + +_ShapeLike: TypeAlias = Union[SupportsIndex, Sequence[SupportsIndex]] + + +class Reducer(Protocol): + """Function to reduce an array along an axis. + + A reducer is any function that takes an array-like, and an optional axis argument, + and returns a reduced array. Examples include `np.max`, `np.mean`, etc. + """ + + def __call__(self, a: npt.ArrayLike, axis: _ShapeLike = ...) -> npt.ArrayLike: + """Reduce an array along an axis.""" + + +def _str_to_callable(obj: Any) -> Callable: + """Deserialize a callable from a string.""" + if isinstance(obj, str): + # e.g. "numpy.max" -> numpy.max + try: + mod_name, qual_name = obj.rsplit(".", 1) + mod = __import__(mod_name, fromlist=[qual_name]) + obj = getattr(mod, qual_name) + except Exception: + try: + # fallback to numpy + # e.g. "max" -> numpy.max + obj = getattr(np, obj) + except Exception: + raise + + if not callable(obj): + raise TypeError(f"Expected a callable or string, got {type(obj)}") + return cast("Callable", obj) + + +def _callable_to_str(obj: Union[str, Callable]) -> str: + """Serialize a callable to a string.""" + if isinstance(obj, str): + return obj + # e.g. numpy.max -> "numpy.max" + return f"{obj.__module__}.{obj.__qualname__}" + + +class ReducerType(Reducer): + """Reducer type for pydantic. + + This just provides a pydantic core schema for a generic callable that accepts an + array and an axis argument and returns an array (of reduced dimensionality). + This serializes/deserializes the callable as a string (module.qualname). + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source: Any, handler: Any) -> Any: + """Get the Pydantic schema for this object.""" + ser_schema = core_schema.plain_serializer_function_ser_schema(_callable_to_str) + return core_schema.no_info_before_validator_function( + _str_to_callable, + # using callable_schema() would be more correct, but prevents dumping schema + core_schema.any_schema(), + serialization=ser_schema, + ) diff --git a/src/ndv/models/data_wrappers/__init__.py b/src/ndv/models/data_wrappers/__init__.py new file mode 100644 index 00000000..30df9865 --- /dev/null +++ b/src/ndv/models/data_wrappers/__init__.py @@ -0,0 +1,3 @@ +from ._data_wrapper import DataWrapper + +__all__ = ["DataWrapper"] diff --git a/src/ndv/models/data_wrappers/_data_wrapper.py b/src/ndv/models/data_wrappers/_data_wrapper.py new file mode 100644 index 00000000..c9819109 --- /dev/null +++ b/src/ndv/models/data_wrappers/_data_wrapper.py @@ -0,0 +1,372 @@ +"""In this module, we provide built-in support for many array types.""" + +from __future__ import annotations + +import json +import logging +import sys +import warnings +from abc import ABC, abstractmethod +from collections.abc import Hashable, Mapping, Sequence +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Protocol, + TypeVar, +) + +import numpy as np +import numpy.typing as npt + +if TYPE_CHECKING: + from collections.abc import Container, Iterator + from typing import Any, TypeAlias, TypeGuard + + import dask.array.core as da + import numpy.typing as npt + import pydantic_core + import pyopencl.array as cl_array + import sparse + import tensorstore as ts + from pydantic import GetCoreSchemaHandler + + Index: TypeAlias = int | slice + + +class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + +ArrayT = TypeVar("ArrayT") +NPArrayLike = TypeVar("NPArrayLike", bound=SupportsIndexing) +_T = TypeVar("_T", bound=type) + + +def _recurse_subclasses(cls: _T) -> Iterator[_T]: + for subclass in cls.__subclasses__(): + yield subclass + yield from _recurse_subclasses(subclass) + + +class DataWrapper(Generic[ArrayT], ABC): + """Interface for wrapping different array-like data types. + + `DataWrapper.create` is a factory method that returns a DataWrapper instance + for the given data type. If your datastore type is not supported, you may implement + a new DataWrapper subclass to handle your data type. To do this, import and + subclass DataWrapper, and (minimally) implement the supports and isel methods. + Ensure that your class is imported before the DataWrapper.create method is called, + and it will be automatically detected and used to wrap your data. + """ + + # Order in which subclasses are checked for support. + # Lower numbers are checked first, and the first supporting subclass is used. + # Default is 50, and fallback to numpy-like duckarray is 100. + # Subclasses can override this to change the priority in which they are checked + PRIORITY: ClassVar[int] = 50 + # These names will be checked when looking for a channel axis + COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c") + # Maximum dimension size consider when guessing the channel axis + MAX_CHANNELS = 16 + + def __init__(self, data: ArrayT) -> None: + self._data = data + + # ----------------------------- Mandatory methods ----------------------------- + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object. + + Any exceptions raised by this method will be suppressed, so it is safe to + directly import necessary dependencies without a try/except block. + """ + + @property + @abstractmethod + def dims(self) -> tuple[Hashable, ...]: + """Return the dimension labels for the data.""" + + @property + @abstractmethod + def coords(self) -> Mapping[Hashable, Sequence]: + """Return the coordinates for the data.""" + + @abstractmethod + def isel(self, index: Mapping[int, int | slice]) -> np.ndarray: + """Return a slice of the data as a numpy array.""" + + def save_as_zarr(self, path: str) -> None: + raise NotImplementedError("Saving as zarr is not supported for this data type") + + @property + def dtype(self) -> np.dtype: + """Return the dtype for the data.""" + try: + return np.dtype(self._data.dtype) # type: ignore + except AttributeError as e: + raise NotImplementedError( + "`dtype` property not properly implemented for DataWrapper of type: " + f"{type(self)}" + ) from e + + # ----------------------------- + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data + + # check subclasses for support + # This allows users to define their own DataWrapper subclasses which will + # be automatically detected (assuming they have been imported by this point) + for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY): + try: + if subclass.supports(data): + logging.debug(f"Using {subclass.__name__} to wrap {type(data)}") + return subclass(data) + except Exception as e: + warnings.warn( + f"Error checking DataWrapper subclass {subclass.__name__}: {e}", + RuntimeWarning, + stacklevel=2, + ) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + @property + def data(self) -> ArrayT: + """Return the data being wrapped.""" + return self._data + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type, handler: GetCoreSchemaHandler + ) -> pydantic_core.CoreSchema: + from pydantic_core import core_schema + + return core_schema.no_info_before_validator_function( + function=cls.create, + schema=core_schema.any_schema(), + ) + + def sizes(self) -> Mapping[Hashable, int]: + """Return the sizes of the dimensions.""" + return {dim: len(self.coords[dim]) for dim in self.dims} + + def guess_channel_axis(self) -> Hashable | None: + """Return the (best guess) axis name for the channel dimension.""" + # for arrays with labeled dimensions, + # see if any of the dimensions are named "channel" + sizes = self.sizes() + for dimkey, val in sizes.items(): + if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: + if val <= self.MAX_CHANNELS: + return self.normalized_axis_key(dimkey) + + # otherwise use the smallest dimension as the channel axis + return min(sizes, key=sizes.get) # type: ignore [arg-type] + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + + # TODO: this needs to be cleared when data.dims changes + @cached_property + def axis_map(self) -> Mapping[Hashable, int]: + """Mapping of ALL valid axis keys to normalized, positive integer keys.""" + axis_index: dict[Hashable, int] = {} + ndims = len(self.dims) + for i, dim in enumerate(self.dims): + axis_index[dim] = i # map dimension label to positive index + axis_index[i] = i # map positive integer index to itself + axis_index[-(ndims - i)] = i # map negative integer index to positive index + return axis_index + + def normalized_axis_key(self, axis: Hashable) -> int: + """Return positive index for `axis` (which can be +/- int or str label).""" + try: + return self.axis_map[axis] + except KeyError as e: + ndims = len(self.dims) + if isinstance(axis, int): + raise IndexError( + f"Axis index {axis} out of bounds for data with {ndims} dimensions" + ) from e + raise IndexError(f"Axis label {axis} not found in data dimensions") from e + + def clear_cache(self) -> None: + """Clear any cached properties.""" + if hasattr(self, "axis_map"): + del self.axis_map + + +########################## + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + """Wrapper for tensorstore.TensorStore objects.""" + + def __init__(self, data: Any) -> None: + super().__init__(data) + + import tensorstore as ts + + self._ts = ts + + spec = self.data.spec().to_json() + dims: Sequence[Hashable] | None = None + self._ts = ts + if (tform := spec.get("transform")) and ("input_labels" in tform): + dims = [str(x) for x in tform["input_labels"]] + elif ( + str(spec.get("driver")).startswith("zarr") + and (zattrs := self.data.kvstore.read(".zattrs").result().value) + and isinstance((zattr_dict := json.loads(zattrs)), dict) + and "_ARRAY_DIMENSIONS" in zattr_dict + ): + dims = zattr_dict["_ARRAY_DIMENSIONS"] + + if isinstance(dims, Sequence) and len(dims) == len(self._data.domain): + self._dims: tuple[Hashable, ...] = tuple(str(x) for x in dims) + self._data = self.data[ts.d[:].label[self._dims]] + else: + self._dims = tuple(range(len(self._data.domain))) + self._coords: Mapping[Hashable, Sequence] = { + i: range(s) for i, s in zip(self._dims, self._data.domain.shape) + } + + @property + def dims(self) -> tuple[Hashable, ...]: + """Return the dimension labels for the data.""" + return self._dims + + @property + def coords(self) -> Mapping[Hashable, Sequence]: + """Return the coordinates for the data.""" + return self._coords + + def sizes(self) -> Mapping[Hashable, int]: + return dict(zip(self._dims, self._data.domain.shape)) + + def isel(self, indexers: Mapping[int, int | slice]) -> np.ndarray: + if not indexers: + slc: slice | tuple = slice(None) + else: + slc = tuple( + indexers.get(i, slice(None)) for i in range(len(self._data.shape)) + ) + result = self._data[slc].read().result() + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper[NPArrayLike]): + """Wrapper for numpy duck array-like objects.""" + + PRIORITY = 100 + + @property + def dims(self) -> tuple[Hashable, ...]: + """Return the dimension labels for the data.""" + return tuple(range(len(self.data.shape))) + + @property + def coords(self) -> Mapping[Hashable, Sequence]: + """Return the coordinates for the data.""" + return {i: range(s) for i, s in enumerate(self.data.shape)} + + def isel(self, indexers: Mapping[int, int | slice]) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) + return self._asarray(self.data[idx]) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or hasattr(obj, "__array__") + ) + and hasattr(obj, "__getitem__") + and hasattr(obj, "shape") + ): + return True + return False + + def _asarray(self, data: npt.ArrayLike) -> np.ndarray: + """Convert data to a numpy array.""" + return np.asarray(data) + + +class SparseArrayWrapper(ArrayLikeWrapper["sparse.Array"]): + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[sparse.COO]: + if (sparse := sys.modules.get("sparse")) and isinstance(obj, sparse.COO): + return True + return False + + def _asarray(self, data: sparse.COO) -> np.ndarray: + return np.asarray(data.todense()) + + +class CLArrayWrapper(ArrayLikeWrapper["cl_array.Array"]): + """Wrapper for pyopencl array objects.""" + + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[cl_array.Array]: + if (cl_array := sys.modules.get("pyopencl.array")) and isinstance( + obj, cl_array.Array + ): + return True + return False + + def _asarray(self, data: cl_array.Array) -> np.ndarray: + return np.asarray(data.get()) + + +class DaskWrapper(ArrayLikeWrapper["da.Array"]): + """Wrapper for dask array objects.""" + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + def _asarray(self, data: da.Array) -> np.ndarray: + return np.asarray(data.compute()) + + def save_as_zarr(self, path: str) -> None: + self._data.to_zarr(url=path) diff --git a/src/ndv/util.py b/src/ndv/util.py index ac72d194..f09617b4 100644 --- a/src/ndv/util.py +++ b/src/ndv/util.py @@ -2,63 +2,50 @@ from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, overload -from qtpy.QtWidgets import QApplication - -from .viewer._viewer import NDViewer +from ndv._views._app import run_app +from ndv.controllers import ArrayViewer if TYPE_CHECKING: - from qtpy.QtCore import QCoreApplication + from typing import Any, Unpack - from .viewer._data_wrapper import DataWrapper + from .models._array_display_model import ArrayDisplayModel, ArrayDisplayModelKwargs + from .models.data_wrappers import DataWrapper +@overload +def imshow( + data: Any | DataWrapper, /, display_model: ArrayDisplayModel = ... +) -> ArrayViewer: ... +@overload +def imshow( + data: Any | DataWrapper, /, **kwargs: Unpack[ArrayDisplayModelKwargs] +) -> ArrayViewer: ... def imshow( data: Any | DataWrapper, - cmap: Any | None = None, - *, - channel_mode: Literal["mono", "composite", "auto"] = "auto", -) -> NDViewer: + /, + display_model: ArrayDisplayModel | None = None, + **kwargs: Unpack[ArrayDisplayModelKwargs], +) -> ArrayViewer: """Display an array or DataWrapper in a new NDViewer window. Parameters ---------- data : Any | DataWrapper - The data to be displayed. If not a DataWrapper, it will be wrapped in one. - cmap : Any | None, optional - The colormap(s) to use for displaying the data. - channel_mode : Literal['mono', 'composite'], optional - The initial mode for displaying the channels. By default "mono" will be - used unless a cmap is provided, in which case "composite" will be used. + The data to be displayed. Any ArrayLike object or an `ndv.DataWrapper`. + display_model: ArrayDisplayModel, optional + The display model to use. If not provided, a new one will be created. + kwargs : Unpack[ArrayDisplayModelKwargs] + Additional keyword arguments to pass to the NDViewer Returns ------- - NDViewer + ViewerController The viewer window. """ - app, should_exec = _get_app() - if cmap is not None: - channel_mode = "composite" - if not isinstance(cmap, (list, tuple)): - cmap = [cmap] - elif channel_mode == "auto": - channel_mode = "mono" - viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode) + viewer = ArrayViewer(data, display_model, **kwargs) viewer.show() - viewer.raise_() - if should_exec: - app.exec() - return viewer - -def _get_app() -> tuple[QCoreApplication, bool]: - is_ipython = False - if (app := QApplication.instance()) is None: - app = QApplication([]) - app.setApplicationName("ndv") - elif (ipy := sys.modules.get("IPython")) and (shell := ipy.get_ipython()): - is_ipython = str(shell.active_eventloop).startswith("qt") - - return app, not is_ipython + run_app() + return viewer diff --git a/src/ndv/v1/__init__.py b/src/ndv/v1/__init__.py new file mode 100644 index 00000000..c3fecb64 --- /dev/null +++ b/src/ndv/v1/__init__.py @@ -0,0 +1,7 @@ +"""Here temporarily to allow for a smooth transition to the new viewer.""" + +from ._old_data_wrapper import DataWrapper +from ._old_viewer import NDViewer +from .util import imshow + +__all__ = ["DataWrapper", "NDViewer", "imshow"] diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/v1/_old_data_wrapper.py similarity index 99% rename from src/ndv/viewer/_data_wrapper.py rename to src/ndv/v1/_old_data_wrapper.py index 502389c2..34294701 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/v1/_old_data_wrapper.py @@ -31,7 +31,9 @@ import zarr from torch._tensor import Tensor - from ._dims_slider import Index, Indices, Sizes + Index = int | slice + Indices = Mapping[Any, Index] + Sizes = Mapping[Any, int] _T_contra = TypeVar("_T_contra", contravariant=True) @@ -130,7 +132,7 @@ def isel_async( """Asynchronous version of isel.""" return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) - def guess_channel_axis(self) -> Hashable | None: + def guess_channel_axis(self) -> Any | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/v1/_old_viewer.py similarity index 94% rename from src/ndv/viewer/_viewer.py rename to src/ndv/v1/_old_viewer.py index 8b976559..f0f9a4de 100755 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/v1/_old_viewer.py @@ -12,31 +12,40 @@ from superqt import QCollapsible, QElidingLabel, QIconifyIcon, ensure_main_thread from superqt.utils import qthrottled, signals_blocked -from ndv.viewer._components import ( +from ndv._views import get_array_canvas_class + +from ._old_data_wrapper import DataWrapper +from ._qt._components import ( ChannelMode, ChannelModeButton, DimToggleButton, QSpinner, ROIButton, ) - -from ._backends import get_canvas_class -from ._data_wrapper import DataWrapper -from ._dims_slider import DimsSliders -from ._lut_control import LutControl +from ._qt._dims_slider import DimsSliders +from ._qt._lut_control import LutControl if TYPE_CHECKING: - from collections.abc import Hashable, Iterable, Sequence + from collections.abc import Iterable, Mapping, Sequence from concurrent.futures import Future from typing import Any, Callable, TypeAlias from qtpy.QtCore import QObject from qtpy.QtGui import QCloseEvent, QKeyEvent - from ._backends._protocols import CanvasElement, PCanvas, PImageHandle, PRoiHandle - from ._dims_slider import DimKey, Indices, Sizes + from ndv._views.bases.graphics._canvas import ArrayCanvas + from ndv._views.bases.graphics._canvas_elements import ( + CanvasElement, + ImageHandle, + RoiHandle, + ) - ImgKey: TypeAlias = Hashable + DimKey = int + Index = int | slice + Indices = Mapping[Any, Index] + Sizes = Mapping[Any, int] + + ImgKey: TypeAlias = Any # any mapping of dimensions to sizes SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] @@ -125,7 +134,7 @@ def __init__( self._data_wrapper: DataWrapper | None = None # mapping of key to a list of objects that control image nodes in the canvas - self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) + self._img_handles: defaultdict[ImgKey, list[ImageHandle]] = defaultdict(list) # mapping of same keys to the LutControl objects control image display props self._lut_ctrls: dict[ImgKey, LutControl] = {} self._lut_ctrl_state: dict[ImgKey, dict] = {} @@ -151,7 +160,7 @@ def __init__( # Canvas selection self._selection: CanvasElement | None = None # ROI - self._roi: PRoiHandle | None = None + self._roi: RoiHandle | None = None # WIDGETS ---------------------------------------------------- @@ -178,9 +187,9 @@ def __init__( # place to display arbitrary text self._hover_info_label = QLabel("", self) # the canvas that displays the images - self._canvas: PCanvas = get_canvas_class()() + self._canvas: ArrayCanvas = get_array_canvas_class()() self._canvas.set_ndim(self._ndims) - self._qcanvas = self._canvas.qwidget() + self._qcanvas = self._canvas.frontend_widget() # Install an event filter so we can intercept mouse/key events self._qcanvas.installEventFilter(self) @@ -337,7 +346,8 @@ def set_roi( # Remove the old ROI if self._roi: self._roi.remove() - + color = cmap.Color(color) if color is not None else None + border_color = cmap.Color(border_color) if border_color is not None else None self._roi = self._canvas.add_roi( vertices=vertices, color=color, border_color=border_color ) @@ -443,7 +453,7 @@ def _toggle_3d(self) -> None: self._add_roi_btn.setEnabled(self._ndims == 2) # FIXME: When toggling 2D again, ROIs cannot be selected if self._roi: - self._roi.visible = self._ndims == 2 + self._roi.set_visible(self._ndims == 2) def _update_slider_ranges(self) -> None: """Set the maximum values of the sliders. @@ -553,7 +563,7 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: datum = self._reduce_data_for_display(data) if handles := self._img_handles[imkey]: for handle in handles: - handle.data = datum + handle.set_data(datum) if ctrl := self._lut_ctrls.get(imkey, None): ctrl.update_autoscale() else: @@ -563,9 +573,13 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: else GRAYS ) if datum.ndim == 2: - handles.append(self._canvas.add_image(datum, cmap=cm)) + handle = self._canvas.add_image(datum) + handle.set_cmap(cm) + handles.append(handle) elif datum.ndim == 3: - handles.append(self._canvas.add_volume(datum, cmap=cm)) + handle = self._canvas.add_volume(datum) + handle.set_cmap(cm) + handles.append(handle) if imkey not in self._lut_ctrls: ch_index = index.get(self._channel_axis, 0) self._lut_ctrls[imkey] = c = LutControl( @@ -672,7 +686,7 @@ def _begin_roi(self, event: QMouseEvent) -> bool: ev_pos = event.position() pos = self._canvas.canvas_to_world((ev_pos.x(), ev_pos.y())) self._roi.move(pos) - self._roi.visible = True + self._roi.set_visible(True) return False def _press_element(self, event: QMouseEvent) -> bool: @@ -682,19 +696,19 @@ def _press_element(self, event: QMouseEvent) -> bool: elements = self._canvas.elements_at(ev_pos) # Deselect prior selection before editing new selection if self._selection: - self._selection.selected = False + self._selection.set_selected(False) for e in elements: - if e.can_select: + if e.can_select(): e.start_move(pos) # Select new selection self._selection = e - self._selection.selected = True + self._selection.set_selected(True) return False return False def _move_selection(self, event: QMouseEvent) -> bool: if event.buttons() == Qt.MouseButton.LeftButton: - if self._selection and self._selection.selected: + if self._selection and self._selection.selected(): ev_pos = event.pos() pos = self._canvas.canvas_to_world((ev_pos.x(), ev_pos.y())) self._selection.move(pos) @@ -714,7 +728,7 @@ def _update_cursor(self, event: QMouseEvent) -> bool: pos = (event.pos().x(), event.pos().y()) for e in self._canvas.elements_at(pos): if (pref := e.cursor_at(pos)) is not None: - self._qcanvas.setCursor(pref) + self._qcanvas.setCursor(pref.to_qt()) return False # Otherwise, normal cursor self._qcanvas.setCursor(Qt.CursorShape.ArrowCursor) @@ -744,7 +758,7 @@ def _update_hover_info(self, event: QMouseEvent) -> bool: # texture has already been reduced to 2D). But a more complete # implementation would gather the full current nD index and query # the data source directly. - value = handle.data[y, x] + value = handle.data()[y, x] if isinstance(value, (np.floating, float)): value = f"{value:.2f}" channels.append(f" {n}: {value}") diff --git a/src/ndv/v1/_qt/__init__.py b/src/ndv/v1/_qt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ndv/viewer/_components.py b/src/ndv/v1/_qt/_components.py similarity index 100% rename from src/ndv/viewer/_components.py rename to src/ndv/v1/_qt/_components.py diff --git a/src/ndv/viewer/_dims_slider.py b/src/ndv/v1/_qt/_dims_slider.py similarity index 91% rename from src/ndv/viewer/_dims_slider.py rename to src/ndv/v1/_qt/_dims_slider.py index 374cb808..f639170c 100644 --- a/src/ndv/viewer/_dims_slider.py +++ b/src/ndv/v1/_qt/_dims_slider.py @@ -24,21 +24,10 @@ from superqt.utils import signals_blocked if TYPE_CHECKING: - from collections.abc import Hashable, Mapping - from typing import TypeAlias + from collections.abc import Mapping from qtpy.QtGui import QResizeEvent - # any hashable represent a single dimension in an ND array - DimKey: TypeAlias = Hashable - # any object that can be used to index a single dimension in an ND array - Index: TypeAlias = int | slice - # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) - # this object is used frequently to query or set the currently displayed slice - Indices: TypeAlias = Mapping[DimKey, Index] - # mapping of dimension keys to the maximum value for that dimension - Sizes: TypeAlias = Mapping[DimKey, int] - SS = """ QSlider::groove:horizontal { @@ -155,7 +144,7 @@ class DimsSlider(QWidget): valueChanged = Signal(object, object) # where object is int | slice - def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None: + def __init__(self, dimension_key: int, parent: QWidget | None = None) -> None: super().__init__(parent) self.setStyleSheet(SS) self._slice_mode = False @@ -242,7 +231,7 @@ def setRange(self, min_val: int, max_val: int) -> None: self._int_slider.setRange(min_val, max_val) self._slice_slider.setRange(min_val, max_val) - def value(self) -> Index: + def value(self) -> int | slice: if not self._slice_mode: return self._int_slider.value() # type: ignore start, *_, stop = cast("tuple[int, ...]", self._slice_slider.value()) @@ -250,7 +239,7 @@ def value(self) -> Index: return start return slice(start, stop) - def setValue(self, val: Index) -> None: + def setValue(self, val: int | slice) -> None: # variant of setValue that always updates the maximum self._set_slice_mode(isinstance(val, slice)) if self._lock_btn.isChecked(): @@ -265,7 +254,7 @@ def setValue(self, val: Index) -> None: self._int_slider.setValue(val) # self._slice_slider.setValue((val, val + 1)) - def forceValue(self, val: Index) -> None: + def forceValue(self, val: int | slice) -> None: """Set value and increase range if necessary.""" if isinstance(val, slice): if isinstance(val.start, int): @@ -368,10 +357,10 @@ class DimsSliders(QWidget): def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) - self._locks_visible: bool | Mapping[DimKey, bool] = False - self._sliders: dict[DimKey, DimsSlider] = {} - self._current_index: dict[DimKey, Index] = {} - self._invisible_dims: set[DimKey] = set() + self._locks_visible: bool | Mapping[int, bool] = False + self._sliders: dict[int, DimsSlider] = {} + self._current_index: dict[int, int | slice] = {} + self._invisible_dims: set[int] = set() self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) @@ -379,19 +368,19 @@ def __init__(self, parent: QWidget | None = None) -> None: layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) - def __contains__(self, key: DimKey) -> bool: + def __contains__(self, key: int) -> bool: """Return True if the dimension key is present in the DimsSliders.""" return key in self._sliders - def slider(self, key: DimKey) -> DimsSlider: + def slider(self, key: int) -> DimsSlider: """Return the DimsSlider widget for the given dimension key.""" return self._sliders[key] - def value(self) -> Indices: + def value(self) -> Mapping[int, int | slice]: """Return mapping of {dim_key -> current index} for each dimension.""" return self._current_index.copy() - def setValue(self, values: Indices) -> None: + def setValue(self, values: Mapping[int, int | slice]) -> None: """Set the current index for each dimension. Parameters @@ -410,11 +399,11 @@ def setValue(self, values: Indices) -> None: if val := self.value(): self.valueChanged.emit(val) - def minima(self) -> Sizes: + def minima(self) -> Mapping[int, int]: """Return mapping of {dim_key -> minimum value} for each dimension.""" return {k: v._int_slider.minimum() for k, v in self._sliders.items()} - def setMinima(self, values: Sizes) -> None: + def setMinima(self, values: Mapping[int, int]) -> None: """Set the minimum value for each dimension. Parameters @@ -427,11 +416,11 @@ def setMinima(self, values: Sizes) -> None: self.add_dimension(name) self._sliders[name].setMinimum(min_val) - def maxima(self) -> Sizes: + def maxima(self) -> Mapping[int, int]: """Return mapping of {dim_key -> maximum value} for each dimension.""" return {k: v._int_slider.maximum() for k, v in self._sliders.items()} - def setMaxima(self, values: Sizes) -> None: + def setMaxima(self, values: Mapping[int, int]) -> None: """Set the maximum value for each dimension. Parameters @@ -444,14 +433,14 @@ def setMaxima(self, values: Sizes) -> None: self.add_dimension(name) self._sliders[name].setMaximum(max_val) - def set_locks_visible(self, visible: bool | Mapping[DimKey, bool]) -> None: + def set_locks_visible(self, visible: bool | Mapping[int, bool]) -> None: """Set the visibility of the lock buttons for all dimensions.""" self._locks_visible = visible for dim, slider in self._sliders.items(): viz = visible if isinstance(visible, bool) else visible.get(dim, False) slider._lock_btn.setVisible(viz) - def add_dimension(self, key: DimKey, val: Index | None = None) -> None: + def add_dimension(self, key: int, val: int | slice | None = None) -> None: """Add a new dimension to the DimsSliders widget. Parameters @@ -481,7 +470,7 @@ def add_dimension(self, key: DimKey, val: Index | None = None) -> None: slider.valueChanged.connect(self._on_dim_slider_value_changed) cast("QVBoxLayout", self.layout()).addWidget(slider) - def set_dimension_visible(self, key: DimKey, visible: bool) -> None: + def set_dimension_visible(self, key: int, visible: bool) -> None: """Set the visibility of a dimension in the DimsSliders widget. Once a dimension is hidden, it will not be shown again until it is explicitly @@ -497,7 +486,7 @@ def set_dimension_visible(self, key: DimKey, visible: bool) -> None: if key in self._sliders: self._sliders[key].setVisible(visible) - def remove_dimension(self, key: DimKey) -> None: + def remove_dimension(self, key: int) -> None: """Remove a dimension from the DimsSliders widget.""" try: slider = self._sliders.pop(key) @@ -514,11 +503,11 @@ def clear(self) -> None: self._current_index = {} self._invisible_dims = set() - def _on_dim_slider_value_changed(self, key: DimKey, value: Index) -> None: + def _on_dim_slider_value_changed(self, key: int, value: int | slice) -> None: self._current_index[key] = value self.valueChanged.emit(self.value()) - def add_or_update_dimension(self, key: DimKey, value: Index) -> None: + def add_or_update_dimension(self, key: int, value: int | slice) -> None: """Add a dimension if it doesn't exist, otherwise update the value.""" if key in self._sliders: self._sliders[key].forceValue(value) diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/v1/_qt/_lut_control.py similarity index 85% rename from src/ndv/viewer/_lut_control.py rename to src/ndv/v1/_qt/_lut_control.py index 75e823aa..37d439e5 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/v1/_qt/_lut_control.py @@ -16,7 +16,7 @@ import cmap - from ._backends._protocols import PImageHandle + from ndv._views.bases.graphics._canvas_elements import ImageHandle class CmapCombo(QColormapComboBox): @@ -36,13 +36,13 @@ class LutControl(QWidget): def __init__( self, name: str = "", - handles: Iterable[PImageHandle] = (), + handles: Iterable[ImageHandle] = (), parent: QWidget | None = None, cmaplist: Iterable[Any] = (), auto_clim: bool = True, ) -> None: super().__init__(parent) - self._handles = handles + self._handles = list(handles) self._name = name self._visible = QCheckBox(name) @@ -51,8 +51,8 @@ def __init__( self._cmap = CmapCombo() self._cmap.currentColormapChanged.connect(self._on_cmap_changed) - for handle in handles: - self._cmap.addColormap(handle.cmap) + for handle in self._handles: + self._cmap.addColormap(handle.cmap()) for color in cmaplist: self._cmap.addColormap(color) @@ -100,17 +100,17 @@ def autoscaleChecked(self) -> bool: def _on_clims_changed(self, clims: tuple[float, float]) -> None: self._auto_clim.setChecked(False) for handle in self._handles: - handle.clim = clims + handle.set_clims(clims) def _on_visible_changed(self, visible: bool) -> None: for handle in self._handles: - handle.visible = visible + handle.set_visible(visible) if visible: self.update_autoscale() def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: for handle in self._handles: - handle.cmap = cmap + handle.set_cmap(cmap) def update_autoscale(self) -> None: if ( @@ -123,15 +123,21 @@ def update_autoscale(self) -> None: # find the min and max values for the current channel clims = [np.inf, -np.inf] for handle in self._handles: - clims[0] = min(clims[0], np.nanmin(handle.data)) - clims[1] = max(clims[1], np.nanmax(handle.data)) + data = handle.data() + clims[0] = min(clims[0], np.nanmin(data)) + clims[1] = max(clims[1], np.nanmax(data)) mi, ma = tuple(int(x) for x in clims) for handle in self._handles: - handle.clim = (mi, ma) + handle.set_clims((mi, ma)) # set the slider values to the new clims with signals_blocked(self._clims): self._clims.setMinimum(min(mi, self._clims.minimum())) self._clims.setMaximum(max(ma, self._clims.maximum())) self._clims.setValue((mi, ma)) + + def add_handle(self, handle: ImageHandle) -> None: + self._handles.append(handle) + self._cmap.addColormap(handle.cmap()) + self.update_autoscale() diff --git a/src/ndv/v1/util.py b/src/ndv/v1/util.py new file mode 100644 index 00000000..562d3790 --- /dev/null +++ b/src/ndv/v1/util.py @@ -0,0 +1,65 @@ +"""Utility and convenience functions.""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from qtpy.QtCore import QCoreApplication + + from . import NDViewer + from ._old_data_wrapper import DataWrapper + + +def imshow( + data: Any | DataWrapper, + cmap: Any | None = None, + *, + channel_mode: Literal["mono", "composite", "auto"] = "auto", +) -> NDViewer: + """Display an array or DataWrapper in a new NDViewer window. + + Parameters + ---------- + data : Any | DataWrapper + The data to be displayed. If not a DataWrapper, it will be wrapped in one. + cmap : Any | None, optional + The colormap(s) to use for displaying the data. + channel_mode : Literal['mono', 'composite'], optional + The initial mode for displaying the channels. By default "mono" will be + used unless a cmap is provided, in which case "composite" will be used. + + Returns + ------- + NDViewer + The viewer window. + """ + from . import NDViewer + + app, should_exec = _get_app() + if cmap is not None: + channel_mode = "composite" + if not isinstance(cmap, (list, tuple)): + cmap = [cmap] + elif channel_mode == "auto": + channel_mode = "mono" + viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode) + viewer.show() + viewer.raise_() + if should_exec: + app.exec() + return viewer + + +def _get_app() -> tuple[QCoreApplication, bool]: + from qtpy.QtWidgets import QApplication + + is_ipython = False + if (app := QApplication.instance()) is None: + app = QApplication([]) + app.setApplicationName("ndv") + elif (ipy := sys.modules.get("IPython")) and (shell := ipy.get_ipython()): + is_ipython = str(shell.active_eventloop).startswith("qt") + + return app, not is_ipython diff --git a/src/ndv/viewer/__init__.py b/src/ndv/viewer/__init__.py deleted file mode 100644 index 09c94709..00000000 --- a/src/ndv/viewer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""viewer source.""" diff --git a/src/ndv/viewer/_backends/__init__.py b/src/ndv/viewer/_backends/__init__.py deleted file mode 100755 index c60046e0..00000000 --- a/src/ndv/viewer/_backends/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -import importlib -import importlib.util -import os -import sys -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from ndv.viewer._backends._protocols import PCanvas - - -def get_canvas_class(backend: str | None = None) -> type[PCanvas]: - backend = backend or os.getenv("NDV_CANVAS_BACKEND", None) - if backend == "vispy" or (backend is None and "vispy" in sys.modules): - from ._vispy import VispyViewerCanvas - - return VispyViewerCanvas - - if backend == "pygfx" or (backend is None and "pygfx" in sys.modules): - from ._pygfx import PyGFXViewerCanvas - - return PyGFXViewerCanvas - - if backend is None: - if importlib.util.find_spec("vispy") is not None: - from ._vispy import VispyViewerCanvas - - return VispyViewerCanvas - - if importlib.util.find_spec("pygfx") is not None: - from ._pygfx import PyGFXViewerCanvas - - return PyGFXViewerCanvas - - raise RuntimeError("No canvas backend found") diff --git a/src/ndv/viewer/_backends/_protocols.py b/src/ndv/viewer/_backends/_protocols.py deleted file mode 100755 index dd51f855..00000000 --- a/src/ndv/viewer/_backends/_protocols.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Literal, Protocol - -if TYPE_CHECKING: - from collections.abc import Sequence - - import cmap - import numpy as np - from qtpy.QtCore import Qt - from qtpy.QtWidgets import QWidget - - -class CanvasElement(Protocol): - """Protocol defining an interactive element on the Canvas.""" - - @property - def visible(self) -> bool: - """Defines whether the element is visible on the canvas.""" - - @visible.setter - def visible(self, visible: bool) -> None: - """Sets element visibility.""" - - @property - def can_select(self) -> bool: - """Defines whether the element can be selected.""" - - @property - def selected(self) -> bool: - """Returns element selection status.""" - - @selected.setter - def selected(self, selected: bool) -> None: - """Sets element selection status.""" - - def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: - """Returns the element's cursor preference at the provided position.""" - - def start_move(self, pos: Sequence[float]) -> None: - """ - Behavior executed at the beginning of a "move" operation. - - In layman's terms, this is the behavior executed during the the "click" - of a "click-and-drag". - """ - - def move(self, pos: Sequence[float]) -> None: - """ - Behavior executed throughout a "move" operation. - - In layman's terms, this is the behavior executed during the "drag" - of a "click-and-drag". - """ - - def remove(self) -> None: - """Removes the element from the canvas.""" - - -class PImageHandle(CanvasElement, Protocol): - @property - def data(self) -> np.ndarray: ... - @data.setter - def data(self, data: np.ndarray) -> None: ... - @property - def clim(self) -> Any: ... - @clim.setter - def clim(self, clims: tuple[float, float]) -> None: ... - @property - def cmap(self) -> cmap.Colormap: ... - @cmap.setter - def cmap(self, cmap: cmap.Colormap) -> None: ... - - -class PRoiHandle(CanvasElement, Protocol): - @property - def vertices(self) -> Sequence[Sequence[float]]: ... - @vertices.setter - def vertices(self, data: Sequence[Sequence[float]]) -> None: ... - @property - def color(self) -> Any: ... - @color.setter - def color(self, color: cmap.Color) -> None: ... - @property - def border_color(self) -> Any: ... - @border_color.setter - def border_color(self, color: cmap.Color) -> None: ... - - -class PCanvas(Protocol): - def __init__(self) -> None: ... - def set_ndim(self, ndim: Literal[2, 3]) -> None: ... - def set_range( - self, - x: tuple[float, float] | None = None, - y: tuple[float, float] | None = None, - z: tuple[float, float] | None = None, - margin: float = ..., - ) -> None: ... - def refresh(self) -> None: ... - def qwidget(self) -> QWidget: ... - def add_image( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... - ) -> PImageHandle: ... - def add_volume( - self, data: np.ndarray | None = ..., cmap: cmap.Colormap | None = ... - ) -> PImageHandle: ... - def canvas_to_world( - self, pos_xy: tuple[float, float] - ) -> tuple[float, float, float]: - """Map XY canvas position (pixels) to XYZ coordinate in world space.""" - - def elements_at(self, pos_xy: tuple[float, float]) -> list[CanvasElement]: ... - def add_roi( - self, - vertices: Sequence[tuple[float, float]] | None = None, - color: cmap.Color | None = None, - border_color: cmap.Color | None = None, - ) -> PRoiHandle: ... diff --git a/tests/conftest.py b/tests/conftest.py index 373acbeb..7bd805e6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,79 @@ +from __future__ import annotations + import gc +import importlib +import importlib.util +import os from collections.abc import Iterator -from typing import TYPE_CHECKING +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any +from unittest.mock import patch import pytest +from ndv._views import gui_frontend +from ndv._views._app import GuiFrontend + if TYPE_CHECKING: + from asyncio import AbstractEventLoop + from collections.abc import Iterator + + import wx from pytest import FixtureRequest from qtpy.QtWidgets import QApplication -@pytest.fixture(autouse=True) -def find_leaks(request: "FixtureRequest", qapp: "QApplication") -> Iterator[None]: +@pytest.fixture +def asyncio_app() -> Iterator[AbstractEventLoop]: + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() + + +@pytest.fixture +def wxapp() -> Iterator[wx.App]: + import wx + + app = wx.App() + yield app + # app.ExitMainLoop() + + +@pytest.fixture +def any_app(request: pytest.FixtureRequest) -> Iterator[Any]: + # this fixture will use the appropriate application depending on the env var + # NDV_GUI_FRONTEND='qt' pytest + # NDV_GUI_FRONTEND='jupyter' pytest + try: + frontend = gui_frontend() + except RuntimeError: + # if we don't find any frontend, and jupyter is available, use that + # since it requires very little setup + if importlib.util.find_spec("jupyter"): + os.environ["NDV_GUI_FRONTEND"] = "jupyter" + gui_frontend.cache_clear() + + frontend = gui_frontend() + + if frontend == GuiFrontend.QT: + app = request.getfixturevalue("qapp") + qtbot = request.getfixturevalue("qtbot") + with patch.object(app, "exec", lambda *_: None): + with _catch_qt_leaks(request, app): + yield app, qtbot + elif frontend == GuiFrontend.JUPYTER: + yield request.getfixturevalue("asyncio_app") + elif frontend == GuiFrontend.WX: + yield request.getfixturevalue("wxapp") + else: + raise RuntimeError("No GUI frontend found") + + +@contextmanager +def _catch_qt_leaks(request: FixtureRequest, qapp: QApplication) -> Iterator[None]: """Run after each test to ensure no widgets have been left around. When this test fails, it means that a widget being tested has an issue closing @@ -29,7 +92,15 @@ def find_leaks(request: "FixtureRequest", qapp: "QApplication") -> Iterator[None # if the test failed, don't worry about checking widgets if request.session.testsfailed - failures_before: return - remaining = qapp.topLevelWidgets() + try: + from vispy.app.backends._qt import CanvasBackendDesktop + + allow: tuple[type, ...] = (CanvasBackendDesktop,) + except (ImportError, RuntimeError): + allow = () + + # This is a known widget that is not cleaned up properly + remaining = [w for w in qapp.topLevelWidgets() if not isinstance(w, allow)] if len(remaining) > nbefore: test_node = request.node diff --git a/tests/test_controller.py b/tests/test_controller.py new file mode 100644 index 00000000..32dad47d --- /dev/null +++ b/tests/test_controller.py @@ -0,0 +1,182 @@ +"""Test controller without canavs or gui frontend""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, cast, no_type_check +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest + +from ndv._types import MouseMoveEvent +from ndv._views import _app, gui_frontend +from ndv._views.bases._array_view import ArrayView +from ndv._views.bases._lut_view import LutView +from ndv._views.bases.graphics._canvas import ArrayCanvas, HistogramCanvas +from ndv._views.bases.graphics._canvas_elements import ImageHandle +from ndv.controllers import ArrayViewer +from ndv.models._array_display_model import ArrayDisplayModel, ChannelMode +from ndv.models._lut_model import LUTModel + +if TYPE_CHECKING: + from ndv.controllers._channel_controller import ChannelController + + +def _get_mock_canvas() -> ArrayCanvas: + mock = MagicMock(spec=ArrayCanvas) + handle = MagicMock(spec=ImageHandle) + handle.data.return_value = np.zeros((10, 10)).astype(np.uint8) + mock.add_image.return_value = handle + return mock + + +def _get_mock_hist_canvas() -> HistogramCanvas: + return MagicMock(spec=HistogramCanvas) + + +def _get_mock_view(*_: Any) -> ArrayView: + mock = MagicMock(spec=ArrayView) + lut_mock = MagicMock(spec=LutView) + mock.add_lut_view.return_value = lut_mock + return mock + + +def _patch_views(f: Callable) -> Callable: + f = patch.object(_app, "get_array_canvas_class", lambda: _get_mock_canvas)(f) + f = patch.object(_app, "get_array_view_class", lambda: _get_mock_view)(f) + f = patch.object(_app, "get_histogram_canvas_class", lambda: _get_mock_hist_canvas)(f) # fmt: skip # noqa + return f + + +@no_type_check +@_patch_views +def test_controller() -> None: + SHAPE = (10, 4, 10, 10) + ctrl = ArrayViewer() + model = ctrl.display_model + mock_view = ctrl.view + mock_view.create_sliders.assert_not_called() + + data = np.empty(SHAPE) + ctrl.data = data + wrapper = ctrl._data_model.data_wrapper + + # showing the controller shows the view + ctrl.show() + mock_view.set_visible.assert_called_once_with(True) + + # sliders are first created with the shape of the data + ranges = {i: range(s) for i, s in enumerate(SHAPE)} + mock_view.create_sliders.assert_called_once_with(ranges) + # visible-axis sliders are hidden + # (2,3) because model.visible_axes is set to (-2, -1) and ndim is 4 + mock_view.hide_sliders.assert_called_once_with((2, 3), show_remainder=True) + # channel mode is set to default (which is currently grayscale) + mock_view.set_channel_mode.assert_called_once_with(model.channel_mode) + # data info is set + mock_view.set_data_info.assert_called_once_with(wrapper.summary_info()) + model.current_index.assign({0: 1}) + + # changing visible axes updates which sliders are visible + model.visible_axes = (0, 3) + mock_view.hide_sliders.assert_called_with((0, 3), show_remainder=True) + + # changing the channel mode updates the sliders and updates the view combobox + mock_view.hide_sliders.reset_mock() + model.channel_mode = "composite" + mock_view.set_channel_mode.assert_called_with(ChannelMode.COMPOSITE) + mock_view.hide_sliders.assert_called_once_with( + (0, 3, model.channel_axis), show_remainder=True + ) + model.channel_mode = ChannelMode.GRAYSCALE + mock_view.hide_sliders.assert_called_with((0, 3), show_remainder=True) + + # when the view changes the current index, the model is updated + idx = {0: 1, 1: 2, 3: 8} + mock_view.current_index.return_value = idx + ctrl._on_view_current_index_changed() + assert model.current_index == idx + + # when the view changes the channel mode, the model is updated + assert model.channel_mode == ChannelMode.GRAYSCALE + ctrl._on_view_channel_mode_changed(ChannelMode.COMPOSITE) + assert model.channel_mode == ChannelMode.COMPOSITE + + # setting a new ArrayDisplay model updates the appropriate view widgets + ch_ctrl = cast("ChannelController", ctrl._lut_controllers[None]) + ch_ctrl.lut_views[0].set_colormap_without_signal.reset_mock() + ctrl.display_model = ArrayDisplayModel(default_lut=LUTModel(cmap="green")) + # fails + # ch_ctrl.lut_views[0].set_colormap_without_signal.assert_called_once() + + +@no_type_check +@_patch_views +def test_canvas() -> None: + SHAPE = (10, 4, 10, 10) + data = np.empty(SHAPE) + ctrl = ArrayViewer() + mock_canvas = ctrl._canvas + + mock_view = ctrl.view + ctrl.data = data + + # clicking the reset zoom button calls set_range on the canvas + ctrl._on_view_reset_zoom_clicked() + mock_canvas.set_range.assert_called_once_with() + + # hovering on the canvas updates the hover info in the view + mock_canvas.canvas_to_world.return_value = (1, 2, 3) + ctrl._on_canvas_mouse_moved(MouseMoveEvent(1, 2)) + mock_canvas.canvas_to_world.assert_called_once_with((1, 2)) + mock_view.set_hover_info.assert_called_once_with("[2, 1] 0") + + +@no_type_check +@_patch_views +def test_histogram_controller() -> None: + ctrl = ArrayViewer() + mock_view = ctrl.view + + ctrl.data = np.zeros((10, 4, 10, 10)).astype(np.uint8) + + # adding a histogram tells the view to add a histogram, and updates the data + ctrl._add_histogram() + mock_view.add_histogram.assert_called_once() + mock_histogram = ctrl._histogram + mock_histogram.set_data.assert_called_once() + + # changing the index updates the histogram + mock_histogram.set_data.reset_mock() + ctrl.display_model.current_index.assign({0: 1, 1: 2, 3: 3}) + mock_histogram.set_data.assert_called_once() + + # switching to composite mode puts the histogram view in the + # lut controller for all channels (this may change) + ctrl.display_model.channel_mode = ChannelMode.COMPOSITE + assert mock_histogram in ctrl._lut_controllers[0].lut_views + + +@pytest.mark.usefixtures("any_app") +def test_array_viewer_with_app() -> None: + """Example usage of new mvc pattern.""" + viewer = ArrayViewer() + assert gui_frontend() in type(viewer._view).__name__.lower() + viewer.show() + + data = np.random.randint(0, 255, size=(10, 10, 10, 10, 10), dtype="uint8") + viewer.data = data + + # test changing current index via the view + index_mock = Mock() + viewer.display_model.current_index.value_changed.connect(index_mock) + index = {0: 4, 1: 1, 2: 2} + # setting the index should trigger the signal, only once + viewer._view.set_current_index(index) + index_mock.assert_called_once() + for k, v in index.items(): + assert viewer.display_model.current_index[k] == v + # setting again should not trigger the signal + index_mock.reset_mock() + viewer._view.set_current_index(index) + index_mock.assert_not_called() diff --git a/tests/test_examples.py b/tests/test_examples.py index 5250ddaa..26325f3d 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -4,21 +4,29 @@ from pathlib import Path import pytest -from qtpy.QtWidgets import QApplication -EXAMPLES = Path(__file__).parent.parent / "examples" -EXAMPLES_PY = list(EXAMPLES.glob("*.py")) +try: + import pytestqt + + if pytestqt.qt_compat.qt_api.pytest_qt_api.startswith("pyside"): + pytest.skip( + "viewer still occasionally segfaults with pyside", allow_module_level=True + ) +except ImportError: + pytest.skip("This module requires qt frontend", allow_module_level=True) -@pytest.fixture -def no_qapp_exec(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(QApplication, "exec", lambda *_: None) + +EXAMPLES = Path(__file__).parent.parent / "examples" +EXAMPLES_PY = list(EXAMPLES.glob("*.py")) @pytest.mark.allow_leaks -@pytest.mark.usefixtures("no_qapp_exec") +@pytest.mark.usefixtures("any_app") @pytest.mark.parametrize("example", EXAMPLES_PY, ids=lambda x: x.name) -def test_example(qapp: QApplication, example: Path) -> None: +@pytest.mark.filterwarnings("ignore:Downcasting integer data") +@pytest.mark.filterwarnings("ignore:.*Falling back to CPUScaledTexture") +def test_example(example: Path) -> None: try: runpy.run_path(str(example)) except ImportError as e: diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..6f4094e0 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,25 @@ +from unittest.mock import Mock + +from ndv.models._array_display_model import ArrayDisplayModel + + +def test_array_display_model() -> None: + m = ArrayDisplayModel() + + mock = Mock() + m.events.channel_axis.connect(mock) + m.current_index.item_added.connect(mock) + m.current_index.item_changed.connect(mock) + + m.channel_axis = 4 + mock.assert_called_once_with(4, None) # new, old + mock.reset_mock() + m.current_index["5"] = 1 + mock.assert_called_once_with(5, 1) # key, value + mock.reset_mock() + m.current_index[5] = 4 + mock.assert_called_once_with(5, 4, 1) # key, new, old + mock.reset_mock() + + assert ArrayDisplayModel.model_json_schema(mode="validation") + assert ArrayDisplayModel.model_json_schema(mode="serialization") diff --git a/tests/test_stats_model.py b/tests/test_stats_model.py deleted file mode 100644 index d285b79f..00000000 --- a/tests/test_stats_model.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np -import pytest - -from ndv.histogram.model import StatsModel - -EPSILON = 1e-6 - - -@pytest.fixture -def data() -> np.ndarray: - gen = np.random.default_rng(0xDEADBEEF) - # Average - 1.000104 - # Std. Dev. - 10.003385 - data = gen.normal(1, 10, (1000, 1000)) - return data - - -def test_empty_stats_model() -> None: - model = StatsModel() - with pytest.raises(RuntimeError): - _ = model.data - assert model.average is None - assert model.standard_deviation is None - assert model.histogram is None - assert model.bins == 256 - - -def test_stats_model(data: np.ndarray) -> None: - model = StatsModel() - model.data = data - assert np.all(model.data == data) - # Basic regression tests - assert abs(model.average - 1.000104) < 1e-6 - assert abs(model.standard_deviation - 10.003385) < 1e-6 - assert 256 == model.bins - values, edges = model.histogram - assert len(values) == 256 - assert np.all(values >= 0) - assert np.all(values <= data.size) - assert len(edges) == 257 - assert edges[0] == np.min(data) - assert edges[256] == np.max(data) diff --git a/tests/test_vispy_histogram_view.py b/tests/test_vispy_histogram_view.py deleted file mode 100644 index 651ac835..00000000 --- a/tests/test_vispy_histogram_view.py +++ /dev/null @@ -1,337 +0,0 @@ -from __future__ import annotations - -import math -from typing import TYPE_CHECKING - -import cmap -import numpy as np -import pytest -from qtpy.QtWidgets import QHBoxLayout, QWidget -from vispy.app.canvas import MouseEvent -from vispy.color import Color - -from ndv.histogram.views._vispy import Grabbable, VispyHistogramView - -if TYPE_CHECKING: - from pytestqt.qtbot import QtBot - -# Accounts for differences between 32-bit and 64-bit floats -EPSILON = 1e-6 -# FIXME: Why do plot checks need a larger epsilon? -PLOT_EPSILON = 1e-4 - - -@pytest.fixture -def data() -> np.ndarray: - gen = np.random.default_rng(seed=0xDEADBEEF) - return gen.normal(10, 10, 10000).astype(np.float64) - - -@pytest.fixture -def view(qtbot: QtBot, data: np.ndarray) -> VispyHistogramView: - # Create view - view = VispyHistogramView() - view._canvas.size = (100, 100) - # FIXME: Why does `qtbot.add_widget(view.view())` not work? - wdg = QWidget() - layout = QHBoxLayout(wdg) - layout.addWidget(view.view()) - qtbot.add_widget(wdg) - # Set initial data - values, bin_edges = np.histogram(data) - view.set_histogram(values, bin_edges) - - return view - - -def test_plot(view: VispyHistogramView) -> None: - plot = view.plot - - assert plot.title == "" - plot.title = "foo" - assert plot._title.text == "foo" - - assert plot.xlabel == "" - plot.xlabel = "bar" - assert plot._xlabel.text == "bar" - - assert plot.ylabel == "" - plot.ylabel = "baz" - assert plot._ylabel.text == "baz" - - # Test axis lock - pan - _domain = plot.xaxis.axis.domain - _range = plot.yaxis.axis.domain - plot.camera.pan([20, 20]) - assert np.all(np.isclose(_domain, [x - 20 for x in plot.xaxis.axis.domain])) - assert np.all(np.isclose(_range, plot.yaxis.axis.domain)) - - # Test axis lock - zoom - _domain = plot.xaxis.axis.domain - _range = plot.yaxis.axis.domain - plot.camera.zoom(0.5) - dx = (_domain[1] - _domain[0]) / 4 - assert np.all( - np.isclose([_domain[0] + dx, _domain[1] - dx], plot.xaxis.axis.domain) - ) - assert np.all(np.isclose(_range, plot.yaxis.axis.domain)) - - -def test_clims(data: np.ndarray, view: VispyHistogramView) -> None: - # on startup, clims should be at the extent of the data - clims = np.min(data), np.max(data) - assert view._clims is not None - assert clims[0] == view._clims[0] - assert clims[1] == view._clims[1] - assert abs(clims[0] - view._lut_line._line.pos[0, 0]) <= EPSILON - assert abs(clims[1] - view._lut_line._line.pos[-1, 0]) <= EPSILON - # set clims, assert a change - clims = 9, 11 - view.set_clims(clims) - assert clims[0] == view._clims[0] - assert clims[1] == view._clims[1] - assert abs(clims[0] - view._lut_line._line.pos[0, 0]) <= EPSILON - assert abs(clims[1] - view._lut_line._line.pos[-1, 0]) <= EPSILON - # set clims backwards - ensure the view flips them - clims = 5, 3 - view.set_clims(clims) - assert clims[1] == view._clims[0] - assert clims[0] == view._clims[1] - assert abs(clims[1] - view._lut_line._line.pos[0, 0]) <= EPSILON - assert abs(clims[0] - view._lut_line._line.pos[-1, 0]) <= EPSILON - - -def test_gamma(data: np.ndarray, view: VispyHistogramView) -> None: - # on startup, gamma should be 1 - assert 1 == view._gamma - gx, gy = (np.max(data) + np.min(data)) / 2, 0.5**view._gamma - assert abs(gx - view._gamma_handle_pos[0, 0]) <= EPSILON - assert abs(gy - view._gamma_handle_pos[0, 1]) <= EPSILON - # set gamma, assert a change - g = 2 - view.set_gamma(g) - assert g == view._gamma - gx, gy = (np.max(data) + np.min(data)) / 2, 0.5**view._gamma - assert abs(gx - view._gamma_handle_pos[0, 0]) <= EPSILON - assert abs(gy - view._gamma_handle_pos[0, 1]) <= EPSILON - # set invalid gammas, assert no change - with pytest.raises(ValueError): - view.set_gamma(-1) - - -def test_cmap(view: VispyHistogramView) -> None: - # By default, histogram is red - assert view._hist_mesh.color == Color("red") - # Set cmap, assert a change - view.set_cmap(cmap.Colormap("blue")) - assert view._hist_mesh.color == Color("blue") - - -def test_visibility(view: VispyHistogramView) -> None: - # By default, everything is visible - assert view._hist_mesh.visible - assert view._lut_line.visible - assert view._gamma_handle.visible - # Visible = False - view.set_visibility(False) - assert not view._hist_mesh.visible - assert not view._lut_line.visible - assert not view._gamma_handle.visible - # Visible = True - view.set_visibility(True) - assert view._hist_mesh.visible - assert view._lut_line.visible - assert view._gamma_handle.visible - - -def test_domain(data: np.ndarray, view: VispyHistogramView) -> None: - def assert_extent(min_x: float, max_x: float) -> None: - domain = view.plot.xaxis.axis.domain - assert abs(min_x - domain[0]) <= PLOT_EPSILON - assert abs(max_x - domain[1]) <= PLOT_EPSILON - min_y, max_y = 0, np.max(np.histogram(data)[0]) - range = view.plot.yaxis.axis.domain # noqa: A001 - assert abs(min_y - range[0]) <= PLOT_EPSILON - assert abs(max_y - range[1]) <= PLOT_EPSILON - - # By default, the view should be around the histogram - assert_extent(np.min(data), np.max(data)) - # Set the domain, request a change - new_domain = (10, 12) - view.set_domain(new_domain) - assert_extent(*new_domain) - # Set the domain to None, assert going back - new_domain = None - view.set_domain(new_domain) - assert_extent(np.min(data), np.max(data)) - # Assert None value in tuple raises ValueError - with pytest.raises(ValueError): - view.set_domain((None, 12)) - # Set the domain with min>max, ensure values flipped - new_domain = (12, 10) - view.set_domain(new_domain) - assert_extent(10, 12) - - -def test_range(data: np.ndarray, view: VispyHistogramView) -> None: - # FIXME: Why do we need a larger epsilon? - _EPSILON = 1e-4 - - def assert_extent(min_y: float, max_y: float) -> None: - min_x, max_x = np.min(data), np.max(data) - domain = view.plot.xaxis.axis.domain - assert abs(min_x - domain[0]) <= _EPSILON - assert abs(max_x - domain[1]) <= _EPSILON - range = view.plot.yaxis.axis.domain # noqa: A001 - assert abs(min_y - range[0]) <= _EPSILON - assert abs(max_y - range[1]) <= _EPSILON - - # By default, the view should be around the histogram - assert_extent(0, np.max(np.histogram(data)[0])) - # Set the range, request a change - new_range = (10, 12) - view.set_range(new_range) - assert_extent(*new_range) - # Set the range to None, assert going back - new_range = None - view.set_range(new_range) - assert_extent(0, np.max(np.histogram(data)[0])) - # Assert None value in tuple raises ValueError - with pytest.raises(ValueError): - view.set_range((None, 12)) - # Set the range with min>max, ensure values flipped - new_range = (12, 10) - view.set_range(new_range) - assert_extent(10, 12) - - -def test_vertical(view: VispyHistogramView) -> None: - # Start out Horizontal - assert not view._vertical - domain_before = view.plot.xaxis.axis.domain - range_before = view.plot.yaxis.axis.domain - # Toggle vertical, assert domain <-> range - view.set_vertical(True) - assert view._vertical - domain_after = view.plot.xaxis.axis.domain - # NB vertical mode inverts y axis - range_after = view.plot.yaxis.axis.domain[::-1] - assert abs(domain_before[0] - range_after[0]) <= PLOT_EPSILON - assert abs(domain_before[1] - range_after[1]) <= PLOT_EPSILON - assert abs(range_before[0] - domain_after[0]) <= PLOT_EPSILON - assert abs(range_before[1] - domain_after[1]) <= PLOT_EPSILON - # Toggle vertical again, assert domain <-> range again - view.set_vertical(False) - assert not view._vertical - domain_after = view.plot.xaxis.axis.domain - range_after = view.plot.yaxis.axis.domain - assert abs(domain_before[0] - domain_after[0]) <= PLOT_EPSILON - assert abs(domain_before[1] - domain_after[1]) <= PLOT_EPSILON - assert abs(range_before[0] - range_after[0]) <= PLOT_EPSILON - assert abs(range_before[1] - range_after[1]) <= PLOT_EPSILON - - -def test_log(view: VispyHistogramView) -> None: - # Start out linear - assert not view._log_y - linear_range = view.plot.yaxis.axis.domain[1] - linear_hist = view._hist_mesh.bounds(1)[1] - # lut line, gamma markers controlled by scale - linear_line_scale = view._handle_transform.scale[1] - - # Toggle log, assert range shrinks - view.set_range_log(True) - assert view._log_y - log_range = view.plot.yaxis.axis.domain[1] - log_hist = view._hist_mesh.bounds(1)[1] - log_line_scale = view._handle_transform.scale[1] - assert abs(math.log10(linear_range) - log_range) <= EPSILON - assert abs(math.log10(linear_hist) - log_hist) <= EPSILON - # NB This final check isn't so simple because of margins, scale checks, - # etc - so need a larger epsilon. - assert abs(math.log10(linear_line_scale) - log_line_scale) <= 0.1 - - # Toggle log, assert range reverts - view.set_range_log(False) - assert not view._log_y - revert_range = view.plot.yaxis.axis.domain[1] - revert_hist = view._hist_mesh.bounds(1)[1] - revert_line_scale = view._handle_transform.scale[1] - assert abs(linear_range - revert_range) <= EPSILON - assert abs(linear_hist - revert_hist) <= EPSILON - assert abs(linear_line_scale - revert_line_scale) <= EPSILON - - -# @pytest.mark.skipif(sys.platform != "darwin", reason="the mouse event is tricky") -def test_move_clim(qtbot: QtBot, view: VispyHistogramView) -> None: - # Set clims within the viewbox - view.set_domain((0, 100)) - view.set_clims((10, 90)) - # Click on the left clim - press_pos = view.node_tform.imap([10])[:2] - event = MouseEvent("mouse_press", pos=press_pos, button=1) - view.on_mouse_press(event) - assert view._grabbed == Grabbable.LEFT_CLIM - assert not view.plot.camera.interactive - # Move it to 50 - move_pos = view.node_tform.imap([50])[:2] - event = MouseEvent("mouse_move", pos=move_pos, button=1) - with qtbot.waitSignal(view.climsChanged): - view.on_mouse_move(event) - assert view._grabbed == Grabbable.LEFT_CLIM - assert not view.plot.camera.interactive - # Release mouse - release_pos = move_pos - event = MouseEvent("mouse_release", pos=release_pos, button=1) - view.on_mouse_release(event) - assert view._grabbed == Grabbable.NONE - assert view.plot.camera.interactive - - # Move both clims to 50 - view.set_clims((50, 50)) - # Ensure clicking and moving at 50 moves the right clim - press_pos = view.node_tform.imap([50])[:2] - event = MouseEvent("mouse_press", pos=press_pos, button=1) - view.on_mouse_press(event) - assert view._grabbed == Grabbable.RIGHT_CLIM - assert not view.plot.camera.interactive - # Move it to 70 - move_pos = view.node_tform.imap([70])[:2] - event = MouseEvent("mouse_move", pos=move_pos, button=1) - with qtbot.waitSignal(view.climsChanged): - view.on_mouse_move(event) - assert view._grabbed == Grabbable.RIGHT_CLIM - assert not view.plot.camera.interactive - # Release mouse - release_pos = move_pos - event = MouseEvent("mouse_release", pos=release_pos, button=1) - view.on_mouse_release(event) - assert view._grabbed == Grabbable.NONE - assert view.plot.camera.interactive - - -def test_move_gamma(qtbot: QtBot, view: VispyHistogramView) -> None: - # Set clims outside the viewbox - # NB the canvas is small in this test, so we have to put the clims - # far away or they'll be grabbed over the gamma - view.set_domain((0, 100)) - view.set_clims((-9950, 10050)) - # Click on the gamma handle - press_pos = view.node_tform.imap(view._handle_transform.map([50, 0.5]))[:2] - event = MouseEvent("mouse_press", pos=press_pos, button=1) - view.on_mouse_press(event) - assert view._grabbed == Grabbable.GAMMA - assert not view.plot.camera.interactive - # Move it to 50 - move_pos = view.node_tform.imap(view._handle_transform.map([50, 0.75]))[:2] - event = MouseEvent("mouse_move", pos=move_pos, button=1) - with qtbot.waitSignal(view.gammaChanged): - view.on_mouse_move(event) - assert view._grabbed == Grabbable.GAMMA - assert not view.plot.camera.interactive - # Release mouse - release_pos = move_pos - event = MouseEvent("mouse_release", pos=release_pos, button=1) - view.on_mouse_release(event) - assert view._grabbed == Grabbable.NONE - assert view.plot.camera.interactive diff --git a/tests/test_nd_viewer.py b/tests/v1/test_v1_viewer.py similarity index 64% rename from tests/test_nd_viewer.py rename to tests/v1/test_v1_viewer.py index 956dd5cc..0c4a6b9f 100644 --- a/tests/test_nd_viewer.py +++ b/tests/v1/test_v1_viewer.py @@ -1,18 +1,25 @@ from __future__ import annotations -import os import sys -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np import pytest + +try: + import pytestqt + + if pytestqt.qt_compat.qt_api.pytest_qt_api.startswith("pyside"): + pytest.skip("V1 viewer segfaults with pyside", allow_module_level=True) + +except ImportError: + pytest.skip("This module requires qt frontend", allow_module_level=True) + + from qtpy.QtCore import QEvent, QPointF, Qt from qtpy.QtGui import QMouseEvent -from ndv import NDViewer - -if TYPE_CHECKING: - from pytestqt.qtbot import QtBot +from ndv.v1 import NDViewer def allow_linux_widget_leaks(func: Any) -> Any: @@ -21,15 +28,10 @@ def allow_linux_widget_leaks(func: Any) -> Any: return func -BACKENDS = ["vispy"] -# avoid pygfx backend on linux CI -if not os.getenv("CI") or sys.platform == "darwin": - BACKENDS.append("pygfx") - - -def test_empty_viewer(qtbot: QtBot) -> None: +@allow_linux_widget_leaks +@pytest.mark.usefixtures("any_app") +def test_empty_viewer() -> None: viewer = NDViewer() - qtbot.add_widget(viewer) viewer.refresh() viewer.set_data(np.random.rand(4, 3, 32, 32)) assert isinstance(viewer.data, np.ndarray) @@ -38,14 +40,14 @@ def test_empty_viewer(qtbot: QtBot) -> None: @allow_linux_widget_leaks -@pytest.mark.parametrize("backend", BACKENDS) -def test_ndviewer(qtbot: QtBot, backend: str, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("NDV_CANVAS_BACKEND", backend) +def test_ndviewer(any_app: Any) -> None: dask_arr = np.empty((4, 3, 2, 32, 32), dtype=np.uint8) v = NDViewer(dask_arr) - qtbot.addWidget(v) + # qtbot.addWidget(v) v.show() - qtbot.waitUntil(v._is_idle, timeout=1000) + if isinstance(any_app, tuple) and len(any_app) == 2: + qtbot = any_app[1] + qtbot.waitUntil(v._is_idle, timeout=1000) v.set_ndim(3) v.set_channel_mode("composite") v.set_current_index({0: 2, 1: 1, 2: 1}) @@ -57,17 +59,18 @@ def test_ndviewer(qtbot: QtBot, backend: str, monkeypatch: pytest.MonkeyPatch) - # wait until there are no running jobs, because the callbacks # in the futures hold a strong reference to the viewer - qtbot.waitUntil(v._is_idle, timeout=3000) + # qtbot.waitUntil(v._is_idle, timeout=3000) # not testing pygfx yet... @pytest.mark.skipif(sys.platform != "darwin", reason="the mouse event is tricky") -def test_hover_info(qtbot: QtBot) -> None: +def test_hover_info(any_app: Any) -> None: data = np.ones((4, 3, 32, 32), dtype=np.float32) viewer = NDViewer(data) - qtbot.addWidget(viewer) viewer.show() - qtbot.waitUntil(viewer._is_idle, timeout=1000) + if isinstance(any_app, tuple) and len(any_app) == 2: + qtbot = any_app[1] + qtbot.waitUntil(viewer._is_idle, timeout=1000) mouse_event = QMouseEvent( QEvent.Type.MouseMove, QPointF(viewer._qcanvas.rect().center()),