diff --git a/ultraplot/axes/three.py b/ultraplot/axes/three.py index 20bb92ddb..cf0e314e9 100644 --- a/ultraplot/axes/three.py +++ b/ultraplot/axes/three.py @@ -33,3 +33,11 @@ def __init__(self, *args, **kwargs): kwargs.setdefault("alpha", 0.0) super().__init__(*args, **kwargs) + + def graph(self, *args, **kwargs): + """ + Draw network graphs on 3D projections. + """ + from .plot import PlotAxes + + return PlotAxes.graph(self, *args, **kwargs) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 69d9eca4b..38e32b60e 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -149,6 +149,22 @@ def test_graph_input(): ax.graph("invalid_input") +def test_graph_on_3d_projection(): + """ + Ensure graph plotting is available on 3D axes. + """ + import networkx as nx + + g = nx.path_graph(5) + _, axs = uplt.subplots(proj="3d") + ax = axs[0] + nodes, edges, labels = ax.graph(g) + assert callable(getattr(ax, "graph", None)) + assert nodes is not False + assert edges is not False + assert labels is False + + def test_graph_layout_input(): """ Test if layout is in a [0, 1] x [0, 1] box