diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java index c4957d4e9f2..80f5ffcd755 100644 --- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java +++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java @@ -21,17 +21,20 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.sysds.api.jmlc.Connection; +import py4j.DefaultGatewayServerListener; import py4j.GatewayServer; -import py4j.GatewayServerListener; import py4j.Py4JNetworkException; -import py4j.Py4JServerConnection; + public class PythonDMLScript { private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName()); final private Connection _connection; + public static GatewayServer GwS; /** * Entry point for Python API. @@ -42,7 +45,7 @@ public class PythonDMLScript { public static void main(String[] args) throws Exception { final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args); DMLScript.loadConfiguration(dmlOptions.configFile); - final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); + GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); GwS.addListener(new DMLGateWayListener()); try { GwS.start(); @@ -67,38 +70,20 @@ private PythonDMLScript() { _connection = new Connection(); } + public static void setDMLGateWayListenerLoggerLevel(Level l){ + Logger.getLogger(DMLGateWayListener.class).setLevel(l); + } + public Connection getConnection() { return _connection; } - protected static class DMLGateWayListener implements GatewayServerListener { + protected static class DMLGateWayListener extends DefaultGatewayServerListener { private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName()); - @Override - public void connectionError(Exception e) { - LOG.warn("Connection error: " + e.getMessage()); - System.exit(1); - } - - @Override - public void connectionStarted(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection Started: " + gatewayConnection.toString()); - } - - @Override - public void connectionStopped(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection stopped: " + gatewayConnection.toString()); - } - - @Override - public void serverError(Exception e) { - LOG.error("Server Error " + e.getMessage()); - } - @Override public void serverPostShutdown() { LOG.info("Shutdown done"); - System.exit(0); } @Override @@ -108,13 +93,12 @@ public void serverPreShutdown() { @Override public void serverStarted() { - LOG.info("GatewayServer Started"); + LOG.info("GatewayServer started"); } @Override public void serverStopped() { - LOG.info("GatewayServer Stopped"); - System.exit(0); + LOG.info("GatewayServer stopped"); } } diff --git a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java index 4b8395107f7..9e7cda13ee8 100644 --- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java +++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java @@ -19,11 +19,55 @@ package org.apache.sysds.test.usertest.pythonapi; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.api.PythonDMLScript; +import org.apache.sysds.test.LoggingUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; import org.junit.Test; +import py4j.GatewayServer; + +import java.security.Permission; +import java.util.List; + /** Simple tests to verify startup of Python Gateway server happens without crashes */ public class StartupTest { + private LoggingUtils.TestAppender appender; + private SecurityManager sm; + + @Before + public void setUp() { + appender = LoggingUtils.overwrite(); + sm = System.getSecurityManager(); + System.setSecurityManager(new NoExitSecurityManager()); + PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL); + Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL); + } + + @After + public void tearDown() { + LoggingUtils.reinsert(appender); + System.setSecurityManager(sm); + } + + private void assertLogMessages(String... expectedMessages) { + List log = LoggingUtils.reinsert(appender); + log.stream().forEach(l -> System.out.println(l.getMessage())); + Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size()); + + for (int i = 0; i < expectedMessages.length; i++) { + // order does not matter + boolean found = false; + for (String message : expectedMessages) { + found |= log.get(i).getMessage().toString().startsWith(message); + } + Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found); + } + } @Test(expected = Exception.class) public void testStartupIncorrect_1() throws Exception { @@ -50,4 +94,48 @@ public void testStartupIncorrect_5() throws Exception { // Number out of range PythonDMLScript.main(new String[] {"-python", "918757"}); } + + @Test + public void testStartupIncorrect_6() throws Exception { + GatewayServer gws1 = null; + try { + PythonDMLScript.main(new String[]{"-python", "4001"}); + gws1 = PythonDMLScript.GwS; + Thread.sleep(200); + PythonDMLScript.main(new String[]{"-python", "4001"}); + Thread.sleep(200); + } catch (SecurityException e) { + assertLogMessages( + "GatewayServer started", + "failed startup" + ); + gws1.shutdown(); + } + } + + @Test + public void testStartupCorrect() throws Exception { + PythonDMLScript.main(new String[]{"-python", "4002"}); + Thread.sleep(200); + PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint(); + script.getConnection(); + PythonDMLScript.GwS.shutdown(); + Thread.sleep(200); + assertLogMessages( + "GatewayServer started", + "Starting JVM shutdown", + "Shutdown done", + "GatewayServer stopped" + ); + } + + class NoExitSecurityManager extends SecurityManager { + @Override + public void checkPermission(Permission perm) { } + + @Override + public void checkExit(int status) { + throw new SecurityException("Intercepted exit()"); + } + } }