Skip to content

Commit 97a2fbe

Browse files
committed
[SYSTEMDS-3902] Sparse data transfer: Python --> Java
This commit implements optimized data transfer for Scipy sparse matrices from Python to the Java runtime. Key changes include the addition of `convertSciPyCSRToMB` and `convertSciPyCOOToMB` in the Java utility layer to directly handle compressed sparse row and coordinate formats. On the Python side, the `SystemDSContext` now supports a `sparse_data_transfer` flag and a new `from_py` method to unify data ingestion. These updates allow sparse data to be transferred without being converted to dense arrays, improving efficiency. Additionally, several data conversion methods were refactored for better maintainability.
1 parent 3f841b7 commit 97a2fbe

7 files changed

Lines changed: 1034 additions & 395 deletions

File tree

src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.nio.ByteOrder;
2424
import java.nio.charset.StandardCharsets;
2525

26+
import org.apache.log4j.Logger;
2627
import org.apache.sysds.common.Types;
2728
import org.apache.sysds.runtime.DMLRuntimeException;
2829
import org.apache.sysds.runtime.frame.data.columns.Array;
@@ -35,6 +36,7 @@
3536
* Utils for converting python data to java.
3637
*/
3738
public class Py4jConverterUtils {
39+
private static final Logger LOG = Logger.getLogger(Py4jConverterUtils.class);
3840
public static MatrixBlock convertPy4JArrayToMB(byte[] data, int rlen, int clen) {
3941
return convertPy4JArrayToMB(data, rlen, clen, false, Types.ValueType.FP64);
4042
}
@@ -63,6 +65,45 @@ public static MatrixBlock convertSciPyCOOToMB(byte[] data, byte[] row, byte[] co
6365
return mb;
6466
}
6567

68+
public static MatrixBlock convertSciPyCSRToMB(byte[] data, byte[] indices, byte[] indptr, int rlen, int clen, int nnz) {
69+
LOG.debug("Converting compressed sparse row matrix to MatrixBlock");
70+
MatrixBlock mb = new MatrixBlock(rlen, clen, true);
71+
mb.allocateSparseRowsBlock(false);
72+
ByteBuffer dataBuf = ByteBuffer.wrap(data);
73+
dataBuf.order(ByteOrder.nativeOrder());
74+
ByteBuffer indicesBuf = ByteBuffer.wrap(indices);
75+
indicesBuf.order(ByteOrder.nativeOrder());
76+
ByteBuffer indptrBuf = ByteBuffer.wrap(indptr);
77+
indptrBuf.order(ByteOrder.nativeOrder());
78+
79+
// Read indptr array to get row boundaries
80+
int[] rowPtrs = new int[rlen + 1];
81+
for(int i = 0; i <= rlen; i++) {
82+
rowPtrs[i] = indptrBuf.getInt();
83+
}
84+
85+
// Iterate through each row
86+
for(int row = 0; row < rlen; row++) {
87+
int startIdx = rowPtrs[row];
88+
int endIdx = rowPtrs[row + 1];
89+
90+
// Set buffer positions to the start of this row
91+
dataBuf.position(startIdx * Double.BYTES);
92+
indicesBuf.position(startIdx * Integer.BYTES);
93+
94+
// Process all non-zeros in this row sequentially
95+
for(int idx = startIdx; idx < endIdx; idx++) {
96+
double val = dataBuf.getDouble();
97+
int colIndex = indicesBuf.getInt();
98+
mb.set(row, colIndex, val);
99+
}
100+
}
101+
102+
mb.recomputeNonZeros();
103+
mb.examSparsity();
104+
return mb;
105+
}
106+
66107
public static MatrixBlock allocateDenseOrSparse(int rlen, int clen, boolean isSparse) {
67108
MatrixBlock ret = new MatrixBlock(rlen, clen, isSparse);
68109
ret.allocateBlock();
@@ -208,6 +249,7 @@ private static void readBufferIntoArray(ByteBuffer buffer, Array<?> array, Types
208249
public static byte[] convertMBtoPy4JDenseArr(MatrixBlock mb) {
209250
byte[] ret = null;
210251
if(mb.isInSparseFormat()) {
252+
LOG.debug("Converting sparse matrix to dense");
211253
mb.sparseToDense();
212254
}
213255

src/main/python/systemds/context/systemds_context.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@
2929
import sys
3030
import struct
3131
import traceback
32+
import warnings
3233
from contextlib import contextmanager
3334
from glob import glob
3435
from queue import Queue
3536
from subprocess import PIPE, Popen
3637
from threading import Thread
37-
from time import sleep, time
38+
from time import sleep
3839
from typing import Dict, Iterable, Sequence, Tuple, Union
3940
from concurrent.futures import ThreadPoolExecutor
4041

4142
import numpy as np
4243
import pandas as pd
44+
import scipy.sparse as sp
4345
from py4j.java_gateway import GatewayParameters, JavaGateway, Py4JNetworkError
4446
from systemds.operator import (
4547
Frame,
@@ -77,6 +79,7 @@ class SystemDSContext(object):
7779
_FIFO_JAVA2PY_PIPES = []
7880
_data_transfer_mode = 0
7981
_multi_pipe_enabled = False
82+
_sparse_data_transfer = True
8083
_logging_initialized = False
8184
_executor_pool = ThreadPoolExecutor(max_workers=os.cpu_count() * 2 or 4)
8285

@@ -89,6 +92,7 @@ def __init__(
8992
py4j_logging_level: int = 50,
9093
data_transfer_mode: int = 1,
9194
multi_pipe_enabled: bool = False,
95+
sparse_data_transfer: bool = True,
9296
):
9397
"""Starts a new instance of SystemDSContext, in which the connection to a JVM systemds instance is handled
9498
Any new instance of this SystemDS Context, would start a separate new JVM.
@@ -103,14 +107,26 @@ def __init__(
103107
The logging levels are as follows: 10 DEBUG, 20 INFO, 30 WARNING, 40 ERROR, 50 CRITICAL.
104108
:param py4j_logging_level: The logging level for Py4j to use, since all communication to the JVM is done through this,
105109
it can be verbose if not set high.
106-
:param data_transfer_mode: default 0,
110+
:param data_transfer_mode: default 1, 0 for py4j, 1 for using pipes (on unix systems)
111+
:param multi_pipe_enabled: default False, if True, use multiple pipes for data transfer
112+
only used if data_transfer_mode is 1.
113+
.. experimental:: This parameter is experimental and may be removed in a future version.
114+
:param sparse_data_transfer: default True, if True, use optimized sparse matrix transfer,
115+
if False, convert sparse matrices to dense arrays before transfer
107116
"""
108117

118+
if multi_pipe_enabled:
119+
warnings.warn(
120+
"The 'multi_pipe_enabled' parameter is experimental and may be removed in a future version.",
121+
DeprecationWarning,
122+
stacklevel=2,
123+
)
109124
self.__setup_logging(logging_level, py4j_logging_level)
110125
self.__start(port, capture_stdout)
111126
self.capture_stats(capture_statistics)
112127
self._log.debug("Started JVM and SystemDS python context manager")
113128
self.__setup_data_transfer(data_transfer_mode, multi_pipe_enabled)
129+
self._sparse_data_transfer = sparse_data_transfer
114130

115131
def __setup_data_transfer(self, data_transfer_mode=0, multi_pipe_enabled=False):
116132
self._data_transfer_mode = data_transfer_mode
@@ -769,21 +785,65 @@ def scalar(self, v: Dict[str, VALID_INPUT_TYPES]) -> Scalar:
769785
# therefore the output type is assign.
770786
return Scalar(self, v, assign=True)
771787

788+
def from_py(
    self,
    src: Union[np.ndarray, sp.spmatrix, pd.DataFrame, pd.Series],
    *args: Sequence[VALID_INPUT_TYPES],
    **kwargs: Dict[str, VALID_INPUT_TYPES],
) -> Union[Matrix, Frame]:
    """Generate DAGNode representing data given by a python object, which will be
    sent to SystemDS on need.

    Numpy arrays and scipy sparse matrices produce a Matrix node; pandas
    DataFrames and Series produce a Frame node.

    :param src: the python object (np.ndarray, sp.spmatrix, pd.DataFrame or pd.Series)
    :param args: unnamed parameters forwarded to the read operation
    :param kwargs: named parameters forwarded to the read operation
    :return: A Matrix or Frame Node
    :raises ValueError: if src has more than two dimensions or is of an
        unsupported type
    """

    def get_params(src, args, kwargs):
        # The '{file_name}' placeholder is substituted when the data is
        # actually transferred to SystemDS.
        unnamed_params = ["'./tmp/{file_name}'"]
        if len(src.shape) == 2:
            named_params = {"rows": src.shape[0], "cols": src.shape[1]}
        elif len(src.shape) == 1:
            # 1-d inputs are treated as a single-column matrix/frame.
            named_params = {"rows": src.shape[0], "cols": 1}
        else:
            # TODO Support tensors.
            raise ValueError("Only two dimensional arrays supported")
        unnamed_params.extend(args)
        named_params.update(kwargs)
        return unnamed_params, named_params

    # Dense and sparse numeric inputs map to Matrix; tabular inputs map to
    # Frame with the explicit '"frame"' data_type annotation.
    if isinstance(src, (np.ndarray, sp.spmatrix)):
        unnamed_params, named_params = get_params(src, args, kwargs)
        return Matrix(self, "read", unnamed_params, named_params, local_data=src)
    elif isinstance(src, (pd.DataFrame, pd.Series)):
        unnamed_params, named_params = get_params(src, args, kwargs)
        named_params["data_type"] = '"frame"'
        return Frame(self, "read", unnamed_params, named_params, local_data=src)
    else:
        raise ValueError(f"Unsupported data type: {type(src)}")
830+
772831
def from_numpy(
773832
self,
774-
mat: np.array,
833+
mat: Union[np.ndarray, sp.spmatrix],
775834
*args: Sequence[VALID_INPUT_TYPES],
776835
**kwargs: Dict[str, VALID_INPUT_TYPES],
777836
) -> Matrix:
778-
"""Generate DAGNode representing matrix with data given by a numpy array, which will be sent to SystemDS
779-
on need.
837+
"""Generate DAGNode representing matrix with data given by a numpy array or scipy sparse matrix,
838+
which will be sent to SystemDS on need.
780839
781-
:param mat: the numpy array
840+
:param mat: the numpy array or scipy sparse matrix
782841
:param args: unnamed parameters
783842
:param kwargs: named parameters
784843
:return: A Matrix
844+
Note: This method is deprecated. Use from_py instead.
785845
"""
786-
846+
self._log.warning(f"Deprecated method from_numpy. Use from_py instead.")
787847
unnamed_params = ["'./tmp/{file_name}'"]
788848

789849
if len(mat.shape) == 2:
@@ -811,7 +871,9 @@ def from_pandas(
811871
:param args: unnamed parameters
812872
:param kwargs: named parameters
813873
:return: A Frame
874+
Note: This method is deprecated. Use from_py instead.
814875
"""
876+
self._log.warning(f"Deprecated method from_pandas. Use from_py instead.")
815877
unnamed_params = ["'./tmp/{file_name}'"]
816878

817879
if len(df.shape) == 2:
@@ -824,9 +886,6 @@ def from_pandas(
824886

825887
unnamed_params.extend(args)
826888
named_params["data_type"] = '"frame"'
827-
828-
self._pd_dataframe = df
829-
830889
named_params.update(kwargs)
831890
return Frame(self, "read", unnamed_params, named_params, local_data=df)
832891

0 commit comments

Comments
 (0)