diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index dd0daee6..968fbe53 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1078,6 +1078,13 @@ def _set_color_source_vec( table_name=table_name, ) + # When both the element's own dataframe and the chosen table contain a + # column with this name, an explicit `table_name=` resolves the ambiguity — + # keep only the table origin and skip the multi-origin error below. + explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins) + if explicit_table_shadows_df: + origins = [o for o in origins if o.origin != "df"] + if len(origins) > 1: raise ValueError( f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. " @@ -1094,6 +1101,15 @@ def _set_color_source_vec( ) if preloaded_color_data is not None: color_source_vector = preloaded_color_data + elif explicit_table_shadows_df: + # Pass the table as `element` so upstream `get_values` skips the + # element-column lookup and avoids the multi-origin error. + color_source_vector = get_values( + value_key=value_to_plot, + element=sdata[table_name], + element_name=element_name, + table_layer=table_layer, + )[value_to_plot] else: color_source_vector = get_values( value_key=value_to_plot, @@ -3170,9 +3186,9 @@ def _validate_col_for_column_table( if col_for_color is None: return None, None - if not labels and col_for_color in sdata[element_name].columns: - table_name = None - elif table_name is not None: + if not labels and col_for_color in sdata[element_name].columns and table_name is None: + return col_for_color, None + if table_name is not None: tables = get_element_annotators(sdata, element_name) if table_name not in tables: logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.") diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 467a5efe..7f342ff7 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -420,3 +420,41 @@ def test_color_column_collision_on_annotating_table_raises(): sdata.pl.render_shapes("s", color="orange") sdata.pl.render_shapes("s", color="#ffa500") + + +def test_explicit_table_name_honored_when_element_has_same_column(): + # regression test for #620: explicit table_name= must not be silently + # discarded when the element has a same-named column with different values. + shapes = ShapesModel.parse( + gpd.GeoDataFrame( + { + "geometry": [Point(5, 5), Point(15, 5)], + "radius": [2.0, 2.0], + "cat": pd.Categorical(["X", "Y"]), + } + ) + ) + obs = pd.DataFrame( + { + "instance_id": [0, 1], + "region": pd.Categorical(["s1", "s1"]), + "cat": pd.Categorical(["A", "B"]), + } + ) + table = TableModel.parse( + AnnData(X=np.zeros((2, 1)), obs=obs), + region=["s1"], + region_key="region", + instance_key="instance_id", + ) + sdata = SpatialData(shapes={"s1": shapes}, tables={"t": table}) + + fig, ax = plt.subplots() + sdata.pl.render_shapes("s1", color="cat", table_name="t").pl.show(ax=ax) + assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["A", "B"] + plt.close(fig) + + fig, ax = plt.subplots() + sdata.pl.render_shapes("s1", color="cat").pl.show(ax=ax) + assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["X", "Y"] + plt.close(fig)