Skip to content

Commit 92fb5dc

Browse files
lukebaumanncopybara-github
authored andcommitted
Enable Pathways profiling with jax.profiler.ProfileOptions.
This change allows users to configure Pathways profiling by passing a jax.profiler.ProfileOptions object to the start_trace function. The options are translated into the Pathways profile request, enabling control over a subset of parameters. Explicitly, `start_timestamp_ms`, `duration_ms`, `host_tracer_level`, `advanced_configuration`, and `python_tracer_level`. PiperOrigin-RevId: 885730249
1 parent 4e6e6b7 commit 92fb5dc

2 files changed

Lines changed: 137 additions & 9 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,55 @@ def toy_computation() -> None:
5959

6060
def _create_profile_request(
6161
log_dir: os.PathLike[str] | str,
62+
profiler_options: jax.profiler.ProfileOptions | None = None,
6263
) -> Mapping[str, Any]:
6364
"""Creates a profile request mapping from the given options."""
64-
profile_request = {}
65-
profile_request["traceLocation"] = str(log_dir)
65+
if profiler_options is None:
66+
profiler_options = jax.profiler.ProfileOptions()
67+
68+
advanced_config = None
69+
if profiler_options.advanced_configuration:
70+
advanced_config = {}
71+
for k, v in profiler_options.advanced_configuration.items():
72+
# Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
73+
# json-compatible dict
74+
if isinstance(v, bool):
75+
advanced_config[k] = {"boolValue": v}
76+
elif isinstance(v, int):
77+
advanced_config[k] = {"intValue": v}
78+
elif isinstance(v, str):
79+
advanced_config[k] = {"stringValue": v}
80+
else:
81+
raise ValueError(
82+
f"Unsupported advanced configuration value type: {type(v)}. "
83+
"Supported types are bool, int, and str."
84+
)
85+
86+
xprof_options: dict[str, Any] = {
87+
"traceDirectory": str(log_dir),
88+
}
89+
90+
if profiler_options.host_tracer_level != 2:
91+
xprof_options["hostTraceLevel"] = profiler_options.host_tracer_level
92+
93+
pw_trace_opts: dict[str, Any] = {}
94+
if profiler_options.python_tracer_level:
95+
pw_trace_opts["enablePythonTracer"] = bool(
96+
profiler_options.python_tracer_level
97+
)
98+
99+
if advanced_config:
100+
pw_trace_opts["advancedConfiguration"] = advanced_config
101+
102+
if pw_trace_opts:
103+
xprof_options["pwTraceOptions"] = pw_trace_opts
104+
105+
profile_request: dict[str, Any] = {
106+
"xprofTraceOptions": xprof_options,
107+
}
108+
109+
if profiler_options.duration_ms > 0:
110+
profile_request["maxDurationSecs"] = profiler_options.duration_ms / 1000.0
66111

67112
return profile_request
68113

@@ -104,7 +149,7 @@ def start_trace(
104149
*,
105150
create_perfetto_link: bool = False,
106151
create_perfetto_trace: bool = False,
107-
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
152+
profiler_options: jax.profiler.ProfileOptions | None = None,
108153
) -> None:
109154
"""Starts a profiler trace.
110155
@@ -133,7 +178,6 @@ def start_trace(
133178
This feature is experimental for Pathways on Cloud and may not be fully
134179
supported.
135180
profiler_options: Profiler options to configure the profiler for collection.
136-
Options are not currently supported and ignored.
137181
"""
138182
if not str(log_dir).startswith("gs://"):
139183
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -144,7 +188,18 @@ def start_trace(
144188
"features for Pathways on Cloud and may not be fully supported."
145189
)
146190

147-
_start_pathways_trace_from_profile_request(_create_profile_request(log_dir))
191+
if jax.version.__version_info__ < (0, 8, 2) and profiler_options is not None:
192+
_logger.warning(
193+
"ProfileOptions are not supported until JAX 0.8.2 and will be omitted. "
194+
"Some options can be specified via command line flags."
195+
)
196+
profiler_options = None
197+
198+
profile_request = _create_profile_request(log_dir, profiler_options)
199+
200+
_logger.debug("Profile request: %s", profile_request)
201+
202+
_start_pathways_trace_from_profile_request(profile_request)
148203

149204
_original_start_trace(
150205
log_dir=log_dir,

pathwaysutils/test/profiling_test.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,18 @@ def test_start_trace_success(self):
225225

226226
self.mock_toy_computation.assert_called_once()
227227
self.mock_plugin_executable_cls.assert_called_once_with(
228-
json.dumps(
229-
{"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}}
230-
)
228+
json.dumps({
229+
"profileRequest": {
230+
"traceLocation": "gs://test_bucket/test_dir",
231+
"profilingStartTimeNs": 0,
232+
"profilingDurationMs": 0,
233+
"hostTraceLevel": 2,
234+
"pwTraceOptions": {
235+
"advancedConfiguration": {},
236+
"enablePythonTracer": True,
237+
},
238+
}
239+
})
231240
)
232241
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
233242
self.mock_original_start_trace.assert_called_once_with(
@@ -393,7 +402,71 @@ def test_monkey_patched_stop_server(self):
393402

394403
def test_create_profile_request_no_options(self):
395404
request = profiling._create_profile_request("gs://bucket/dir")
396-
self.assertEqual(request, {"traceLocation": "gs://bucket/dir"})
405+
self.assertEqual(
406+
request,
407+
{
408+
"traceLocation": "gs://bucket/dir",
409+
"profilingStartTimeNs": 0,
410+
"profilingDurationMs": 0,
411+
"hostTraceLevel": 2,
412+
"pwTraceOptions": {
413+
"advancedConfiguration": {},
414+
"enablePythonTracer": True,
415+
},
416+
},
417+
)
418+
419+
def test_create_profile_request_default_options(self):
420+
options = jax.profiler.ProfileOptions()
421+
request = profiling._create_profile_request(
422+
"gs://bucket/dir", profiler_options=options
423+
)
424+
self.assertEqual(
425+
request,
426+
{
427+
"traceLocation": "gs://bucket/dir",
428+
"profilingStartTimeNs": 0,
429+
"profilingDurationMs": 0,
430+
"hostTraceLevel": 2,
431+
"pwTraceOptions": {
432+
"advancedConfiguration": {},
433+
"enablePythonTracer": True,
434+
},
435+
},
436+
)
437+
438+
def test_create_profile_request_with_options(self):
439+
options = jax.profiler.ProfileOptions()
440+
options.host_tracer_level = 2
441+
options.python_tracer_level = 1
442+
options.duration_ms = 2000
443+
options.start_timestamp_ns = 123456789
444+
options.advanced_configuration = {
445+
"tpu_num_chips_to_profile_per_task": 3,
446+
"tpu_num_sparse_core_tiles_to_trace": 5,
447+
"tpu_trace_mode": "TRACE_COMPUTE",
448+
}
449+
450+
request = profiling._create_profile_request(
451+
"gs://bucket/dir", profiler_options=options
452+
)
453+
self.assertEqual(
454+
request,
455+
{
456+
"traceLocation": "gs://bucket/dir",
457+
"hostTraceLevel": 2,
458+
"profilingDurationMs": 2000,
459+
"profilingStartTimeNs": 123456789,
460+
"pwTraceOptions": {
461+
"enablePythonTracer": True,
462+
"advancedConfiguration": {
463+
"tpu_num_chips_to_profile_per_task": 3,
464+
"tpu_num_sparse_core_tiles_to_trace": 5,
465+
"tpu_trace_mode": "TRACE_COMPUTE",
466+
},
467+
},
468+
},
469+
)
397470

398471
@parameterized.parameters(
399472
({"traceLocation": "gs://test_bucket/test_dir"},),

0 commit comments

Comments
 (0)