Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,13 @@ def _set_color_source_vec(
table_name=table_name,
)

# When both the element's own dataframe and the chosen table contain a
# column with this name, an explicit `table_name=` resolves the ambiguity —
# keep only the table origin and skip the multi-origin error below.
explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins)
if explicit_table_shadows_df:
origins = [o for o in origins if o.origin != "df"]

if len(origins) > 1:
raise ValueError(
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
Expand All @@ -1094,6 +1101,15 @@ def _set_color_source_vec(
)
if preloaded_color_data is not None:
color_source_vector = preloaded_color_data
elif explicit_table_shadows_df:
# Pass the table as `element` so upstream `get_values` skips the
# element-column lookup and avoids the multi-origin error.
color_source_vector = get_values(
value_key=value_to_plot,
element=sdata[table_name],
element_name=element_name,
table_layer=table_layer,
)[value_to_plot]
else:
color_source_vector = get_values(
value_key=value_to_plot,
Expand Down Expand Up @@ -3170,9 +3186,9 @@ def _validate_col_for_column_table(
if col_for_color is None:
return None, None

if not labels and col_for_color in sdata[element_name].columns:
table_name = None
elif table_name is not None:
if not labels and col_for_color in sdata[element_name].columns and table_name is None:
return col_for_color, None
if table_name is not None:
tables = get_element_annotators(sdata, element_name)
if table_name not in tables:
logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.")
Expand Down
38 changes: 38 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,41 @@ def test_color_column_collision_on_annotating_table_raises():
sdata.pl.render_shapes("s", color="orange")

sdata.pl.render_shapes("s", color="#ffa500")


def test_explicit_table_name_honored_when_element_has_same_column():
# regression test for #620: explicit table_name= must not be silently
# discarded when the element has a same-named column with different values.
shapes = ShapesModel.parse(
gpd.GeoDataFrame(
{
"geometry": [Point(5, 5), Point(15, 5)],
"radius": [2.0, 2.0],
"cat": pd.Categorical(["X", "Y"]),
}
)
)
obs = pd.DataFrame(
{
"instance_id": [0, 1],
"region": pd.Categorical(["s1", "s1"]),
"cat": pd.Categorical(["A", "B"]),
}
)
table = TableModel.parse(
AnnData(X=np.zeros((2, 1)), obs=obs),
region=["s1"],
region_key="region",
instance_key="instance_id",
)
sdata = SpatialData(shapes={"s1": shapes}, tables={"t": table})

fig, ax = plt.subplots()
sdata.pl.render_shapes("s1", color="cat", table_name="t").pl.show(ax=ax)
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["A", "B"]
plt.close(fig)

fig, ax = plt.subplots()
sdata.pl.render_shapes("s1", color="cat").pl.show(ax=ax)
assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["X", "Y"]
plt.close(fig)
Loading