|
| 1 | +r"""Run a TPU workload with Shared Pathways Service. |
| 2 | +
|
| 3 | +Run your TPU workload locally using Shared Pathways Service. The service |
| 4 | +deploys a Pathways proxy to run the TPU-specific components of your workload |
| 5 | +on the requested TPU slices. |
| 6 | +
|
| 7 | +Example: |
| 8 | +python3 run_tpu_workload.py \ |
| 9 | + --cluster my-cluster \ |
| 10 | + --project my-project \ |
| 11 | + --region=us-central1 \ |
| 12 | + --gcs_bucket=my-gcs-bucket \ |
| 13 | + --pathways_service=pathways-head:8000 \ |
| 14 | + --tpu_type=tpuv6e:4x8 \ |
| 15 | + --tpu_count=1 \ |
| 16 | + --command "python3 my_workload.py ..." |
| 17 | +
|
| 18 | +""" |
| 19 | + |
| 20 | +import subprocess |
| 21 | +from collections.abc import Sequence |
| 22 | + |
| 23 | +from absl import app |
| 24 | +from absl import flags |
| 25 | +from absl import logging |
| 26 | +from pathwaysutils.experimental.shared_pathways_service import isc_pathways |
| 27 | + |
| 28 | + |
FLAGS = flags.FLAGS

# --- Target cluster and storage ---------------------------------------------
flags.DEFINE_string(
    name="cluster", default=None, help="The name of the GKE cluster."
)
flags.DEFINE_string(name="project", default=None, help="The GCP project ID.")
flags.DEFINE_string(name="region", default=None, help="The GCP region.")
flags.DEFINE_string(
    name="gcs_bucket", default=None, help="The Google Cloud Storage bucket."
)
flags.DEFINE_string(
    name="pathways_service",
    default=None,
    help="The address and port of the Pathways Resource Manager.",
)

# --- Requested TPU resources ------------------------------------------------
flags.DEFINE_string(
    name="tpu_type",
    default="tpuv6e:2x2",
    help="The TPU machine type and topology.",
)
flags.DEFINE_integer(
    name="tpu_count", default=1, help="The number of TPU slices."
)

# --- Optional Pathways proxy configuration ----------------------------------
flags.DEFINE_string(
    name="proxy_job_name",
    default=None,
    help=(
        "The name to use for the GKE job for proxy."
        " If not provided, a random name will be generated."
    ),
)
flags.DEFINE_string(
    name="proxy_server_image",
    default=None,
    help=(
        "The proxy server image to use."
        " If not provided, a default will be used."
    ),
)
flags.DEFINE_list(
    name="proxy_options",
    default=None,
    help=(
        "Configuration options for the Pathways proxy."
        ' Specify entries in the form "key:value".'
        " For example: --proxy_options=use_insecure_credentials:true"
    ),
)

# --- Workload ---------------------------------------------------------------
flags.DEFINE_string(
    name="command", default=None, help="The command to run on TPUs."
)
| 62 | + |
| 63 | + |
def run_workload(
    *,
    cluster: str,
    project: str,
    region: str,
    gcs_bucket: str,
    pathways_service: str,
    tpu_type: str,
    tpu_count: int,
    command: str,
    proxy_job_name: str | None = None,
    proxy_server_image: str | None = None,
    proxy_options: Sequence[str] | None = None,
    connect_fn=isc_pathways.connect,
) -> None:
  """Executes a shell command while a Shared Pathways connection is open.

  The proxy options are parsed with `isc_pathways.ProxyOptions.from_list`,
  the context manager returned by `connect_fn` is entered, and `command` is
  run through the shell inside that context.

  Args:
    cluster: The name of the GKE cluster.
    project: The GCP project ID.
    region: The GCP region.
    gcs_bucket: The Google Cloud Storage bucket.
    pathways_service: The address and port of the Pathways Resource Manager.
    tpu_type: The TPU machine type and topology.
    tpu_count: The number of TPU slices; passed together with `tpu_type` as
      the single `expected_tpu_instances` entry.
    command: The command to run on TPUs. It is handed to the shell as-is.
    proxy_job_name: The name to use for the GKE job for proxy.
    proxy_server_image: The proxy server image to use. A falsy value selects
      `isc_pathways.DEFAULT_PROXY_IMAGE`.
    proxy_options: Configuration options for the Pathways proxy.
    connect_fn: The function to use for establishing the connection context.
      It is called with keyword arguments only.

  Raises:
    subprocess.CalledProcessError: If `command` exits with a non-zero status.
      The error is logged and then re-raised.
  """
  options = isc_pathways.ProxyOptions.from_list(proxy_options)

  logging.info("Connecting to Shared Pathways Service...")
  connect_kwargs = dict(
      cluster=cluster,
      project=project,
      region=region,
      gcs_bucket=gcs_bucket,
      pathways_service=pathways_service,
      expected_tpu_instances={tpu_type: tpu_count},
      proxy_job_name=proxy_job_name,
      proxy_server_image=(
          proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE
      ),
      proxy_options=options,
  )
  with connect_fn(**connect_kwargs):
    logging.info("Connection established. Running command: %s", command)
    try:
      subprocess.run(command, shell=True, check=True)
    except subprocess.CalledProcessError as err:
      # Log here for visibility, but let the caller decide how to react.
      logging.error("Command failed with error: %s", err)
      raise
| 115 | + |
| 116 | + |
def main(argv: Sequence[str]) -> None:
  """Checks the command line and runs the workload described by the flags.

  Args:
    argv: Positional command-line arguments left over after flag parsing. Only
      the program name is expected.

  Raises:
    app.UsageError: If extra positional arguments were given, or if any of the
      required flags was not set.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # absl runs flag validators while parsing (and on later assignment). By the
  # time app.run() invokes main(), parsing is already finished, so registering
  # "required" validators here would never fire and unset flags would reach
  # run_workload() as None. Check the values explicitly instead.
  required_flags = (
      "cluster",
      "project",
      "region",
      "gcs_bucket",
      "pathways_service",
      "command",
  )
  missing = [name for name in required_flags if FLAGS[name].value is None]
  if missing:
    raise app.UsageError(
        "Missing required flag(s): "
        + ", ".join(f"--{name}" for name in missing)
    )

  run_workload(
      cluster=FLAGS.cluster,
      project=FLAGS.project,
      region=FLAGS.region,
      gcs_bucket=FLAGS.gcs_bucket,
      pathways_service=FLAGS.pathways_service,
      tpu_type=FLAGS.tpu_type,
      tpu_count=FLAGS.tpu_count,
      command=FLAGS.command,
      proxy_job_name=FLAGS.proxy_job_name,
      proxy_server_image=FLAGS.proxy_server_image,
      proxy_options=FLAGS.proxy_options,
  )
| 143 | + |
| 144 | + |
# Script entry point: app.run() parses the command-line flags and then calls
# main() with the remaining positional arguments.
if __name__ == "__main__":
  app.run(main)
0 commit comments