Commit df74e5f (parent: 4e52198)
[SPARK-XXXXX][PYTHON] Add DataStreamReader.name() to Classic PySpark
### What changes were proposed in this pull request?

This PR adds the `name()` method to Classic PySpark's `DataStreamReader` class. This method allows users to specify a name for streaming sources, which is used in checkpoint metadata and enables stable checkpoint locations for source evolution.

Changes include:
- Add `name()` method to `DataStreamReader` in `python/pyspark/sql/streaming/readwriter.py`
- Add comprehensive test suite in `python/pyspark/sql/tests/streaming/test_streaming_reader_name.py`
- Update compatibility test to mark `name` as currently missing from Connect (until the Connect PR merges)

The method validates that `source_name` contains only ASCII letters, digits, and underscores, raising `PySparkTypeError` or `PySparkValueError` for invalid inputs.

### Why are the changes needed?

This brings Classic PySpark to feature parity with the Scala/Java API for streaming source naming. The `name()` method is essential for:
1. Identifying sources in checkpoint metadata
2. Enabling stable checkpoint locations during source evolution
3. Providing consistency across Classic and Connect implementations

### Does this PR introduce _any_ user-facing change?

Yes. Users can now call `.name()` on `DataStreamReader` in Classic PySpark:

```python
spark.readStream.format("parquet").name("my_source").load("/path")
```

### How was this patch tested?

- Added comprehensive unit tests in `test_streaming_reader_name.py` covering:
  - Valid name patterns (letters, digits, underscores)
  - Invalid names (hyphens, spaces, dots, special characters, empty strings, None, wrong types)
  - Method chaining
  - Different data formats (parquet, json)
  - Integration with streaming queries
- Updated compatibility tests to account for the current state where Classic has `name` but Connect doesn't yet

### Was this patch authored or co-authored using generative AI tooling?

Yes.
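For a quick sense of the user-facing behavior, here is a rough sketch of the validation in action (assumes an active `SparkSession` bound to `spark` with this patch applied; the printed messages are illustrative):

```python
from pyspark.errors import PySparkTypeError, PySparkValueError

reader = spark.readStream.format("rate")

# Valid: ASCII letters, digits, and underscores only.
reader.name("my_source_1")

# Disallowed characters raise PySparkValueError (INVALID_STREAMING_SOURCE_NAME).
try:
    reader.name("my-source")
except PySparkValueError as e:
    print(e)

# Non-string or empty input raises PySparkTypeError (NOT_STR).
try:
    reader.name(123)
except PySparkTypeError as e:
    print(e)
```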
4 files changed: 257 additions & 1 deletion


python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions
```diff
@@ -465,6 +465,11 @@
       "Parameter value <arg_name> must be a valid UUID format: <origin>"
     ]
   },
+  "INVALID_STREAMING_SOURCE_NAME": {
+    "message": [
+      "Invalid streaming source name '<source_name>'. Source names must contain only ASCII letters, digits, and underscores."
+    ]
+  },
   "INVALID_TIMEOUT_TIMESTAMP": {
     "message": [
       "Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
```

python/pyspark/sql/streaming/readwriter.py

Lines changed: 47 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import re
 import sys
 from collections.abc import Iterator
 from typing import cast, overload, Any, Callable, List, Optional, TYPE_CHECKING, Union
@@ -241,6 +242,52 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader":
             self._jreader = self._jreader.option(k, to_str(options[k]))
         return self
 
+    def name(self, source_name: str) -> "DataStreamReader":
+        """Specifies a name for the streaming source.
+
+        This name is used to identify the source in checkpoint metadata and enables
+        stable checkpoint locations for source evolution.
+
+        .. versionadded:: 4.2.0
+
+        Parameters
+        ----------
+        source_name : str
+            the name to assign to this streaming source. Must contain only ASCII letters,
+            digits, and underscores.
+
+        Returns
+        -------
+        :class:`DataStreamReader`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> spark.readStream.format("rate").name("my_source")  # doctest: +SKIP
+        <...streaming.readwriter.DataStreamReader object ...>
+        """
+        if not source_name or not isinstance(source_name, str):
+            raise PySparkTypeError(
+                errorClass="NOT_STR",
+                messageParameters={
+                    "arg_name": "source_name",
+                    "arg_type": type(source_name).__name__,
+                },
+            )
+
+        # Validate that source_name contains only ASCII letters, digits, and underscores
+        if not re.match(r"^[a-zA-Z0-9_]+$", source_name):
+            raise PySparkValueError(
+                errorClass="INVALID_STREAMING_SOURCE_NAME",
+                messageParameters={"source_name": source_name},
+            )
+
+        self._jreader = self._jreader.name(source_name)
+        return self
+
     def load(
         self,
         path: Optional[str] = None,
```
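The accept/reject rule reduces to the `^[a-zA-Z0-9_]+$` pattern above; empty strings and non-string values are caught earlier by the type check. A standalone sketch of the same pattern (`_NAME_RE` is a local name for illustration):

```python
import re

# Same pattern as the validation in DataStreamReader.name().
_NAME_RE = re.compile(r"^[a-zA-Z0-9_]+$")

for candidate in ["my_source", "123source", "_private", "my-source", "my source", "my.source"]:
    verdict = "accepted" if _NAME_RE.match(candidate) else "rejected"
    print(f"{candidate!r}: {verdict}")
# 'my_source', '123source', '_private' -> accepted
# 'my-source', 'my source', 'my.source' -> rejected
```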
python/pyspark/sql/tests/streaming/test_streaming_reader_name.py

Lines changed: 204 additions & 0 deletions (new file)

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tempfile
import time

from pyspark.errors import PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase


class DataStreamReaderNameTests(ReusedSQLTestCase):
    """Test suite for DataStreamReader.name() functionality in PySpark."""

    @classmethod
    def setUpClass(cls):
        super(DataStreamReaderNameTests, cls).setUpClass()
        # Enable streaming source evolution feature
        cls.spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", "true")
        cls.spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")

    def test_name_with_valid_names(self):
        """Test that various valid source name patterns work correctly."""
        valid_names = [
            "mySource",
            "my_source",
            "MySource123",
            "_private",
            "source_123_test",
            "123source",
        ]

        for name in valid_names:
            with tempfile.TemporaryDirectory(prefix=f"test_{name}_") as tmpdir:
                self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
                df = (
                    self.spark.readStream.format("parquet")
                    .schema("id LONG")
                    .name(name)
                    .load(tmpdir)
                )
                self.assertTrue(df.isStreaming, f"DataFrame should be streaming for name: {name}")

    def test_name_method_chaining(self):
        """Test that name() returns the reader for method chaining."""
        with tempfile.TemporaryDirectory(prefix="test_chaining_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            df = (
                self.spark.readStream.format("parquet")
                .schema("id LONG")
                .name("my_source")
                .option("maxFilesPerTrigger", "1")
                .load(tmpdir)
            )

            self.assertTrue(df.isStreaming, "DataFrame should be streaming")

    def test_name_before_format(self):
        """Test that order doesn't matter - name can be set before format."""
        with tempfile.TemporaryDirectory(prefix="test_before_format_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            df = (
                self.spark.readStream.name("my_source")
                .format("parquet")
                .schema("id LONG")
                .load(tmpdir)
            )

            self.assertTrue(df.isStreaming, "DataFrame should be streaming")

    def test_invalid_name_with_hyphen(self):
        """Test that source name with hyphen is rejected."""
        with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            with self.assertRaises(Exception) as context:
                self.spark.readStream.format("parquet").schema("id LONG").name("my-source").load(
                    tmpdir
                )

            # The error message should contain information about invalid name
            self.assertIn("source", str(context.exception).lower())

    def test_invalid_name_with_space(self):
        """Test that source name with space is rejected."""
        with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            with self.assertRaises(Exception) as context:
                self.spark.readStream.format("parquet").schema("id LONG").name("my source").load(
                    tmpdir
                )

            self.assertIn("source", str(context.exception).lower())

    def test_invalid_name_with_dot(self):
        """Test that source name with dot is rejected."""
        with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            with self.assertRaises(Exception) as context:
                self.spark.readStream.format("parquet").schema("id LONG").name("my.source").load(
                    tmpdir
                )

            self.assertIn("source", str(context.exception).lower())

    def test_invalid_name_with_special_chars(self):
        """Test that source name with special characters is rejected."""
        with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
            with self.assertRaises(Exception) as context:
                self.spark.readStream.format("parquet").schema("id LONG").name("my@source").load(
                    tmpdir
                )

            self.assertIn("source", str(context.exception).lower())

    def test_invalid_name_empty_string(self):
        """Test that empty string is rejected."""
        with self.assertRaises(PySparkTypeError):
            self.spark.readStream.format("rate").name("").load()

    def test_invalid_name_none(self):
        """Test that None is rejected."""
        with self.assertRaises(PySparkTypeError):
            self.spark.readStream.format("rate").name(None).load()

    def test_invalid_name_wrong_type(self):
        """Test that non-string types are rejected."""
        with self.assertRaises(PySparkTypeError):
            self.spark.readStream.format("rate").name(123).load()

    def test_name_with_different_formats(self):
        """Test that name() works with different streaming data sources."""
        with tempfile.TemporaryDirectory(prefix="test_name_formats_") as tmpdir:
            # Create test data
            self.spark.range(10).write.mode("overwrite").parquet(tmpdir + "/parquet_data")
            self.spark.range(10).selectExpr("id", "CAST(id AS STRING) as value").write.mode(
                "overwrite"
            ).json(tmpdir + "/json_data")

            # Test with parquet
            parquet_df = (
                self.spark.readStream.format("parquet")
                .name("parquet_source")
                .schema("id LONG")
                .load(tmpdir + "/parquet_data")
            )
            self.assertTrue(parquet_df.isStreaming, "Parquet DataFrame should be streaming")

            # Test with json - specify schema
            json_df = (
                self.spark.readStream.format("json")
                .name("json_source")
                .schema("id LONG, value STRING")
                .load(tmpdir + "/json_data")
            )
            self.assertTrue(json_df.isStreaming, "JSON DataFrame should be streaming")

    def test_name_persists_through_query(self):
        """Test that the name persists when starting a streaming query."""
        with tempfile.TemporaryDirectory(prefix="test_name_query_") as tmpdir:
            data_dir = tmpdir + "/data"
            checkpoint_dir = tmpdir + "/checkpoint"

            # Create test data
            self.spark.range(10).write.mode("overwrite").parquet(data_dir)

            df = (
                self.spark.readStream.format("parquet")
                .schema("id LONG")
                .name("parquet_source_test")
                .load(data_dir)
            )

            query = (
                df.writeStream.format("noop").option("checkpointLocation", checkpoint_dir).start()
            )

            try:
                # Let it run briefly
                time.sleep(1)

                # Verify query is running
                self.assertTrue(query.isActive, "Query should be active")
            finally:
                query.stop()


if __name__ == "__main__":
    from pyspark.testing import main

    main()
```
python/pyspark/sql/tests/test_connect_compatibility.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -487,7 +487,7 @@ def test_streaming_reader_compatibility(self):
         """Test Data Stream Reader compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        expected_missing_connect_methods = set()
+        expected_missing_connect_methods = {"name"}
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicDataStreamReader,
```
