-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathray_datasource.py
More file actions
224 lines (184 loc) · 8.75 KB
/
ray_datasource.py
File metadata and controls
224 lines (184 loc) · 8.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Module to read a Paimon table into a Ray Dataset, by using the Ray Datasource API.
"""
import heapq
import itertools
import logging
from functools import partial
from typing import List, Optional, Iterable
import pyarrow
from packaging.version import parse
import ray
from ray.data.datasource import Datasource
from pypaimon.read.split import Split
from pypaimon.read.table_read import TableRead
from pypaimon.schema.data_types import PyarrowFieldParser
logger = logging.getLogger(__name__)
# Ray version constants for compatibility
RAY_VERSION_SCHEMA_IN_READ_TASK = "2.48.0" # Schema moved from BlockMetadata to ReadTask
RAY_VERSION_PER_TASK_ROW_LIMIT = "2.52.0" # per_task_row_limit parameter introduced
class RayDatasource(Datasource):
    """
    Ray Data Datasource implementation for reading Paimon tables.

    This datasource enables distributed parallel reading of Paimon table splits,
    allowing Ray to read multiple splits concurrently across the cluster.

    .. note::
        Ray Data converts ``large_binary()`` to ``binary()`` when reading.
        When writing back via :meth:`TableWrite.write_ray`, the conversion is handled automatically.
    """

    def __init__(self, table_read: TableRead, splits: List[Split]):
        """
        Initialize PaimonDatasource.

        Args:
            table_read: TableRead instance for reading data
            splits: List of splits to read
        """
        self.table_read = table_read
        self.splits = splits
        # Lazily resolved pyarrow schema; populated on first get_read_tasks() call.
        self._schema = None

    def get_name(self) -> str:
        """Return a human-readable name identifying the underlying Paimon table."""
        identifier = self.table_read.table.identifier
        table_name = identifier.get_full_name() if hasattr(identifier, 'get_full_name') else str(identifier)
        return f"PaimonTable({table_name})"

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Estimate the total data size in bytes, or ``None`` if unknown.

        Returns 0 for an empty split list; otherwise sums per-split file sizes,
        guarding missing/non-positive sizes the same way the split-distribution
        logic does (the original summed ``file_size`` unguarded, inconsistently
        with the rest of this class).
        """
        if not self.splits:
            return 0
        total_size = sum(
            split.file_size
            for split in self.splits
            if hasattr(split, 'file_size') and split.file_size > 0
        )
        return total_size if total_size > 0 else None

    @staticmethod
    def _distribute_splits_into_equal_chunks(
        splits: Iterable[Split], n_chunks: int
    ) -> List[List[Split]]:
        """
        Implement a greedy knapsack algorithm to distribute the splits across tasks,
        based on their file size, as evenly as possible.
        """
        def _size(split) -> int:
            # Treat an unknown or non-positive file size as zero so it never
            # skews the balancing.
            return split.file_size if hasattr(split, 'file_size') and split.file_size > 0 else 0

        chunks: List[List[Split]] = [[] for _ in range(n_chunks)]
        # Min-heap of (accumulated_size, chunk_id): the lightest chunk is always on top.
        chunk_sizes = [(0, chunk_id) for chunk_id in range(n_chunks)]
        heapq.heapify(chunk_sizes)
        # From largest to smallest, add each split to the currently smallest chunk.
        for split in sorted(splits, key=_size, reverse=True):
            accumulated, chunk_id = heapq.heappop(chunk_sizes)
            chunks[chunk_id].append(split)
            heapq.heappush(chunk_sizes, (accumulated + _size(split), chunk_id))
        return chunks

    def get_read_tasks(self, parallelism: int, **kwargs) -> List:
        """Return a list of read tasks that can be executed in parallel.

        Args:
            parallelism: Requested number of read tasks; capped at the number
                of splits.
            **kwargs: May contain ``per_task_row_limit``, forwarded to
                ``ReadTask`` on Ray >= 2.52.0 and ignored on older versions.

        Raises:
            ValueError: If ``parallelism`` is less than 1.
        """
        from ray.data.datasource import ReadTask
        from ray.data.block import BlockMetadata

        per_task_row_limit = kwargs.get('per_task_row_limit', None)
        if parallelism < 1:
            raise ValueError(f"parallelism must be at least 1, got {parallelism}")
        if self._schema is None:
            self._schema = PyarrowFieldParser.from_paimon_schema(self.table_read.read_type)
        if parallelism > len(self.splits):
            parallelism = len(self.splits)
            logger.warning(
                "Reducing the parallelism to %s, as that is the number of splits",
                parallelism,
            )

        # The Ray version is loop-invariant: parse it once instead of up to
        # three times per chunk inside the loop below.
        ray_version = parse(ray.__version__)
        schema_in_read_task = ray_version >= parse(RAY_VERSION_SCHEMA_IN_READ_TASK)
        supports_row_limit = ray_version >= parse(RAY_VERSION_PER_TASK_ROW_LIMIT)

        # Extract plain attributes so the read closure does not capture `self`,
        # which would otherwise be serialized to every Ray worker
        # (see https://github.com/ray-project/ray/issues/49107).
        table = self.table_read.table
        predicate = self.table_read.predicate
        read_type = self.table_read.read_type
        schema = self._schema

        # Default arguments bind the extracted values at definition time; no
        # additional functools.partial wrapping of the same keywords is needed
        # (the previous code redundantly did both).
        def _get_read_task(
            splits: List[Split],
            table=table,
            predicate=predicate,
            read_type=read_type,
            schema=schema,
        ) -> Iterable[pyarrow.Table]:
            """Read function that will be executed by Ray workers."""
            from pypaimon.read.table_read import TableRead
            worker_table_read = TableRead(table, predicate, read_type)
            arrow_table = worker_table_read.to_arrow(splits)
            if arrow_table is not None and arrow_table.num_rows > 0:
                return [arrow_table]
            # Preserve the schema even when the chunk yields no rows.
            return [schema.empty_table()]

        read_tasks = []
        # Distribute splits across tasks using the greedy load-balancing algorithm.
        for chunk_splits in self._distribute_splits_into_equal_chunks(self.splits, parallelism):
            if not chunk_splits:
                continue
            # Aggregate per-chunk statistics for BlockMetadata.
            total_rows = 0
            total_size = 0
            for split in chunk_splits:
                if predicate is None:
                    # Only estimate rows if no predicate (predicate filtering changes row count)
                    if hasattr(split, 'row_count') and split.row_count > 0:
                        total_rows += split.row_count
                if hasattr(split, 'file_size') and split.file_size > 0:
                    total_size += split.file_size
            input_files = list(itertools.chain.from_iterable(
                split.file_paths
                for split in chunk_splits
                if hasattr(split, 'file_paths') and split.file_paths
            ))
            # For primary-key tables, pre-merge row counts overcount; with a
            # predicate, pre-filter counts overcount. Report None in both cases
            # so Ray computes the actual count during execution.
            if table and table.is_primary_key_table:
                num_rows = None
            elif predicate is not None:
                num_rows = None
            else:
                num_rows = total_rows if total_rows > 0 else None
            size_bytes = total_size if total_size > 0 else None
            metadata_kwargs = {
                'num_rows': num_rows,
                'size_bytes': size_bytes,
                'input_files': input_files if input_files else None,
                'exec_stats': None,  # Will be populated by Ray during execution
            }
            if not schema_in_read_task:
                # Older Ray (< 2.48.0) expects the schema on BlockMetadata.
                metadata_kwargs['schema'] = schema
            metadata = BlockMetadata(**metadata_kwargs)

            read_task_kwargs = {
                'read_fn': partial(_get_read_task, chunk_splits),
                'metadata': metadata,
            }
            if schema_in_read_task:
                # Newer Ray (>= 2.48.0) expects the schema on the ReadTask itself.
                read_task_kwargs['schema'] = schema
            if supports_row_limit and per_task_row_limit is not None:
                read_task_kwargs['per_task_row_limit'] = per_task_row_limit
            read_tasks.append(ReadTask(**read_task_kwargs))
        return read_tasks