Closed

24 commits
8e7d6da
Add LLM inference support to JMLC API via Py4J bridge
kubraaksux Feb 12, 2026
47dd0db
Refactor loadModel to accept worker script path as parameter
kubraaksux Feb 13, 2026
672a3fa
Add dynamic port allocation and improve resource cleanup
kubraaksux Feb 13, 2026
dacdc1c
Move llm_worker.py to fix Python module collision
kubraaksux Feb 13, 2026
29f657c
Use python3 with fallback to python in Connection.java
kubraaksux Feb 14, 2026
e40e4f2
Add batch inference with FrameBlock and metrics support
kubraaksux Feb 14, 2026
fdd1684
Clean up test: extract constants and shared setup method
kubraaksux Feb 14, 2026
b9ba3e0
Add token counts, GPU support, and improve error handling
kubraaksux Feb 14, 2026
2e984a2
Increase worker startup timeout to 300s for larger models
kubraaksux Feb 16, 2026
bf666c2
Revert accidental changes to MatrixBlockDictionary.java
kubraaksux Feb 16, 2026
5faa691
Add GPU batching support to JMLC LLM inference
kubraaksux Feb 16, 2026
c9c85d4
Keep both sequential and batched inference modes in PreparedScript
kubraaksux Feb 16, 2026
4b44dd1
Add gitignore rules for .env files, meeting notes, and local tool config
kubraaksux Feb 16, 2026
72bc334
Add llmPredict builtin, opcode and ParamBuiltinOp entries
kubraaksux Feb 16, 2026
0ad1b56
Add llmPredict parser validation in ParameterizedBuiltinFunctionExpre…
kubraaksux Feb 16, 2026
1e48362
Wire llmPredict through hop, lop and instruction generation
kubraaksux Feb 16, 2026
de675ac
Add llmPredict CP instruction with HTTP-based inference
kubraaksux Feb 16, 2026
5eab87d
Remove Py4J-based LLM inference from JMLC API
kubraaksux Feb 16, 2026
bea062a
Rewrite LLM test to use llmPredict DML built-in
kubraaksux Feb 16, 2026
edf4e39
Add OpenAI-compatible HTTP inference server for HuggingFace models
kubraaksux Feb 16, 2026
45882e2
Fix llmPredict code quality issues
kubraaksux Feb 16, 2026
c3e9a1f
Add concurrency parameter to llmPredict built-in
kubraaksux Feb 16, 2026
53e3feb
Remove license header from test, clarify llm_server.py docstring
kubraaksux Feb 16, 2026
e872f22
Fix JMLC frame binding: match DML variable names to registered inputs
kubraaksux Feb 16, 2026
159 changes: 159 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/Connection.java
@@ -28,7 +28,11 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.hops.OptimizerUtils;
@@ -66,6 +70,7 @@
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.DataConverter;
import py4j.GatewayServer;

/**
* Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with
@@ -91,6 +96,12 @@ public class Connection implements Closeable
private final DMLConfig _dmlconf;
private final CompilerConfig _cconf;
private static FileSystem fs = null;
private Process _pythonProcess = null;
private GatewayServer _gatewayServer = null;
private LLMCallback _llmWorker = null;
private CountDownLatch _workerLatch = null;

private static final Log LOG = LogFactory.getLog(Connection.class.getName());

/**
* Connection constructor, the starting point for any other JMLC API calls.
@@ -287,13 +298,161 @@ public PreparedScript prepareScript(String script, Map<String,String> nsscripts,
//return newly created precompiled script
return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf);
}

/**
* Loads a HuggingFace model via a Python worker for LLM inference.
* Uses auto-detected available ports for Py4J communication.
*
* @param modelName HuggingFace model name
* @param workerScriptPath path to the Python worker script
* @return LLMCallback interface to the Python worker
*/
public LLMCallback loadModel(String modelName, String workerScriptPath) {
//auto-find available ports
int javaPort = findAvailablePort();
int pythonPort = findAvailablePort();
return loadModel(modelName, workerScriptPath, javaPort, pythonPort);
}

/**
* Loads a HuggingFace model via a Python worker for LLM inference.
* Starts a Python subprocess and connects via Py4J.
*
* @param modelName HuggingFace model name
* @param workerScriptPath path to the Python worker script
* @param javaPort port for Java gateway server
* @param pythonPort port for Python callback server
* @return LLMCallback interface to the Python worker
*/
public LLMCallback loadModel(String modelName, String workerScriptPath, int javaPort, int pythonPort) {
if (_llmWorker != null)
return _llmWorker;
try {
//initialize latch for worker registration
_workerLatch = new CountDownLatch(1);

//start Py4J gateway server with callback support
_gatewayServer = new GatewayServer.GatewayServerBuilder()
.entryPoint(this)
.javaPort(javaPort)
.callbackClient(pythonPort, java.net.InetAddress.getLoopbackAddress())
.build();
_gatewayServer.start();

//give gateway time to start
Thread.sleep(500);

//start python worker process with both ports
String pythonCmd = findPythonCommand();
LOG.info("Starting LLM worker with script: " + workerScriptPath +
" (python=" + pythonCmd + ", javaPort=" + javaPort + ", pythonPort=" + pythonPort + ")");
_pythonProcess = new ProcessBuilder(
pythonCmd, workerScriptPath, modelName,
String.valueOf(javaPort), String.valueOf(pythonPort)
).redirectErrorStream(true).start();

//read python output in background thread
Thread outputReader = new Thread(() -> {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(_pythonProcess.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null) {
LOG.info("[LLM Worker] " + line);
}
} catch (IOException e) {
LOG.error("Error reading LLM worker output", e);
}
});
outputReader.setName("llm-worker-output");
outputReader.setDaemon(true);
outputReader.start();

//larger models (7B+) need more time to load weights into GPU memory
long deadlineNs = System.nanoTime() + TimeUnit.SECONDS.toNanos(300);
while (!_workerLatch.await(2, TimeUnit.SECONDS)) {
if (!_pythonProcess.isAlive()) {
int exitCode = _pythonProcess.exitValue();
throw new DMLException("LLM worker process died during startup (exit code " + exitCode + ")");
}
if (System.nanoTime() > deadlineNs) {
throw new DMLException("Timeout waiting for LLM worker to register (300s)");
}
}

} catch (DMLException e) {
throw e;
} catch (Exception e) {
throw new DMLException("Failed to start LLM worker: " + e.getMessage());
}
return _llmWorker;
}

/**
* Finds an available Python command, trying python3 first, then python.
* @return python command name
*/
private static String findPythonCommand() {
for (String cmd : new String[]{"python3", "python"}) {
try {
Process p = new ProcessBuilder(cmd, "--version")
.redirectErrorStream(true).start();
int exitCode = p.waitFor();
if (exitCode == 0)
return cmd;
} catch (Exception e) {
//command not found, try next
}
}
throw new DMLException("No Python installation found (tried python3, python)");
}

/**
* Finds an available port on the local machine.
* @return available port number
*/
private int findAvailablePort() {
try (java.net.ServerSocket socket = new java.net.ServerSocket()) {
//SO_REUSEADDR must be set before bind to take effect
socket.setReuseAddress(true);
socket.bind(new java.net.InetSocketAddress(0));
return socket.getLocalPort();
} catch (IOException e) {
throw new DMLException("Failed to find available port: " + e.getMessage());
}
}

/**
* Called by the Python worker to register itself via Py4J.
*/
public void registerWorker(LLMCallback worker) {
_llmWorker = worker;
if (_workerLatch != null) {
_workerLatch.countDown();
}
LOG.info("LLM worker registered successfully");
}

/**
* Close connection to SystemDS, which shuts down any running LLM worker
* and clears the thread-local DML and compiler configurations.
*/
@Override
public void close() {

//shutdown LLM worker if running
if (_pythonProcess != null) {
_pythonProcess.destroyForcibly();
try {
_pythonProcess.waitFor(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
_pythonProcess = null;
}
if (_gatewayServer != null) {
_gatewayServer.shutdown();
_gatewayServer = null;
}
_llmWorker = null;

//clear thread-local configurations
ConfigurationManager.clearLocalConfigs();
if( ConfigurationManager.isCodegenEnabled() )
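For context, a minimal usage sketch of the new loadModel API; the model name and worker script path below are illustrative placeholders, not files added in this PR:

import org.apache.sysds.api.jmlc.Connection;
import org.apache.sysds.api.jmlc.LLMCallback;

public class LoadModelExample {
  public static void main(String[] args) {
    //Connection implements Closeable, so try-with-resources also
    //terminates the Python worker and Py4J gateway via close()
    try (Connection conn = new Connection()) {
      //the two-argument overload auto-detects free ports for Py4J
      LLMCallback worker = conn.loadModel(
        "distilgpt2",                 //hypothetical model name
        "scripts/llm/llm_worker.py"); //hypothetical worker script path
      String out = worker.generate("Hello, SystemDS!", 32, 0.7, 0.9);
      System.out.println(out);
    }
  }
}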
31 changes: 31 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/LLMCallback.java
@@ -0,0 +1,31 @@
package org.apache.sysds.api.jmlc;

/**
* Interface for the Python LLM worker.
* The Python side implements this via a Py4J callback.
*/
public interface LLMCallback {

/**
* Generates text using the LLM model.
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return generated text continuation
*/
String generate(String prompt, int maxNewTokens, double temperature, double topP);

/**
* Generates text and returns result with token counts as a JSON string.
* Format: {"text": "...", "input_tokens": N, "output_tokens": M}
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return JSON string with generated text and token counts
*/
String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP);
}
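Because the Python side supplies the implementation through a Py4J callback, a plain Java stub suffices for exercising the Java plumbing; a minimal sketch, with a hypothetical class name and echo behavior:

import org.apache.sysds.api.jmlc.LLMCallback;

//Hypothetical test double: echoes the prompt so PreparedScript's
//generateBatch* methods can be exercised without a Python worker.
public class EchoLLMCallback implements LLMCallback {
  @Override
  public String generate(String prompt, int maxNewTokens, double temperature, double topP) {
    return prompt + " [echo]";
  }
  @Override
  public String generateWithTokenCount(String prompt, int maxNewTokens, double temperature, double topP) {
    //follows the JSON contract documented above; token counts are dummy values
    return "{\"text\": \"" + prompt + " [echo]\", \"input_tokens\": 1, \"output_tokens\": 1}";
  }
}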
109 changes: 109 additions & 0 deletions src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java
@@ -80,6 +80,9 @@ public class PreparedScript implements ConfigurableAPI
private final CompilerConfig _cconf;
private HashMap<String, String> _outVarLineage;

//LLM inference support
private LLMCallback _llmWorker = null;

private PreparedScript(PreparedScript that) {
//shallow copy, except for a separate symbol table
//and related meta data of reused inputs
@@ -160,6 +163,112 @@ public CompilerConfig getCompilerConfig() {
return _cconf;
}

/**
* Sets the LLM worker callback for text generation.
*
* @param worker the LLM callback interface
*/
public void setLLMWorker(LLMCallback worker) {
_llmWorker = worker;
}

/**
* Gets the LLM worker callback.
*
* @return the LLM callback interface, or null if not set
*/
public LLMCallback getLLMWorker() {
return _llmWorker;
}

/**
* Generates text using the LLM worker.
*
* @param prompt the input prompt text
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature (0.0 = deterministic, higher = more random)
* @param topP nucleus sampling probability threshold
* @return generated text
* @throws DMLException if no LLM worker is set
*/
public String generate(String prompt, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
return _llmWorker.generate(prompt, maxNewTokens, temperature, topP);
}

/**
* Generates text for multiple prompts and returns results as a FrameBlock.
* The FrameBlock has two columns: [prompt, generated_text].
*
* @param prompts array of input prompt texts
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature
* @param topP nucleus sampling probability threshold
* @return FrameBlock with columns [prompt, generated_text]
*/
public FrameBlock generateBatch(String[] prompts, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
//generate text for each prompt
String[][] data = new String[prompts.length][2];
for (int i = 0; i < prompts.length; i++) {
data[i][0] = prompts[i];
data[i][1] = _llmWorker.generate(prompts[i], maxNewTokens, temperature, topP);
}
//create FrameBlock with string schema
ValueType[] schema = new ValueType[]{ValueType.STRING, ValueType.STRING};
String[] colNames = new String[]{"prompt", "generated_text"};
FrameBlock fb = new FrameBlock(schema, colNames);
for (String[] row : data)
fb.appendRow(row);
return fb;
}

/**
* Generates text for multiple prompts and returns results with timing metrics.
* The FrameBlock has five columns: [prompt, generated_text, time_ms, input_tokens, output_tokens].
*
* @param prompts array of input prompt texts
* @param maxNewTokens maximum number of new tokens to generate
* @param temperature sampling temperature
* @param topP nucleus sampling probability threshold
* @return FrameBlock with columns [prompt, generated_text, time_ms, input_tokens, output_tokens]
*/
public FrameBlock generateBatchWithMetrics(String[] prompts, int maxNewTokens, double temperature, double topP) {
if (_llmWorker == null) {
throw new DMLException("No LLM worker set. Call setLLMWorker() first.");
}
//generate text for each prompt with timing and token counts
String[][] data = new String[prompts.length][5];
for (int i = 0; i < prompts.length; i++) {
long start = System.nanoTime();
String json = _llmWorker.generateWithTokenCount(prompts[i], maxNewTokens, temperature, topP);
long elapsed = (System.nanoTime() - start) / 1_000_000;
//parse JSON response: {"text": "...", "input_tokens": N, "output_tokens": M}
try {
org.apache.wink.json4j.JSONObject obj = new org.apache.wink.json4j.JSONObject(json);
data[i][0] = prompts[i];
data[i][1] = obj.getString("text");
data[i][2] = String.valueOf(elapsed);
data[i][3] = String.valueOf(obj.getInt("input_tokens"));
data[i][4] = String.valueOf(obj.getInt("output_tokens"));
} catch (Exception e) {
throw new DMLException("Failed to parse LLM worker response: " + e.getMessage());
}
}
//create FrameBlock with schema
ValueType[] schema = new ValueType[]{
ValueType.STRING, ValueType.STRING, ValueType.INT64, ValueType.INT64, ValueType.INT64};
String[] colNames = new String[]{"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
FrameBlock fb = new FrameBlock(schema, colNames);
for (String[] row : data)
fb.appendRow(row);
return fb;
}

/**
* Binds a scalar boolean to a registered input variable.
*
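A minimal end-to-end sketch of the batched path; the model name, script path, and prompts are illustrative, and an open Connection conn is assumed as in the loadModel example above:

//bind the worker to a prepared script, then run batched inference
PreparedScript ps = conn.prepareScript("print('ready');", new String[]{}, new String[]{});
ps.setLLMWorker(conn.loadModel("distilgpt2", "scripts/llm/llm_worker.py"));

String[] prompts = {"What is SystemDS?", "Summarize JMLC in one sentence."};
FrameBlock result = ps.generateBatchWithMetrics(prompts, 64, 0.7, 0.9);

//columns: prompt, generated_text, time_ms, input_tokens, output_tokens
for (int i = 0; i < result.getNumRows(); i++)
  System.out.println(result.get(i, 1) + " (" + result.get(i, 2) + " ms)");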