diff --git a/plans/interactive-selection.md b/plans/interactive-selection.md new file mode 100644 index 00000000..f2f9c150 --- /dev/null +++ b/plans/interactive-selection.md @@ -0,0 +1,230 @@ +# Interactive region selection in spatialdata-plot + +Status: spec (v0). Materialized from session handoff on 2026-05-21. + +## Goal + +A minimal, in-notebook (Jupyter / VSCode-Remote-SSH) widget that lets the user +draw a region on a spatialdata-plot canvas and persist it back into the +SpatialData object as a ShapesModel element. Works over an SSH bridge to a +SLURM compute node. No napari, no desktop GUI. + +## Confirmed design decisions + +- Output: persisted ShapesModel written back to the on-disk zarr via + `sdata.write_element`. Survives kernel restarts. +- Selector shapes in v0: rectangle, polygon (click vertices), lasso (freehand). +- Scale handling: auto-downsample on the fly. Pyramid-aware when available; + `dask.coarsen` fallback when not. +- Layers in v0: images only. The image is rendered once via the existing + `sdata.pl.render_images().pl.show()` pipeline into a matplotlib figure, + exported to PNG, and laid under a client-side drawing canvas. +- Backend: **custom anywidget** with HTML5/SVG drawing tools (rectangle, + polygon, freehand-lasso). All drawing happens in the browser; shape + geometry is reported back to Python via traitlet sync. Image is sent + once as a base64 data URL; mouse moves never round-trip the kernel. + No bokeh/datashader. + +### Why anywidget, not ipympl or plotly + +The original spec called for `%matplotlib widget` (ipympl). The prototype +revealed two showstoppers over SSH: +1. **ipympl streams PNG frames per mouse-move** over websocket — every drag + incurs SSH round-trip latency, making freehand drawing unusable. +2. **plotly's `FigureWidget`** has broken two-way shape sync in + VSCode-Remote-SSH (regardless of plotly 5 vs 6 — different bugs each). + +A small (~250-line) anywidget with traitlet-synced shape geometry was the +only architecture that worked reliably in VSCode-Remote and produced +responsive drawing. The image render still uses sdata-plot's matplotlib +pipeline; we just don't drive interaction through it. + +## Resolved questions (locked 2026-05-21, task #1) + +- **Q1 — Channel/contrast widgets**: **No live widgets in v0.** `channel=` and + `clims=` remain optional kwargs that forward to `render_images`. No + ipywidgets-driven controls. Widget toolbar deferred to v1. +- **Q2 — Auto-redraw on zoom**: **v1.** v0 renders once at the chosen scale; + `xlim_changed`/`ylim_changed` does not re-pick pyramid level. Static extent + ships sooner. +- **Q3 — Selector kind switching**: **One per call.** `selector=` is fixed at + session construction; no mid-session switching. Switchable kinds deferred to + v1. +- **Q4 — `name=` default**: **Required.** No default; omitting `name=` raises. + Keeps persisted element names intentional and zarr listings legible. + +## Public API sketch + +```python +import spatialdata_plot # registers .pl + +session = sdata.pl.interactive( + coordinate_system=None, # optional pre-selection; None = let user pick in UI + element=None, # optional pre-selection; None = let user pick in UI + persist=True, # show "Write to disk" button (False = memory only) +) +session.show() # renders the ipywidgets controls + draw canvas + +# User picks CS + image, clicks Render, draws shapes, names + Saves each set. +# Each Save adds an entry to sdata.shapes (memory). Write to disk persists +# the most recent commit via sdata.write_element. + +sdata["tumor_region"] # ShapesModel +sub = sdata.query.polygon(sdata, sdata["tumor_region"]) +``` + +Removed kwargs vs original spec: +- `selector=` — UI has a tool toggle (rect/polygon/lasso); no need to bind one + selector at construction (Q3 resolution). +- `name=` — typed in the UI before each Save (Q4 resolution). +- `channel=`, `clims=` — deferred to v1 (Q1 resolution). +- `max_render_pixels=` — render is fixed at `figsize=(7,7), dpi=120` ≈ 840×840 + PNG; pyramid-aware downsampling deferred to v1. +- `overwrite=` — collision handling is automatic: same name → append UTC + timestamp. + +## Module layout + +``` +src/spatialdata_plot/pl/interactive/ + __init__.py # exports interactive, InteractiveSession, DrawCanvas + _session.py # InteractiveSession class — ipywidgets controls + _canvas.py # DrawCanvas anywidget + traitlets + _render.py # render_to_png helper (sdata.pl → PNG + extent) + _commit.py # pixel-shape → CS-correct shapely Polygon → ShapesModel + _persist.py # write_element + collision/timestamp policy + static/ + draw_canvas.js # the ESM module; _esm = Path(...) reads at import + +tests/test_interactive/ + test_commit.py # pixel→CS conversion + ShapesModel correctness + test_render.py # render_to_png returns valid PNG + extent + test_persist.py # collision/timestamp policy + test_canvas.py # smoke: instantiate widget, check traitlet defaults +``` + +`sdata.pl.interactive(...)` is a method on `PlotAccessor` in +`src/spatialdata_plot/_accessor.py`. It constructs an `InteractiveSession` +and returns it; `session.show()` displays the controls + draw canvas. + +Dropped from the original spec: +- `_downsample.py` — pyramid-aware downsampling deferred to v1; v0 renders + at a fixed dpi (`figsize=(7,7), dpi=120`). +- `_selectors.py` — matplotlib selectors are replaced by the anywidget; the + three drawing tools (rect/polygon/lasso) live in `static/draw_canvas.js`. + +## Coordinate-system rules (highest-risk surface) + +1. Session is bound to ONE coordinate system at construction. +2. Render is in that CS; axes coords on the canvas equal coords in the CS + (1:1). +3. On commit, vertices are already in the rendered CS — no transform needed + for the selection itself. +4. The committed ShapesModel is registered with `{cs_name: Identity()}`. +5. Cross-CS selection is the user's job downstream. Not v0. + +Avoids the classic double-applied-transform bug. + +## Rendering + +`_render.render_to_png(sdata, element, coordinate_system) -> (png_bytes, image_w, image_h, xlim, ylim)` + +- Uses `sdata.pl.render_images(element=...).pl.show(coordinate_systems=..., ax=...)`. +- Axes fills the figure (`ax.add_axes([0,0,1,1])`, `set_axis_off()`) so PNG pixel + coordinates map exactly to data coordinates via `xlim`/`ylim`. +- Fixed at `figsize=(7,7)` × `dpi=120` ≈ 840×840 PNG for v0. Pyramid-aware + downsampling deferred to v1. +- 3D / z-stacks: refused by `render_images` itself (commit 3ebefe1) — we + propagate that error. + +## Drawing tools (in `static/draw_canvas.js`) + +| kind | gesture | commit trigger | +|-------------|------------------------------------------------|-----------------------------------------------| +| rectangle | left-drag corner → corner | mouse release | +| polygon | click each vertex | snap-to-first-vertex (within 10 px) or Enter | +| lasso | left-drag freehand | mouse release | + +Plus client-side: wheel-zoom, shift-drag-pan, alt-click-shape-to-delete, +hover-highlight, Ctrl+Z undo, Delete clear, R/P/L tool shortcuts, F fit. + +Lasso vertices are simplified server-side via `shapely.simplify(tolerance=0.5)` +in `_commit` before persisting. + +## Persistence policy + +- `sdata.path` set → `sdata.write_element(name)` on every commit. +- Not zarr-backed → warn once, keep in memory. +- `overwrite=False` default. Collision → rename to `"_"`. +- `session.commits` list tracks names committed this session. + +## Risks (pre-mitigated) + +1. CS mistakes → identity transform + unit tests. +2. Image too large → `max_render_pixels` hard cap with clear error. +3. ipympl flakiness in VSCode → documented fallback to browser-Jupyter via + `ssh -L 8888:localhost:8888 node`. +4. Walltime kill → auto-persist every commit. +5. Lasso 10k vertices → `shapely.simplify`. +6. Concurrent zarr writers → documented, no locking in v0. +7. 3D / z-stacks → refuse with same error as static render (commit 3ebefe1). +8. Auto-zoom redraw not in v0 → static extent ships first. + +## Test strategy + +- Unit: `_commit` (synthetic pixel-coord shapes → CS-coord ShapesModel correctness). +- Unit: `_render` (returns valid PNG bytes + extent matching the axis limits). +- Unit: `_persist` (collision-rename + timestamp policy). +- Smoke: `_canvas` (instantiate `DrawCanvas`, check traitlet defaults). +- NO visual / live-canvas tests in v0 — the JS widget can't be driven from Python. + Manual checklist in PR description covers the canvas behaviour. + +## Dependencies + +Exposed as `[project.optional-dependencies].interactive` so the feature is +opt-in (`pip install spatialdata-plot[interactive]`). Mirrors the pixi +`interactive` dep-group. + +- `anywidget` (NEW) — the widget framework. +- `ipywidgets` (NEW or pin existing transitive) — for the controls VBox. +- `ipykernel` — needed by anywidget for comm channel. +- `shapely`, `geopandas` — already transitive via spatialdata. + +`ipympl` and `plotly` are NOT runtime deps of the new architecture (we tried +both and rejected them). They remain in the prototype/pixi feature only for +historical comparison and may be dropped from the interactive feature later. + +## v1 roadmap (after v0 ships) + +1. Auto-downsample on zoom (pyramid-aware redraw on `xlim_changed`). +2. Channel + contrast widget controls in the figure toolbar. +3. Labels overlay (segmentation visible during selection). +4. Multiple selectors per session; switchable kinds. +5. Datashader path for points-heavy elements. + +## Task queue + +1. Resolve spec open questions Q1–Q4 +2. Add ipympl dep + pixi interactive feature +3. Scaffold `pl/interactive` submodule +4. Wire `sdata.pl.interactive` entrypoint +5. Implement `_commit`: vertices → ShapesModel +6. Implement `_persist`: zarr write policy +7. Implement `_downsample`: scale picker + warn +8. Implement `_render`: image render to ax +9. Implement `_selectors`: Rectangle/Polygon/Lasso adapters +10. Wire `InteractiveSession` end-to-end +11. Manual end-to-end test on cluster +12. Document feature in module docstring + README + +## Operating rules + +- Repo CLAUDE.md rules apply: plan-first for multi-file work, no drive-by + refactors, run pixi-defined tasks (lint/format/test) before commits, no + pre-commit / no visual tests locally (CI only). +- Pixi only. No venv/pip. `dev-py313` environment. +- Don't stage with `-A`; stage only what's touched. +- Human drives the actual ipympl canvas; agent cannot see it. Agent can + drive a parallel headless kernel on the same node for non-UI checks. +- If task #1 answers change the spec materially, update this file before + starting #2. diff --git a/pyproject.toml b/pyproject.toml index 7ebf7057..a4444830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,11 @@ dependencies = [ "scikit-learn", "spatialdata>=0.3", ] +optional-dependencies.interactive = [ + "anywidget", + "ipykernel", + "ipywidgets", +] urls.Documentation = "https://spatialdata.scverse.org/projects/plot/en/latest/index.html" urls.Home-page = "https://github.com/scverse/spatialdata-plot.git" urls.Source = "https://github.com/scverse/spatialdata-plot.git" @@ -61,6 +66,16 @@ doc = [ "sphinxcontrib-katex", "sphinxext-opengraph", ] +interactive-extras = [ + # Prototype-only helpers used by Sandbox.ipynb. The published runtime extra + # is [project.optional-dependencies].interactive above (anywidget/ipykernel/ + # ipywidgets only) — these are kept here for the dev-interactive-py313 env. + "ipympl", + # pinned to 5.x: plotly 6's anywidget-backed FigureWidget doesn't relay + # client-side draw events back to Python, so layout.shapes never syncs. + "plotly>=5.20,<6", + "squidpy", +] [tool.hatch] build.hooks.vcs.version-file = "_version.py" @@ -86,29 +101,49 @@ envs.hatch-test.scripts.cov-report = [ "coverage report", "coverage xml -o cover metadata.allow-direct-references = true version.source = "vcs" -[tool.pixi] -workspace.channels = [ "conda-forge" ] -workspace.platforms = [ "linux-64", "osx-arm64" ] -dependencies.python = ">=3.11" -pypi-dependencies.spatialdata-plot = { path = ".", editable = true } -tasks.format = "ruff format ." -tasks.kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"' -tasks.lab = "jupyter lab" -tasks.lint = "ruff check ." -tasks.pre-commit-install = "pre-commit install" -tasks.pre-commit-run = "pre-commit run --all-files" -tasks.test = "pytest -v --color=yes --tb=short --durations=10" +[tool.pixi.workspace] +channels = [ "conda-forge" ] +platforms = [ "linux-64", "osx-arm64" ] + +[tool.pixi.dependencies] +python = ">=3.11" + +[tool.pixi.pypi-dependencies] +spatialdata-plot = { path = ".", editable = true } + +# When the `interactive` feature is active, install the package with the +# `interactive` PyPI extra (anywidget, ipykernel, ipywidgets) so the pixi +# env mirrors what `pip install spatialdata-plot[interactive]` would give. +[tool.pixi.feature.interactive.pypi-dependencies] +spatialdata-plot = { path = ".", editable = true, extras = [ "interactive" ] } + +[tool.pixi.tasks] +format = "ruff format ." +kernel-install = 'python -m ipykernel install --user --name pixi-dev --display-name "sdata-plot (dev)"' +kernel-install-interactive = 'python -m ipykernel install --user --name sdata-plot-interactive --display-name "sdata-plot (interactive)"' +lab = "jupyter lab" +lint = "ruff check ." +pre-commit-install = "pre-commit install" +pre-commit-run = "pre-commit run --all-files" +test = "pytest -v --color=yes --tb=short --durations=10" + # for gh-actions -feature.py311.dependencies.python = "3.11.*" -feature.py313.dependencies.python = "3.13.*" +[tool.pixi.feature.py311.dependencies] +python = "3.11.*" + +[tool.pixi.feature.py313.dependencies] +python = "3.13.*" + +[tool.pixi.environments] # 3.13 lane -environments.default = { features = [ "py313" ], solve-group = "py313" } +default = { features = [ "py313" ], solve-group = "py313" } # 3.11 lane (for gh-actions) -environments.dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" } -environments.dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" } -environments.docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" } -environments.docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" } -environments.test-py313 = { features = [ "test", "py313" ], solve-group = "py313" } +dev-py311 = { features = [ "dev", "test", "py311" ], solve-group = "py311" } +dev-py313 = { features = [ "dev", "test", "py313" ], solve-group = "py313" } +dev-interactive-py313 = { features = [ "dev", "test", "interactive", "interactive-extras", "py313" ], solve-group = "py313" } +docs-py311 = { features = [ "doc", "py311" ], solve-group = "py311" } +docs-py313 = { features = [ "doc", "py313" ], solve-group = "py313" } +test-py313 = { features = [ "test", "py313" ], solve-group = "py313" } [tool.ruff] line-length = 120 diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 37a80593..49a02c50 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -171,6 +171,85 @@ def _copy( return sdata + def annotate( + self, + coordinate_system: str, + element: str, + *, + persist: bool = True, + max_width: int = 880, + ) -> None: + """Draw and save regions interactively on an image element. + + Renders the image element in the given coordinate system as a + client-side drawing canvas (rectangle / polygon / lasso tools). + Drawn shapes are saved into ``sdata.shapes`` under a user-typed name + on click of the *Save* button — each save creates one ShapesModel + with one row per drawn shape, registered with an ``Identity`` + transformation in the chosen coordinate system. The canvas is + cleared on every Save so the next set of shapes can be drawn + independently. + + Same-name commits overwrite both in memory and (via + ``SpatialData.write_element``) on disk. + + Requires the ``interactive`` extra: ``pip install 'spatialdata-plot[interactive]'``. + + Parameters + ---------- + coordinate_system : + Coordinate system to render and resolve drawn shapes against. + Drawn polygons are stored with an ``Identity`` transformation + in this CS. + element : + Name of the image element to render. + persist : + If ``True`` (default), show a *Write to disk* button that calls + :meth:`SpatialData.write_element` for the most recent save. + Set to ``False`` to limit the session to in-memory commits. + max_width : + Maximum display width in CSS pixels. The widget fills its + container width but never exceeds this. Display hint only; + the underlying render is always 840 × 840 px. + + Returns + ------- + None + Displays the widget in the current notebook cell. Drawn and + saved shapes appear in ``sdata.shapes``; inspect them there. + + Raises + ------ + ValueError + If ``coordinate_system`` is unknown, ``element`` is unknown, + or ``element`` is not registered in ``coordinate_system``. + ImportError + If the ``interactive`` extra is not installed. + + Examples + -------- + >>> import spatialdata_plot # noqa: F401 registers .pl + >>> sdata.pl.annotate("global", "he_image") + >>> # ... user draws and clicks Save with name "tumor" ... + >>> sdata.shapes["tumor"] + """ + try: + from spatialdata_plot.pl.interactive._session import _InteractiveSession + except ImportError as exc: + raise ImportError( + "sdata.pl.annotate() requires the `interactive` extra. " + "Install with: pip install 'spatialdata-plot[interactive]'" + ) from exc + + session = _InteractiveSession( + self._sdata, + coordinate_system=coordinate_system, + element=element, + persist=persist, + max_width=max_width, + ) + session.show() + @_deprecation_alias(elements="element", version="0.3.0") def render_shapes( self, diff --git a/src/spatialdata_plot/pl/interactive/__init__.py b/src/spatialdata_plot/pl/interactive/__init__.py new file mode 100644 index 00000000..35cc2868 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/__init__.py @@ -0,0 +1,11 @@ +"""Interactive region selection on a SpatialData image. + +Use via :meth:`spatialdata_plot.pl.basic.PlotAccessor.annotate`: + +>>> import spatialdata_plot # noqa: F401 registers .pl +>>> sdata.pl.annotate("global", "he_image") +""" + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/src/spatialdata_plot/pl/interactive/_canvas.py b/src/spatialdata_plot/pl/interactive/_canvas.py new file mode 100644 index 00000000..e024f7d2 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_canvas.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from pathlib import Path + +import anywidget +import traitlets + +_ESM_PATH = Path(__file__).parent / "static" / "draw_canvas.js" + +TOOLS = ("rectangle", "polygon", "lasso") + + +class DrawCanvas(anywidget.AnyWidget): + """Client-side SVG drawing surface for interactive region selection. + + The image (PNG data URL) is shown as a CSS-transformed background; an + overlay SVG catches mouse events and emits committed shapes in image- + pixel coordinates via the ``shapes`` traitlet. + + Convert the pixel-coord shapes to data/CS coordinates with + :func:`spatialdata_plot.pl.interactive._commit.pixel_shape_to_polygon`. + + Traitlets + --------- + image_url + ``data:image/png;base64,...`` for the rendered image. + image_width, image_height + Pixel dimensions of the PNG (used to set the SVG ``viewBox``). + tool + One of ``TOOLS``. + shapes + List of ``{"type": "rect"|"polygon", "verts": [[x, y], ...]}`` in + image-pixel coordinates. JS pushes to this on commit. + clear_trigger, close_poly_trigger, undo_trigger, fit_trigger + Integer counters. Increment from Python to invoke the corresponding + JS action; JS observers are stateless w.r.t. the value, only the + change event matters. + """ + + _esm = _ESM_PATH + + image_url = traitlets.Unicode("").tag(sync=True) + image_width = traitlets.Int(720).tag(sync=True) + image_height = traitlets.Int(720).tag(sync=True) + max_display_width = traitlets.Int(880).tag(sync=True) + tool = traitlets.Enum(TOOLS, default_value="rectangle").tag(sync=True) + shapes = traitlets.List([]).tag(sync=True) + clear_trigger = traitlets.Int(0).tag(sync=True) + close_poly_trigger = traitlets.Int(0).tag(sync=True) + undo_trigger = traitlets.Int(0).tag(sync=True) + fit_trigger = traitlets.Int(0).tag(sync=True) diff --git a/src/spatialdata_plot/pl/interactive/_commit.py b/src/spatialdata_plot/pl/interactive/_commit.py new file mode 100644 index 00000000..0972c7cd --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_commit.py @@ -0,0 +1,46 @@ +"""Convert canvas pixel-coord shapes into a CS-coord ShapesModel.""" + +from __future__ import annotations + +from typing import Any + +import geopandas as gpd +from shapely.geometry import Polygon +from spatialdata.models import ShapesModel +from spatialdata.transformations.transformations import Identity + +from ._render import RenderExtent + +_LASSO_SIMPLIFY_TOL = 0.5 +_DENSE_POLYGON_VERTEX_THRESHOLD = 50 + + +def pixel_shape_to_polygon(shape: dict[str, Any], extent: RenderExtent) -> Polygon | None: + """Convert a single ``DrawCanvas`` shape entry to a CS-coord shapely Polygon. + + Returns ``None`` if the shape is invalid (no verts, <3 verts after + construction, or empty). + """ + verts = shape.get("verts") if isinstance(shape, dict) else None + if not verts: + return None + + xmin, xmax = float(extent.xlim[0]), float(extent.xlim[1]) + y_lo, y_hi = sorted((float(extent.ylim[0]), float(extent.ylim[1]))) + w, h = extent.image_w, extent.image_h + + cs_verts = [(xmin + (v[0] / w) * (xmax - xmin), y_lo + (v[1] / h) * (y_hi - y_lo)) for v in verts] + if len(cs_verts) < 3: + return None + poly = Polygon(cs_verts) + if poly.is_empty: + return None + if shape.get("type") == "polygon" and len(cs_verts) > _DENSE_POLYGON_VERTEX_THRESHOLD: + poly = poly.simplify(_LASSO_SIMPLIFY_TOL, preserve_topology=True) + return poly + + +def build_shapes_model(polygons: list[Polygon], coordinate_system: str) -> Any: + """Wrap shapely polygons in a ShapesModel registered with Identity in ``coordinate_system``.""" + gdf = gpd.GeoDataFrame({"geometry": polygons}) + return ShapesModel.parse(gdf, transformations={coordinate_system: Identity()}) diff --git a/src/spatialdata_plot/pl/interactive/_persist.py b/src/spatialdata_plot/pl/interactive/_persist.py new file mode 100644 index 00000000..b4142a83 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_persist.py @@ -0,0 +1,13 @@ +"""Commit a ShapesModel into sdata.shapes.""" + +from __future__ import annotations + +from typing import Any + +import spatialdata as sd + + +def commit_to_memory(sdata: sd.SpatialData, shapes_model: Any, name: str) -> str: + """Add ``shapes_model`` to ``sdata.shapes`` under ``name``, overwriting on collision.""" + sdata.shapes[name] = shapes_model + return name diff --git a/src/spatialdata_plot/pl/interactive/_render.py b/src/spatialdata_plot/pl/interactive/_render.py new file mode 100644 index 00000000..0e33c7be --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_render.py @@ -0,0 +1,54 @@ +"""Render an image element to a PNG suitable for the DrawCanvas background.""" + +from __future__ import annotations + +from dataclasses import dataclass +from io import BytesIO + +import matplotlib.pyplot as plt +import spatialdata as sd + +_FIGSIZE = (7, 7) +_DPI = 120 +_IMAGE_W = _FIGSIZE[0] * _DPI +_IMAGE_H = _FIGSIZE[1] * _DPI + + +@dataclass(frozen=True) +class RenderExtent: + """Geometry of a render — PNG pixel dims + CS-coord limits at render time. + + For matplotlib image axes (``origin='upper'``) ``ylim`` is reversed: + the smaller y maps to PNG row 0. ``pixel_shape_to_polygon`` accepts + either orientation. + """ + + image_w: int + image_h: int + xlim: tuple[float, float] + ylim: tuple[float, float] + + +def render_to_png( + sdata: sd.SpatialData, + element: str, + coordinate_system: str, +) -> tuple[bytes, RenderExtent]: + """Render ``element`` in ``coordinate_system`` to PNG + its extent. + + The matplotlib axes fills the figure (``[0, 0, 1, 1]`` with axis off) so + the PNG-pixel ↔ data-coord mapping is exactly ``xlim`` × ``ylim``. + """ + fig = plt.figure(figsize=_FIGSIZE, dpi=_DPI) + try: + ax = fig.add_axes([0, 0, 1, 1]) + sdata.pl.render_images(element=element).pl.show(coordinate_systems=coordinate_system, ax=ax) + xlim = ax.get_xlim() + ylim = ax.get_ylim() + ax.set_axis_off() + buf = BytesIO() + fig.savefig(buf, format="png", dpi=_DPI, pad_inches=0) + finally: + plt.close(fig) + extent = RenderExtent(_IMAGE_W, _IMAGE_H, tuple(xlim), tuple(ylim)) + return buf.getvalue(), extent diff --git a/src/spatialdata_plot/pl/interactive/_session.py b/src/spatialdata_plot/pl/interactive/_session.py new file mode 100644 index 00000000..4f7691e0 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/_session.py @@ -0,0 +1,324 @@ +"""ipywidgets-based session orchestrating the DrawCanvas. Internal.""" + +from __future__ import annotations + +import base64 +from typing import Any, Literal + +import ipywidgets as W +import spatialdata as sd +from IPython.display import display +from shapely.geometry import Polygon +from spatialdata.transformations.operations import get_transformation + +from ._canvas import DrawCanvas +from ._commit import build_shapes_model, pixel_shape_to_polygon +from ._persist import commit_to_memory +from ._render import RenderExtent, render_to_png + +BannerKind = Literal["info", "success", "error", "hint"] + +_BANNER_CLASS = { + "info": "sdp-banner sdp-banner-info", + "success": "sdp-banner sdp-banner-success", + "error": "sdp-banner sdp-banner-error", + "hint": "sdp-banner sdp-banner-hint", +} + +_CSS = """ + +""" + + +def _fmt_banner(msg: str, kind: BannerKind = "info") -> str: + return f"
{msg}
" + + +def _validate(sdata: sd.SpatialData, coordinate_system: str, element: str) -> None: + if coordinate_system not in sdata.coordinate_systems: + raise ValueError( + f"Unknown coordinate system {coordinate_system!r}. Available: {list(sdata.coordinate_systems)}" + ) + if element not in sdata.images: + raise ValueError(f"Unknown image element {element!r}. Available: {list(sdata.images)}") + transforms = get_transformation(sdata[element], get_all=True) + if coordinate_system not in transforms: + raise ValueError( + f"Image {element!r} is not registered in coordinate system " + f"{coordinate_system!r}. Registered in: {list(transforms)}" + ) + + +class _InteractiveSession: + """Drives the DrawCanvas widget. Constructed by :meth:`PlotAccessor.annotate`.""" + + def __init__( + self, + sdata: sd.SpatialData, + coordinate_system: str, + element: str, + *, + persist: bool = True, + max_width: int = 880, + ) -> None: + _validate(sdata, coordinate_system, element) + + self._sdata = sdata + self._cs = coordinate_system + self._element = element + self._persist_enabled = persist + self._max_width = max_width + self.canvas: DrawCanvas | None = None + self._extent: RenderExtent | None = None + self._commits: list[str] = [] + + self._style = W.HTML(value=_CSS) + + icon_btn_layout = W.Layout(width="36px") + self.tool_tb = W.ToggleButtons( + options=[("Rect", "rectangle"), ("Polygon", "polygon"), ("Lasso", "lasso")], + value="rectangle", + description="", + tooltips=["R", "P", "L"], + ) + self.tool_tb.observe(self._on_tool_change, names="value") + self.close_poly_btn = self._trigger_btn( + "", + "check", + "close_poly_trigger", + tooltip="Close polygon (Enter)", + disabled=True, + layout=icon_btn_layout, + ) + self.undo_btn = self._trigger_btn( + "", + "rotate-left", + "undo_trigger", + tooltip="Undo (Ctrl+Z)", + disabled=True, + layout=icon_btn_layout, + ) + self.clear_btn = self._trigger_btn( + "", + "trash", + "clear_trigger", + tooltip="Clear canvas", + after=lambda: self._set_banner("Canvas cleared.", "info"), + layout=icon_btn_layout, + ) + self.fit_btn = self._trigger_btn( + "", + "compress", + "fit_trigger", + tooltip="Fit view (F)", + layout=icon_btn_layout, + ) + + self.name_tx = W.Text( + value="", + placeholder="name…", + layout=W.Layout(flex="1 1 140px", min_width="100px"), + ) + self.save_btn = W.Button( + description="Save", + button_style="success", + icon="save", + tooltip="Save shapes to sdata.shapes[name]", + ) + self.save_btn.on_click(self._on_save) + + save_row_widgets: list[W.Widget] = [self.name_tx, self.save_btn] + self.persist_btn: W.Button | None = None + if persist: + self.persist_btn = W.Button( + description="", + icon="hdd-o", + button_style="warning", + tooltip="Write last save to disk", + layout=icon_btn_layout, + ) + self.persist_btn.on_click(self._on_persist) + self.persist_btn.disabled = True + save_row_widgets.append(self.persist_btn) + + self.banner = W.HTML( + value=_fmt_banner( + f"Annotating {element!r} in coordinate system {coordinate_system!r}. " + "Pick a tool and draw. Click canvas first so keyboard shortcuts work. " + "R/P/L tools · Wheel zoom · Shift+drag pan · " + "Alt+click shape to delete · Ctrl+Z undo · F fit", + "hint", + ) + ) + self.plot_box = W.VBox([]) + + row_layout = W.Layout( + display="flex", + flex_flow="row wrap", + align_items="center", + gap="6px", + ) + card_layout = W.Layout(max_width=f"{max_width}px", width="100%") + toolbar = W.Box( + children=[self.tool_tb, self.close_poly_btn, self.undo_btn, self.clear_btn, self.fit_btn], + layout=row_layout, + ) + save_row = W.Box(children=save_row_widgets, layout=row_layout) + + controls_card = W.VBox( + [ + W.HTML( + value=( + f"
Annotate
" + f"
{element!r} · {coordinate_system!r}
" + ) + ), + toolbar, + save_row, + self.banner, + ], + layout=card_layout, + ) + controls_card.add_class("sdp-card") + + canvas_card = W.VBox([self.plot_box], layout=card_layout) + canvas_card.add_class("sdp-card") + + self.controls = W.VBox([self._style, controls_card, canvas_card]) + + def show(self) -> None: + self._render() + display(self.controls) + + def _set_banner(self, msg: str, kind: BannerKind = "info") -> None: + self.banner.value = _fmt_banner(msg, kind) + + def _trigger_btn( + self, + description: str, + icon: str, + trait_name: str, + *, + tooltip: str = "", + disabled: bool = False, + after: Any = None, + layout: W.Layout | None = None, + ) -> W.Button: + btn = W.Button(description=description, icon=icon, tooltip=tooltip) + btn.disabled = disabled + if layout is not None: + btn.layout = layout + + def _on_click(_b: W.Button) -> None: + if self.canvas is None: + return + setattr(self.canvas, trait_name, getattr(self.canvas, trait_name) + 1) + if after is not None: + after() + + btn.on_click(_on_click) + return btn + + def _render(self) -> None: + png_bytes, extent = render_to_png(self._sdata, self._element, self._cs) + data_url = "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii") + self._extent = extent + + self.canvas = DrawCanvas( + image_url=data_url, + image_width=extent.image_w, + image_height=extent.image_h, + max_display_width=self._max_width, + tool=self.tool_tb.value, + ) + self.canvas.observe(self._on_shapes_change, names="shapes") + self.plot_box.children = (self.canvas,) + + def _on_shapes_change(self, change: dict[str, Any]) -> None: + shapes = change["new"] or [] + self.undo_btn.disabled = len(shapes) == 0 + + def _on_tool_change(self, change: dict[str, Any]) -> None: + if self.canvas is None: + return + self.canvas.tool = change["new"] + self.close_poly_btn.disabled = change["new"] != "polygon" + self._set_banner(f"Tool: {change['new']}", "info") + + def _collect_polygons(self) -> list[Polygon]: + assert self.canvas is not None + assert self._extent is not None + polys: list[Polygon] = [] + for sh in self.canvas.shapes: + p = pixel_shape_to_polygon(sh, self._extent) + if p is not None: + polys.append(p) + return polys + + def _commit_polygons(self, polys: list[Polygon], name: str) -> str: + shapes_model = build_shapes_model(polys, self._cs) + target = commit_to_memory(self._sdata, shapes_model, name) + self._commits.append(target) + return target + + def _reset_canvas_state(self) -> None: + assert self.canvas is not None + self.canvas.clear_trigger += 1 + + def _on_save(self, _btn: W.Button) -> None: + name = self.name_tx.value.strip() + if not name: + self._set_banner("Name is required.", "error") + return + if self.canvas is None or not self.canvas.shapes: + self._set_banner("No shapes drawn yet.", "error") + return + + polys = self._collect_polygons() + if not polys: + self._set_banner( + f"{len(self.canvas.shapes)} shape(s) on canvas but none parsed as valid polygons.", + "error", + ) + return + + target = self._commit_polygons(polys, name) + self._reset_canvas_state() + + self._set_banner(f"Saved {target!r} with {len(polys)} polygon(s).", "success") + if self.persist_btn is not None: + self.persist_btn.disabled = self._sdata.path is None + + def _on_persist(self, _btn: W.Button) -> None: + if not self._persist_enabled: + return + if not self._commits: + self._set_banner("Nothing saved this session yet.", "error") + return + target = self._commits[-1] + try: + self._sdata.write_element(target, overwrite=True) + except (ValueError, OSError) as exc: + self._set_banner(str(exc), "error") + return + self._set_banner(f"Persisted {target!r} → {self._sdata.path}", "success") diff --git a/src/spatialdata_plot/pl/interactive/static/draw_canvas.js b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js new file mode 100644 index 00000000..73d968b4 --- /dev/null +++ b/src/spatialdata_plot/pl/interactive/static/draw_canvas.js @@ -0,0 +1,495 @@ +// anywidget ESM for spatialdata_plot.pl.interactive.DrawCanvas. +// Pure client-side drawing on an SVG overlay above the rendered image PNG. +// Shape geometry (in image-pixel coordinates) is synced back to Python via +// the `shapes` traitlet; conversion to data/CS coords happens server-side. + +function render({ model, el }) { + const W = model.get("image_width"); + const H = model.get("image_height"); + const maxW = model.get("max_display_width") || 880; + + const wrap = document.createElement("div"); + wrap.style.cssText = ` + display: block; + width: 100%; + max-width: ${maxW}px; + background: #18181b; + padding: 6px; + border-radius: 10px; + box-shadow: 0 2px 6px rgba(0,0,0,0.08); + box-sizing: border-box; + `; + const container = document.createElement("div"); + container.style.cssText = ` + position: relative; + width: 100%; + aspect-ratio: ${W} / ${H}; + user-select: none; + background: #000; + border-radius: 6px; + overflow: hidden; + `; + wrap.appendChild(container); + + const img = document.createElement("img"); + img.src = model.get("image_url"); + img.style.cssText = ` + position: absolute; inset: 0; width: 100%; height: 100%; + pointer-events: none; + `; + img.draggable = false; + container.appendChild(img); + + const svgNS = "http://www.w3.org/2000/svg"; + const svg = document.createElementNS(svgNS, "svg"); + svg.style.cssText = ` + position: absolute; inset: 0; width: 100%; height: 100%; + cursor: crosshair; touch-action: none; + `; + svg.setAttribute("preserveAspectRatio", "none"); + container.appendChild(svg); + + el.appendChild(wrap); + + let shapes = []; + let drawing = null; + let drawingNode = null; // SVG node for the in-progress shape; updated in mousemove + let pendingPoly = null; + let hoverIndex = -1; + let vbox = { x: 0, y: 0, w: W, h: H }; + const SNAP_PX = 10; + const LASSO_MIN_PX = 1; // viewbox-px gate on lasso vert push + + function applyViewbox() { + const sx = W / vbox.w; + const sy = H / vbox.h; + img.style.transformOrigin = "0 0"; + img.style.transform = `scale(${sx}, ${sy}) translate(${-vbox.x}px, ${-vbox.y}px)`; + svg.setAttribute("viewBox", `${vbox.x} ${vbox.y} ${vbox.w} ${vbox.h}`); + } + applyViewbox(); + + function setShapes(next) { + if (next === shapes) return; + if (next.length === 0 && shapes.length === 0) return; + shapes = next; + model.set("shapes", shapes); + model.save_changes(); + } + + function popLastShape() { + if (shapes.length === 0) return; + setShapes(shapes.slice(0, -1)); + } + + function getXY(e) { + const r = svg.getBoundingClientRect(); + const fx = (e.clientX - r.left) / r.width; + const fy = (e.clientY - r.top) / r.height; + return [vbox.x + fx * vbox.w, vbox.y + fy * vbox.h]; + } + + function vboxScalePerSvgPx() { + return vbox.w / svg.getBoundingClientRect().width; + } + + function makeEl(tag, attrs) { + const n = document.createElementNS(svgNS, tag); + for (const k in attrs) n.setAttribute(k, attrs[k]); + return n; + } + + function pointsAttr(verts) { + return verts.map((v) => v.join(",")).join(" "); + } + + function shapeNode(s, color, opts) { + opts = opts || {}; + const common = { + stroke: color, + "stroke-width": opts.lw || 2, + "vector-effect": "non-scaling-stroke", + "stroke-dasharray": opts.dashed ? "6,4" : "", + }; + const fillOp = opts.fillOp == null ? 0.15 : opts.fillOp; + if (s.type === "rect") { + const [x0, y0] = s.verts[0]; + const [x1, y1] = s.verts[2]; + return makeEl("rect", { + ...common, + x: Math.min(x0, x1), + y: Math.min(y0, y1), + width: Math.abs(x1 - x0), + height: Math.abs(y1 - y0), + fill: color, + "fill-opacity": fillOp, + }); + } + if (s.type === "polygon") { + return makeEl("polygon", { + ...common, + points: pointsAttr(s.verts), + fill: color, + "fill-opacity": fillOp, + }); + } + if (s.type === "polyline") { + return makeEl("polyline", { + ...common, + points: pointsAttr(s.verts), + fill: "none", + }); + } + return null; + } + + function updateDrawingNode() { + if (!drawing || !drawingNode) return; + if (drawing.type === "rect") { + const [x0, y0] = drawing.verts[0]; + const [x1, y1] = drawing.verts[2]; + drawingNode.setAttribute("x", Math.min(x0, x1)); + drawingNode.setAttribute("y", Math.min(y0, y1)); + drawingNode.setAttribute("width", Math.abs(x1 - x0)); + drawingNode.setAttribute("height", Math.abs(y1 - y0)); + } else { + drawingNode.setAttribute("points", pointsAttr(drawing.verts)); + } + } + + function distPx(a, b) { + return Math.hypot(a[0] - b[0], a[1] - b[1]); + } + + function shouldSnapClosePoly(e) { + if (!pendingPoly || pendingPoly.verts.length < 3) return false; + const r = svg.getBoundingClientRect(); + const fx = pendingPoly.verts[0][0]; + const fy = pendingPoly.verts[0][1]; + const cx = r.left + ((fx - vbox.x) / vbox.w) * r.width; + const cy = r.top + ((fy - vbox.y) / vbox.h) * r.height; + return distPx([e.clientX, e.clientY], [cx, cy]) <= SNAP_PX; + } + + function attachShapeListeners(n, i) { + n.style.cursor = "pointer"; + n.dataset.idx = String(i); + n.addEventListener("mouseenter", () => { + hoverIndex = i; + redraw(); + }); + n.addEventListener("mouseleave", () => { + if (hoverIndex === i) { + hoverIndex = -1; + redraw(); + } + }); + n.addEventListener("click", (ev) => { + if (ev.altKey) { + const next = shapes.slice(); + next.splice(i, 1); + hoverIndex = -1; + setShapes(next); + ev.stopPropagation(); + } + }); + } + + function redraw() { + while (svg.firstChild) svg.removeChild(svg.firstChild); + drawingNode = null; + shapes.forEach((s, i) => { + const isHover = i === hoverIndex; + const n = shapeNode(s, isHover ? "#fb923c" : "#22d3ee", { + lw: isHover ? 3 : 2, + fillOp: isHover ? 0.25 : 0.15, + }); + if (n) { + attachShapeListeners(n, i); + svg.appendChild(n); + } + }); + if (drawing) { + const n = shapeNode(drawing, "#ec4899", { dashed: true }); + if (n) { + n.style.pointerEvents = "none"; + svg.appendChild(n); + drawingNode = n; + } + } + if (pendingPoly && pendingPoly.verts.length > 0) { + const px = vboxScalePerSvgPx(); + const rPx = 5 * px; + pendingPoly.verts.forEach(([x, y], i) => { + const c = makeEl("circle", { + cx: x, + cy: y, + r: i === 0 ? rPx * 1.3 : rPx, + fill: i === 0 ? "#facc15" : "#ec4899", + stroke: "white", + "stroke-width": 1.5 * px, + "vector-effect": "non-scaling-stroke", + }); + c.style.pointerEvents = "none"; + svg.appendChild(c); + }); + } + } + + function commitPendingPolygon() { + if (pendingPoly && pendingPoly.verts.length >= 3) { + setShapes([ + ...shapes, + { type: "polygon", verts: pendingPoly.verts }, + ]); + } + pendingPoly = null; + drawing = null; + redraw(); + } + + function vboxEq(a, b) { + return a.x === b.x && a.y === b.y && a.w === b.w && a.h === b.h; + } + + function zoomAt(clientX, clientY, factor) { + const r = svg.getBoundingClientRect(); + const fx = (clientX - r.left) / r.width; + const fy = (clientY - r.top) / r.height; + const px = vbox.x + fx * vbox.w; + const py = vbox.y + fy * vbox.h; + let newW = vbox.w / factor; + let newH = vbox.h / factor; + const minW = Math.max(5, W * 0.02); + const minH = Math.max(5, H * 0.02); + if (newW < minW) newW = minW; + if (newH < minH) newH = minH; + if (newW > W) { + newW = W; + newH = H; + } + const next = { x: px - fx * newW, y: py - fy * newH, w: newW, h: newH }; + clampVboxObj(next); + if (vboxEq(vbox, next)) return; + vbox = next; + applyViewbox(); + redraw(); + } + function panBy(dxClient, dyClient) { + const r = svg.getBoundingClientRect(); + const next = { + x: vbox.x - dxClient * (vbox.w / r.width), + y: vbox.y - dyClient * (vbox.h / r.height), + w: vbox.w, + h: vbox.h, + }; + clampVboxObj(next); + if (vboxEq(vbox, next)) return; + vbox = next; + applyViewbox(); + redraw(); + } + function clampVboxObj(v) { + if (v.x < 0) v.x = 0; + if (v.y < 0) v.y = 0; + if (v.x + v.w > W) v.x = W - v.w; + if (v.y + v.h > H) v.y = H - v.h; + } + function fitView() { + const next = { x: 0, y: 0, w: W, h: H }; + if (vboxEq(vbox, next)) return; + vbox = next; + applyViewbox(); + redraw(); + } + + let panStart = null; // panning iff panStart !== null + + function onWheel(e) { + e.preventDefault(); + const factor = e.deltaY < 0 ? 1.2 : 1 / 1.2; + zoomAt(e.clientX, e.clientY, factor); + } + + function onMouseDown(e) { + if (e.button === 1 || (e.button === 0 && e.shiftKey)) { + panStart = [e.clientX, e.clientY]; + svg.style.cursor = "grabbing"; + e.preventDefault(); + return; + } + if (e.button !== 0) return; + svg.focus(); + const tool = model.get("tool"); + if (tool === "polygon" && shouldSnapClosePoly(e)) { + commitPendingPolygon(); + e.preventDefault(); + return; + } + const [x, y] = getXY(e); + if (tool === "rectangle") { + drawing = { + type: "rect", + verts: [ + [x, y], + [x, y], + [x, y], + [x, y], + ], + }; + redraw(); + } else if (tool === "lasso") { + drawing = { type: "polygon", verts: [[x, y]] }; + redraw(); + } else if (tool === "polygon") { + if (!pendingPoly) pendingPoly = { type: "polygon", verts: [] }; + pendingPoly.verts.push([x, y]); + drawing = { type: "polyline", verts: pendingPoly.verts }; + redraw(); + } + e.preventDefault(); + } + + function onMouseMove(e) { + if (panStart !== null) { + const dx = e.clientX - panStart[0]; + const dy = e.clientY - panStart[1]; + panStart = [e.clientX, e.clientY]; + panBy(dx, dy); + return; + } + if (!drawing) return; + const tool = model.get("tool"); + const [x, y] = getXY(e); + if (tool === "rectangle") { + const [x0, y0] = drawing.verts[0]; + drawing.verts = [ + [x0, y0], + [x, y0], + [x, y], + [x0, y], + ]; + updateDrawingNode(); + } else if (tool === "lasso") { + const last = drawing.verts[drawing.verts.length - 1]; + if (distPx(last, [x, y]) >= LASSO_MIN_PX) { + drawing.verts.push([x, y]); + updateDrawingNode(); + } + } + } + + function onMouseUp(e) { + if (panStart !== null) { + panStart = null; + svg.style.cursor = "crosshair"; + return; + } + const tool = model.get("tool"); + if (tool === "rectangle" && drawing) { + const [[x0, y0], , [x1, y1]] = drawing.verts; + if (Math.abs(x1 - x0) >= 2 && Math.abs(y1 - y0) >= 2) { + setShapes([...shapes, { type: "rect", verts: drawing.verts }]); + } + drawing = null; + redraw(); + } else if (tool === "lasso" && drawing && drawing.verts.length >= 3) { + setShapes([...shapes, { type: "polygon", verts: drawing.verts }]); + drawing = null; + redraw(); + } + } + + function onKeyDown(e) { + const tool = model.get("tool"); + if (e.key === "r" || e.key === "R") { + model.set("tool", "rectangle"); + model.save_changes(); + e.preventDefault(); + return; + } + if (e.key === "p" || e.key === "P") { + model.set("tool", "polygon"); + model.save_changes(); + e.preventDefault(); + return; + } + if (e.key === "l" || e.key === "L") { + model.set("tool", "lasso"); + model.save_changes(); + e.preventDefault(); + return; + } + if (e.key === "f" || e.key === "F") { + fitView(); + e.preventDefault(); + return; + } + if (e.key === "Enter") { + if (tool === "polygon" && pendingPoly) commitPendingPolygon(); + e.preventDefault(); + return; + } + if (e.key === "Escape") { + pendingPoly = null; + drawing = null; + redraw(); + e.preventDefault(); + return; + } + if ((e.ctrlKey || e.metaKey) && (e.key === "z" || e.key === "Z")) { + popLastShape(); + e.preventDefault(); + return; + } + if (e.key === "Delete" || e.key === "Backspace") { + popLastShape(); + e.preventDefault(); + return; + } + } + + svg.tabIndex = 0; + svg.addEventListener("wheel", onWheel, { passive: false }); + svg.addEventListener("mousedown", onMouseDown); + svg.addEventListener("mousemove", onMouseMove); + svg.addEventListener("mouseup", onMouseUp); + svg.addEventListener("mouseleave", (e) => { + if (panStart === null) onMouseUp(e); + }); + svg.addEventListener("keydown", onKeyDown); + svg.addEventListener("contextmenu", (e) => e.preventDefault()); + + function updateCursor() { + svg.style.cursor = "crosshair"; + svg.title = `Tool: ${model.get("tool")}. R/P/L: tools. Enter: close poly. Esc: cancel. Ctrl+Z: undo. Alt+click shape: delete. Wheel: zoom. Shift+drag: pan. F: fit.`; + } + updateCursor(); + + model.on("change:tool", () => { + const hadInProgress = pendingPoly !== null || drawing !== null; + pendingPoly = null; + drawing = null; + updateCursor(); + if (hadInProgress) redraw(); + }); + model.on("change:clear_trigger", () => { + const hadInProgress = drawing !== null || pendingPoly !== null; + drawing = null; + pendingPoly = null; + if (shapes.length === 0 && !hadInProgress) return; + setShapes([]); + redraw(); + }); + model.on("change:close_poly_trigger", () => { + commitPendingPolygon(); + }); + model.on("change:undo_trigger", () => { + popLastShape(); + }); + model.on("change:fit_trigger", () => { + fitView(); + }); +} + +export default { render }; diff --git a/tests/test_interactive/__init__.py b/tests/test_interactive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_interactive/test_annotate.py b/tests/test_interactive/test_annotate.py new file mode 100644 index 00000000..7416ded0 --- /dev/null +++ b/tests/test_interactive/test_annotate.py @@ -0,0 +1,59 @@ +"""Tests for the user-facing sdata.pl.annotate() validation paths.""" + +from __future__ import annotations + +import numpy as np +import pytest +import spatialdata as sd + +pytest.importorskip("anywidget") +pytest.importorskip("ipywidgets") + +from spatialdata.models import Image2DModel +from spatialdata.transformations.transformations import Identity + +import spatialdata_plot # noqa: F401 registers .pl + + +@pytest.fixture +def no_display(monkeypatch): + monkeypatch.setattr( + "spatialdata_plot.pl.interactive._session._InteractiveSession.show", + lambda self: None, + ) + + +def _make_sdata_with_image() -> sd.SpatialData: + arr = np.random.default_rng(0).integers(0, 255, size=(3, 32, 32), dtype=np.uint8) + img = Image2DModel.parse(arr, dims=("c", "y", "x")) + return sd.SpatialData(images={"img": img}) + + +def test_annotate_unknown_coordinate_system_raises(no_display): + sdata = _make_sdata_with_image() + with pytest.raises(ValueError, match="Unknown coordinate system"): + sdata.pl.annotate("does_not_exist", "img") + + +def test_annotate_unknown_element_raises(no_display): + sdata = _make_sdata_with_image() + with pytest.raises(ValueError, match="Unknown image element"): + sdata.pl.annotate("global", "no_such_image") + + +def test_annotate_element_not_in_cs_raises(no_display): + rng = np.random.default_rng(0) + arr = rng.integers(0, 255, size=(3, 32, 32), dtype=np.uint8) + img = Image2DModel.parse( + arr, + dims=("c", "y", "x"), + transformations={"other_cs": Identity()}, + ) + anchor = Image2DModel.parse( + rng.integers(0, 255, size=(3, 32, 32), dtype=np.uint8), + dims=("c", "y", "x"), + transformations={"global": Identity()}, + ) + sdata = sd.SpatialData(images={"img": img, "anchor": anchor}) + with pytest.raises(ValueError, match="not registered in coordinate system"): + sdata.pl.annotate("global", "img") diff --git a/tests/test_interactive/test_canvas.py b/tests/test_interactive/test_canvas.py new file mode 100644 index 00000000..a6c7ea64 --- /dev/null +++ b/tests/test_interactive/test_canvas.py @@ -0,0 +1,49 @@ +"""Smoke tests for the DrawCanvas anywidget class.""" + +from __future__ import annotations + +import pytest + +pytest.importorskip("anywidget") +pytest.importorskip("ipywidgets") + + +def test_draw_canvas_imports(): + from spatialdata_plot.pl.interactive._canvas import DrawCanvas + + assert DrawCanvas is not None + + +def test_draw_canvas_default_traitlets(): + from spatialdata_plot.pl.interactive._canvas import DrawCanvas + + c = DrawCanvas() + assert c.tool == "rectangle" + assert c.shapes == [] + assert c.image_width == 720 + assert c.image_height == 720 + assert c.max_display_width == 880 + assert c.clear_trigger == 0 + assert c.close_poly_trigger == 0 + assert c.undo_trigger == 0 + assert c.fit_trigger == 0 + + +def test_draw_canvas_esm_file_is_bundled(): + """The ESM module file must ship with the package.""" + from spatialdata_plot.pl.interactive import _canvas + + assert _canvas._ESM_PATH.exists(), f"{_canvas._ESM_PATH} not bundled" + assert _canvas._ESM_PATH.suffix == ".js" + assert _canvas._ESM_PATH.stat().st_size > 0 + + +def test_draw_canvas_traitlet_assignment(): + """Setting traitlets from Python should work (Python → JS sync).""" + from spatialdata_plot.pl.interactive._canvas import DrawCanvas + + c = DrawCanvas() + c.tool = "polygon" + assert c.tool == "polygon" + c.clear_trigger += 1 + assert c.clear_trigger == 1 diff --git a/tests/test_interactive/test_commit.py b/tests/test_interactive/test_commit.py new file mode 100644 index 00000000..87ba5357 --- /dev/null +++ b/tests/test_interactive/test_commit.py @@ -0,0 +1,70 @@ +"""Tests for pixel-coord → CS-coord conversion and ShapesModel construction.""" + +from __future__ import annotations + +from shapely.geometry import Polygon +from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Identity + +from spatialdata_plot.pl.interactive._commit import ( + build_shapes_model, + pixel_shape_to_polygon, +) +from spatialdata_plot.pl.interactive._render import RenderExtent + + +def _extent(w=100, h=100, xlim=(0.0, 100.0), ylim=(100.0, 0.0)) -> RenderExtent: + return RenderExtent(w, h, xlim, ylim) + + +def test_rect_maps_to_full_cs_extent(): + shape = {"type": "rect", "verts": [[0, 0], [100, 0], [100, 100], [0, 100]]} + poly = pixel_shape_to_polygon(shape, _extent(xlim=(0.0, 50.0), ylim=(50.0, 0.0))) + assert poly.bounds == (0.0, 0.0, 50.0, 50.0) + + +def test_rect_subregion(): + shape = {"type": "rect", "verts": [[25, 25], [75, 25], [75, 75], [25, 75]]} + poly = pixel_shape_to_polygon(shape, _extent()) + assert poly.bounds == (25.0, 25.0, 75.0, 75.0) + + +def test_y_axis_orientation_matplotlib_image(): + shape = {"type": "polygon", "verts": [[0, 0], [10, 0], [10, 10], [0, 10]]} + poly = pixel_shape_to_polygon(shape, _extent()) + assert poly.bounds == (0.0, 0.0, 10.0, 10.0) + + +def test_y_axis_non_reversed_ylim(): + shape = {"type": "polygon", "verts": [[0, 0], [10, 0], [10, 10], [0, 10]]} + poly = pixel_shape_to_polygon(shape, _extent(ylim=(0.0, 100.0))) + assert poly.bounds == (0.0, 0.0, 10.0, 10.0) + + +def test_invalid_shapes_return_none(): + ext = _extent(xlim=(0, 1), ylim=(0, 1)) + assert pixel_shape_to_polygon({"type": "polygon", "verts": []}, ext) is None + assert pixel_shape_to_polygon({"type": "polygon", "verts": [[1, 1], [2, 2]]}, ext) is None + assert pixel_shape_to_polygon({}, ext) is None + + +def test_lasso_simplification_for_high_vertex_count(): + n = 200 + verts = ( + [[i, 0] for i in range(n)] + + [[n, j] for j in range(n)] + + [[n - i, n] for i in range(n)] + + [[0, n - j] for j in range(n)] + ) + shape = {"type": "polygon", "verts": verts} + poly = pixel_shape_to_polygon(shape, _extent(w=n, h=n, xlim=(0.0, float(n)), ylim=(float(n), 0.0))) + assert len(poly.exterior.coords) < 4 * n + assert poly.bounds == (0.0, 0.0, float(n), float(n)) + + +def test_build_shapes_model_registers_identity_transform(): + polys = [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])] + sm = build_shapes_model(polys, "my_cs") + transforms = get_transformation(sm, get_all=True) + assert "my_cs" in transforms + assert isinstance(transforms["my_cs"], Identity) diff --git a/tests/test_interactive/test_persist.py b/tests/test_interactive/test_persist.py new file mode 100644 index 00000000..357781d4 --- /dev/null +++ b/tests/test_interactive/test_persist.py @@ -0,0 +1,33 @@ +"""Tests for the in-memory commit policy.""" + +from __future__ import annotations + +import geopandas as gpd +import spatialdata as sd +from shapely.geometry import Polygon +from spatialdata.models import ShapesModel +from spatialdata.transformations.transformations import Identity + +from spatialdata_plot.pl.interactive._persist import commit_to_memory + + +def _make_shape() -> ShapesModel: + gdf = gpd.GeoDataFrame({"geometry": [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]}) + return ShapesModel.parse(gdf, transformations={"global": Identity()}) + + +def test_commit_to_memory_stores_under_name(): + sdata = sd.SpatialData() + target = commit_to_memory(sdata, _make_shape(), "tumor_region") + assert target == "tumor_region" + assert "tumor_region" in sdata.shapes + + +def test_commit_to_memory_overwrites_on_collision(): + sdata = sd.SpatialData() + first = _make_shape() + sdata.shapes["tumor_region"] = first + second = _make_shape() + target = commit_to_memory(sdata, second, "tumor_region") + assert target == "tumor_region" + assert sdata.shapes["tumor_region"] is second diff --git a/tests/test_interactive/test_render.py b/tests/test_interactive/test_render.py new file mode 100644 index 00000000..ecb12dbe --- /dev/null +++ b/tests/test_interactive/test_render.py @@ -0,0 +1,36 @@ +"""Smoke test for the matplotlib → PNG render path.""" + +from __future__ import annotations + +from io import BytesIO + +import numpy as np +import spatialdata as sd +from PIL import Image + +from spatialdata_plot.pl.interactive._render import _IMAGE_H, _IMAGE_W, render_to_png + + +def _make_sdata_with_image() -> sd.SpatialData: + from spatialdata.models import Image2DModel + + arr = np.random.default_rng(0).integers(0, 255, size=(3, 64, 64), dtype=np.uint8) + img = Image2DModel.parse(arr, dims=("c", "y", "x")) + return sd.SpatialData(images={"img": img}) + + +def test_render_to_png_returns_valid_png(): + sdata = _make_sdata_with_image() + png_bytes, extent = render_to_png(sdata, "img", "global") + assert png_bytes.startswith(b"\x89PNG\r\n\x1a\n") + decoded = Image.open(BytesIO(png_bytes)) + assert decoded.size == (extent.image_w, extent.image_h) == (_IMAGE_W, _IMAGE_H) + + +def test_render_to_png_returns_extent_matching_image(): + sdata = _make_sdata_with_image() + _, extent = render_to_png(sdata, "img", "global") + # Image2DModel(c=3, y=64, x=64) with no transformations: xlim covers + # roughly [0, 64], ylim is reversed under origin='upper'. + assert extent.xlim[0] <= 0 and extent.xlim[1] >= 63 + assert extent.ylim[0] >= 63 and extent.ylim[1] <= 0