Skip to content

Commit 5dd86f7

Browse files
authored
Merge pull request #74 from kbarbary/fix-pandas-dep-warning
fix deprecation warning for DataFrame.swapaxes
2 parents 9f5b1db + 8a091d7 commit 5dd86f7

3 files changed

Lines changed: 11 additions & 3 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
118118
- [@mikemhenry](https://github.com/mikemhenry)
119119
- [@c-feldmann](https://github.com/c-feldmann)
120120
- Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
121+
- Kyle Barbary [@kbarbary](https://github.com/kbarbary)

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,4 @@ Scikit-Mol has been developed as a community effort with contributions from peop
118118
- [@mikemhenry](https://github.com/mikemhenry)
119119
- [@c-feldmann](https://github.com/c-feldmann)
120120
- Mieczyslaw Torchala [@mieczyslaw](https://github.com/mieczyslaw)
121+
- Kyle Barbary [@kbarbary](https://github.com/kbarbary)

scikit_mol/parallel.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from collections.abc import Sequence
2-
from typing import Any, Callable, Optional
2+
from typing import Any, Callable, Optional, Union
33

44
import numpy as np
5+
import pandas as pd
56
from joblib import Parallel, delayed, effective_n_jobs
67

78

89
def parallelized_with_batches(
910
fn: Callable[[Sequence[Any]], Any],
10-
inputs_list: np.ndarray,
11+
inputs_list: Union[np.ndarray, pd.DataFrame],
1112
n_jobs: Optional[int] = None,
1213
**job_kwargs: Any,
1314
) -> Sequence[Optional[Any]]:
@@ -33,6 +34,11 @@ def parallelized_with_batches(
3334
n_jobs = len(inputs_list)
3435
pool = Parallel(n_jobs=n_jobs, **job_kwargs)
3536

36-
input_chunks = np.array_split(inputs_list, n_jobs)
37+
if isinstance(inputs_list, (pd.DataFrame, pd.Series)):
38+
indexes = np.array_split(range(len(inputs_list)), n_jobs)
39+
input_chunks = [inputs_list.iloc[idx] for idx in indexes]
40+
else:
41+
input_chunks = np.array_split(inputs_list, n_jobs)
42+
3743
results = pool(delayed(fn)(chunk) for chunk in input_chunks)
3844
return results

0 commit comments

Comments
 (0)