Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,19 @@ def _reparse_points(
df: pd.DataFrame,
transformation: Any,
coordinate_system: str,
color_column: str | None = None,
) -> None:
"""Re-register a points DataFrame in *sdata_filt* with its transformation."""
"""Re-register a points DataFrame in *sdata_filt* with its transformation.

``PointsModel.parse`` silently drops columns whose names collide with
reserved coordinate axes (currently only ``"z"``). When ``color_column``
names such a column, re-attach it so downstream color lookup can find it.
"""
dd_frame = dask.dataframe.from_pandas(df, npartitions=1)
sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
parsed = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"})
if color_column is not None and color_column in df.columns and color_column not in parsed.columns:
parsed[color_column] = dd_frame[color_column]
sdata_filt.points[element] = parsed
set_transformation(
element=sdata_filt.points[element],
transformation=transformation,
Expand Down Expand Up @@ -820,7 +829,7 @@ def _render_points(

# Convert back to dask dataframe to modify sdata
transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system]
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system)
_reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color)

if col_for_color is not None:
assert isinstance(col_for_color, str)
Expand Down Expand Up @@ -877,6 +886,7 @@ def _render_points(
points_pd_with_color,
transformation_in_cs,
coordinate_system,
col_for_color,
)

_warn_groups_ignored_continuous(groups, color_source_vector, col_for_color)
Expand All @@ -897,7 +907,7 @@ def _render_points(
# filter the materialized points, adata, and re-register in sdata_filt
points = points[keep].reset_index(drop=True)
adata = adata[keep]
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system)
_reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color)

# color_source_vector is None when the values aren't categorical
if color_source_vector is None and render_params.transfunc is not None:
Expand Down
34 changes: 34 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,40 @@ def test_no_table_fallback_warning_for_element_column(caplog):
plt.close("all")


def test_render_points_color_by_z_data_column():
# regression test for #615
pts = PointsModel.parse(
pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "z": [0.1, 0.5, 0.9]}),
)
assert "z" in pts.columns
sdata = SpatialData(points={"p": pts})
fig, ax = plt.subplots()
try:
sdata.pl.render_points("p", color="z").pl.show(ax=ax)
finally:
plt.close(fig)


def test_render_points_color_by_z_with_extra_columns():
# regression test for #615
pts = PointsModel.parse(
pd.DataFrame(
{
"x": [1.0, 2.0, 3.0],
"y": [1.0, 2.0, 3.0],
"z": [0.1, 0.5, 0.9],
"score": [0.0, 0.5, 1.0],
}
),
)
sdata = SpatialData(points={"p": pts})
fig, ax = plt.subplots()
try:
sdata.pl.render_points("p", color="score").pl.show(ax=ax)
finally:
plt.close(fig)


def test_render_points_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
Expand Down
Loading