11from collections .abc import Sequence
2- from typing import Any , Callable , Optional
2+ from typing import Any , Callable , Optional , Union
33
44import numpy as np
5+ import pandas as pd
56from joblib import Parallel , delayed , effective_n_jobs
67
78
89def parallelized_with_batches (
910 fn : Callable [[Sequence [Any ]], Any ],
10- inputs_list : np .ndarray ,
11+ inputs_list : Union [ np .ndarray , pd . DataFrame ] ,
1112 n_jobs : Optional [int ] = None ,
1213 ** job_kwargs : Any ,
1314) -> Sequence [Optional [Any ]]:
@@ -33,6 +34,11 @@ def parallelized_with_batches(
3334 n_jobs = len (inputs_list )
3435 pool = Parallel (n_jobs = n_jobs , ** job_kwargs )
3536
36- input_chunks = np .array_split (inputs_list , n_jobs )
37+ if isinstance (inputs_list , (pd .DataFrame , pd .Series )):
38+ indexes = np .array_split (range (len (inputs_list )), n_jobs )
39+ input_chunks = [inputs_list .iloc [idx ] for idx in indexes ]
40+ else :
41+ input_chunks = np .array_split (inputs_list , n_jobs )
42+
3743 results = pool (delayed (fn )(chunk ) for chunk in input_chunks )
3844 return results
0 commit comments