Skip to content

Commit bfe97ea

Browse files
committed
fix(render): align rasterized artists with get_extent (pixel-edge extent)
Images, labels, and datashader output were drawn with matplotlib's default pixel-center imshow extent (-0.5, W-0.5, ...), placing them half a pixel off the world coordinates used for the axis limits (get_extent), the affine box, and matplotlib point/shape overlays. The affine's linear part amplifies the constant 0.5 offset, so a Scale(1000) image shifted by 500 world units (#216). Switch to the pixel-edge convention (0, W, H, 0) so each artist's data box is the same [0, shape] box get_extent transforms — they now coincide under any affine. Labels are drawn outside the shared _ax_show_and_transform helper via a direct imshow(origin="lower"), so they get an explicit extent=(0, W, 0, H) (origin-lower order) to match. This also removes the residual half-canvas-pixel offset of datashader points relative to the matplotlib backend. The pixel-edge convention is spatialdata's own: get_extent reports an image as occupying [0, shape], and get_centroids places a single pixel's centroid at its half-integer center — so a get_centroids overlay now lands dead-center on its label pixel instead of on the corner. Tests: - Non-visual regressions (test_utils.py) asserting the rendered world box equals get_extent for images and labels (Identity + Scale), and that the datashader points image occupies the points' extent. - Visual regression (test_render_labels.py) overlaying get_centroids on a small label grid, where the half-pixel shift is large in display pixels: dots sit at pixel centers after the fix, at pixel corners before.
1 parent 50a6606 commit bfe97ea

4 files changed

Lines changed: 91 additions & 4 deletions

File tree

src/spatialdata_plot/pl/_datashader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,10 @@ def _ax_show_and_transform(
730730
norm: Normalize | None = None,
731731
interpolation: str | None = None,
732732
) -> matplotlib.image.AxesImage:
733-
# ``extent`` uses mpl's pixel-grid convention; world placement happens via
734-
# ``set_transform(trans_data)`` afterwards.
735-
image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5)
733+
# Pixel-edge extent [0, W] x [0, H], matching get_extent (which sets the axis limits)
734+
# and the affine's data box (placement is via set_transform below). mpl's default
735+
# pixel-center extent (-0.5, W-0.5, ...) offsets by half a pixel, amplified by the affine.
736+
image_extent = (0.0, array.shape[1], array.shape[0], 0.0)
736737
# ``alpha`` is applied only when no cmap is set, so RGBA arrays already
737738
# carrying per-pixel alpha (e.g. datashader output) are not double-attenuated.
738739
imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm}

src/spatialdata_plot/pl/render.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2400,7 +2400,16 @@ def _draw_labels(
24002400
# non-linear norm (LogNorm/PowerNorm). Display the RGB without a norm and build the
24012401
# continuous colorbar mappable separately from the resolved norm (mirrors the outline path),
24022402
# so the colorbar reflects the real norm subclass.
2403-
img = ax.imshow(labels, rasterized=True, alpha=alpha, origin="lower", zorder=render_params.zorder)
2403+
# Pixel-edge extent (0, W, 0, H) matching get_extent/the affine box; mpl's default
2404+
# pixel-center extent would offset labels half a pixel. Order follows origin="lower".
2405+
img = ax.imshow(
2406+
labels,
2407+
rasterized=True,
2408+
alpha=alpha,
2409+
origin="lower",
2410+
extent=(0.0, labels.shape[1], 0.0, labels.shape[0]),
2411+
zorder=render_params.zorder,
2412+
)
24042413
img.set_transform(trans_data)
24052414
if color_spec.is_categorical:
24062415
return img

tests/pl/test_render_labels.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ def test_plot_labels_render_permutations(self, sdata_blobs: SpatialData):
6161
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", colorbar=False, **kw).pl.show(ax=ax)
6262
ax.set_title(title, fontsize=8)
6363

64+
def test_plot_label_centroids_sit_at_pixel_centers(self):
65+
# Regression for #216: on a tiny grid each data-pixel spans many display pixels, so a
66+
# half-pixel image/overlay shift is blatant. Centroids (spatialdata's pixel-edge
67+
# convention) must sit dead-center on their label pixels, not on the pixel corners.
68+
from spatialdata import get_centroids
69+
from spatialdata.models import PointsModel
70+
from spatialdata.transformations import Identity
71+
72+
arr = np.zeros((6, 6), dtype=np.int32)
73+
for label, (row, col) in enumerate([(1, 1), (1, 4), (4, 1), (4, 4), (2, 3)], start=1):
74+
arr[row, col] = label
75+
sdata = SpatialData(labels={"lab": Labels2DModel.parse(arr, dims=("y", "x"))})
76+
centroids = get_centroids(sdata["lab"], coordinate_system="global").compute()
77+
sdata["centroids"] = PointsModel.parse(centroids[["x", "y"]], transformations={"global": Identity()})
78+
sdata.pl.render_labels("lab").pl.render_points("centroids", color="red", size=100).pl.show()
79+
6480
def test_plot_can_render_labels(self, sdata_blobs: SpatialData):
6581
sdata_blobs.pl.render_labels(element="blobs_labels").pl.show()
6682

tests/pl/test_utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,3 +1286,64 @@ def test_first_color_per_category_skips_nan_and_handles_numeric_categories():
12861286
source = pd.Categorical([1, None, 2, 1], categories=[1, 2])
12871287
cv = ["#aaaaaa", "#ffffff", "#bbbbbb", "#aaaaaa"]
12881288
assert _first_color_per_category(source, cv) == {1: "#aaaaaa", 2: "#bbbbbb"}
1289+
1290+
1291+
def _rendered_image_world_box(im, ax) -> tuple[float, float, float, float]:
1292+
"""World-space (x0, x1, y0, y1) the AxesImage's pixel grid actually occupies."""
1293+
left, right, bottom, top = im.get_extent()
1294+
element_affine = im.get_transform() - ax.transData
1295+
(x0, y0), (x1, y1) = element_affine.transform([(left, bottom), (right, top)])
1296+
return float(x0), float(x1), float(y0), float(y1)
1297+
1298+
1299+
@pytest.mark.parametrize("transform_name", ["identity", "scale"])
1300+
@pytest.mark.parametrize("element", ["image", "labels"])
1301+
def test_rasterized_artist_aligns_with_get_extent(transform_name, element):
1302+
# Regression test for #216: a rasterized image/labels artist must occupy the
1303+
# same world box as get_extent (which sets the axis limits). The old pixel-center
1304+
# extent left it half a pixel off, amplified by the affine (Scale -> hundreds of px).
1305+
from spatialdata import get_extent
1306+
from spatialdata.models import Image2DModel
1307+
from spatialdata.transformations import Identity, Scale
1308+
1309+
transform = Identity() if transform_name == "identity" else Scale([1000.0, 1000.0], axes=("x", "y"))
1310+
if element == "image":
1311+
el = Image2DModel.parse(np.zeros((1, 4, 8)), dims=("c", "y", "x"), transformations={"global": transform})
1312+
sdata = SpatialData(images={"el": el})
1313+
ax = sdata.pl.render_images().pl.show(return_ax=True)
1314+
else:
1315+
el = Labels2DModel.parse(
1316+
np.zeros((4, 8), dtype=np.int32), dims=("y", "x"), transformations={"global": transform}
1317+
)
1318+
sdata = SpatialData(labels={"el": el})
1319+
ax = sdata.pl.render_labels().pl.show(return_ax=True)
1320+
1321+
ext = get_extent(sdata["el"], coordinate_system="global")
1322+
x0, x1, y0, y1 = _rendered_image_world_box(ax.get_images()[0], ax)
1323+
assert (min(x0, x1), max(x0, x1)) == pytest.approx(ext["x"])
1324+
assert (min(y0, y1), max(y0, y1)) == pytest.approx(ext["y"])
1325+
plt.close("all")
1326+
1327+
1328+
def test_datashader_points_image_aligns_with_points_extent():
1329+
# Regression test for #216 (Sonja's case): the rasterized datashader-points image
1330+
# must occupy the points' world extent, matching where matplotlib scatters them.
1331+
from spatialdata.models import Image2DModel
1332+
from spatialdata.transformations import Identity
1333+
1334+
sdata = SpatialData(
1335+
images={"img": Image2DModel.parse(np.full((10, 10, 3), 128, dtype=np.uint8), dims=("y", "x", "c"))},
1336+
points={
1337+
"pts": PointsModel.parse(
1338+
pd.DataFrame({"x": [0.1, 0.9, 0.9, 0.1], "y": [0.1, 0.1, 0.9, 0.9]}),
1339+
transformations={"global": Identity()},
1340+
)
1341+
},
1342+
)
1343+
ax = sdata.pl.render_images().pl.render_points("pts", method="datashader", size=40).pl.show(return_ax=True)
1344+
1345+
# second image on the axis is the rasterized points (first is the background image)
1346+
x0, x1, y0, y1 = _rendered_image_world_box(ax.get_images()[1], ax)
1347+
assert (min(x0, x1), max(x0, x1)) == pytest.approx((0.1, 0.9))
1348+
assert (min(y0, y1), max(y0, y1)) == pytest.approx((0.1, 0.9))
1349+
plt.close("all")

0 commit comments

Comments
 (0)