diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 502309b..3d593d6 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -35,14 +35,17 @@ class _ProfileState: executable: plugin_executable.PluginExecutable | None = None + profile_request: Mapping[str, Any] | None = None lock: threading.Lock def __init__(self) -> None: self.executable = None + self.profile_request = None self.lock = threading.Lock() def reset(self) -> None: self.executable = None + self.profile_request = None _first_profile_start = True @@ -153,6 +156,7 @@ def _start_pathways_trace_from_profile_request( _profile_state.executable = plugin_executable.PluginExecutable( json.dumps({"profileRequest": profile_request}) ) + _profile_state.profile_request = profile_request try: _, result_future = _profile_state.executable.call() result_future.result() @@ -233,8 +237,19 @@ def stop_trace() -> None: if _profile_state.executable is None: raise ValueError("stop_trace called before a trace is being taken!") try: - _, result_future = _profile_state.executable.call() - result_future.result() + if ( + _profile_state.profile_request + and "xprofTraceOptions" in _profile_state.profile_request + ): + out_avals = [jax.core.ShapedArray((1,), jnp.object_)] + out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])] + else: + out_avals = () + out_shardings = () + + _profile_state.executable.call( + out_avals=out_avals, out_shardings=out_shardings + ) finally: _profile_state.reset() finally: diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index 7d524f8..d0bed7a 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -14,12 +14,12 @@ import json import logging -import unittest from unittest import mock from absl.testing import absltest from absl.testing import parameterized import jax +from jax import numpy as jnp from pathwaysutils import profiling import requests @@ -213,10 +213,9 @@ def test_lock_released_on_stop_failure(self): """Tests that the lock is released if stop_trace fails.""" profiling.start_trace("gs://test_bucket/test_dir3") self.assertFalse(profiling._profile_state.lock.locked()) - mock_result = ( - self.mock_plugin_executable_cls.return_value.call.return_value[1] + self.mock_plugin_executable_cls.return_value.call.side_effect = ( + RuntimeError("stop failed") ) - mock_result.result.side_effect = RuntimeError("stop failed") with self.assertRaisesRegex(RuntimeError, "stop failed"): profiling.stop_trace() self.assertFalse(profiling._profile_state.lock.locked()) @@ -277,6 +276,34 @@ def test_stop_trace_success(self): with self.subTest("executable_is_none"): self.assertIsNone(profiling._profile_state.executable) + @absltest.skipIf( + jax.version.__version_info__ < (0, 9, 2), + "ProfileOptions requires JAX 0.9.2 or newer", + ) + def test_stop_trace_with_xprof_options_passes_out_avals(self): + options = jax.profiler.ProfileOptions() + options.duration_ms = 2000 + + # Bypass start_trace and explicitly populate profile state + request = profiling._create_profile_request( + "gs://test_bucket/test_dir", options + ) + profiling._profile_state.profile_request = request + profiling._profile_state.executable = ( + self.mock_plugin_executable_cls.return_value + ) + + profiling.stop_trace() + + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + _, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args + self.assertIn("out_avals", kwargs) + self.assertIn("out_shardings", kwargs) + self.assertLen(kwargs["out_avals"], 1) + # Check that it's an object dtype ShapedArray + self.assertEqual(kwargs["out_avals"][0].shape, (1,)) + self.assertEqual(kwargs["out_avals"][0].dtype, jnp.object_) + def test_stop_trace_before_start_error(self): with self.assertRaisesRegex( ValueError, "stop_trace called before a trace is being taken!" @@ -406,7 +433,7 @@ def test_create_profile_request_default_options(self, profiler_options): }, ) - @unittest.skipIf( + @absltest.skipIf( jax.version.__version_info__ < (0, 9, 2), "ProfileOptions requires JAX 0.9.2 or newer", ) @@ -444,41 +471,45 @@ def test_create_profile_request_with_options(self): }, ) - @unittest.skipIf( + @absltest.skipIf( jax.version.__version_info__ < (0, 9, 2), "ProfileOptions requires JAX 0.9.2 or newer", ) @parameterized.parameters( ({"traceLocation": "gs://test_bucket/test_dir"},), - ({ - "traceLocation": "gs://test_bucket/test_dir", - "blockUntilStart": True, - "maxDurationSecs": 10.0, - "devices": {"deviceIds": [1, 2]}, - "includeResourceManagers": True, - "maxNumHosts": 5, - "xprofTraceOptions": { + ( + { + "traceLocation": "gs://test_bucket/test_dir", "blockUntilStart": True, - "traceDirectory": "gs://test_bucket/test_dir", + "maxDurationSecs": 10.0, + "devices": {"deviceIds": [1, 2]}, + "includeResourceManagers": True, + "maxNumHosts": 5, + "xprofTraceOptions": { + "blockUntilStart": True, + "traceDirectory": "gs://test_bucket/test_dir", + }, }, - },), - ({ - "traceLocation": "gs://bucket/dir", - "xprofTraceOptions": { - "hostTraceLevel": 0, - "traceOptions": { - "traceMode": "TRACE_COMPUTE", - "numSparseCoresToTrace": 1, - "numSparseCoreTilesToTrace": 2, - "numChipsToProfilePerTask": 3, - "powerTraceLevel": 4, - "enableFwThrottleEvent": True, - "enableFwPowerLevelEvent": True, - "enableFwThermalEvent": True, + ), + ( + { + "traceLocation": "gs://bucket/dir", + "xprofTraceOptions": { + "hostTraceLevel": 0, + "traceOptions": { + "traceMode": "TRACE_COMPUTE", + "numSparseCoresToTrace": 1, + "numSparseCoreTilesToTrace": 2, + "numChipsToProfilePerTask": 3, + "powerTraceLevel": 4, + "enableFwThrottleEvent": True, + "enableFwPowerLevelEvent": True, + "enableFwThermalEvent": True, + }, + "traceDirectory": "gs://bucket/dir", }, - "traceDirectory": "gs://bucket/dir", }, - },), + ), ) def test_start_pathways_trace_from_profile_request(self, profile_request): @@ -496,10 +527,9 @@ def test_original_stop_trace_called_on_stop_failure(self): """Tests that original_stop_trace is called if pathways stop_trace fails.""" profiling.start_trace("gs://test_bucket/test_dir") self.assertFalse(profiling._profile_state.lock.locked()) - mock_result = ( - self.mock_plugin_executable_cls.return_value.call.return_value[1] + self.mock_plugin_executable_cls.return_value.call.side_effect = ( + RuntimeError("stop failed") ) - mock_result.result.side_effect = RuntimeError("stop failed") with self.assertRaisesRegex(RuntimeError, "stop failed"): profiling.stop_trace() self.mock_original_stop_trace.assert_called_once()