Skip to content

Commit 0d126c6

Browse files
lukebaumann authored and copybara-github committed
Enable Pathways profiling with jax.profiler.ProfileOptions.
This change allows users to configure Pathways profiling by passing a jax.profiler.ProfileOptions object to the start_trace function. The options are translated into the Pathways profile request, enabling control over the following subset of parameters: `start_timestamp_ms`, `duration_ms`, `host_tracer_level`, `advanced_configuration`, and `python_tracer_level`. This requires JAX 0.9.2 or newer and a Pathways image tagged 0.9.2 or above. PiperOrigin-RevId: 885730249
1 parent 4e6e6b7 commit 0d126c6

2 files changed

Lines changed: 134 additions & 11 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,72 @@ def toy_computation() -> None:
5757
x.block_until_ready()
5858

5959

60+
def _is_default_profile_options(
61+
profiler_options: jax.profiler.ProfileOptions,
62+
) -> bool:
63+
default_options = jax.profiler.ProfileOptions()
64+
return (
65+
profiler_options.host_tracer_level == default_options.host_tracer_level
66+
and profiler_options.python_tracer_level
67+
== default_options.python_tracer_level
68+
and profiler_options.duration_ms == default_options.duration_ms
69+
and not profiler_options.advanced_configuration
70+
)
71+
72+
6073
def _create_profile_request(
6174
log_dir: os.PathLike[str] | str,
75+
profiler_options: jax.profiler.ProfileOptions | None = None,
6276
) -> Mapping[str, Any]:
6377
"""Creates a profile request mapping from the given options."""
64-
profile_request = {}
65-
profile_request["traceLocation"] = str(log_dir)
78+
profile_request: dict[str, Any] = {
79+
"traceLocation": str(log_dir),
80+
}
81+
82+
if profiler_options is None or _is_default_profile_options(profiler_options):
83+
return profile_request
84+
85+
advanced_config = None
86+
if profiler_options.advanced_configuration:
87+
advanced_config = {}
88+
for k, v in profiler_options.advanced_configuration.items():
89+
# Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
90+
# json-compatible dict
91+
if isinstance(v, bool):
92+
advanced_config[k] = {"boolValue": v}
93+
elif isinstance(v, int):
94+
advanced_config[k] = {"intValue": v}
95+
elif isinstance(v, str):
96+
advanced_config[k] = {"stringValue": v}
97+
else:
98+
raise ValueError(
99+
f"Unsupported advanced configuration value type: {type(v)}. "
100+
"Supported types are bool, int, and str."
101+
)
102+
103+
xprof_options: dict[str, Any] = {
104+
"traceDirectory": str(log_dir),
105+
}
106+
107+
if profiler_options.host_tracer_level != 2:
108+
xprof_options["hostTraceLevel"] = profiler_options.host_tracer_level
109+
110+
pw_trace_opts: dict[str, Any] = {}
111+
if profiler_options.python_tracer_level:
112+
pw_trace_opts["enablePythonTracer"] = bool(
113+
profiler_options.python_tracer_level
114+
)
115+
116+
if advanced_config:
117+
pw_trace_opts["advancedConfiguration"] = advanced_config
118+
119+
if pw_trace_opts:
120+
xprof_options["pwTraceOptions"] = pw_trace_opts
121+
122+
profile_request["xprofTraceOptions"] = xprof_options
123+
124+
if profiler_options.duration_ms > 0:
125+
profile_request["maxDurationSecs"] = profiler_options.duration_ms / 1000.0
66126

67127
return profile_request
68128

@@ -104,7 +164,7 @@ def start_trace(
104164
*,
105165
create_perfetto_link: bool = False,
106166
create_perfetto_trace: bool = False,
107-
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
167+
profiler_options: jax.profiler.ProfileOptions | None = None,
108168
) -> None:
109169
"""Starts a profiler trace.
110170
@@ -133,7 +193,6 @@ def start_trace(
133193
This feature is experimental for Pathways on Cloud and may not be fully
134194
supported.
135195
profiler_options: Profiler options to configure the profiler for collection.
136-
Options are not currently supported and ignored.
137196
"""
138197
if not str(log_dir).startswith("gs://"):
139198
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -144,7 +203,18 @@ def start_trace(
144203
"features for Pathways on Cloud and may not be fully supported."
145204
)
146205

147-
_start_pathways_trace_from_profile_request(_create_profile_request(log_dir))
206+
if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None:
207+
_logger.warning(
208+
"ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
209+
"Some options can be specified via command line flags."
210+
)
211+
profiler_options = None
212+
213+
profile_request = _create_profile_request(log_dir, profiler_options)
214+
215+
_logger.debug("Profile request: %s", profile_request)
216+
217+
_start_pathways_trace_from_profile_request(profile_request)
148218

149219
_original_start_trace(
150220
log_dir=log_dir,

pathwaysutils/test/profiling_test.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
import logging
17+
import unittest
1718
from unittest import mock
1819

1920
from absl.testing import absltest
@@ -225,9 +226,11 @@ def test_start_trace_success(self):
225226

226227
self.mock_toy_computation.assert_called_once()
227228
self.mock_plugin_executable_cls.assert_called_once_with(
228-
json.dumps(
229-
{"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}}
230-
)
229+
json.dumps({
230+
"profileRequest": {
231+
"traceLocation": "gs://test_bucket/test_dir",
232+
}
233+
})
231234
)
232235
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
233236
self.mock_original_start_trace.assert_called_once_with(
@@ -391,10 +394,60 @@ def test_monkey_patched_stop_server(self):
391394

392395
mocks["stop_server"].assert_called_once()
393396

394-
def test_create_profile_request_no_options(self):
395-
request = profiling._create_profile_request("gs://bucket/dir")
396-
self.assertEqual(request, {"traceLocation": "gs://bucket/dir"})
397+
@parameterized.parameters(None, jax.profiler.ProfileOptions())
def test_create_profile_request_default_options(self, profiler_options):
  """Checks that None or default options yield a request with only traceLocation."""
  expected_request = {"traceLocation": "gs://bucket/dir"}

  actual_request = profiling._create_profile_request(
      "gs://bucket/dir", profiler_options=profiler_options
  )

  self.assertEqual(actual_request, expected_request)
408+
409+
@unittest.skipIf(
    jax.version.__version_info__ < (0, 9, 2),
    "ProfileOptions requires JAX 0.9.2 or newer",
)
def test_create_profile_request_with_options(self):
  """Checks the request built from a ProfileOptions with several fields set."""
  advanced_configuration = {
      "tpu_num_chips_to_profile_per_task": 3,
      "tpu_num_sparse_core_tiles_to_trace": 5,
      "tpu_trace_mode": "TRACE_COMPUTE",
  }
  profile_options = jax.profiler.ProfileOptions()
  profile_options.host_tracer_level = 2
  profile_options.python_tracer_level = 1
  profile_options.duration_ms = 2000
  profile_options.start_timestamp_ns = 123456789
  profile_options.advanced_configuration = advanced_configuration

  # The expected request has no hostTraceLevel and no start-timestamp entry.
  expected_request = {
      "traceLocation": "gs://bucket/dir",
      "maxDurationSecs": 2.0,
      "xprofTraceOptions": {
          "traceDirectory": "gs://bucket/dir",
          "pwTraceOptions": {
              "enablePythonTracer": True,
              "advancedConfiguration": {
                  "tpu_num_chips_to_profile_per_task": {"intValue": 3},
                  "tpu_num_sparse_core_tiles_to_trace": {"intValue": 5},
                  "tpu_trace_mode": {"stringValue": "TRACE_COMPUTE"},
              },
          },
      },
  }

  actual_request = profiling._create_profile_request(
      "gs://bucket/dir", profiler_options=profile_options
  )

  self.assertEqual(actual_request, expected_request)
397446

447+
@unittest.skipIf(
448+
jax.version.__version_info__ < (0, 9, 2),
449+
"ProfileOptions requires JAX 0.9.2 or newer",
450+
)
398451
@parameterized.parameters(
399452
({"traceLocation": "gs://test_bucket/test_dir"},),
400453
({

0 commit comments

Comments
 (0)