From 9b0e68ee3fb5f04d2ed4d0f92809807b92bc3016 Mon Sep 17 00:00:00 2001 From: e-strauss <92718421+e-strauss@users.noreply.github.com> Date: Sat, 8 Mar 2025 13:41:35 +0100 Subject: [PATCH 1/2] [MINOR] Fix in Python API GatewayServerListener --- .../org/apache/sysds/api/PythonDMLScript.java | 42 +++++----------- .../test/usertest/pythonapi/StartupTest.java | 50 +++++++++++++++++++ 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java index c4957d4e9f2..80f5ffcd755 100644 --- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java +++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java @@ -21,17 +21,20 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.sysds.api.jmlc.Connection; +import py4j.DefaultGatewayServerListener; import py4j.GatewayServer; -import py4j.GatewayServerListener; import py4j.Py4JNetworkException; -import py4j.Py4JServerConnection; + public class PythonDMLScript { private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName()); final private Connection _connection; + public static GatewayServer GwS; /** * Entry point for Python API. @@ -42,7 +45,7 @@ public class PythonDMLScript { public static void main(String[] args) throws Exception { final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args); DMLScript.loadConfiguration(dmlOptions.configFile); - final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); + GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); GwS.addListener(new DMLGateWayListener()); try { GwS.start(); @@ -67,38 +70,20 @@ private PythonDMLScript() { _connection = new Connection(); } + public static void setDMLGateWayListenerLoggerLevel(Level l){ + Logger.getLogger(DMLGateWayListener.class).setLevel(l); + } + public Connection getConnection() { return _connection; } - protected static class DMLGateWayListener implements GatewayServerListener { + protected static class DMLGateWayListener extends DefaultGatewayServerListener { private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName()); - @Override - public void connectionError(Exception e) { - LOG.warn("Connection error: " + e.getMessage()); - System.exit(1); - } - - @Override - public void connectionStarted(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection Started: " + gatewayConnection.toString()); - } - - @Override - public void connectionStopped(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection stopped: " + gatewayConnection.toString()); - } - - @Override - public void serverError(Exception e) { - LOG.error("Server Error " + e.getMessage()); - } - @Override public void serverPostShutdown() { LOG.info("Shutdown done"); - System.exit(0); } @Override @@ -108,13 +93,12 @@ public void serverPreShutdown() { @Override public void serverStarted() { - LOG.info("GatewayServer Started"); + LOG.info("GatewayServer started"); } @Override public void serverStopped() { - LOG.info("GatewayServer Stopped"); - System.exit(0); + LOG.info("GatewayServer stopped"); } } diff --git a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java index 4b8395107f7..788d0c2544d 100644 --- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java +++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java @@ -19,11 +19,47 @@ package org.apache.sysds.test.usertest.pythonapi; +import org.apache.log4j.Level; +import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.api.PythonDMLScript; +import org.apache.sysds.test.LoggingUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; import org.junit.Test; +import java.util.List; + + /** Simple tests to verify startup of Python Gateway server happens without crashes */ public class StartupTest { + private LoggingUtils.TestAppender appender; + + @Before + public void setUp() { + appender = LoggingUtils.overwrite(); + PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL); + } + + @After + public void tearDown() { + LoggingUtils.reinsert(appender); + } + + private void assertLogMessages(String... expectedMessages) { + List log = LoggingUtils.reinsert(appender); + log.stream().forEach(l -> System.out.println(l.getMessage())); + Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size()); + + for (int i = 0; i < expectedMessages.length; i++) { + // order does not matter + boolean found = false; + for (String message : expectedMessages) { + found |= log.get(i).getMessage().toString().startsWith(message); + } + Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found); + } + } @Test(expected = Exception.class) public void testStartupIncorrect_1() throws Exception { @@ -50,4 +86,18 @@ public void testStartupIncorrect_5() throws Exception { // Number out of range PythonDMLScript.main(new String[] {"-python", "918757"}); } + + @Test + public void testStartupCorrect() throws Exception { + PythonDMLScript.main(new String[]{"-python", "4001"}); + Thread.sleep(200); + PythonDMLScript.GwS.shutdown(); + Thread.sleep(200); + assertLogMessages( + "GatewayServer started", + "Starting JVM shutdown", + "Shutdown done", + "GatewayServer stopped" + ); + } } From c17750953d077b397ba86ef060098b8d4013fd40 Mon Sep 17 00:00:00 2001 From: e-strauss <92718421+e-strauss@users.noreply.github.com> Date: Fri, 14 Mar 2025 13:08:55 +0100 Subject: [PATCH 2/2] code coverage --- .../test/usertest/pythonapi/StartupTest.java | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java index 788d0c2544d..9e7cda13ee8 100644 --- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java +++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java @@ -20,6 +20,7 @@ package org.apache.sysds.test.usertest.pythonapi; import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.api.PythonDMLScript; import org.apache.sysds.test.LoggingUtils; @@ -27,23 +28,30 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import py4j.GatewayServer; +import java.security.Permission; import java.util.List; /** Simple tests to verify startup of Python Gateway server happens without crashes */ public class StartupTest { private LoggingUtils.TestAppender appender; + private SecurityManager sm; @Before public void setUp() { appender = LoggingUtils.overwrite(); + sm = System.getSecurityManager(); + System.setSecurityManager(new NoExitSecurityManager()); PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL); + Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL); } @After public void tearDown() { LoggingUtils.reinsert(appender); + System.setSecurityManager(sm); } private void assertLogMessages(String... expectedMessages) { @@ -87,10 +95,30 @@ public void testStartupIncorrect_5() throws Exception { PythonDMLScript.main(new String[] {"-python", "918757"}); } + @Test + public void testStartupIncorrect_6() throws Exception { + GatewayServer gws1 = null; + try { + PythonDMLScript.main(new String[]{"-python", "4001"}); + gws1 = PythonDMLScript.GwS; + Thread.sleep(200); + PythonDMLScript.main(new String[]{"-python", "4001"}); + Thread.sleep(200); + } catch (SecurityException e) { + assertLogMessages( + "GatewayServer started", + "failed startup" + ); + gws1.shutdown(); + } + } + @Test public void testStartupCorrect() throws Exception { - PythonDMLScript.main(new String[]{"-python", "4001"}); + PythonDMLScript.main(new String[]{"-python", "4002"}); Thread.sleep(200); + PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint(); + script.getConnection(); PythonDMLScript.GwS.shutdown(); Thread.sleep(200); assertLogMessages( @@ -100,4 +128,14 @@ public void testStartupCorrect() throws Exception { "GatewayServer stopped" ); } + + class NoExitSecurityManager extends SecurityManager { + @Override + public void checkPermission(Permission perm) { } + + @Override + public void checkExit(int status) { + throw new SecurityException("Intercepted exit()"); + } + } }