Skip to content

Commit 641ac6b

Browse files
lukebaumann and copybara-github
authored and committed
Fix JaxRuntimeError during profiler stop_trace array verification
When executing a profiler request with Xprof Trace Options, the IFRT proxy returns a single JAX shaped array. Previously, the Python profiler expected zero outputs, causing a JaxRuntimeError (`Mismatch between out_handlers and num_results: 0 vs 1`). This CL updates the profiler state implementation so that, when stopping a trace with Xprof options, it expects a single output of shape `(1,)` from `PluginExecutable.call()`, consuming the URL suffix parameter and fixing the crash.

PiperOrigin-RevId: 888268746
1 parent 44d0853 commit 641ac6b

2 files changed

Lines changed: 41 additions & 10 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,17 @@
3535

3636
class _ProfileState:
3737
executable: plugin_executable.PluginExecutable | None = None
38+
profile_request: Mapping[str, Any] | None = None
3839
lock: threading.Lock
3940

4041
def __init__(self) -> None:
4142
self.executable = None
43+
self.profile_request = None
4244
self.lock = threading.Lock()
4345

4446
def reset(self) -> None:
4547
self.executable = None
48+
self.profile_request = None
4649

4750

4851
_first_profile_start = True
@@ -153,6 +156,7 @@ def _start_pathways_trace_from_profile_request(
153156
_profile_state.executable = plugin_executable.PluginExecutable(
154157
json.dumps({"profileRequest": profile_request})
155158
)
159+
_profile_state.profile_request = profile_request
156160
try:
157161
_, result_future = _profile_state.executable.call()
158162
result_future.result()
@@ -233,8 +237,19 @@ def stop_trace() -> None:
233237
if _profile_state.executable is None:
234238
raise ValueError("stop_trace called before a trace is being taken!")
235239
try:
236-
_, result_future = _profile_state.executable.call()
237-
result_future.result()
240+
if (
241+
_profile_state.profile_request
242+
and "xprofTraceOptions" in _profile_state.profile_request
243+
):
244+
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
245+
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
246+
else:
247+
out_avals = ()
248+
out_shardings = ()
249+
250+
_profile_state.executable.call(
251+
out_avals=out_avals, out_shardings=out_shardings
252+
)
238253
finally:
239254
_profile_state.reset()
240255
finally:

pathwaysutils/test/profiling_test.py

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -213,10 +213,7 @@ def test_lock_released_on_stop_failure(self):
213213
"""Tests that the lock is released if stop_trace fails."""
214214
profiling.start_trace("gs://test_bucket/test_dir3")
215215
self.assertFalse(profiling._profile_state.lock.locked())
216-
mock_result = (
217-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
218-
)
219-
mock_result.result.side_effect = RuntimeError("stop failed")
216+
self.mock_plugin_executable_cls.return_value.call.side_effect = RuntimeError("stop failed")
220217
with self.assertRaisesRegex(RuntimeError, "stop failed"):
221218
profiling.stop_trace()
222219
self.assertFalse(profiling._profile_state.lock.locked())
@@ -277,6 +274,28 @@ def test_stop_trace_success(self):
277274
with self.subTest("executable_is_none"):
278275
self.assertIsNone(profiling._profile_state.executable)
279276

277+
@mock.patch("jax.version.__version_info__", (0, 9, 2))
278+
def test_stop_trace_with_xprof_options_passes_out_avals(self):
279+
from jax import numpy as jnp
280+
options = jax.profiler.ProfileOptions()
281+
options.duration_ms = 2000
282+
283+
# Bypass start_trace and explicitly populate profile state
284+
request = profiling._create_profile_request("gs://test_bucket/test_dir", options)
285+
profiling._profile_state.profile_request = request
286+
profiling._profile_state.executable = self.mock_plugin_executable_cls.return_value
287+
288+
profiling.stop_trace()
289+
290+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
291+
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
292+
self.assertIn("out_avals", kwargs)
293+
self.assertIn("out_shardings", kwargs)
294+
self.assertLen(kwargs["out_avals"], 1)
295+
# Check that it's an object dtype ShapedArray
296+
self.assertEqual(kwargs["out_avals"][0].shape, (1,))
297+
self.assertEqual(kwargs["out_avals"][0].dtype, jnp.object_)
298+
280299
def test_stop_trace_before_start_error(self):
281300
with self.assertRaisesRegex(
282301
ValueError, "stop_trace called before a trace is being taken!"
@@ -496,10 +515,7 @@ def test_original_stop_trace_called_on_stop_failure(self):
496515
"""Tests that original_stop_trace is called if pathways stop_trace fails."""
497516
profiling.start_trace("gs://test_bucket/test_dir")
498517
self.assertFalse(profiling._profile_state.lock.locked())
499-
mock_result = (
500-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
501-
)
502-
mock_result.result.side_effect = RuntimeError("stop failed")
518+
self.mock_plugin_executable_cls.return_value.call.side_effect = RuntimeError("stop failed")
503519
with self.assertRaisesRegex(RuntimeError, "stop failed"):
504520
profiling.stop_trace()
505521
self.mock_original_stop_trace.assert_called_once()

0 commit comments

Comments
 (0)