Skip to content

Commit 58b1aae

Browse files
lukebaumanncopybara-github
authored andcommitted
Fix JaxRuntimeError during profiler stop_trace array verification
When executing a profiler request with Xprof Trace Options, the IFRT proxy outputs a single JAX shaped array. Previously, the Python profiler expected 0 outputs, causing a JaxRuntimeError (`Mismatch between out_handlers and num_results: 0 vs 1`). This CL updates the profiler state implementation to correctly expect `(1,)` output from `PluginExecutable.call()` when stopping a trace with Xprof options, consuming the URL suffix parameter and fixing the crash. PiperOrigin-RevId: 888268746
1 parent 44d0853 commit 58b1aae

2 files changed

Lines changed: 48 additions & 8 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@
3535

3636
class _ProfileState:
3737
executable: plugin_executable.PluginExecutable | None = None
38+
profile_request: Mapping[str, Any] | None = None
3839
lock: threading.Lock
3940

4041
def __init__(self) -> None:
4142
self.executable = None
43+
self.profile_request = None
4244
self.lock = threading.Lock()
4345

4446
def reset(self) -> None:
4547
self.executable = None
48+
self.profile_request = None
4649

4750

4851
_first_profile_start = True
@@ -153,6 +156,7 @@ def _start_pathways_trace_from_profile_request(
153156
_profile_state.executable = plugin_executable.PluginExecutable(
154157
json.dumps({"profileRequest": profile_request})
155158
)
159+
_profile_state.profile_request = profile_request
156160
try:
157161
_, result_future = _profile_state.executable.call()
158162
result_future.result()
@@ -233,8 +237,19 @@ def stop_trace() -> None:
233237
if _profile_state.executable is None:
234238
raise ValueError("stop_trace called before a trace is being taken!")
235239
try:
236-
_, result_future = _profile_state.executable.call()
237-
result_future.result()
240+
if (
241+
_profile_state.profile_request
242+
and "xprofTraceOptions" in _profile_state.profile_request
243+
):
244+
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
245+
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
246+
else:
247+
out_avals = ()
248+
out_shardings = ()
249+
250+
_profile_state.executable.call(
251+
out_avals=out_avals, out_shardings=out_shardings
252+
)
238253
finally:
239254
_profile_state.reset()
240255
finally:

pathwaysutils/test/profiling_test.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,9 @@ def test_lock_released_on_stop_failure(self):
213213
"""Tests that the lock is released if stop_trace fails."""
214214
profiling.start_trace("gs://test_bucket/test_dir3")
215215
self.assertFalse(profiling._profile_state.lock.locked())
216-
mock_result = (
217-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
216+
self.mock_plugin_executable_cls.return_value.call.side_effect = (
217+
RuntimeError("stop failed")
218218
)
219-
mock_result.result.side_effect = RuntimeError("stop failed")
220219
with self.assertRaisesRegex(RuntimeError, "stop failed"):
221220
profiling.stop_trace()
222221
self.assertFalse(profiling._profile_state.lock.locked())
@@ -277,6 +276,33 @@ def test_stop_trace_success(self):
277276
with self.subTest("executable_is_none"):
278277
self.assertIsNone(profiling._profile_state.executable)
279278

279+
@mock.patch("jax.version.__version_info__", (0, 9, 2))
280+
def test_stop_trace_with_xprof_options_passes_out_avals(self):
281+
from jax import numpy as jnp
282+
283+
options = jax.profiler.ProfileOptions()
284+
options.duration_ms = 2000
285+
286+
# Bypass start_trace and explicitly populate profile state
287+
request = profiling._create_profile_request(
288+
"gs://test_bucket/test_dir", options
289+
)
290+
profiling._profile_state.profile_request = request
291+
profiling._profile_state.executable = (
292+
self.mock_plugin_executable_cls.return_value
293+
)
294+
295+
profiling.stop_trace()
296+
297+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
298+
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
299+
self.assertIn("out_avals", kwargs)
300+
self.assertIn("out_shardings", kwargs)
301+
self.assertLen(kwargs["out_avals"], 1)
302+
# Check that it's an object dtype ShapedArray
303+
self.assertEqual(kwargs["out_avals"][0].shape, (1,))
304+
self.assertEqual(kwargs["out_avals"][0].dtype, jnp.object_)
305+
280306
def test_stop_trace_before_start_error(self):
281307
with self.assertRaisesRegex(
282308
ValueError, "stop_trace called before a trace is being taken!"
@@ -496,10 +522,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
496522
"""Tests that original_stop_trace is called if pathways stop_trace fails."""
497523
profiling.start_trace("gs://test_bucket/test_dir")
498524
self.assertFalse(profiling._profile_state.lock.locked())
499-
mock_result = (
500-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
525+
self.mock_plugin_executable_cls.return_value.call.side_effect = (
526+
RuntimeError("stop failed")
501527
)
502-
mock_result.result.side_effect = RuntimeError("stop failed")
503528
with self.assertRaisesRegex(RuntimeError, "stop failed"):
504529
profiling.stop_trace()
505530
self.mock_original_stop_trace.assert_called_once()

0 commit comments

Comments
 (0)