1616# limitations under the License.
1717################################################################################
1818
19- from typing import List , Optional
19+ from typing import List , Optional , Tuple
2020
2121import pyarrow as pa
2222from pyarrow import RecordBatch
@@ -53,6 +53,7 @@ def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], p
5353 self .first_row_id = first_row_id
5454 self .max_sequence_number = max_sequence_number
5555 self .system_fields = system_fields
56+ < << << << HEAD
5657 self .blob_as_descriptor = blob_as_descriptor
5758 self .blob_descriptor_fields = blob_descriptor_fields or set ()
5859 self .file_io = file_io
@@ -66,6 +67,35 @@ def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], p
6667 for field_name in self .blob_descriptor_fields
6768 if field_name in self .blob_field_names
6869 }
70+ == == == =
71+ self .requested_field_names = [field .name for field in fields ] if fields else None
72+ self .fields = fields
73+
74+ def _align_to_requested_names (
75+ self ,
76+ inter_arrays : List ,
77+ inter_names : List ,
78+ requested_field_names : List [str ],
79+ num_rows : int ,
80+ ) - > Tuple [List , List ]:
81+ name_to_idx = {n : i for i , n in enumerate (inter_names )}
82+ ordered_arrays = []
83+ ordered_names = []
84+ for name in requested_field_names :
85+ idx = name_to_idx .get (name )
86+ if idx is None and name .startswith ("_KEY_" ) and name [5 :] in name_to_idx :
87+ idx = name_to_idx [name [5 :]]
88+ if idx is not None :
89+ ordered_arrays .append (inter_arrays [idx ])
90+ ordered_names .append (name )
91+ else :
92+ field = self .schema_map .get (name )
93+ ordered_arrays .append (
94+ pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
95+ )
96+ ordered_names .append (name )
97+ return ordered_arrays , ordered_names
98+ >> >> >> > 277 fef48c (support shards read of data evolution )
6999
70100 def read_arrow_batch (self , start_idx = None , end_idx = None ) -> Optional [RecordBatch ]:
71101 if isinstance (self .format_reader , FormatBlobReader ):
@@ -75,11 +105,27 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
75105 if record_batch is None :
76106 return None
77107
108+ num_rows = record_batch .num_rows
78109 if self .partition_info is None and self .index_mapping is None :
79110 if self .row_tracking_enabled and self .system_fields :
80111 record_batch = self ._assign_row_tracking (record_batch )
112+ if self .requested_field_names is not None :
113+ inter_arrays = list (record_batch .columns )
114+ inter_names = list (record_batch .schema .names )
115+ ordered_arrays , ordered_names = self ._align_to_requested_names (
116+ inter_arrays , inter_names , self .requested_field_names , num_rows
117+ )
118+ record_batch = pa .RecordBatch .from_arrays (ordered_arrays , ordered_names )
81119 return record_batch
82120
121+ if (self .partition_info is None and self .index_mapping is not None
122+ and not self .requested_field_names ):
123+ ncol = record_batch .num_columns
124+ if len (self .index_mapping ) == ncol and self .index_mapping == list (range (ncol )):
125+ if self .row_tracking_enabled and self .system_fields :
126+ record_batch = self ._assign_row_tracking (record_batch )
127+ return record_batch
128+
83129 inter_arrays = []
84130 inter_names = []
85131 num_rows = record_batch .num_rows
@@ -93,32 +139,123 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch
93139 inter_names .append (partition_field .name )
94140 else :
95141 real_index = self .partition_info .get_real_index (i )
96- if real_index < record_batch .num_columns :
142+ name = (
143+ self .requested_field_names [i ]
144+ if self .requested_field_names and i < len (self .requested_field_names )
145+ else f"_col_{ i } "
146+ )
147+ batch_names = record_batch .schema .names
148+ col_idx = None
149+ if name in batch_names :
150+ col_idx = record_batch .schema .get_field_index (name )
151+ elif name .startswith ("_KEY_" ) and name [5 :] in batch_names :
152+ col_idx = record_batch .schema .get_field_index (name [5 :])
153+ if col_idx is not None :
154+ inter_arrays .append (record_batch .column (col_idx ))
155+ inter_names .append (name )
156+ elif real_index < record_batch .num_columns :
97157 inter_arrays .append (record_batch .column (real_index ))
98- inter_names .append (record_batch .schema .field (real_index ).name )
158+ inter_names .append (name )
159+ else :
160+ field = self .schema_map .get (name )
161+ inter_arrays .append (
162+ pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
163+ )
164+ inter_names .append (name )
99165 else :
100- inter_arrays = record_batch .columns
101- inter_names = record_batch .schema .names
166+ inter_arrays = list ( record_batch .columns )
167+ inter_names = list ( record_batch .schema .names )
102168
103- if self .index_mapping is not None :
169+ if self .requested_field_names is not None :
170+ inter_arrays , inter_names = self ._align_to_requested_names (
171+ inter_arrays , inter_names , self .requested_field_names , num_rows
172+ )
173+
174+ if self .index_mapping is not None and not (
175+ self .requested_field_names is not None and inter_names == self .requested_field_names ):
104176 mapped_arrays = []
105177 mapped_names = []
178+ partition_names = (
179+ set (pf .name for pf in self .partition_info .partition_fields )
180+ if self .partition_info else set ()
181+ )
182+ non_partition_indices = [idx for idx , name in enumerate (inter_names ) if name not in partition_names ]
106183 for i , real_index in enumerate (self .index_mapping ):
107- if 0 <= real_index < len (inter_arrays ):
108- mapped_arrays .append (inter_arrays [real_index ])
109- mapped_names .append (inter_names [real_index ])
184+ if 0 <= real_index < len (non_partition_indices ):
185+ actual_index = non_partition_indices [real_index ]
186+ mapped_arrays .append (inter_arrays [actual_index ])
187+ mapped_names .append (inter_names [actual_index ])
110188 else :
111- null_array = pa .nulls (num_rows )
189+ name = (
190+ self .requested_field_names [i ]
191+ if self .requested_field_names and i < len (self .requested_field_names )
192+ else f"null_col_{ i } "
193+ )
194+ field = self .schema_map .get (name )
195+ null_array = pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
112196 mapped_arrays .append (null_array )
113- mapped_names .append (f"null_col_{ i } " )
197+ mapped_names .append (name )
198+
199+ if self .partition_info :
200+ partition_arrays_map = {
201+ inter_names [i ]: inter_arrays [i ]
202+ for i in range (len (inter_names ))
203+ if inter_names [i ] in partition_names
204+ }
205+
206+ if self .requested_field_names :
207+ final_arrays = []
208+ final_names = []
209+ mapped_name_to_array = {name : arr for name , arr in zip (mapped_names , mapped_arrays )}
210+
211+ for name in self .requested_field_names :
212+ if name in mapped_name_to_array :
213+ final_arrays .append (mapped_name_to_array [name ])
214+ final_names .append (name )
215+ elif name in partition_arrays_map :
216+ final_arrays .append (partition_arrays_map [name ])
217+ final_names .append (name )
218+ else :
219+ # Field not in file (e.g. index_mapping -1): output null column
220+ field = self .schema_map .get (name )
221+ null_arr = pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
222+ final_arrays .append (null_arr )
223+ final_names .append (name )
224+
225+ inter_arrays = final_arrays
226+ inter_names = final_names
227+ else :
228+ mapped_name_set = set (mapped_names )
229+ for name , arr in partition_arrays_map .items ():
230+ if name not in mapped_name_set :
231+ mapped_arrays .append (arr )
232+ mapped_names .append (name )
233+ inter_arrays = mapped_arrays
234+ inter_names = mapped_names
235+ else :
236+ inter_arrays = mapped_arrays
237+ inter_names = mapped_names
114238
115239 if self .system_primary_key :
116240 for i in range (len (self .system_primary_key )):
117- if not mapped_names [i ].startswith ("_KEY_" ):
118- mapped_names [i ] = f"_KEY_{ mapped_names [i ]} "
241+ if i < len (inter_names ) and not inter_names [i ].startswith ("_KEY_" ):
242+ inter_names [i ] = f"_KEY_{ inter_names [i ]} "
243+
244+ if self .requested_field_names is not None and len (inter_arrays ) < len (self .requested_field_names ):
245+ for name in self .requested_field_names [len (inter_arrays ):]:
246+ field = self .schema_map .get (name )
247+ inter_arrays .append (
248+ pa .nulls (num_rows , type = field .type ) if field is not None else pa .nulls (num_rows )
249+ )
250+ inter_names .append (name )
119251
120- inter_arrays = mapped_arrays
121- inter_names = mapped_names
252+ for i , name in enumerate (inter_names ):
253+ target_field = self .schema_map .get (name )
254+ if target_field is not None and inter_arrays [i ].type != target_field .type :
255+ try :
256+ inter_arrays [i ] = inter_arrays [i ].cast (target_field .type )
257+ except (pa .ArrowInvalid , pa .ArrowNotImplementedError ):
258+ inter_arrays [i ] = pa .nulls (num_rows , type = target_field .type )
122259
123260 # to contains 'not null' property
124261 final_fields = []
@@ -205,18 +342,28 @@ def _deserialize_descriptor_or_none(raw: bytes):
205342 def _assign_row_tracking (self , record_batch : RecordBatch ) -> RecordBatch :
206343 """Assign row tracking meta fields (_ROW_ID and _SEQUENCE_NUMBER)."""
207344 arrays = list (record_batch .columns )
345+ num_cols = len (arrays )
208346
209347 # Handle _ROW_ID field
210348 if SpecialFields .ROW_ID .name in self .system_fields .keys ():
211349 idx = self .system_fields [SpecialFields .ROW_ID .name ]
212350 # Create a new array that fills with computed row IDs
213- arrays [idx ] = pa .array (range (self .first_row_id , self .first_row_id + record_batch .num_rows ), type = pa .int64 ())
351+ if idx < num_cols :
352+ if self .first_row_id is None :
353+ raise ValueError (
354+ "Row tracking requires first_row_id on the file; "
355+ "got None. Ensure file metadata has first_row_id when reading _ROW_ID."
356+ )
357+ arrays [idx ] = pa .array (
358+ range (self .first_row_id , self .first_row_id + record_batch .num_rows ),
359+ type = pa .int64 ())
214360
215361 # Handle _SEQUENCE_NUMBER field
216362 if SpecialFields .SEQUENCE_NUMBER .name in self .system_fields .keys ():
217363 idx = self .system_fields [SpecialFields .SEQUENCE_NUMBER .name ]
218364 # Create a new array that fills with max_sequence_number
219- arrays [idx ] = pa .repeat (self .max_sequence_number , record_batch .num_rows )
365+ if idx < num_cols :
366+ arrays [idx ] = pa .repeat (self .max_sequence_number , record_batch .num_rows )
220367
221368 names = record_batch .schema .names
222369 table = None
0 commit comments