@@ -57,12 +57,72 @@ def toy_computation() -> None:
5757 x .block_until_ready ()
5858
5959
60+ def _is_default_profile_options (
61+ profiler_options : jax .profiler .ProfileOptions ,
62+ ) -> bool :
63+ default_options = jax .profiler .ProfileOptions ()
64+ return (
65+ profiler_options .host_tracer_level == default_options .host_tracer_level
66+ and profiler_options .python_tracer_level
67+ == default_options .python_tracer_level
68+ and profiler_options .duration_ms == default_options .duration_ms
69+ and not profiler_options .advanced_configuration
70+ )
71+
72+
6073def _create_profile_request (
6174 log_dir : os .PathLike [str ] | str ,
75+ profiler_options : jax .profiler .ProfileOptions | None = None ,
6276) -> Mapping [str , Any ]:
6377 """Creates a profile request mapping from the given options."""
64- profile_request = {}
65- profile_request ["traceLocation" ] = str (log_dir )
78+ profile_request : dict [str , Any ] = {
79+ "traceLocation" : str (log_dir ),
80+ }
81+
82+ if profiler_options is None or _is_default_profile_options (profiler_options ):
83+ return profile_request
84+
85+ advanced_config = None
86+ if profiler_options .advanced_configuration :
87+ advanced_config = {}
88+ for k , v in profiler_options .advanced_configuration .items ():
89+ # Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
90+ # json-compatible dict
91+ if isinstance (v , bool ):
92+ advanced_config [k ] = {"boolValue" : v }
93+ elif isinstance (v , int ):
94+ advanced_config [k ] = {"intValue" : v }
95+ elif isinstance (v , str ):
96+ advanced_config [k ] = {"stringValue" : v }
97+ else :
98+ raise ValueError (
99+ f"Unsupported advanced configuration value type: { type (v )} . "
100+ "Supported types are bool, int, and str."
101+ )
102+
103+ xprof_options : dict [str , Any ] = {
104+ "traceDirectory" : str (log_dir ),
105+ }
106+
107+ if profiler_options .host_tracer_level != 2 :
108+ xprof_options ["hostTraceLevel" ] = profiler_options .host_tracer_level
109+
110+ pw_trace_opts : dict [str , Any ] = {}
111+ if profiler_options .python_tracer_level :
112+ pw_trace_opts ["enablePythonTracer" ] = bool (
113+ profiler_options .python_tracer_level
114+ )
115+
116+ if advanced_config :
117+ pw_trace_opts ["advancedConfiguration" ] = advanced_config
118+
119+ if pw_trace_opts :
120+ xprof_options ["pwTraceOptions" ] = pw_trace_opts
121+
122+ profile_request ["xprofTraceOptions" ] = xprof_options
123+
124+ if profiler_options .duration_ms > 0 :
125+ profile_request ["maxDurationSecs" ] = profiler_options .duration_ms / 1000.0
66126
67127 return profile_request
68128
@@ -104,7 +164,7 @@ def start_trace(
104164 * ,
105165 create_perfetto_link : bool = False ,
106166 create_perfetto_trace : bool = False ,
107- profiler_options : jax .profiler .ProfileOptions | None = None , # pylint: disable=unused-argument
167+ profiler_options : jax .profiler .ProfileOptions | None = None ,
108168) -> None :
109169 """Starts a profiler trace.
110170
@@ -133,7 +193,6 @@ def start_trace(
133193 This feature is experimental for Pathways on Cloud and may not be fully
134194 supported.
135195 profiler_options: Profiler options to configure the profiler for collection.
136- Options are not currently supported and ignored.
137196 """
138197 if not str (log_dir ).startswith ("gs://" ):
139198 raise ValueError (f"log_dir must be a GCS bucket path, got { log_dir } " )
@@ -144,7 +203,18 @@ def start_trace(
144203 "features for Pathways on Cloud and may not be fully supported."
145204 )
146205
147- _start_pathways_trace_from_profile_request (_create_profile_request (log_dir ))
206+ if jax .version .__version_info__ < (0 , 9 , 2 ) and profiler_options is not None :
207+ _logger .warning (
208+ "ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
209+ "Some options can be specified via command line flags."
210+ )
211+ profiler_options = None
212+
213+ profile_request = _create_profile_request (log_dir , profiler_options )
214+
215+ _logger .debug ("Profile request: %s" , profile_request )
216+
217+ _start_pathways_trace_from_profile_request (profile_request )
148218
149219 _original_start_trace (
150220 log_dir = log_dir ,
0 commit comments