|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +import textwrap |
| 19 | +from pathlib import Path |
| 20 | + |
| 21 | +import nbformat |
| 22 | +import papermill as pm |
| 23 | +import pytest |
| 24 | + |
| 25 | +pytestmark = pytest.mark.notebook |
| 26 | + |
| 27 | +NOTEBOOK_PATH = Path(__file__).parents[2] / "notebooks" / "spark_integration_example.ipynb" |
| 28 | + |
| 29 | +# --------------------------------------------------------------------------- |
| 30 | +# Mock pyspark |
| 31 | +# Replaces pyspark.sql.SparkSession with a fake one |
| 32 | +# --------------------------------------------------------------------------- |
| 33 | +_MOCK_PYSPARK = textwrap.dedent("""\ |
| 34 | + import sys |
| 35 | + import types |
| 36 | + from unittest.mock import MagicMock |
| 37 | +
|
| 38 | + def _make_fake_pyspark(): |
| 39 | + pyspark_mod = types.ModuleType("pyspark") |
| 40 | + sql_mod = types.ModuleType("pyspark.sql") |
| 41 | + pyspark_mod.sql = sql_mod |
| 42 | + sys.modules.setdefault("pyspark", pyspark_mod) |
| 43 | + sys.modules.setdefault("pyspark.sql", sql_mod) |
| 44 | + return pyspark_mod, sql_mod |
| 45 | +
|
| 46 | + _pyspark, _sql = _make_fake_pyspark() |
| 47 | +
|
| 48 | + _SHOW_CATALOGS = ( |
| 49 | + "+-------------+\\n" |
| 50 | + "|catalogName |\\n" |
| 51 | + "+-------------+\\n" |
| 52 | + "|spark_catalog|\\n" |
| 53 | + "|local |\\n" |
| 54 | + "+-------------+\\n" |
| 55 | + ) |
| 56 | + _SHOW_NAMESPACES = ( |
| 57 | + "+---------+\\n" |
| 58 | + "|namespace|\\n" |
| 59 | + "+---------+\\n" |
| 60 | + "|default |\\n" |
| 61 | + "+---------+\\n" |
| 62 | + ) |
| 63 | + _SHOW_TABLES = ( |
| 64 | + "+---------+-----------+-----------+\\n" |
| 65 | + "|namespace|tableName |isTemporary|\\n" |
| 66 | + "+---------+-----------+-----------+\\n" |
| 67 | + "|default |test_all |false |\\n" |
| 68 | + "+---------+-----------+-----------+\\n" |
| 69 | + ) |
| 70 | + _DESCRIBE_TABLE = ( |
| 71 | + "+--------------------+---------+-------+\\n" |
| 72 | + "|col_name |data_type|comment|\\n" |
| 73 | + "+--------------------+---------+-------+\\n" |
| 74 | + "|boolean_col |boolean |null |\\n" |
| 75 | + "|integer_col |integer |null |\\n" |
| 76 | + "+--------------------+---------+-------+\\n" |
| 77 | + ) |
| 78 | + _SQL_RESPONSES = { |
| 79 | + "SHOW CATALOGS": _SHOW_CATALOGS, |
| 80 | + "SHOW NAMESPACES": _SHOW_NAMESPACES, |
| 81 | + "SHOW TABLES FROM default": _SHOW_TABLES, |
| 82 | + "DESCRIBE TABLE default.test_all_types": _DESCRIBE_TABLE, |
| 83 | + } |
| 84 | +
|
| 85 | + def _make_df(output): |
| 86 | + df = MagicMock() |
| 87 | + df.show.side_effect = lambda *a, **kw: print(output, end="") |
| 88 | + return df |
| 89 | +
|
| 90 | + class _FakeBuilder: |
| 91 | + def remote(self, url): return self |
| 92 | + def getOrCreate(self): return _FakeSession() |
| 93 | +
|
| 94 | + class _FakeSession: |
| 95 | + builder = _FakeBuilder() |
| 96 | + def sql(self, query): |
| 97 | + key = query.strip().rstrip(";") |
| 98 | + output = _SQL_RESPONSES.get(key, "+------+\\n| col |\\n+------+\\n| val |\\n+------+\\n") |
| 99 | + return _make_df(output) |
| 100 | +
|
| 101 | + _FakeSparkSession = MagicMock(spec=object) |
| 102 | + _FakeSparkSession.builder = _FakeBuilder() |
| 103 | + _sql.SparkSession = _FakeSparkSession |
| 104 | +""") |
| 105 | + |
| 106 | + |
| 107 | +def get_all_stdout(nb: nbformat.NotebookNode) -> str: |
| 108 | + """Concatenate all stdout streams from every executed cell.""" |
| 109 | + return "".join( |
| 110 | + out.get("text", "") |
| 111 | + for cell in nb.cells |
| 112 | + for out in cell.get("outputs", []) |
| 113 | + if out.get("output_type") == "stream" and out.get("name") == "stdout" |
| 114 | + ) |
| 115 | + |
| 116 | + |
| 117 | +def _inject_mock_and_execute(notebook_path: Path, output_path: Path) -> nbformat.NotebookNode: |
| 118 | + """ |
| 119 | + Load the real notebook, prepend the mock-pyspark setup cell, write to a |
| 120 | + temporary copy and execute it with papermill. |
| 121 | + """ |
| 122 | + nb = nbformat.read(str(notebook_path), as_version=4) |
| 123 | + |
| 124 | + mock_cell = nbformat.v4.new_code_cell(_MOCK_PYSPARK) |
| 125 | + mock_cell.metadata["tags"] = ["injected-mock"] |
| 126 | + nb.cells.insert(0, mock_cell) |
| 127 | + |
| 128 | + patched_path = output_path.parent / "spark_patched.ipynb" |
| 129 | + nbformat.write(nb, str(patched_path)) |
| 130 | + |
| 131 | + return pm.execute_notebook(str(patched_path), str(output_path), kernel_name="python3") |
| 132 | + |
| 133 | + |
| 134 | +@pytest.fixture(scope="session") |
| 135 | +def spark_nb(tmp_path_factory: pytest.TempPathFactory) -> nbformat.NotebookNode: |
| 136 | + out = tmp_path_factory.mktemp("nb_out") / "spark_integration_example_out.ipynb" |
| 137 | + return _inject_mock_and_execute(NOTEBOOK_PATH, out) |
| 138 | + |
| 139 | + |
| 140 | +class TestSmoke: |
| 141 | + def test_notebook_completes_without_error(self, spark_nb: nbformat.NotebookNode) -> None: |
| 142 | + assert spark_nb is not None |
| 143 | + |
| 144 | + def test_all_code_cells_executed(self, spark_nb: nbformat.NotebookNode) -> None: |
| 145 | + for cell in spark_nb.cells: |
| 146 | + if cell.cell_type == "code": |
| 147 | + assert cell.get("execution_count") is not None, f"Cell not executed:\n{cell.source[:80]}" |
| 148 | + |
| 149 | + |
| 150 | +class TestCellOutputs: |
| 151 | + def test_show_catalogs_lists_spark_catalog_and_local(self, spark_nb: nbformat.NotebookNode) -> None: |
| 152 | + stdout = get_all_stdout(spark_nb) |
| 153 | + assert "spark_catalog" in stdout |
| 154 | + assert "local" in stdout |
| 155 | + |
| 156 | + def test_show_namespaces_contains_default(self, spark_nb: nbformat.NotebookNode) -> None: |
| 157 | + assert "default" in get_all_stdout(spark_nb) |
| 158 | + |
| 159 | + def test_show_tables_produces_tabular_output(self, spark_nb: nbformat.NotebookNode) -> None: |
| 160 | + assert "+---------+-----------+-----------+" in get_all_stdout(spark_nb) |
| 161 | + |
| 162 | + def test_describe_table_lists_column_names(self, spark_nb: nbformat.NotebookNode) -> None: |
| 163 | + assert "col_name" in get_all_stdout(spark_nb) |
| 164 | + |
| 165 | + def test_describe_table_lists_data_types(self, spark_nb: nbformat.NotebookNode) -> None: |
| 166 | + stdout = get_all_stdout(spark_nb) |
| 167 | + assert "boolean" in stdout or "integer" in stdout |
| 168 | + |
| 169 | + def test_show_tables_includes_test_table_row(self, spark_nb: nbformat.NotebookNode) -> None: |
| 170 | + assert "test_all" in get_all_stdout(spark_nb) |
0 commit comments