diff --git a/tests/test_figures.py b/tests/test_figures.py index 85ab92c..4226725 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -232,6 +232,20 @@ def test_frame_names_preserved(self) -> None: combined_names = {frame.name for frame in combined.frames} assert original_names == combined_names + def test_frame_layout_preserved(self) -> None: + """Test that frame layout (e.g., axis range) is preserved.""" + fig = xpx(self.da_3d).line(animation_frame="time", range_y=[0, 10]) + overlay_fig = xpx(self.da_3d).scatter(animation_frame="time") + + # Verify base has frame layout + assert fig.frames[0].layout is not None + + combined = overlay(fig, overlay_fig) + + # Frame layout should be preserved + for i, frame in enumerate(combined.frames): + assert frame.layout == fig.frames[i].layout + class TestOverlayFacetsAndAnimation: """Tests for overlay with both facets and animation.""" @@ -560,6 +574,20 @@ def test_mismatched_animation_frames_raises(self) -> None: with pytest.raises(ValueError, match="frame names don't match"): add_secondary_y(fig1, fig2) + def test_frame_layout_preserved(self) -> None: + """Test that frame layout (e.g., axis range) is preserved.""" + base = xpx(self.da_2d).line(animation_frame="time", range_y=[0, 10]) + secondary = xpx(self.da_2d).bar(animation_frame="time") + + # Verify base has frame layout + assert base.frames[0].layout is not None + + combined = add_secondary_y(base, secondary) + + # Frame layout should be preserved + for i, frame in enumerate(combined.frames): + assert frame.layout == base.frames[i].layout + class TestAddSecondaryYDeepCopy: """Tests to ensure add_secondary_y creates deep copies.""" diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 939d10d..70bf660 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -142,6 +142,7 @@ def _merge_frames( data=merged_data, name=frame_name, traces=list(range(base_trace_count + sum(overlay_trace_counts))), + layout=base_frame.layout, ) ) @@ -403,6 +404,7 @@ def _merge_secondary_y_frames( data=merged_data, name=frame_name, traces=list(range(base_trace_count + secondary_trace_count)), + layout=base_frame.layout, ) )