Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/test_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,20 @@ def test_frame_names_preserved(self) -> None:
combined_names = {frame.name for frame in combined.frames}
assert original_names == combined_names

def test_frame_layout_preserved(self) -> None:
"""Test that frame layout (e.g., axis range) is preserved."""
fig = xpx(self.da_3d).line(animation_frame="time", range_y=[0, 10])
overlay_fig = xpx(self.da_3d).scatter(animation_frame="time")

# Verify base has frame layout
assert fig.frames[0].layout is not None

combined = overlay(fig, overlay_fig)

# Frame layout should be preserved
for i, frame in enumerate(combined.frames):
assert frame.layout == fig.frames[i].layout


class TestOverlayFacetsAndAnimation:
"""Tests for overlay with both facets and animation."""
Expand Down Expand Up @@ -560,6 +574,20 @@ def test_mismatched_animation_frames_raises(self) -> None:
with pytest.raises(ValueError, match="frame names don't match"):
add_secondary_y(fig1, fig2)

def test_frame_layout_preserved(self) -> None:
"""Test that frame layout (e.g., axis range) is preserved."""
base = xpx(self.da_2d).line(animation_frame="time", range_y=[0, 10])
secondary = xpx(self.da_2d).bar(animation_frame="time")

# Verify base has frame layout
assert base.frames[0].layout is not None

combined = add_secondary_y(base, secondary)

# Frame layout should be preserved
for i, frame in enumerate(combined.frames):
assert frame.layout == base.frames[i].layout


class TestAddSecondaryYDeepCopy:
"""Tests to ensure add_secondary_y creates deep copies."""
Expand Down
2 changes: 2 additions & 0 deletions xarray_plotly/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def _merge_frames(
data=merged_data,
name=frame_name,
traces=list(range(base_trace_count + sum(overlay_trace_counts))),
layout=base_frame.layout,
)
)

Expand Down Expand Up @@ -403,6 +404,7 @@ def _merge_secondary_y_frames(
data=merged_data,
name=frame_name,
traces=list(range(base_trace_count + secondary_trace_count)),
layout=base_frame.layout,
)
)

Expand Down