Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.jobs.job_replica_http_client import (
SSH_CONNECT_TIMEOUT,
get_service_replica_client,
)
from dstack._internal.server.services.jobs.job_replica_tunnel import SSH_CONNECT_TIMEOUT
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
from dstack._internal.utils.common import get_current_datetime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,14 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import grpc

from dstack._internal.core.services.ssh.tunnel import (
SSH_DEFAULT_OPTIONS,
IPSocket,
SocketPair,
UnixSocket,
)
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.ssh import container_ssh_tunnel
from dstack._internal.utils.common import get_or_error
from dstack._internal.server.services.jobs.job_replica_tunnel import get_service_replica_tunnel

SSH_CONNECT_TIMEOUT = timedelta(seconds=10)
# Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES).
_MAX_GRPC_MESSAGE_BYTES = 256 * 1024
_GRPC_CHANNEL_OPTIONS = (
Expand All @@ -29,29 +18,20 @@
)


@asynccontextmanager
async def get_service_replica_grpc_channel_over_uds(
    uds_path: Path,
) -> AsyncGenerator[Any, None]:
    """Open an insecure gRPC channel over a Unix domain socket.

    The channel is created with ``grpc.aio.insecure_channel`` for the target
    ``unix://<uds_path>`` using ``_GRPC_CHANNEL_OPTIONS`` and is yielded to the
    caller. It is closed when the context exits, whether or not the body
    raised.

    Args:
        uds_path: Filesystem path of the Unix domain socket to connect to.

    Yields:
        The channel object returned by ``grpc.aio.insecure_channel``.
    """
    grpc_channel = grpc.aio.insecure_channel(
        f"unix://{uds_path}",
        options=_GRPC_CHANNEL_OPTIONS,
    )
    try:
        yield grpc_channel
    finally:
        # Always release the channel, even if the caller's body failed.
        await grpc_channel.close()


@asynccontextmanager
async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]:
    """Yield a gRPC channel connected to the service port of a job replica.

    The function composes two context managers: ``get_service_replica_tunnel``
    provides the path of a local Unix domain socket for the job, and
    ``get_service_replica_grpc_channel_over_uds`` opens a gRPC channel on that
    socket. Both are torn down, innermost first, when the context exits.

    Args:
        job: The job whose replica should be reached.

    Yields:
        The gRPC channel produced by
        ``get_service_replica_grpc_channel_over_uds``.
    """
    # Tunnel setup is delegated to the shared helper instead of being repeated
    # inline here. An ``asynccontextmanager`` generator must yield exactly once,
    # so there is a single code path ending in a single ``yield``.
    async with get_service_replica_tunnel(job) as uds_path:
        async with get_service_replica_grpc_channel_over_uds(uds_path) as channel:
            yield channel
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,26 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from tempfile import TemporaryDirectory

from httpx import AsyncClient, AsyncHTTPTransport

from dstack._internal.core.services.ssh.tunnel import (
SSH_DEFAULT_OPTIONS,
IPSocket,
SocketPair,
UnixSocket,
)
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.ssh import container_ssh_tunnel
from dstack._internal.utils.common import get_or_error
from dstack._internal.server.services.jobs.job_replica_tunnel import get_service_replica_tunnel

SSH_CONNECT_TIMEOUT = timedelta(seconds=10)

@asynccontextmanager
async def get_service_replica_http_client_over_uds(
    uds_path: Path,
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``httpx.AsyncClient`` whose transport uses a Unix domain socket.

    Args:
        uds_path: Filesystem path of the Unix domain socket; it is passed to
            ``AsyncHTTPTransport`` as a string.

    Yields:
        An ``AsyncClient`` that is closed when the context exits.
    """
    uds_transport = AsyncHTTPTransport(uds=str(uds_path))
    async with AsyncClient(transport=uds_transport) as http_client:
        yield http_client


@asynccontextmanager
async def get_service_replica_client(
    job: JobModel,
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client connected to the service port of a job replica.

    The function composes two context managers: ``get_service_replica_tunnel``
    provides the path of a local Unix domain socket for the job, and
    ``get_service_replica_http_client_over_uds`` builds an ``AsyncClient`` on
    that socket. Both are torn down, innermost first, when the context exits.

    Args:
        job: The job whose replica should be reached.

    Yields:
        The ``AsyncClient`` produced by
        ``get_service_replica_http_client_over_uds``.
    """
    # Tunnel setup is delegated to the shared helper instead of being repeated
    # inline here. An ``asynccontextmanager`` generator must yield exactly once,
    # so there is a single code path ending in a single ``yield``.
    async with get_service_replica_tunnel(job) as uds_path:
        async with get_service_replica_http_client_over_uds(uds_path) as client:
            yield client
43 changes: 43 additions & 0 deletions src/dstack/_internal/server/services/jobs/job_replica_tunnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""SSH tunnel to a job replica's service port, exposed as a local Unix domain socket."""

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from tempfile import TemporaryDirectory

from dstack._internal.core.services.ssh.tunnel import (
SSH_DEFAULT_OPTIONS,
IPSocket,
SocketPair,
UnixSocket,
)
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.ssh import container_ssh_tunnel
from dstack._internal.utils.common import get_or_error

# Passed to the SSH tunnel as the "ConnectTimeout" option, in whole seconds.
SSH_CONNECT_TIMEOUT = timedelta(seconds=10)
# File name of the local Unix socket created inside a temporary directory.
_REPLICA_SOCKET_NAME = "replica.sock"


@asynccontextmanager
async def get_service_replica_tunnel(job: JobModel) -> AsyncGenerator[Path, None]:
    """Forward a job's service port to a local Unix socket and yield its path.

    A temporary directory holds the socket file. ``container_ssh_tunnel`` is
    asked to forward ``localhost:<service_port>`` on the remote side to that
    socket, using ``SSH_DEFAULT_OPTIONS`` plus a ``ConnectTimeout`` derived
    from ``SSH_CONNECT_TIMEOUT``.

    The yielded path is only usable inside the ``async with`` body: leaving it
    exits the tunnel context and then removes the temporary directory.

    Args:
        job: The job whose service port should be exposed locally.

    Yields:
        Absolute path of the local Unix domain socket.
    """
    connect_timeout_seconds = int(SSH_CONNECT_TIMEOUT.total_seconds())
    ssh_options = {
        **SSH_DEFAULT_OPTIONS,
        "ConnectTimeout": str(connect_timeout_seconds),
    }
    spec = get_job_spec(job)
    with TemporaryDirectory() as tmp_dir:
        socket_path = (Path(tmp_dir) / _REPLICA_SOCKET_NAME).absolute()
        # Remote end: the service port inside the container; local end: the
        # socket file in the temporary directory.
        forwarding = SocketPair(
            remote=IPSocket("localhost", get_or_error(spec.service_port)),
            local=UnixSocket(socket_path),
        )
        async with container_ssh_tunnel(
            job=job,
            forwarded_sockets=[forwarding],
            options=ssh_options,
        ):
            yield socket_path
Loading
Loading