@@ -22,15 +22,22 @@ def WB(where: str, merge: bool, batch_size: int, batches: csp.ts[[pa.RecordBatch
2222 data = write_record_batches (where , batches , {}, merge , batch_size )
2323
2424
def _concat_batches(batches: list[pa.RecordBatch]) -> pa.RecordBatch:
    """Concatenate record batches into a single ``pa.RecordBatch``.

    Stand-in for ``pa.concat_batches`` (not available on older pyarrow
    versions): round-trips through a ``pa.Table`` and collapses its chunks.

    Args:
        batches: Non-empty list of record batches sharing one schema.

    Returns:
        A single record batch holding all rows of *batches*.

    Raises:
        ValueError: If *batches* is empty, or if ``combine_chunks`` could
            not collapse the table into exactly one batch.
    """
    if not batches:
        # pa.Table.from_batches([]) fails with an opaque "must pass schema"
        # error; fail early with a clear message instead.
        raise ValueError("Cannot concatenate an empty list of record batches")
    combined_table = pa.Table.from_batches(batches).combine_chunks()
    combined_batches = combined_table.to_batches()
    if len(combined_batches) > 1:
        raise ValueError("Not able to combine multiple record batches into one record batch")
    return combined_batches[0]
31+
32+
2533class TestArrow :
def make_record_batch(self, ts_col_name: str, row_size: int, ts: datetime) -> pa.RecordBatch:
    """Build a record batch of *row_size* rows.

    Columns: *ts_col_name* — the constant timestamp *ts* at millisecond
    precision; "name" — letters cycling through A..Z by row index.
    """
    ts_type = pa.timestamp("ms")
    # Cycle the uppercase alphabet: row 0 -> "A", row 26 -> "A" again, ...
    names = [chr(ord("A") + idx % 26) for idx in range(row_size)]
    columns = {
        ts_col_name: pa.array([ts] * row_size, type=ts_type),
        "name": pa.array(names),
    }
    schema = pa.schema([(ts_col_name, ts_type), ("name", pa.string())])
    return pa.RecordBatch.from_pydict(columns, schema=schema)
3441
3542 def make_data (self , ts_col_name : str , row_sizes : [int ], start : datetime = _STARTTIME , interval : int = 1 ):
3643 res = [
@@ -100,7 +107,7 @@ def test_start_found(self, small_batches: bool, row_sizes: [int], row_sizes_prev
100107 assert [len (r [1 ][0 ]) for r in results ["data" ]] == clean_row_sizes
101108 assert [r [1 ][0 ] for r in results ["data" ]] == clean_rbs
102109
103- results = csp .run (G , "TsCol" , schema , [pa . concat_batches (full_rbs )], small_batches , starttime = dt_start - delta )
110+ results = csp .run (G , "TsCol" , schema , [_concat_batches (full_rbs )], small_batches , starttime = dt_start - delta )
104111 assert len (results ["data" ]) == len (clean_row_sizes )
105112 assert [len (r [1 ][0 ]) for r in results ["data" ]] == clean_row_sizes
106113 assert [r [1 ][0 ] for r in results ["data" ]] == clean_rbs
@@ -126,7 +133,7 @@ def test_split(self, small_batches: bool, row_sizes: [int], repeat: int, dt_coun
126133 for idx , tup in enumerate (results ["data" ]):
127134 assert tup [1 ] == rbs_indivs [idx ]
128135
129- results = csp .run (G , "TsCol" , schema , [pa . concat_batches (rbs_full )], small_batches , starttime = _STARTTIME )
136+ results = csp .run (G , "TsCol" , schema , [_concat_batches (rbs_full )], small_batches , starttime = _STARTTIME )
130137 assert len (results ["data" ]) == len (rbs_indivs )
131138 for idx , tup in enumerate (results ["data" ]):
132139 assert pa .Table .from_batches (tup [1 ]) == pa .Table .from_batches (rbs_indivs [idx ])
@@ -201,7 +208,7 @@ def test_write_record_batches_concat(self, row_sizes: [int], concat: bool):
201208 if not concat :
202209 rbs_ts_expected = [rb [0 ] for rb in rbs_ts ]
203210 else :
204- rbs_ts_expected = [pa . concat_batches (rbs_ts [0 ])]
211+ rbs_ts_expected = [_concat_batches (rbs_ts [0 ])]
205212 assert rbs_ts_expected == res .to_batches ()
206213
207214 def test_write_record_batches_batch_sizes (self ):
@@ -214,7 +221,7 @@ def test_write_record_batches_batch_sizes(self):
214221 res = pq .read_table (temp_file .name )
215222 orig = pa .Table .from_batches (rbs )
216223 assert res .equals (orig )
217- rbs_ts_expected = [pa . concat_batches (rbs [2 * i : 2 * i + 2 ]) for i in range (5 )]
224+ rbs_ts_expected = [_concat_batches (rbs [2 * i : 2 * i + 2 ]) for i in range (5 )]
218225 assert rbs_ts_expected == res .to_batches ()
219226
220227 row_sizes = [10 ] * 10
@@ -226,5 +233,5 @@ def test_write_record_batches_batch_sizes(self):
226233 res = pq .read_table (temp_file .name )
227234 orig = pa .Table .from_batches (rbs )
228235 assert res .equals (orig )
229- rbs_ts_expected = [pa . concat_batches (rbs [3 * i : 3 * i + 3 ]) for i in range (4 )]
236+ rbs_ts_expected = [_concat_batches (rbs [3 * i : 3 * i + 3 ]) for i in range (4 )]
230237 assert rbs_ts_expected == res .to_batches ()
0 commit comments