Skip to content

Commit 95c45d4

Browse files
authored
Add papermill-based tests for PyIceberg examples (#3330)
Closes #3328

# Rationale for this change

`pyiceberg_example.ipynb` and `spark_integration_example.ipynb` had no automated test coverage. Breaking changes to notebook cells could go undetected in CI. This PR adds papermill-based tests that execute the real notebooks as-is, so any change to a cell is automatically reflected in the tests.

## Are these changes tested?

Yes. The tests themselves are the change. Run them with:

```bash
make test-notebook
```

## Are there any user-facing changes?

No.
1 parent a9ad3a3 commit 95c45d4

5 files changed

Lines changed: 314 additions & 1 deletion

File tree

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
.PHONY: help install install-uv check-license lint \
1818
test test-integration test-integration-setup test-integration-exec test-integration-cleanup test-integration-rebuild \
19-
test-s3 test-adls test-gcs test-coverage coverage-report \
19+
test-s3 test-adls test-gcs test-coverage coverage-report test-notebook \
2020
docs-serve docs-build notebook notebook-infra \
2121
clean
2222

@@ -150,6 +150,9 @@ coverage-report: ## Combine and report coverage
150150
uv run $(PYTHON_ARG) coverage html
151151
uv run $(PYTHON_ARG) coverage xml
152152

153+
test-notebook: ## Run notebook tests (pyiceberg_example and spark_integration_example) via papermill
154+
$(TEST_RUNNER) pytest tests/notebooks/test_pyiceberg_example.py tests/notebooks/test_spark_integration_example.py -m notebook $(PYTEST_ARGS)
155+
153156
# ================
154157
# Documentation
155158
# ================

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ dev = [
122122
"google-cloud-bigquery>=3.33.0,<4",
123123
"pyarrow-stubs>=20.0.0.20251107", # Remove when pyarrow >= 23.0.0 https://github.com/apache/arrow/pull/47609
124124
"sqlalchemy>=2.0.18,<3",
125+
"papermill>=2.6.0",
126+
"nbformat>=5.10.0",
127+
"ipykernel>=6.29.0",
125128
]
126129
# for mkdocs
127130
docs = [
@@ -161,6 +164,7 @@ markers = [
161164
"integration: marks integration tests against Apache Spark",
162165
"gcs: marks a test as requiring access to gcs compliant storage (use with --gs.token, --gs.project, and --gs.endpoint)",
163166
"benchmark: collection of tests to validate read/write performance before and after a change",
167+
"notebook: marks tests that execute Jupyter notebooks via papermill",
164168
]
165169

166170
# Turns a warning into an error
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from pathlib import Path
19+
20+
import nbformat
21+
import papermill as pm
22+
import pytest
23+
24+
pytestmark = pytest.mark.notebook
25+
26+
NOTEBOOK_PATH = Path(__file__).parents[2] / "notebooks" / "pyiceberg_example.ipynb"
27+
28+
29+
def get_all_stdout(nb: nbformat.NotebookNode) -> str:
    """Return the concatenated stdout text of every cell in ``nb``.

    Only outputs with ``output_type == "stream"`` and ``name == "stdout"`` are
    included; every other output is ignored. Cells without an ``outputs``
    entry contribute nothing, and a stream output without ``text`` counts as
    an empty string.

    The notebook format permits multi-line strings to be stored as a list of
    strings, so a list-valued ``text`` is joined rather than raising
    ``TypeError``.
    """
    chunks: list[str] = []
    for cell in nb.cells:
        for output in cell.get("outputs", []):
            if output.get("output_type") != "stream" or output.get("name") != "stdout":
                continue
            text = output.get("text", "")
            # A notebook that was not normalised on load may still carry a list of lines.
            chunks.append(text if isinstance(text, str) else "".join(text))
    return "".join(chunks)
37+
38+
39+
@pytest.fixture(scope="session")
def pyiceberg_nb(tmp_path_factory: pytest.TempPathFactory) -> nbformat.NotebookNode:
    """Execute the PyIceberg example notebook and return the executed notebook.

    The fixture is session-scoped. The notebook at ``NOTEBOOK_PATH`` is run
    through papermill with the ``python3`` kernel and the executed copy is
    written to ``pyiceberg_example_out.ipynb`` inside a fresh temporary
    directory.
    """
    output_dir = tmp_path_factory.mktemp("nb_out")
    executed_path = output_dir / "pyiceberg_example_out.ipynb"
    return pm.execute_notebook(
        str(NOTEBOOK_PATH),
        str(executed_path),
        kernel_name="python3",
    )
43+
44+
45+
class TestSmoke:
    """Smoke checks: the PyIceberg notebook ran from start to finish."""

    def test_notebook_completes_without_error(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        """papermill raises PapermillExecutionError if any cell fails."""
        assert pyiceberg_nb is not None

    def test_all_code_cells_executed(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        """Every non-blank code cell must carry an execution count.

        Blank code cells (for example the empty trailing cell Jupyter tends to
        leave behind) are skipped by the notebook executor and never receive
        an ``execution_count``, so they are not treated as failures.
        """
        for cell in pyiceberg_nb.cells:
            if cell.cell_type != "code" or not cell.source.strip():
                continue
            assert cell.get("execution_count") is not None, f"Cell not executed:\n{cell.source[:80]}"
54+
55+
56+
class TestCellOutputs:
    """Checks on the text the executed PyIceberg notebook wrote to stdout."""

    def test_pyiceberg_version_printed(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "PyIceberg version:" in printed

    def test_warehouse_location_printed(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        for expected in ("Warehouse location:", "iceberg_warehouse_"):
            assert expected in printed

    def test_catalog_loaded_successfully(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "Catalog loaded successfully!" in printed

    def test_namespace_default_created(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "default" in printed

    def test_rows_written_is_five(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "Rows written: 5" in printed

    def test_schema_evolved_message(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "Schema evolved!" in printed

    def test_tip_per_mile_column_present_after_evolution(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "tip_per_mile" in printed

    def test_filter_result_is_positive(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        """The notebook prints 'Rows with tip_per_mile > 1.0: N' — N must be > 0."""
        printed = get_all_stdout(pyiceberg_nb)
        assert "Rows with tip_per_mile > 1.0:" in printed
        # Only the first line carrying the label is inspected.
        first_match = next(
            (line for line in printed.splitlines() if "Rows with tip_per_mile > 1.0:" in line),
            None,
        )
        if first_match is not None:
            count = int(first_match.split(":")[-1].strip())
            assert count > 0

    def test_snapshot_id_printed(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        assert "Current snapshot ID:" in printed

    def test_table_history_has_entries(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        for expected in ("Table history:", "Snapshot:"):
            assert expected in printed

    def test_warehouse_contains_parquet_and_metadata_files(self, pyiceberg_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(pyiceberg_nb)
        for expected in (".parquet", ".metadata.json"):
            assert expected in printed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import textwrap
19+
from pathlib import Path
20+
21+
import nbformat
22+
import papermill as pm
23+
import pytest
24+
25+
pytestmark = pytest.mark.notebook
26+
27+
NOTEBOOK_PATH = Path(__file__).parents[2] / "notebooks" / "spark_integration_example.ipynb"
28+
29+
# ---------------------------------------------------------------------------
30+
# Mock pyspark
31+
# Replaces pyspark.sql.SparkSession with a fake one
32+
# ---------------------------------------------------------------------------
33+
# Source of the code cell injected ahead of the notebook's own cells. It
# registers stand-in ``pyspark`` / ``pyspark.sql`` modules whose SparkSession
# answers a handful of SQL statements with canned ``show()`` output.
_MOCK_PYSPARK = textwrap.dedent("""\
    import sys
    import types
    from unittest.mock import MagicMock


    def _make_fake_pyspark():
        pyspark_mod = types.ModuleType("pyspark")
        sql_mod = types.ModuleType("pyspark.sql")
        pyspark_mod.sql = sql_mod
        # Register unconditionally: with setdefault() an already-imported
        # pyspark would stay in sys.modules and the SparkSession patch below
        # would land on modules that nothing ever imports.
        sys.modules["pyspark"] = pyspark_mod
        sys.modules["pyspark.sql"] = sql_mod
        return pyspark_mod, sql_mod


    _pyspark, _sql = _make_fake_pyspark()

    _SHOW_CATALOGS = (
        "+-------------+\\n"
        "|catalogName  |\\n"
        "+-------------+\\n"
        "|spark_catalog|\\n"
        "|local        |\\n"
        "+-------------+\\n"
    )
    _SHOW_NAMESPACES = (
        "+---------+\\n"
        "|namespace|\\n"
        "+---------+\\n"
        "|default  |\\n"
        "+---------+\\n"
    )
    _SHOW_TABLES = (
        "+---------+-----------+-----------+\\n"
        "|namespace|tableName  |isTemporary|\\n"
        "+---------+-----------+-----------+\\n"
        "|default  |test_all   |false      |\\n"
        "+---------+-----------+-----------+\\n"
    )
    _DESCRIBE_TABLE = (
        "+--------------------+---------+-------+\\n"
        "|col_name            |data_type|comment|\\n"
        "+--------------------+---------+-------+\\n"
        "|boolean_col         |boolean  |null   |\\n"
        "|integer_col         |integer  |null   |\\n"
        "+--------------------+---------+-------+\\n"
    )
    _SQL_RESPONSES = {
        "SHOW CATALOGS": _SHOW_CATALOGS,
        "SHOW NAMESPACES": _SHOW_NAMESPACES,
        "SHOW TABLES FROM default": _SHOW_TABLES,
        "DESCRIBE TABLE default.test_all_types": _DESCRIBE_TABLE,
    }


    def _make_df(output):
        df = MagicMock()
        # show() prints the canned table regardless of its arguments.
        df.show.side_effect = lambda *a, **kw: print(output, end="")
        return df


    class _FakeBuilder:
        def remote(self, url):
            return self

        def getOrCreate(self):
            return _FakeSession()


    class _FakeSession:
        builder = _FakeBuilder()

        def sql(self, query):
            # Lookup is exact apart from surrounding whitespace and a trailing ';'.
            key = query.strip().rstrip(";")
            output = _SQL_RESPONSES.get(key, "+------+\\n| col  |\\n+------+\\n| val  |\\n+------+\\n")
            return _make_df(output)


    _FakeSparkSession = MagicMock(spec=object)
    _FakeSparkSession.builder = _FakeBuilder()
    _sql.SparkSession = _FakeSparkSession
""")
105+
106+
107+
def get_all_stdout(nb: nbformat.NotebookNode) -> str:
    """Return the concatenated stdout text of every cell in ``nb``.

    Only outputs with ``output_type == "stream"`` and ``name == "stdout"`` are
    included; every other output is ignored. Cells without an ``outputs``
    entry contribute nothing, and a stream output without ``text`` counts as
    an empty string.

    The notebook format permits multi-line strings to be stored as a list of
    strings, so a list-valued ``text`` is joined rather than raising
    ``TypeError``.
    """
    chunks: list[str] = []
    for cell in nb.cells:
        for output in cell.get("outputs", []):
            if output.get("output_type") != "stream" or output.get("name") != "stdout":
                continue
            text = output.get("text", "")
            # A notebook that was not normalised on load may still carry a list of lines.
            chunks.append(text if isinstance(text, str) else "".join(text))
    return "".join(chunks)
115+
116+
117+
def _inject_mock_and_execute(notebook_path: Path, output_path: Path) -> nbformat.NotebookNode:
    """Execute ``notebook_path`` with the mock-pyspark cell placed first.

    The notebook is read as nbformat v4, a code cell holding ``_MOCK_PYSPARK``
    (tagged ``injected-mock``) is inserted at index 0, and the result is saved
    as ``spark_patched.ipynb`` next to ``output_path``. papermill then runs
    that patched copy with the ``python3`` kernel, writes the executed
    notebook to ``output_path`` and its return value is handed back.
    """
    notebook = nbformat.read(str(notebook_path), as_version=4)

    setup_cell = nbformat.v4.new_code_cell(_MOCK_PYSPARK)
    setup_cell.metadata["tags"] = ["injected-mock"]
    notebook.cells.insert(0, setup_cell)

    # The patched copy lives in the same directory as the executed output.
    patched_path = output_path.with_name("spark_patched.ipynb")
    nbformat.write(notebook, str(patched_path))

    return pm.execute_notebook(
        str(patched_path),
        str(output_path),
        kernel_name="python3",
    )
132+
133+
134+
@pytest.fixture(scope="session")
def spark_nb(tmp_path_factory: pytest.TempPathFactory) -> nbformat.NotebookNode:
    """Execute the Spark integration notebook against the injected mock.

    The fixture is session-scoped. The executed notebook is written to
    ``spark_integration_example_out.ipynb`` inside a fresh temporary directory.
    """
    output_dir = tmp_path_factory.mktemp("nb_out")
    executed_path = output_dir / "spark_integration_example_out.ipynb"
    return _inject_mock_and_execute(NOTEBOOK_PATH, executed_path)
138+
139+
140+
class TestSmoke:
    """Smoke checks: the Spark integration notebook ran from start to finish."""

    def test_notebook_completes_without_error(self, spark_nb: nbformat.NotebookNode) -> None:
        """The ``spark_nb`` fixture produced an executed notebook."""
        assert spark_nb is not None

    def test_all_code_cells_executed(self, spark_nb: nbformat.NotebookNode) -> None:
        """Every non-blank code cell must carry an execution count.

        Blank code cells (for example the empty trailing cell Jupyter tends to
        leave behind) are skipped by the notebook executor and never receive
        an ``execution_count``, so they are not treated as failures.
        """
        for cell in spark_nb.cells:
            if cell.cell_type != "code" or not cell.source.strip():
                continue
            assert cell.get("execution_count") is not None, f"Cell not executed:\n{cell.source[:80]}"
148+
149+
150+
class TestCellOutputs:
    """Checks on the text the executed Spark notebook wrote to stdout."""

    def test_show_catalogs_lists_spark_catalog_and_local(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        for expected in ("spark_catalog", "local"):
            assert expected in printed

    def test_show_namespaces_contains_default(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        assert "default" in printed

    def test_show_tables_produces_tabular_output(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        assert "+---------+-----------+-----------+" in printed

    def test_describe_table_lists_column_names(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        assert "col_name" in printed

    def test_describe_table_lists_data_types(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        # Either type name appearing anywhere in stdout is enough.
        assert any(type_name in printed for type_name in ("boolean", "integer"))

    def test_show_tables_includes_test_table_row(self, spark_nb: nbformat.NotebookNode) -> None:
        printed = get_all_stdout(spark_nb)
        assert "test_all" in printed

uv.lock

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)