diff --git a/runner/android_junit_runner/CHANGELOG.md b/runner/android_junit_runner/CHANGELOG.md index 43ae3fda7..ea1058f71 100644 --- a/runner/android_junit_runner/CHANGELOG.md +++ b/runner/android_junit_runner/CHANGELOG.md @@ -6,6 +6,8 @@ **Bug Fixes** +* Ensure @Before and @Test run on the same thread in AndroidJUnit4ClassRunner. + **New Features** * Make perfetto trace sections for tests more identifiable by prefixing with "test:" and using fully qualified class name. (b/204992764) diff --git a/runner/android_junit_runner/java/androidx/test/internal/runner/junit4/AndroidJUnit4ClassRunner.java b/runner/android_junit_runner/java/androidx/test/internal/runner/junit4/AndroidJUnit4ClassRunner.java index a4dfde2c7..716194615 100644 --- a/runner/android_junit_runner/java/androidx/test/internal/runner/junit4/AndroidJUnit4ClassRunner.java +++ b/runner/android_junit_runner/java/androidx/test/internal/runner/junit4/AndroidJUnit4ClassRunner.java @@ -16,6 +16,7 @@ package androidx.test.internal.runner.junit4; import static androidx.test.platform.app.InstrumentationRegistry.getArguments; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import androidx.test.internal.runner.RunnerArgs; import androidx.test.internal.runner.junit4.statement.RunAfters; @@ -23,14 +24,18 @@ import androidx.test.internal.runner.junit4.statement.UiThreadStatement; import androidx.test.internal.util.AndroidRunnerParams; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.junit.internal.runners.statements.FailOnTimeout; +import org.junit.internal.runners.model.ReflectiveCallable; +import org.junit.internal.runners.statements.Fail; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; +import org.junit.runners.model.TestTimedOutException; /** A specialized {@link BlockJUnit4ClassRunner} that can handle timeouts */ public class AndroidJUnit4ClassRunner extends BlockJUnit4ClassRunner { @@ -55,13 +60,35 @@ public AndroidJUnit4ClassRunner(Class klass) throws InitializationError { this(klass, RunnerArgs.parseTestTimeout(getArguments())); } + private static final ThreadLocal currentTestStartedLatch = new ThreadLocal<>(); + private static final ThreadLocal currentTestFinishedLatch = new ThreadLocal<>(); + /** Returns a {@link Statement} that invokes {@code method} on {@code test} */ @Override protected Statement methodInvoker(FrameworkMethod method, Object test) { + final Statement invoker; if (UiThreadStatement.shouldRunOnUiThread(method)) { - return new UiThreadStatement(super.methodInvoker(method, test), true); + invoker = new UiThreadStatement(super.methodInvoker(method, test), true); + } else { + invoker = super.methodInvoker(method, test); } - return super.methodInvoker(method, test); + return new Statement() { + @Override + public void evaluate() throws Throwable { + CountDownLatch startLatch = currentTestStartedLatch.get(); + if (startLatch != null) { + startLatch.countDown(); + } + try { + invoker.evaluate(); + } finally { + CountDownLatch finishLatch = currentTestFinishedLatch.get(); + if (finishLatch != null) { + finishLatch.countDown(); + } + } + } + }; } @Override @@ -76,28 +103,101 @@ protected Statement withAfters(FrameworkMethod method, Object target, Statement return afters.isEmpty() ? statement : new RunAfters(method, statement, afters, target); } + @Override + protected Statement methodBlock(FrameworkMethod method) { + Object test; + try { + test = + new ReflectiveCallable() { + @Override + protected Object runReflectiveCall() throws Throwable { + return createTest(); + } + }.run(); + } catch (Throwable e) { + return new Fail(e); + } + + Statement statement = methodInvoker(method, test); + statement = possiblyExpectingExceptions(method, test, statement); + statement = withBefores(method, test, statement); + statement = withAfters(method, test, statement); + statement = withPotentialTimeout(method, test, statement); + try { + java.lang.reflect.Method withRulesMethod = + BlockJUnit4ClassRunner.class.getDeclaredMethod( + "withRules", FrameworkMethod.class, Object.class, Statement.class); + withRulesMethod.setAccessible(true); + statement = (Statement) withRulesMethod.invoke(this, method, test, statement); + } catch (Exception e) { + throw new RuntimeException(e); + } + return statement; + } + /** * Default to {@link org.junit.Test#timeout()} level timeout if set. Otherwise, set the timeout * that was passed to the instrumentation via argument. */ @Override protected Statement withPotentialTimeout(FrameworkMethod method, Object test, Statement next) { - // test level timeout i.e @Test(timeout = 123) long timeout = getTimeout(method.getAnnotation(Test.class)); - - // use runner arg timeout if test level timeout is not present if (timeout <= 0 && perTestTimeout > 0) { timeout = perTestTimeout; } + final long finalTimeout = timeout; - if (timeout <= 0) { - // no timeout was set + if (finalTimeout <= 0 || UiThreadStatement.shouldRunOnUiThread(method)) { return next; } - // Cannot switch to use builder as that is not supported in JUnit 4.10 which is what is - // available in AOSP. - return new FailOnTimeout(next, timeout); + return new Statement() { + @Override + @SuppressWarnings("Interruption") // We want to interrupt the thread to stop the test. + public void evaluate() throws Throwable { + final AtomicReference failure = new AtomicReference<>(); + final CountDownLatch testStartedLatch = new CountDownLatch(1); + final CountDownLatch testFinishedLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(1); + + Thread thread = + new Thread( + new Runnable() { + @Override + public void run() { + currentTestStartedLatch.set(testStartedLatch); + currentTestFinishedLatch.set(testFinishedLatch); + try { + next.evaluate(); + } catch (Throwable t) { + failure.set(t); + } finally { + testStartedLatch.countDown(); + testFinishedLatch.countDown(); + doneLatch.countDown(); + currentTestStartedLatch.remove(); + currentTestFinishedLatch.remove(); + } + } + }, + "Time-limited test"); + thread.setDaemon(true); + thread.start(); + + testStartedLatch.await(); + boolean finishedInTime = testFinishedLatch.await(finalTimeout, MILLISECONDS); + + if (!finishedInTime) { + thread.interrupt(); + throw new TestTimedOutException(finalTimeout, MILLISECONDS); + } + + doneLatch.await(); + if (failure.get() != null) { + throw failure.get(); + } + } + }; } private long getTimeout(Test annotation) { diff --git a/runner/android_junit_runner/javatests/androidx/test/internal/runner/junit4/AndroidAnnotatedBuilderTest.java b/runner/android_junit_runner/javatests/androidx/test/internal/runner/junit4/AndroidAnnotatedBuilderTest.java index 616fa615a..965ac1a34 100644 --- a/runner/android_junit_runner/javatests/androidx/test/internal/runner/junit4/AndroidAnnotatedBuilderTest.java +++ b/runner/android_junit_runner/javatests/androidx/test/internal/runner/junit4/AndroidAnnotatedBuilderTest.java @@ -26,13 +26,17 @@ import java.lang.reflect.InvocationTargetException; import java.util.Arrays; import java.util.Collection; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.junit.runner.JUnitCore; +import org.junit.runner.Result; import org.junit.runner.RunWith; import org.junit.runner.Runner; import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; +import org.junit.runners.model.InitializationError; import org.junit.runners.model.RunnerBuilder; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -133,4 +137,41 @@ public Runner buildAndroidRunner(Class runnerClass, Class t // attempt to create a runner for a class with no @RunWith annotation ab.runnerForClass(NoRunWithClass.class); } + + @SuppressWarnings("NonFinalStaticField") // Static fields are needed to check thread assignment. + public static class TimeoutTestClass { + static Thread beforeThread; + static Thread testThread; + static Thread afterThread; + + @Before + public void before() { + beforeThread = Thread.currentThread(); + } + + @Test(timeout = 5000) + public void testWithTimeout() { + testThread = Thread.currentThread(); + } + + @After + public void after() { + afterThread = Thread.currentThread(); + } + } + + @Test + public void testThreadsSameWithTimeout() throws InitializationError { + TimeoutTestClass.beforeThread = null; + TimeoutTestClass.testThread = null; + TimeoutTestClass.afterThread = null; + + AndroidJUnit4ClassRunner runner = new AndroidJUnit4ClassRunner(TimeoutTestClass.class, 0); + Result result = new JUnitCore().run(runner); + + assertEquals(0, result.getFailureCount()); + Assert.assertNotNull(TimeoutTestClass.beforeThread); + assertEquals(TimeoutTestClass.beforeThread, TimeoutTestClass.testThread); + assertEquals(TimeoutTestClass.testThread, TimeoutTestClass.afterThread); + } }