-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathprofiling.py
More file actions
398 lines (318 loc) · 12.8 KB
/
profiling.py
File metadata and controls
398 lines (318 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiling Utilities."""
import asyncio
import dataclasses
import json
import logging
import os
import threading
from typing import Any, Mapping
import urllib.parse
import fastapi
import jax
from jax import numpy as jnp
from pathwaysutils import plugin_executable
import requests
import uvicorn
# Module-level logger shared by all Pathways profiling helpers below.
_logger = logging.getLogger(__name__)
class _ProfileState:
  """Mutable bookkeeping for the Pathways trace of this process.

  Attributes:
    executable: The plugin executable used to start the Pathways trace, or
      None when no Pathways trace is in progress.
    profile_request: The profile request the current trace was started with,
      or None when no trace is in progress.
    lock: Lock taken by the start/stop helpers while they read or mutate
      this state.
  """

  executable: plugin_executable.PluginExecutable | None = None
  profile_request: Mapping[str, Any] | None = None
  lock: threading.Lock

  def __init__(self) -> None:
    self.lock = threading.Lock()
    self.reset()

  def reset(self) -> None:
    """Forgets the current trace; the lock object itself is kept."""
    self.executable, self.profile_request = None, None
# True until the first Pathways trace is started in this process; it is flipped
# under _profile_state.lock so toy_computation() runs before only the first
# profile.
_first_profile_start = True
# Process-wide Pathways trace state shared by start_trace()/stop_trace().
_profile_state = _ProfileState()
# JAX's own profiler entry points, captured when this module is imported (i.e.
# before monkey_patch_jax() can replace them) so the local JAX client trace can
# still be started and stopped alongside the Pathways trace.
_original_start_trace = jax.profiler.start_trace
_original_stop_trace = jax.profiler.stop_trace
def toy_computation() -> None:
  """Runs a trivial jitted increment and waits for it to finish.

  Used as a warm-up before the first Pathways profile is started.
  """
  increment = jax.jit(lambda value: value + 1)
  increment(jnp.array(1)).block_until_ready()
def _is_default_profile_options(
profiler_options: jax.profiler.ProfileOptions,
) -> bool:
if jax.version.__version_info__ < (0, 9, 2):
return True
default_options = jax.profiler.ProfileOptions()
return (
profiler_options.host_tracer_level == default_options.host_tracer_level
and profiler_options.python_tracer_level
== default_options.python_tracer_level
and profiler_options.duration_ms == default_options.duration_ms
and not getattr(profiler_options, "advanced_configuration", None)
)
def _create_profile_request(
log_dir: os.PathLike[str] | str,
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> Mapping[str, Any]:
"""Creates a profile request mapping from the given options."""
profile_request: dict[str, Any] = {
"traceLocation": str(log_dir),
}
if profiler_options is None or _is_default_profile_options(profiler_options):
return profile_request
advanced_config = None
if getattr(profiler_options, "advanced_configuration", None):
advanced_config = {}
for k, v in getattr(profiler_options, "advanced_configuration").items():
# Convert python dict to tensorflow.ProfileOptions.AdvancedConfigValue
# json-compatible dict
if isinstance(v, bool):
advanced_config[k] = {"boolValue": v}
elif isinstance(v, int):
advanced_config[k] = {"intValue": v}
elif isinstance(v, str):
advanced_config[k] = {"stringValue": v}
else:
raise ValueError(
f"Unsupported advanced configuration value type: {type(v)}. "
"Supported types are bool, int, and str."
)
xprof_options: dict[str, Any] = {
"traceDirectory": str(log_dir),
}
if profiler_options.host_tracer_level != 2:
xprof_options["hostTraceLevel"] = profiler_options.host_tracer_level
pw_trace_opts: dict[str, Any] = {}
if profiler_options.python_tracer_level:
pw_trace_opts["enablePythonTracer"] = bool(
profiler_options.python_tracer_level
)
if advanced_config:
pw_trace_opts["advancedConfiguration"] = advanced_config
if pw_trace_opts:
xprof_options["pwTraceOptions"] = pw_trace_opts
profile_request["xprofTraceOptions"] = xprof_options
if profiler_options.duration_ms > 0:
profile_request["maxDurationSecs"] = profiler_options.duration_ms / 1000.0
return profile_request
def _start_pathways_trace_from_profile_request(
    profile_request: Mapping[str, Any],
) -> None:
  """Launches a trace on the Pathways components described by a request.

  Only Pathways components are traced, not the JAX client code. The first
  call in the process runs ``toy_computation()`` before starting the trace.

  Args:
    profile_request: A mapping containing the profile request options.

  Raises:
    ValueError: If a Pathways trace is already in progress.
    Exception: Anything raised while starting the trace through the plugin
      executable; the trace state is reset before it is re-raised.
  """
  global _first_profile_start

  with _profile_state.lock:
    if _first_profile_start:
      _first_profile_start = False
      toy_computation()

    if _profile_state.executable is not None:
      raise ValueError(
          "start_trace called while a trace is already being taken!"
      )

    payload = json.dumps({"profileRequest": profile_request})
    _profile_state.executable = plugin_executable.PluginExecutable(payload)
    _profile_state.profile_request = profile_request

    try:
      _, start_future = _profile_state.executable.call()
      start_future.result()
    except Exception:
      _logger.exception("Failed to start trace")
      _profile_state.reset()
      raise
def start_trace(
log_dir: os.PathLike[str] | str,
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> None:
"""Starts a profiler trace.
The trace will capture CPU and TPU activity, including Python
functions and JAX on-device operations. Use :func:`stop_trace` to end the
trace and save the results to ``log_dir``.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
doesn't need to be running when collecting the trace.
Only one trace may be collected at a time. A RuntimeError will be raised if
:func:`start_trace` is called while another trace is running.
Args:
log_dir: The GCS directory to save the profiler trace to (usually the
TensorBoard log directory), e.g., "gs://my_bucket/profiles".
create_perfetto_link: A boolean which, if true, creates and prints link to
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
block until the link is opened and Perfetto loads the trace. This feature
is experimental for Pathways on Cloud and may not be fully supported.
create_perfetto_trace: A boolean which, if true, additionally dumps a
``perfetto_trace.json.gz`` file that is compatible for upload with the
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
generated if ``create_perfetto_link`` is true. This could be useful if you
want to generate a Perfetto-compatible trace without blocking the process.
This feature is experimental for Pathways on Cloud and may not be fully
supported.
profiler_options: Profiler options to configure the profiler for collection.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
if create_perfetto_link or create_perfetto_trace:
_logger.warning(
"create_perfetto_link and create_perfetto_trace are experimental "
"features for Pathways on Cloud and may not be fully supported."
)
if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None:
_logger.warning(
"ProfileOptions are not supported until JAX 0.9.2 and will be omitted. "
"Some options can be specified via command line flags."
)
profiler_options = None
profile_request = _create_profile_request(log_dir, profiler_options)
_logger.debug("Profile request: %s", profile_request)
_start_pathways_trace_from_profile_request(profile_request)
_original_start_trace(
log_dir=log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
)
def stop_trace() -> None:
"""Stops the currently-running profiler trace."""
try:
with _profile_state.lock:
if _profile_state.executable is None:
raise ValueError("stop_trace called before a trace is being taken!")
try:
if (
_profile_state.profile_request
and "xprofTraceOptions" in _profile_state.profile_request
):
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
else:
out_avals = ()
out_shardings = ()
_profile_state.executable.call(
out_avals=out_avals, out_shardings=out_shardings
)
finally:
_profile_state.reset()
finally:
_original_stop_trace()
# Thread running the uvicorn profiling server created by start_server(), or
# None if start_server() has not been called in this process. It is never
# reset, because servers cannot be stopped (see stop_server()).
_profiler_thread: threading.Thread | None = None
def start_server(port: int) -> None:
  """Starts the profiling server on port `port`.

  The signature is slightly different from `jax.profiler.start_server`
  because no handle to the server is returned because there is no
  `xla_client.profiler.ProfilerServer` to return.

  The server runs a FastAPI app under uvicorn in a background thread and
  exposes ``POST /profiling``. That endpoint takes ``duration_ms`` and
  ``repository_path``, runs a trace for that long, and then responds.

  Args:
    port : The port to start the server on.

  Raises:
    ValueError: If a profiler server has already been started.
  """

  def server_loop(port: int):
    _logger.debug("Starting JAX profiler server on port %s", port)
    app = fastapi.FastAPI()

    @dataclasses.dataclass
    class ProfilingConfig:
      duration_ms: int
      repository_path: str

    @app.post("/profiling")
    async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
      _logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
      _logger.debug("Writing profiling data to %s", pc.repository_path)
      await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)
      try:
        await asyncio.sleep(pc.duration_ms / 1e3)
      finally:
        # Stop the trace even if the wait is interrupted (e.g. the task is
        # cancelled); otherwise the running trace would block later requests.
        await asyncio.to_thread(jax.profiler.stop_trace)
      return {"response": "profiling completed"}

    # Listens on all network interfaces (0.0.0.0).
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="debug")

  global _profiler_thread
  if _profiler_thread is not None:
    raise ValueError("Only one profiler server can be active at a time.")
  _profiler_thread = threading.Thread(target=server_loop, args=(port,))
  _profiler_thread.start()
def stop_server() -> None:
  """Checks that a profiler server was started; it is not actually stopped.

  Pathways profiling servers cannot be stopped at this time, so when a server
  thread exists this function does nothing.

  Raises:
    ValueError: If start_server() has not created a server thread.
  """
  if _profiler_thread is not None:
    return
  raise ValueError("No active profiler server.")
def collect_profile(
port: int,
duration_ms: int,
host: str,
log_dir: os.PathLike[str] | str,
) -> bool:
"""Collects a JAX profile and saves it to the specified directory.
Args:
port: The port on which the JAX profiler server is running.
duration_ms: The duration in milliseconds for which to collect the profile.
host: The host on which the JAX profiler server is running.
log_dir: The GCS path to save the profile data.
Returns:
True if the profile was collected successfully, False otherwise.
Raises:
ValueError: If the log_dir is not a GCS path.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
request_json = {
"duration_ms": duration_ms,
"repository_path": log_dir,
}
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
try:
response = requests.post(address, json=request_json)
response.raise_for_status()
except requests.exceptions.RequestException:
_logger.exception("Failed to collect profiling data")
return False
return True
def monkey_patch_jax() -> None:
"""Monkey patches JAX with Pathways versions of functions.
The signatures in patched functions should match the original.
Patched functions are:
- `jax.profiler.start_trace`
https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.start_trace.html
- `jax.profiler.stop_trace`
https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.stop_trace.html
- `jax.profiler.start_server`
https://jax.readthedocs.io/en/latest/_autosummary/jax.profiler.start_server.html
- `jax.profiler.stop_server`
"""
def start_trace_patch(
log_dir,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None,
) -> None:
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
start_trace(
log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
profiler_options=profiler_options,
)
jax.profiler.start_trace = start_trace_patch
jax._src.profiler.start_trace = start_trace_patch # pylint: disable=protected-access
def stop_trace_patch() -> None:
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
stop_trace()
jax.profiler.stop_trace = stop_trace_patch
jax._src.profiler.stop_trace = stop_trace_patch # pylint: disable=protected-access
def start_server_patch(port: int) -> None:
_logger.debug(
"jax.profile.start_server patched with pathways' start_server"
)
start_server(port)
jax.profiler.start_server = start_server_patch
def stop_server_patch() -> None:
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
stop_server()
jax.profiler.stop_server = stop_server_patch