Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,8 @@ inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
prefill_slice: "v5e-16" # slice to use for prefill in disaggregation mode
generate_slice: "v5e-16" # slice to use for generatation in disaggregation mode
grpc_tls_certificate_path: "" # Path to the TLS certificate file for gRPC server
grpc_tls_private_key_path: "" # Path to the TLS private key file for gRPC server
inference_benchmark_test: False
enable_model_warmup: False
enable_llm_inference_pool: False # Bool to launch inference server for llm_inference_gateway with their specified APIs
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,8 @@ class InferenceServer(BaseModel):
inference_server: str = Field("MaxtextInterleavedServer", description="Inference server to start.")
prefill_slice: str = Field("v5e-16", description="Slice to use for prefill in disaggregation mode.")
generate_slice: str = Field("v5e-16", description="Slice to use for generatation in disaggregation mode.")
grpc_tls_certificate_path: str = Field("", description="Path to the TLS certificate file for gRPC server.")
grpc_tls_private_key_path: str = Field("", description="Path to the TLS private key file for gRPC server.")


class InferenceBenchmark(BaseModel):
Expand Down
15 changes: 14 additions & 1 deletion src/maxtext/inference/maxengine/maxengine_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def main(config):

# Import the real server_lib now that it's known present.
from jetstream.core import server_lib # type: ignore # pylint: disable=import-outside-toplevel
import grpc # pylint: disable=import-outside-toplevel
import pathwaysutils # pylint: disable=unused-import,import-outside-toplevel

pathwaysutils.initialize()
Expand All @@ -78,15 +79,27 @@ def main(config):
if config.prometheus_port != 0:
metrics_server_config = config_lib.MetricsServerConfig(port=config.prometheus_port)

# Configure gRPC credentials. To enable secure TLS serving (e.g., for AIVS compliance or production),
# provide both `grpc_tls_certificate_path` (X.509 public cert) and `grpc_tls_private_key_path`.
# Otherwise, defaults to insecure credentials for local unit testing and development.
if config.grpc_tls_certificate_path and config.grpc_tls_private_key_path:
with open(config.grpc_tls_private_key_path, "rb") as f:
private_key = f.read()
with open(config.grpc_tls_certificate_path, "rb") as f:
certificate = f.read()
credentials = grpc.ssl_server_credentials([(private_key, certificate)])
else:
credentials = grpc.insecure_server_credentials()

# We separate credential from run so that we can unit test it with
# local credentials.
# TODO: Add grpc credentials for OSS.
# pylint: disable=unexpected-keyword-arg
jetstream_server = server_lib.run(
threads=256,
port=9000,
config=server_config,
devices=devices,
credentials=credentials,
metrics_server_config=metrics_server_config,
enable_jax_profiler=config.enable_jax_profiler if config.enable_jax_profiler else False,
jax_profiler_port=config.jax_profiler_port if config.jax_profiler_port else 9999,
Expand Down
Loading