From 7eecc6c2f7667876b252e7a9b6f0350478c32d8d Mon Sep 17 00:00:00 2001 From: Branden Vandermoon Date: Wed, 13 May 2026 19:57:59 +0000 Subject: [PATCH] Add grpc authentication to maxengine_server --- src/maxtext/configs/base.yml | 2 ++ src/maxtext/configs/types.py | 2 ++ .../inference/maxengine/maxengine_server.py | 15 ++++++++++++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 6e19ccc445..d50f5cfbfd 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -976,6 +976,8 @@ inference_metadata_file: "" # path to a json file inference_server: "MaxtextInterleavedServer" # inference server to start prefill_slice: "v5e-16" # slice to use for prefill in disaggregation mode generate_slice: "v5e-16" # slice to use for generatation in disaggregation mode +grpc_tls_certificate_path: "" # Path to the TLS certificate file for gRPC server +grpc_tls_private_key_path: "" # Path to the TLS private key file for gRPC server inference_benchmark_test: False enable_model_warmup: False enable_llm_inference_pool: False # Bool to launch inference server for llm_inference_gateway with their specified APIs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 20594bccc3..6283b22cf2 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1573,6 +1573,8 @@ class InferenceServer(BaseModel): inference_server: str = Field("MaxtextInterleavedServer", description="Inference server to start.") prefill_slice: str = Field("v5e-16", description="Slice to use for prefill in disaggregation mode.") generate_slice: str = Field("v5e-16", description="Slice to use for generatation in disaggregation mode.") + grpc_tls_certificate_path: str = Field("", description="Path to the TLS certificate file for gRPC server.") + grpc_tls_private_key_path: str = Field("", description="Path to the TLS private key file for gRPC server.") class InferenceBenchmark(BaseModel): diff --git a/src/maxtext/inference/maxengine/maxengine_server.py b/src/maxtext/inference/maxengine/maxengine_server.py index 587a113c78..b4f48ad42c 100644 --- a/src/maxtext/inference/maxengine/maxengine_server.py +++ b/src/maxtext/inference/maxengine/maxengine_server.py @@ -66,6 +66,7 @@ def main(config): # Import the real server_lib now that it's known present. from jetstream.core import server_lib # type: ignore # pylint: disable=import-outside-toplevel + import grpc # pylint: disable=import-outside-toplevel import pathwaysutils # pylint: disable=unused-import,import-outside-toplevel pathwaysutils.initialize() @@ -78,15 +79,27 @@ def main(config): if config.prometheus_port != 0: metrics_server_config = config_lib.MetricsServerConfig(port=config.prometheus_port) + # Configure gRPC credentials. To enable secure TLS serving (e.g., for AIVS compliance or production), + # provide both `grpc_tls_certificate_path` (X.509 public cert) and `grpc_tls_private_key_path`. + # Otherwise, defaults to insecure credentials for local unit testing and development. + if config.grpc_tls_certificate_path and config.grpc_tls_private_key_path: + with open(config.grpc_tls_private_key_path, "rb") as f: + private_key = f.read() + with open(config.grpc_tls_certificate_path, "rb") as f: + certificate = f.read() + credentials = grpc.ssl_server_credentials([(private_key, certificate)]) + else: + credentials = grpc.insecure_server_credentials() + # We separate credential from run so that we can unit test it with # local credentials. - # TODO: Add grpc credentials for OSS. # pylint: disable=unexpected-keyword-arg jetstream_server = server_lib.run( threads=256, port=9000, config=server_config, devices=devices, + credentials=credentials, metrics_server_config=metrics_server_config, enable_jax_profiler=config.enable_jax_profiler if config.enable_jax_profiler else False, jax_profiler_port=config.jax_profiler_port if config.jax_profiler_port else 9999,