Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions pathwaysutils/test/debug/timing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
import time
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from pathwaysutils.debug import timing

from google3.testing.pybase import googletest
from google3.testing.pybase import parameterized


class TimingTest(parameterized.TestCase):

Expand Down Expand Up @@ -83,4 +82,4 @@ def my_function():


if __name__ == "__main__":
googletest.main()
absltest.main()
8 changes: 3 additions & 5 deletions pathwaysutils/test/debug/watchdog_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
import traceback
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from pathwaysutils.debug import watchdog

from google3.testing.pybase import googletest
from google3.testing.pybase import parameterized


class WatchdogTest(parameterized.TestCase):
def test_watchdog_start_join(self):
with (
Expand Down Expand Up @@ -93,4 +91,4 @@ def test_log_thread_strack_succes(self, thread_ident, expected_log_output):


if __name__ == "__main__":
googletest.main()
absltest.main()
7 changes: 3 additions & 4 deletions pathwaysutils/test/initialize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

import os

from absl.testing import absltest
from absl.testing import parameterized
import jax
from pathwaysutils import _initialize

from google3.testing.pybase import googletest
from google3.testing.pybase import parameterized


class InitializeTest(parameterized.TestCase):

Expand Down Expand Up @@ -89,4 +88,4 @@ def test_persistence_enabled(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
6 changes: 3 additions & 3 deletions pathwaysutils/test/lru_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax.extend
from pathwaysutils import lru_cache
from google3.testing.pybase import googletest


class LruCacheTest(googletest.TestCase):
class LruCacheTest(absltest.TestCase):

def test_cache_hits(self):
x = [100]
Expand Down Expand Up @@ -82,4 +82,4 @@ def f(i):


if __name__ == "__main__":
googletest.main()
absltest.main()
7 changes: 3 additions & 4 deletions pathwaysutils/test/persistence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import datetime

from absl.testing import absltest
import jax
import numpy as np
from pathwaysutils.persistence import helper

from google3.testing.pybase import googletest


class PersistenceTest(googletest.TestCase):
class PersistenceTest(absltest.TestCase):
location = "/path/to/location"
name = "name"
dtype = np.dtype(np.int32)
Expand Down Expand Up @@ -106,4 +105,4 @@ def test_get_bulk_write_request(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
6 changes: 3 additions & 3 deletions pathwaysutils/test/plugin_executable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
These should not exercise a specific feature that uses side channel, but rather
the general logic of the class.
"""
from absl.testing import absltest
import jax
from pathwaysutils import plugin_executable
from google3.testing.pybase import googletest

PluginExecutable = plugin_executable.PluginExecutable
XlaRuntimeError = jax.errors.JaxRuntimeError


class PluginExecutableTest(googletest.TestCase):
class PluginExecutableTest(absltest.TestCase):

def setUp(self):
jax.config.update("jax_platforms", "cpu")
Expand All @@ -26,4 +26,4 @@ def test_bad_program(self):
PluginExecutable("this is not json")

if __name__ == "__main__":
googletest.main()
absltest.main()
7 changes: 3 additions & 4 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
import logging
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import jax
from pathwaysutils import profiling
import requests

from google3.testing.pybase import googletest
from google3.testing.pybase import parameterized


class ProfilingTest(parameterized.TestCase):
"""Tests for Pathways on Cloud profiling."""
Expand Down Expand Up @@ -457,4 +456,4 @@ def test_jax_profiler_trace_calls_patched_functions(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
9 changes: 4 additions & 5 deletions pathwaysutils/test/proxy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,22 @@

from unittest import mock

from absl.testing import absltest
import jax
from jax.extend import backend
from pathwaysutils import jax as pw_jax
from pathwaysutils import proxy_backend

from google3.testing.pybase import googletest


class ProxyBackendTest(googletest.TestCase):
class ProxyBackendTest(absltest.TestCase):

def setUp(self):
super().setUp()
jax.config.update("jax_platforms", "proxy")
jax.config.update("jax_backend_target", "grpc://localhost:12345")
backend.clear_backends()

@googletest.skip("b/408025233")
@absltest.skip("b/408025233")
def test_no_proxy_backend_registration_raises_error(self):
self.assertRaises(RuntimeError, backend.backends)

Expand All @@ -48,4 +47,4 @@ def test_proxy_backend_registration(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Loading