Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@

class _ProfileState:
executable: plugin_executable.PluginExecutable | None = None
profile_request: Mapping[str, Any] | None = None
lock: threading.Lock

def __init__(self) -> None:
self.executable = None
self.profile_request = None
self.lock = threading.Lock()

def reset(self) -> None:
self.executable = None
self.profile_request = None


_first_profile_start = True
Expand Down Expand Up @@ -153,6 +156,7 @@ def _start_pathways_trace_from_profile_request(
_profile_state.executable = plugin_executable.PluginExecutable(
json.dumps({"profileRequest": profile_request})
)
_profile_state.profile_request = profile_request
try:
_, result_future = _profile_state.executable.call()
result_future.result()
Expand Down Expand Up @@ -233,8 +237,19 @@ def stop_trace() -> None:
if _profile_state.executable is None:
raise ValueError("stop_trace called before a trace is being taken!")
try:
_, result_future = _profile_state.executable.call()
result_future.result()
if (
_profile_state.profile_request
and "xprofTraceOptions" in _profile_state.profile_request
):
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
else:
out_avals = ()
out_shardings = ()

_profile_state.executable.call(
out_avals=out_avals, out_shardings=out_shardings
)
finally:
_profile_state.reset()
finally:
Expand Down
98 changes: 64 additions & 34 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import json
import logging
import unittest
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from pathwaysutils import profiling
import requests

Expand Down Expand Up @@ -213,10 +213,9 @@ def test_lock_released_on_stop_failure(self):
"""Tests that the lock is released if stop_trace fails."""
profiling.start_trace("gs://test_bucket/test_dir3")
self.assertFalse(profiling._profile_state.lock.locked())
mock_result = (
self.mock_plugin_executable_cls.return_value.call.return_value[1]
self.mock_plugin_executable_cls.return_value.call.side_effect = (
RuntimeError("stop failed")
)
mock_result.result.side_effect = RuntimeError("stop failed")
with self.assertRaisesRegex(RuntimeError, "stop failed"):
profiling.stop_trace()
self.assertFalse(profiling._profile_state.lock.locked())
Expand Down Expand Up @@ -277,6 +276,34 @@ def test_stop_trace_success(self):
with self.subTest("executable_is_none"):
self.assertIsNone(profiling._profile_state.executable)

@absltest.skipIf(
jax.version.__version_info__ < (0, 9, 2),
"ProfileOptions requires JAX 0.9.2 or newer",
)
def test_stop_trace_with_xprof_options_passes_out_avals(self):
options = jax.profiler.ProfileOptions()
options.duration_ms = 2000

# Bypass start_trace and explicitly populate profile state
request = profiling._create_profile_request(
"gs://test_bucket/test_dir", options
)
profiling._profile_state.profile_request = request
profiling._profile_state.executable = (
self.mock_plugin_executable_cls.return_value
)

profiling.stop_trace()

self.mock_plugin_executable_cls.return_value.call.assert_called_once()
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
self.assertIn("out_avals", kwargs)
self.assertIn("out_shardings", kwargs)
self.assertLen(kwargs["out_avals"], 1)
# Check that it's an object dtype ShapedArray
self.assertEqual(kwargs["out_avals"][0].shape, (1,))
self.assertEqual(kwargs["out_avals"][0].dtype, jnp.object_)

def test_stop_trace_before_start_error(self):
with self.assertRaisesRegex(
ValueError, "stop_trace called before a trace is being taken!"
Expand Down Expand Up @@ -406,7 +433,7 @@ def test_create_profile_request_default_options(self, profiler_options):
},
)

@unittest.skipIf(
@absltest.skipIf(
jax.version.__version_info__ < (0, 9, 2),
"ProfileOptions requires JAX 0.9.2 or newer",
)
Expand Down Expand Up @@ -444,41 +471,45 @@ def test_create_profile_request_with_options(self):
},
)

@unittest.skipIf(
@absltest.skipIf(
jax.version.__version_info__ < (0, 9, 2),
"ProfileOptions requires JAX 0.9.2 or newer",
)
@parameterized.parameters(
({"traceLocation": "gs://test_bucket/test_dir"},),
({
"traceLocation": "gs://test_bucket/test_dir",
"blockUntilStart": True,
"maxDurationSecs": 10.0,
"devices": {"deviceIds": [1, 2]},
"includeResourceManagers": True,
"maxNumHosts": 5,
"xprofTraceOptions": {
(
{
"traceLocation": "gs://test_bucket/test_dir",
"blockUntilStart": True,
"traceDirectory": "gs://test_bucket/test_dir",
"maxDurationSecs": 10.0,
"devices": {"deviceIds": [1, 2]},
"includeResourceManagers": True,
"maxNumHosts": 5,
"xprofTraceOptions": {
"blockUntilStart": True,
"traceDirectory": "gs://test_bucket/test_dir",
},
},
},),
({
"traceLocation": "gs://bucket/dir",
"xprofTraceOptions": {
"hostTraceLevel": 0,
"traceOptions": {
"traceMode": "TRACE_COMPUTE",
"numSparseCoresToTrace": 1,
"numSparseCoreTilesToTrace": 2,
"numChipsToProfilePerTask": 3,
"powerTraceLevel": 4,
"enableFwThrottleEvent": True,
"enableFwPowerLevelEvent": True,
"enableFwThermalEvent": True,
),
(
{
"traceLocation": "gs://bucket/dir",
"xprofTraceOptions": {
"hostTraceLevel": 0,
"traceOptions": {
"traceMode": "TRACE_COMPUTE",
"numSparseCoresToTrace": 1,
"numSparseCoreTilesToTrace": 2,
"numChipsToProfilePerTask": 3,
"powerTraceLevel": 4,
"enableFwThrottleEvent": True,
"enableFwPowerLevelEvent": True,
"enableFwThermalEvent": True,
},
"traceDirectory": "gs://bucket/dir",
},
"traceDirectory": "gs://bucket/dir",
},
},),
),
)

def test_start_pathways_trace_from_profile_request(self, profile_request):
Expand All @@ -496,10 +527,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
"""Tests that original_stop_trace is called if pathways stop_trace fails."""
profiling.start_trace("gs://test_bucket/test_dir")
self.assertFalse(profiling._profile_state.lock.locked())
mock_result = (
self.mock_plugin_executable_cls.return_value.call.return_value[1]
self.mock_plugin_executable_cls.return_value.call.side_effect = (
RuntimeError("stop failed")
)
mock_result.result.side_effect = RuntimeError("stop failed")
with self.assertRaisesRegex(RuntimeError, "stop failed"):
profiling.stop_trace()
self.mock_original_stop_trace.assert_called_once()
Expand Down
Loading