Skip to content

Commit ddf933f

Browse files
guptaakacopybara-github
authored andcommitted
Add CLI mode to Shared Pathways Service
This commit adds a new script, `run_tpu_workload.py`, which lets users pass a command to the Shared Pathways Service. No code changes are required: users simply run this script with the `--command` flag. PiperOrigin-RevId: 885215367
1 parent a57c2a0 commit ddf933f

1 file changed

Lines changed: 146 additions & 0 deletions

File tree

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
r"""Run a TPU workload with Shared Pathways Service.
2+
3+
Run your TPU workload locally using Shared Pathways Service. The service will
4+
deploy a Pathways proxy to run the TPU-specific components of your workload on
5+
the requested TPU slices.
6+
7+
Example:
8+
python3 run_tpu_workload.py \
9+
--cluster my-cluster \
10+
--project my-project \
11+
--region=us-central1 \
12+
--gcs_bucket=my-gcs-bucket \
13+
--pathways_service=pathways-head:8000 \
14+
--tpu_type=tpuv6e:4x8 \
15+
--tpu_count=1 \
16+
--command "python3 my_workload.py ..."
17+
18+
"""
19+
20+
import subprocess
21+
from collections.abc import Sequence
22+
23+
from absl import app
24+
from absl import flags
25+
from absl import logging
26+
from pathwaysutils.experimental.shared_pathways_service import isc_pathways
27+
28+
29+
FLAGS = flags.FLAGS

# --- Target cluster and storage -------------------------------------------
flags.DEFINE_string(
    name="cluster", default=None, help="The name of the GKE cluster."
)
flags.DEFINE_string(name="project", default=None, help="The GCP project ID.")
flags.DEFINE_string(name="region", default=None, help="The GCP region.")
flags.DEFINE_string(
    name="gcs_bucket", default=None, help="The Google Cloud Storage bucket."
)
flags.DEFINE_string(
    name="pathways_service",
    default=None,
    help="The address and port of the Pathways Resource Manager.",
)

# --- Requested TPU resources ------------------------------------------------
flags.DEFINE_string(
    name="tpu_type",
    default="tpuv6e:2x2",
    help="The TPU machine type and topology.",
)
flags.DEFINE_integer(
    name="tpu_count", default=1, help="The number of TPU slices."
)

# --- Optional Pathways proxy configuration ----------------------------------
flags.DEFINE_string(
    name="proxy_job_name",
    default=None,
    help=(
        "The name to use for the GKE job for proxy. If not provided, a random name"
        " will be generated."
    ),
)
flags.DEFINE_string(
    name="proxy_server_image",
    default=None,
    help="The proxy server image to use. If not provided, a default will be used.",
)
flags.DEFINE_list(
    name="proxy_options",
    default=None,
    help=(
        "Configuration options for the Pathways proxy. Specify entries in the form"
        ' "key:value". For example: --proxy_options=use_insecure_credentials:true'
    ),
)

# --- Workload -----------------------------------------------------------------
flags.DEFINE_string(
    name="command", default=None, help="The command to run on TPUs."
)
62+
63+
64+
def run_workload(
    *,
    cluster: str,
    project: str,
    region: str,
    gcs_bucket: str,
    pathways_service: str,
    tpu_type: str,
    tpu_count: int,
    command: str,
    proxy_job_name: str | None = None,
    proxy_server_image: str | None = None,
    proxy_options: Sequence[str] | None = None,
    connect_fn=isc_pathways.connect,
) -> None:
  """Executes a shell command while a Shared Pathways connection is open.

  The connection is opened via `connect_fn` (used as a context manager) and
  stays open for as long as the command is running.

  Args:
    cluster: Name of the GKE cluster.
    project: GCP project ID.
    region: GCP region.
    gcs_bucket: Google Cloud Storage bucket.
    pathways_service: Address and port of the Pathways Resource Manager.
    tpu_type: TPU machine type and topology; used as the key of the
      `expected_tpu_instances` mapping passed to `connect_fn`.
    tpu_count: Number of TPU slices; used as the value for `tpu_type` in that
      mapping.
    command: Command line to execute. It is run through the shell.
    proxy_job_name: Name for the proxy's GKE job; forwarded unchanged.
    proxy_server_image: Proxy server image. When falsy,
      `isc_pathways.DEFAULT_PROXY_IMAGE` is used instead.
    proxy_options: Raw proxy option entries; converted with
      `isc_pathways.ProxyOptions.from_list` before being forwarded.
    connect_fn: Callable returning the connection context manager. Exposed as
      a parameter so an alternative can be supplied.

  Raises:
    subprocess.CalledProcessError: If the command exits with a non-zero
      status. The error is logged and then re-raised.
  """
  proxy_config = isc_pathways.ProxyOptions.from_list(proxy_options)

  logging.info("Connecting to Shared Pathways Service...")
  connection_kwargs = dict(
      cluster=cluster,
      project=project,
      region=region,
      gcs_bucket=gcs_bucket,
      pathways_service=pathways_service,
      expected_tpu_instances={tpu_type: tpu_count},
      proxy_job_name=proxy_job_name,
      proxy_server_image=proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE,
      proxy_options=proxy_config,
  )

  with connect_fn(**connection_kwargs):
    logging.info("Connection established. Running command: %s", command)
    try:
      # check=True turns a non-zero exit status into CalledProcessError.
      subprocess.run(command, shell=True, check=True)
    except subprocess.CalledProcessError as err:
      logging.error("Command failed with error: %s", err)
      raise
115+
116+
117+
def main(argv: Sequence[str]) -> None:
  """absl entry point: forwards the parsed flag values to `run_workload`.

  Args:
    argv: Positional command-line arguments left over after flag parsing. Only
      the program name is accepted.

  Raises:
    app.UsageError: If any extra positional argument is present.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  run_workload(
      cluster=FLAGS.cluster,
      project=FLAGS.project,
      region=FLAGS.region,
      gcs_bucket=FLAGS.gcs_bucket,
      pathways_service=FLAGS.pathways_service,
      tpu_type=FLAGS.tpu_type,
      tpu_count=FLAGS.tpu_count,
      command=FLAGS.command,
      proxy_job_name=FLAGS.proxy_job_name,
      proxy_server_image=FLAGS.proxy_server_image,
      proxy_options=FLAGS.proxy_options,
  )


if __name__ == "__main__":
  # Required-flag validators are only evaluated while flags are parsed (or
  # when a flag is assigned). app.run() parses argv before calling main(), so
  # the validators must be registered here; registering them inside main()
  # would be too late and a missing flag would reach run_workload() as None.
  # Keeping this under the __main__ guard means importing this module does
  # not make the flags required for other binaries.
  flags.mark_flags_as_required([
      "cluster",
      "project",
      "region",
      "gcs_bucket",
      "pathways_service",
      "command",
  ])
  app.run(main)

0 commit comments

Comments
 (0)