Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 13 additions & 29 deletions src/main/java/org/apache/sysds/api/PythonDMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.jmlc.Connection;

import py4j.DefaultGatewayServerListener;
import py4j.GatewayServer;
import py4j.GatewayServerListener;
import py4j.Py4JNetworkException;
import py4j.Py4JServerConnection;


public class PythonDMLScript {

private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName());
final private Connection _connection;
public static GatewayServer GwS;

/**
* Entry point for Python API.
Expand All @@ -42,7 +45,7 @@ public class PythonDMLScript {
public static void main(String[] args) throws Exception {
final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
DMLScript.loadConfiguration(dmlOptions.configFile);
final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort);
GwS.addListener(new DMLGateWayListener());
try {
GwS.start();
Expand All @@ -67,38 +70,20 @@ private PythonDMLScript() {
_connection = new Connection();
}

public static void setDMLGateWayListenerLoggerLevel(Level l){
Logger.getLogger(DMLGateWayListener.class).setLevel(l);
}

public Connection getConnection() {
return _connection;
}

protected static class DMLGateWayListener implements GatewayServerListener {
protected static class DMLGateWayListener extends DefaultGatewayServerListener {
private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName());

@Override
public void connectionError(Exception e) {
LOG.warn("Connection error: " + e.getMessage());
System.exit(1);
}

@Override
public void connectionStarted(Py4JServerConnection gatewayConnection) {
LOG.debug("Connection Started: " + gatewayConnection.toString());
}

@Override
public void connectionStopped(Py4JServerConnection gatewayConnection) {
LOG.debug("Connection stopped: " + gatewayConnection.toString());
}

@Override
public void serverError(Exception e) {
LOG.error("Server Error " + e.getMessage());
}

@Override
public void serverPostShutdown() {
LOG.info("Shutdown done");
System.exit(0);
}

@Override
Expand All @@ -108,13 +93,12 @@ public void serverPreShutdown() {

@Override
public void serverStarted() {
LOG.info("GatewayServer Started");
LOG.info("GatewayServer started");
}

@Override
public void serverStopped() {
LOG.info("GatewayServer Stopped");
System.exit(0);
LOG.info("GatewayServer stopped");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,55 @@

package org.apache.sysds.test.usertest.pythonapi;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.sysds.api.PythonDMLScript;
import org.apache.sysds.test.LoggingUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import py4j.GatewayServer;

import java.security.Permission;
import java.util.List;


/** Simple tests to verify startup of Python Gateway server happens without crashes */
public class StartupTest {
private LoggingUtils.TestAppender appender;
private SecurityManager sm;

@Before
public void setUp() {
appender = LoggingUtils.overwrite();
sm = System.getSecurityManager();
System.setSecurityManager(new NoExitSecurityManager());
PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL);
Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL);
}

@After
public void tearDown() {
LoggingUtils.reinsert(appender);
System.setSecurityManager(sm);
}

private void assertLogMessages(String... expectedMessages) {
List<LoggingEvent> log = LoggingUtils.reinsert(appender);
log.stream().forEach(l -> System.out.println(l.getMessage()));
Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size());

for (int i = 0; i < expectedMessages.length; i++) {
// order does not matter
boolean found = false;
for (String message : expectedMessages) {
found |= log.get(i).getMessage().toString().startsWith(message);
}
Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found);
}
}

@Test(expected = Exception.class)
public void testStartupIncorrect_1() throws Exception {
Expand All @@ -50,4 +94,48 @@ public void testStartupIncorrect_5() throws Exception {
// Number out of range
PythonDMLScript.main(new String[] {"-python", "918757"});
}

@Test
public void testStartupIncorrect_6() throws Exception {
GatewayServer gws1 = null;
try {
PythonDMLScript.main(new String[]{"-python", "4001"});
gws1 = PythonDMLScript.GwS;
Thread.sleep(200);
PythonDMLScript.main(new String[]{"-python", "4001"});
Thread.sleep(200);
} catch (SecurityException e) {
assertLogMessages(
"GatewayServer started",
"failed startup"
);
gws1.shutdown();
}
}

@Test
public void testStartupCorrect() throws Exception {
PythonDMLScript.main(new String[]{"-python", "4002"});
Thread.sleep(200);
PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint();
script.getConnection();
PythonDMLScript.GwS.shutdown();
Thread.sleep(200);
assertLogMessages(
"GatewayServer started",
"Starting JVM shutdown",
"Shutdown done",
"GatewayServer stopped"
);
}

class NoExitSecurityManager extends SecurityManager {
@Override
public void checkPermission(Permission perm) { }

@Override
public void checkExit(int status) {
throw new SecurityException("Intercepted exit()");
}
}
}
Loading