diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 34b315e..0b888e3 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -59,10 +59,22 @@ def toy_computation() -> None: def _create_profile_request( log_dir: os.PathLike[str] | str, + profiler_options: jax.profiler.ProfileOptions | None = None, ) -> Mapping[str, Any]: """Creates a profile request mapping from the given options.""" - profile_request = {} - profile_request["traceLocation"] = str(log_dir) + if profiler_options is None: + profiler_options = jax.profiler.ProfileOptions() + + profile_request = { + "traceLocation": str(log_dir), + "profilingStartTimeNs": profiler_options.start_timestamp_ns, + "profilingDurationMs": profiler_options.duration_ms, + "hostTraceLevel": profiler_options.host_tracer_level, + "pwTraceOptions": { + "advancedConfiguration": profiler_options.advanced_configuration, + "enablePythonTracer": bool(profiler_options.python_tracer_level), + }, + } return profile_request @@ -104,7 +116,7 @@ def start_trace( *, create_perfetto_link: bool = False, create_perfetto_trace: bool = False, - profiler_options: jax.profiler.ProfileOptions | None = None, # pylint: disable=unused-argument + profiler_options: jax.profiler.ProfileOptions | None = None, ) -> None: """Starts a profiler trace. @@ -133,7 +145,6 @@ def start_trace( This feature is experimental for Pathways on Cloud and may not be fully supported. profiler_options: Profiler options to configure the profiler for collection. - Options are not currently supported and ignored. """ if not str(log_dir).startswith("gs://"): raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}") @@ -144,7 +155,11 @@ def start_trace( "features for Pathways on Cloud and may not be fully supported." ) - _start_pathways_trace_from_profile_request(_create_profile_request(log_dir)) + profile_request = _create_profile_request(log_dir, profiler_options) + + _logger.debug("Profile request: %s", profile_request) + + _start_pathways_trace_from_profile_request(profile_request) _original_start_trace( log_dir=log_dir, diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index e2cbe4f..5edd48a 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -225,9 +225,18 @@ def test_start_trace_success(self): self.mock_toy_computation.assert_called_once() self.mock_plugin_executable_cls.assert_called_once_with( - json.dumps( - {"profileRequest": {"traceLocation": "gs://test_bucket/test_dir"}} - ) + json.dumps({ + "profileRequest": { + "traceLocation": "gs://test_bucket/test_dir", + "profilingStartTimeNs": 0, + "profilingDurationMs": 0, + "hostTraceLevel": 2, + "pwTraceOptions": { + "advancedConfiguration": {}, + "enablePythonTracer": True, + }, + } + }) ) self.mock_plugin_executable_cls.return_value.call.assert_called_once() self.mock_original_start_trace.assert_called_once_with( @@ -393,7 +402,71 @@ def test_monkey_patched_stop_server(self): def test_create_profile_request_no_options(self): request = profiling._create_profile_request("gs://bucket/dir") - self.assertEqual(request, {"traceLocation": "gs://bucket/dir"}) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + "profilingStartTimeNs": 0, + "profilingDurationMs": 0, + "hostTraceLevel": 2, + "pwTraceOptions": { + "advancedConfiguration": {}, + "enablePythonTracer": True, + }, + }, + ) + + def test_create_profile_request_default_options(self): + options = jax.profiler.ProfileOptions() + request = profiling._create_profile_request( + "gs://bucket/dir", profiler_options=options + ) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + "profilingStartTimeNs": 0, + "profilingDurationMs": 0, + "hostTraceLevel": 2, + "pwTraceOptions": { + "advancedConfiguration": {}, + "enablePythonTracer": True, + }, + }, + ) + + def test_create_profile_request_with_options(self): + options = jax.profiler.ProfileOptions() + options.host_tracer_level = 2 + options.python_tracer_level = 1 + options.duration_ms = 2000 + options.start_timestamp_ns = 123456789 + options.advanced_configuration = { + "tpu_num_chips_to_profile_per_task": 3, + "tpu_num_sparse_core_tiles_to_trace": 5, + "tpu_trace_mode": "TRACE_COMPUTE", + } + + request = profiling._create_profile_request( + "gs://bucket/dir", profiler_options=options + ) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + "hostTraceLevel": 2, + "profilingDurationMs": 2000, + "profilingStartTimeNs": 123456789, + "pwTraceOptions": { + "enablePythonTracer": True, + "advancedConfiguration": { + "tpu_num_chips_to_profile_per_task": 3, + "tpu_num_sparse_core_tiles_to_trace": 5, + "tpu_trace_mode": "TRACE_COMPUTE", + }, + }, + }, + ) @parameterized.parameters( ({"traceLocation": "gs://test_bucket/test_dir"},),