Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,22 @@ def toy_computation() -> None:

def _create_profile_request(
log_dir: os.PathLike[str] | str,
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> Mapping[str, Any]:
"""Creates a profile request mapping from the given options."""
profile_request = {}
profile_request["traceLocation"] = str(log_dir)
if profiler_options is None:
profiler_options = jax.profiler.ProfileOptions()

profile_request = {
"traceLocation": str(log_dir),
"profilingStartTimeNs": profiler_options.start_timestamp_ns,
"profilingDurationMs": profiler_options.duration_ms,
"hostTraceLevel": profiler_options.host_tracer_level,
"pwTraceOptions": {
"advancedConfiguration": profiler_options.advanced_configuration,
"enablePythonTracer": bool(profiler_options.python_tracer_level),
},
}

return profile_request

Expand Down Expand Up @@ -104,7 +116,7 @@ def start_trace(
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> None:
"""Starts a profiler trace.

Expand Down Expand Up @@ -133,7 +145,6 @@ def start_trace(
This feature is experimental for Pathways on Cloud and may not be fully
supported.
profiler_options: Profiler options to configure the profiler for collection.
Options are not currently supported and ignored.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
Expand All @@ -144,7 +155,11 @@ def start_trace(
"features for Pathways on Cloud and may not be fully supported."
)

_start_pathways_trace_from_profile_request(_create_profile_request(log_dir))
profile_request = _create_profile_request(log_dir, profiler_options)

_logger.debug("Profile request: %s", profile_request)

_start_pathways_trace_from_profile_request(profile_request)

_original_start_trace(
log_dir=log_dir,
Expand Down
81 changes: 77 additions & 4 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,18 @@ def test_start_trace_success(self):

self.mock_toy_computation.assert_called_once()
self.mock_plugin_executable_cls.assert_called_once_with(
json.dumps(
{"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}}
)
json.dumps({
"profileRequest": {
"traceLocation": "gs://test_bucket/test_dir",
"profilingStartTimeNs": 0,
"profilingDurationMs": 0,
"hostTraceLevel": 2,
"pwTraceOptions": {
"advancedConfiguration": {},
"enablePythonTracer": True,
},
}
})
)
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
self.mock_original_start_trace.assert_called_once_with(
Expand Down Expand Up @@ -393,7 +402,71 @@ def test_monkey_patched_stop_server(self):

def test_create_profile_request_no_options(self):
request = profiling._create_profile_request("gs://bucket/dir")
self.assertEqual(request, {"traceLocation": "gs://bucket/dir"})
self.assertEqual(
request,
{
"traceLocation": "gs://bucket/dir",
"profilingStartTimeNs": 0,
"profilingDurationMs": 0,
"hostTraceLevel": 2,
"pwTraceOptions": {
"advancedConfiguration": {},
"enablePythonTracer": True,
},
},
)

def test_create_profile_request_default_options(self):
options = jax.profiler.ProfileOptions()
request = profiling._create_profile_request(
"gs://bucket/dir", profiler_options=options
)
self.assertEqual(
request,
{
"traceLocation": "gs://bucket/dir",
"profilingStartTimeNs": 0,
"profilingDurationMs": 0,
"hostTraceLevel": 2,
"pwTraceOptions": {
"advancedConfiguration": {},
"enablePythonTracer": True,
},
},
)

def test_create_profile_request_with_options(self):
options = jax.profiler.ProfileOptions()
options.host_tracer_level = 2
options.python_tracer_level = 1
options.duration_ms = 2000
options.start_timestamp_ns = 123456789
options.advanced_configuration = {
"tpu_num_chips_to_profile_per_task": 3,
"tpu_num_sparse_core_tiles_to_trace": 5,
"tpu_trace_mode": "TRACE_COMPUTE",
}

request = profiling._create_profile_request(
"gs://bucket/dir", profiler_options=options
)
self.assertEqual(
request,
{
"traceLocation": "gs://bucket/dir",
"hostTraceLevel": 2,
"profilingDurationMs": 2000,
"profilingStartTimeNs": 123456789,
"pwTraceOptions": {
"enablePythonTracer": True,
"advancedConfiguration": {
"tpu_num_chips_to_profile_per_task": 3,
"tpu_num_sparse_core_tiles_to_trace": 5,
"tpu_trace_mode": "TRACE_COMPUTE",
},
},
},
)

@parameterized.parameters(
({"traceLocation": "gs://test_bucket/test_dir"},),
Expand Down
Loading