diff --git a/setup.py b/setup.py index 378c908d..b3295e4a 100644 --- a/setup.py +++ b/setup.py @@ -63,6 +63,7 @@ "boto3>=1.35.3,<2.0", "botocore>=1.35.6 ", "kubernetes==33.1.0", + "kr8s>=0.20.0", "pyyaml==6.0.2", "ratelimit==2.2.1", "tabulate==0.9.0", diff --git a/src/sagemaker/hyperpod/cli/commands/space.py b/src/sagemaker/hyperpod/cli/commands/space.py index 8dbdb2b2..b44cce2d 100644 --- a/src/sagemaker/hyperpod/cli/commands/space.py +++ b/src/sagemaker/hyperpod/cli/commands/space.py @@ -10,6 +10,8 @@ _hyperpod_telemetry_emitter, ) from sagemaker.hyperpod.common.telemetry.constants import Feature +from sagemaker.hyperpod.cli.constants.space_constants import DEFAULT_SPACE_PORT +from sagemaker.hyperpod.common.cli_decorators import handle_cli_exceptions @click.command("hyp-space") @@ -18,6 +20,7 @@ schema_pkg="hyperpod_space_template", registry=SCHEMA_REGISTRY, ) +@handle_cli_exceptions() def space_create(version, debug, config): """Create a space resource.""" space_config = SpaceConfig(**config) @@ -29,6 +32,7 @@ def space_create(version, debug, config): @click.command("hyp-space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") +@handle_cli_exceptions() def space_list(namespace, output): """List space resources.""" spaces = HPSpace.list(namespace=namespace) @@ -70,6 +74,7 @@ def space_list(namespace, output): @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") +@handle_cli_exceptions() def space_describe(name, namespace, output): """Describe a space resource.""" current_space = HPSpace.get(name=name, namespace=namespace) @@ -86,6 +91,7 @@ def space_describe(name, namespace, output): @click.command("hyp-space") @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") +@handle_cli_exceptions() def space_delete(name, namespace): """Delete a space resource.""" current_space = HPSpace.get(name=name, namespace=namespace) @@ -99,6 +105,7 @@ def space_delete(name, namespace): registry=SCHEMA_REGISTRY, is_update=True, ) +@handle_cli_exceptions() def space_update(version, config): """Update a space resource.""" current_space = HPSpace.get(name=config['name'], namespace=config['namespace']) @@ -112,6 +119,7 @@ def space_update(version, config): @click.command("hyp-space") @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") +@handle_cli_exceptions() def space_start(name, namespace): """Start a space resource.""" current_space = HPSpace.get(name=name, namespace=namespace) @@ -122,6 +130,7 @@ def space_start(name, namespace): @click.command("hyp-space") @click.option("--name", required=True, help="Name of the space") @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") +@handle_cli_exceptions() def space_stop(name, namespace): """Stop a space resource.""" current_space = HPSpace.get(name=name, namespace=namespace) @@ -134,8 +143,30 @@ def space_stop(name, namespace): @click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") @click.option("--pod-name", required=False, help="Name of the pod to get logs from") @click.option("--container", required=False, help="Name of the container to get logs from") +@handle_cli_exceptions() def space_get_logs(name, namespace, pod_name, container): """Get logs for a space resource.""" current_space = HPSpace.get(name=name, namespace=namespace) logs = current_space.get_logs(pod_name=pod_name, container=container) click.echo(logs) + + +@click.command("hyp-space") +@click.option("--name", required=True, help="Name of the space") +@click.option("--namespace", "-n", required=False, default="default", help="Kubernetes namespace") +@click.option("--local-port", required=False, default=DEFAULT_SPACE_PORT, help="Localhost port that is mapped to the space") +def space_portforward(name, namespace, local_port): + """Port forward to localhost for a space resource.""" + # Validate input port + try: + local_port = int(local_port) + except ValueError: + raise ValueError("Port values must be valid integers") + + if not (1 <= local_port <= 65535): + raise ValueError(f"Port must be between 1 and 65535, got {local_port}") + + current_space = HPSpace.get(name=name, namespace=namespace) + click.echo(f"Forwarding from local port {local_port} to space `{name}` in namespace `{namespace}`.") + click.echo(f"Please access the service via `http://localhost:{local_port}`. Press Ctrl+C to stop.") + current_space.portforward_space(local_port) diff --git a/src/sagemaker/hyperpod/cli/commands/space_access.py b/src/sagemaker/hyperpod/cli/commands/space_access.py index 1de7e96c..e0fc1408 100644 --- a/src/sagemaker/hyperpod/cli/commands/space_access.py +++ b/src/sagemaker/hyperpod/cli/commands/space_access.py @@ -4,6 +4,7 @@ _hyperpod_telemetry_emitter, ) from sagemaker.hyperpod.common.telemetry.constants import Feature +from sagemaker.hyperpod.common.cli_decorators import handle_cli_exceptions @click.command("hyp-space-access") @@ -14,6 +15,7 @@ default="vscode-remote", help="Remote access type supported values: [vscode-remote, web-ui] [default: vscode-remote]" ) +@handle_cli_exceptions() def space_access_create(name, namespace, connection_type): """Create a space access resource.""" space = HPSpace.get(name=name, namespace=namespace) diff --git a/src/sagemaker/hyperpod/cli/commands/space_template.py b/src/sagemaker/hyperpod/cli/commands/space_template.py index ab84ee5c..d0bfd6a0 100644 --- a/src/sagemaker/hyperpod/cli/commands/space_template.py +++ b/src/sagemaker/hyperpod/cli/commands/space_template.py @@ -3,10 +3,12 @@ import yaml from tabulate import tabulate from sagemaker.hyperpod.space.hyperpod_space_template import HPSpaceTemplate +from sagemaker.hyperpod.common.cli_decorators import handle_cli_exceptions @click.command("hyp-space-template") @click.option("--file", "-f", required=True, help="YAML file containing the configuration") +@handle_cli_exceptions() def space_template_create(file): """Create a space-template resource.""" template = HPSpaceTemplate(file_path=file) @@ -17,6 +19,7 @@ def space_template_create(file): @click.command("hyp-space-template") @click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["table", "json"]), default="table") +@handle_cli_exceptions() def space_template_list(namespace, output): """List space-template resources.""" templates = HPSpaceTemplate.list(namespace) @@ -43,6 +46,7 @@ def space_template_list(namespace, output): @click.option("--name", required=True, help="Name of the space template") @click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--output", "-o", type=click.Choice(["yaml", "json"]), default="yaml") +@handle_cli_exceptions() def space_template_describe(name, namespace, output): """Describe a space-template resource.""" template = HPSpaceTemplate.get(name, namespace) @@ -56,6 +60,7 @@ def space_template_describe(name, namespace, output): @click.command("hyp-space-template") @click.option("--name", required=True, help="Name of the space template") @click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") +@handle_cli_exceptions() def space_template_delete(name, namespace): """Delete a space-template resource.""" template = HPSpaceTemplate.get(name, namespace) @@ -67,6 +72,7 @@ def space_template_delete(name, namespace): @click.option("--name", required=True, help="Name of the space template") @click.option("--namespace", "-n", required=False, default=None, help="Kubernetes namespace") @click.option("--file", "-f", required=True, help="YAML file containing the updated template") +@handle_cli_exceptions() def space_template_update(name, namespace, file): """Update a space-template resource.""" template = HPSpaceTemplate.get(name, namespace) diff --git a/src/sagemaker/hyperpod/cli/constants/space_constants.py b/src/sagemaker/hyperpod/cli/constants/space_constants.py index b595a7aa..eff5e11a 100644 --- a/src/sagemaker/hyperpod/cli/constants/space_constants.py +++ b/src/sagemaker/hyperpod/cli/constants/space_constants.py @@ -13,8 +13,8 @@ SPACE_GROUP = "workspace.jupyter.org" SPACE_VERSION = "v1alpha1" SPACE_PLURAL = "workspaces" +DEFAULT_SPACE_PORT = "8888" # Immutable fields that cannot be updated after space creation IMMUTABLE_FIELDS = { "storage", # storage is immutable per Go struct validation } -ENABLE_MIG_PROFILE_VALIDATION = False diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index d33b5f85..a33aee29 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -48,6 +48,7 @@ space_start, space_stop, space_get_logs, + space_portforward, ) from sagemaker.hyperpod.cli.commands.space_template import ( space_template_create, @@ -164,7 +165,10 @@ def stop(): pass - +@cli.group(cls=CLICommand) +def portforward(): + """Port forward for space resources.""" + pass @cli.group(cls=CLICommand) @@ -252,7 +256,7 @@ def exec(): get_logs.add_command(custom_get_logs) get_logs.add_command(space_get_logs) - +portforward.add_command(space_portforward) get_operator_logs.add_command(pytorch_get_operator_logs) get_operator_logs.add_command(js_get_operator_logs) diff --git a/src/sagemaker/hyperpod/space/hyperpod_space.py b/src/sagemaker/hyperpod/space/hyperpod_space.py index 756fd12f..5ccadbb7 100644 --- a/src/sagemaker/hyperpod/space/hyperpod_space.py +++ b/src/sagemaker/hyperpod/space/hyperpod_space.py @@ -1,24 +1,24 @@ import logging import yaml import boto3 -from typing import List, Optional, ClassVar, Dict, Any +from typing import List, Optional, ClassVar, Dict, Set, Any from pydantic import BaseModel, Field, ConfigDict, model_validator from kubernetes import client, config from kubernetes.client.rest import ApiException +from kr8s.objects import Pod from sagemaker.hyperpod.common.config.metadata import Metadata +from hyperpod_space_template.v1_0.model import ResourceRequirements from sagemaker.hyperpod.common.utils import ( handle_exception, get_default_namespace, setup_logging, verify_kubernetes_version_compatibility, - get_current_cluster, - get_current_region, - get_cluster_instance_types, ) from sagemaker.hyperpod.space.utils import ( map_kubernetes_response_to_model, - get_pod_instance_type, + validate_space_mig_resources, + validate_mig_profile_in_cluster, ) from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( _hyperpod_telemetry_emitter, @@ -28,17 +28,14 @@ SPACE_GROUP, SPACE_VERSION, SPACE_PLURAL, - ENABLE_MIG_PROFILE_VALIDATION, + DEFAULT_SPACE_PORT, ) from sagemaker.hyperpod.cli.constants.space_access_constants import ( SPACE_ACCESS_GROUP, SPACE_ACCESS_VERSION, SPACE_ACCESS_PLURAL, ) -from hyperpod_space_template.v1_0.model import SpaceConfig - -if ENABLE_MIG_PROFILE_VALIDATION: - from sagemaker.hyperpod.training.hyperpod_pytorch_job import list_accelerator_partition_types +from hyperpod_space_template.v1_0.model import SpaceConfig, ResourceRequirements class HPSpace(BaseModel): @@ -214,6 +211,104 @@ def verify_kube_config(cls): except Exception as e: raise RuntimeError(f"Failed to load kubeconfig: {e}") + @staticmethod + def _extract_mig_profiles(resources: Optional[ResourceRequirements]) -> Set[str]: + """Extract MIG profile resource keys from resources without validation. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - resources + - ResourceRequirements or None + - The resource requirements to extract MIG profiles from + + **Returns:** + + set: Set of MIG profile resource keys found in the resources + """ + if not resources: + return set() + + mig_profiles = set() + + if resources.requests: + mig_profiles.update([ + key for key in resources.requests.keys() + if key.startswith("nvidia.com/mig-") + ]) + + if resources.limits: + mig_profiles.update([ + key for key in resources.limits.keys() + if key.startswith("nvidia.com/mig-") + ]) + + return mig_profiles + + def _validate_and_extract_mig_profiles(self, resources: Optional[ResourceRequirements]) -> Set[str]: + """Validate MIG resources and extract MIG profiles. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - resources + - ResourceRequirements or None + - The resource requirements to validate + + **Returns:** + + set: Set of MIG profile resource keys found in the resources + + **Raises:** + + RuntimeError: If MIG validation fails or profiles are invalid + """ + if not resources: + return set() + + # Validate requests + if resources.requests: + valid, err = validate_space_mig_resources(resources.requests) + if not valid: + raise RuntimeError(err) + + # Validate limits + if resources.limits: + valid, err = validate_space_mig_resources(resources.limits) + if not valid: + raise RuntimeError(err) + + # Extract MIG profiles + mig_profiles = self._extract_mig_profiles(resources) + + # Validate that requests and limits use the same MIG profile + if len(mig_profiles) > 1: + raise RuntimeError( + "MIG profile mismatch: requests and limits must use the same MIG profile. " + f"Found: {', '.join(mig_profiles)}" + ) + + # Validate MIG profile exists in cluster + if mig_profiles: + mig_profile = list(mig_profiles)[0] + valid, err = validate_mig_profile_in_cluster(mig_profile) + if not valid: + raise RuntimeError(err) + + return mig_profiles + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_space") def create(self, debug: bool = False): """Create and submit the HyperPod Space to the Kubernetes cluster. @@ -252,32 +347,13 @@ def create(self, debug: bool = False): >>> # Create a space with default settings >>> space.create() """ - self.verify_kube_config() logger = self.get_logger() logger = setup_logging(logger, debug) - # Validate supported MIG profiles for the cluster - if ENABLE_MIG_PROFILE_VALIDATION: - if self.config.resources: - mig_profiles = set() - if self.config.resources.requests: - mig_profiles.update([key for key in self.config.resources.requests.keys() if key.startswith("nvidia.com/mig")]) - if self.config.resources.limits: - mig_profiles.update([key for key in self.config.resources.limits.keys() if key.startswith("nvidia.com/mig")]) - - if len(mig_profiles) > 1: - raise RuntimeError("Space only supports one MIG profile") - - if mig_profiles: - cluster_instance_types = get_cluster_instance_types( - get_current_cluster(), - get_current_region() - ) - supported_mig_profiles = {profile for instance_type in cluster_instance_types for profile in list_accelerator_partition_types(instance_type)} - if list(mig_profiles)[0] not in supported_mig_profiles: - raise RuntimeError(f"Accelerator partition type '{list(mig_profiles)[0]}' does not exist in this cluster. Use 'hyp list-accelerator-partition-type' to check for available resources.") + # Validate and extract MIG profiles + self._validate_and_extract_mig_profiles(self.config.resources) # Convert config to domain model domain_config = self.config.to_domain() @@ -551,32 +627,29 @@ def update(self, **kwargs): self.verify_kube_config() logger = self.get_logger() - # Validate supported MIG profile for node which the Space is running on - if ENABLE_MIG_PROFILE_VALIDATION: - if "resources" in kwargs: - mig_profiles = set() - mig_profiles.update([key for key in kwargs["resources"].get("requests", {}).keys() if key.startswith("nvidia.com/mig")]) - mig_profiles.update([key for key in kwargs["resources"].get("limits", {}).keys() if key.startswith("nvidia.com/mig")]) - - if len(mig_profiles) > 1: - raise RuntimeError("Space only supports one MIG profile") - - if mig_profiles: - pods = self.list_pods() - if not pods: - raise RuntimeError(f"No pods found for space '{self.config.name}'") - - node_instance_type = get_pod_instance_type(pods[0], self.config.namespace) - supported_mig_profiles = set(list_accelerator_partition_types(node_instance_type)) - if list(mig_profiles)[0] not in supported_mig_profiles: - raise RuntimeError(f"Accelerator partition type '{list(mig_profiles)[0]}' does not exist in this cluster. Use 'hyp list-accelerator-partition-type' to check for available resources.") - - # Ensure existing MIG profile gets removed before setting a new one - existing_config = HPSpace.get(self.config.name, self.config.namespace).config - existing_mig_profiles = [key for key in existing_config.resources.requests.keys() if key.startswith("nvidia.com/mig")] - if existing_mig_profiles: - kwargs["resources"]["requests"].update({existing_mig_profiles[0]: None}) - kwargs["resources"]["limits"].update({existing_mig_profiles[0]: None}) + # Validate MIG profile configuration + if "resources" in kwargs: + resources = kwargs["resources"] + + if isinstance(resources, dict): + resources = ResourceRequirements(**resources) + + # Validate and extract MIG profiles + mig_profiles = self._validate_and_extract_mig_profiles(resources) + + # Remove existing MIG profiles if changing to a different one + if mig_profiles: + mig_profile = list(mig_profiles)[0] + + existing_config = HPSpace.get(self.config.name, self.config.namespace).config + existing_mig_profiles = self._extract_mig_profiles(existing_config.resources) + + if existing_mig_profiles and mig_profile not in existing_mig_profiles: + # Remove existing MIG profiles by setting to None + for existing_profile in existing_mig_profiles: + if existing_profile != mig_profile: + kwargs["resources"].setdefault("requests", {})[existing_profile] = None + kwargs["resources"].setdefault("limits", {})[existing_profile] = None custom_api = client.CustomObjectsApi() @@ -821,4 +894,81 @@ def create_space_access(self, connection_type: str = "vscode-remote") -> Dict[st } except Exception as e: logger.error(f"Failed to create space access for {self.config.name}!") - handle_exception(e, self.config.name, self.config.namespace) \ No newline at end of file + handle_exception(e, self.config.name, self.config.namespace) + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "portforward_space") + def portforward_space(self, local_port: str, remote_port: str = DEFAULT_SPACE_PORT): + """Forward local port to the space pod for development access. + + Creates a port forwarding connection from a local port to a remote port + on the space pod, enabling direct access to services running inside the + space. + + **Parameters:** + + .. list-table:: + :header-rows: 1 + :widths: 20 20 60 + + * - Parameter + - Type + - Description + * - local_port + - str + - The local port to forward from + * - remote_port + - str, optional + - The remote port on the space pod to forward to (default: DEFAULT_SPACE_PORT) + + **Raises:** + + RuntimeError: If no pods are found for the space or if the space is not in Available status + KeyboardInterrupt: When the user stops the port forwarding with Ctrl+C + Exception: If the port forwarding setup fails or Kubernetes API call fails + + .. dropdown:: Usage Examples + :open: + + .. code-block:: python + + >>> # Forward local port 8080 to default remote port + >>> space = HPSpace.get("myspace") + >>> space.portforward_space("8080") + + >>> # Forward local port 3000 to remote port 8888 + >>> space.portforward_space("3000", "8888") + + >>> # Access forwarded service (in another terminal) + >>> # curl http://localhost:8080 + """ + + self.verify_kube_config() + logger = self.get_logger() + + # Check if space is in Available status + if self.status and self.status.get("conditions"): + is_available = False + for condition in self.status["conditions"]: + if condition.get("type") == "Available" and condition.get("status") == "True": + is_available = True + break + + if not is_available: + raise RuntimeError(f"Space '{self.config.name}' is not in Available status. Port forwarding is only allowed for available spaces.") + + pods = self.list_pods() + if not pods: + raise RuntimeError(f"No pods found for space '{self.config.name}'") + + pod_name = pods[0] + pod = Pod.get(name=pod_name, namespace=self.config.namespace) + pf = pod.portforward(remote_port=int(remote_port), local_port=int(local_port)) + + logger.debug(f"Forwarding from local port {local_port} to space pod: {pod_name}.") + + try: + pf.run_forever() + except KeyboardInterrupt: + logger.debug("Stopping space port forward...") + finally: + pf.stop() diff --git a/src/sagemaker/hyperpod/space/utils.py b/src/sagemaker/hyperpod/space/utils.py index da8200a1..5c5ec8d3 100644 --- a/src/sagemaker/hyperpod/space/utils.py +++ b/src/sagemaker/hyperpod/space/utils.py @@ -1,9 +1,11 @@ """Utility functions for space operations.""" +import os import re -from typing import Dict, Any, Set, List +from typing import Dict, Any, Set, List, Tuple, Optional from pydantic import BaseModel from kubernetes import client +from sagemaker.hyperpod.training.constants import VALIDATE_PROFILE_IN_CLUSTER def camel_to_snake(name: str) -> str: @@ -89,3 +91,57 @@ def get_pod_instance_type(pod_name: str, namespace: str = "default") -> str: return instance_type raise RuntimeError(f"Instance type not found for node '{pod.spec.node_name}'") + + +def validate_space_mig_resources(resources: Optional[Dict[str, Optional[str]]]) -> Tuple[bool, str]: + """Validate MIG profile configuration in space resources. + + Ensures that: + 1. Only one MIG profile is specified + 2. MIG profiles are not mixed with full GPU requests + + Args: + resources: Dictionary of resource requests or limits (e.g., {"nvidia.com/gpu": "1", "cpu": "2"}) + + Returns: + Tuple of (is_valid, error_message) + """ + if not resources: + return True, "" + + # Extract GPU-related resource keys + mig_profiles = [key for key in resources.keys() if key.startswith("nvidia.com/mig-")] + has_full_gpu = "nvidia.com/gpu" in resources + + # Check for multiple MIG profiles + if len(mig_profiles) > 1: + return False, "Space only supports one MIG profile" + + # Check for mixing full GPU with MIG partitions + if has_full_gpu and mig_profiles: + return False, "Cannot mix full GPU (nvidia.com/gpu) with MIG partitions (nvidia.com/mig-*)" + + return True, "" + + +def validate_mig_profile_in_cluster(mig_profile: str) -> Tuple[bool, str]: + """Validate that a MIG profile exists on at least one node in the cluster. + + Args: + mig_profile: Full MIG profile resource key (e.g., 'nvidia.com/mig-1g.5gb') + + Returns: + Tuple of (is_valid, error_message) + """ + if os.getenv(VALIDATE_PROFILE_IN_CLUSTER) == "false": + return True, "" + + v1 = client.CoreV1Api() + for node in v1.list_node().items: + if node.status and node.status.allocatable: + allocatable = node.status.allocatable.get(mig_profile) + if allocatable and int(allocatable) > 0: + return True, "" + + return False, (f"Accelerator partition type '{mig_profile}' does not exist in this cluster. " + f"Use 'hyp list-accelerator-partition-type' to check for available resources.") diff --git a/test/integration_tests/space/cli/test_cli_space.py b/test/integration_tests/space/cli/test_cli_space.py index b0d0a012..e912668f 100644 --- a/test/integration_tests/space/cli/test_cli_space.py +++ b/test/integration_tests/space/cli/test_cli_space.py @@ -1,10 +1,14 @@ import time import pytest +import threading +import socket +import requests from click.testing import CliRunner from sagemaker.hyperpod.cli.commands.space import ( space_create, space_list, space_describe, space_delete, - space_update, space_start, space_stop, space_get_logs + space_update, space_start, space_stop, space_get_logs, space_portforward ) +from sagemaker.hyperpod.space.hyperpod_space import HPSpace from test.integration_tests.utils import get_time_str # --------- Test Configuration --------- @@ -25,6 +29,44 @@ def space_name(): class TestSpaceCLI: """Integration tests for HyperPod Space CLI commands.""" + def _wait_for_space_available(self, space_name, namespace="default", timeout=300): + """Wait for space to become available.""" + start_time = time.time() + while time.time() - start_time < timeout: + try: + space = HPSpace.get(name=space_name, namespace=namespace) + status = space.status + if status and status.get("conditions"): + for condition in status["conditions"]: + if condition.get("type") == "Available" and condition.get("status") == "True": + return True + time.sleep(10) + except Exception as e: + print(f"Error checking space status: {e}") + time.sleep(10) + return False + + def _is_port_available(self, port): + """Check if a port is available for use.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(('localhost', port)) + return True + except OSError: + return False + + def _test_http_endpoint(self, port, timeout=30): + """Test if HTTP endpoint responds with 200.""" + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}", timeout=5) + if response.status_code == 200: + return True + except (requests.exceptions.RequestException, requests.exceptions.ConnectionError): + time.sleep(3) + return False + @pytest.mark.dependency(name="create") def test_space_create(self, runner, space_name): """Test creating a space via CLI.""" @@ -86,6 +128,7 @@ def test_space_describe_json(self, runner, space_name): assert "{" in result.output and "}" in result.output @pytest.mark.dependency(depends=["create"]) + @pytest.mark.dependency(name="stop") def test_space_stop(self, runner, space_name): """Test stopping a space.""" result = runner.invoke(space_stop, [ @@ -95,7 +138,8 @@ def test_space_stop(self, runner, space_name): assert result.exit_code == 0, result.output assert f"Space '{space_name}' stop requested" in result.output - @pytest.mark.dependency(depends=["create"]) + @pytest.mark.dependency(depends=["stop"]) + @pytest.mark.dependency(name="start") def test_space_start(self, runner, space_name): """Test starting a space.""" result = runner.invoke(space_start, [ @@ -116,6 +160,53 @@ def test_space_update(self, runner, space_name): assert result.exit_code == 0, result.output assert f"Space '{space_name}' updated successfully" in result.output + @pytest.mark.dependency(depends=["start"]) + def test_space_portforward(self, runner, space_name): + """Test port forwarding to a space.""" + # Find an available port + test_port = 8080 + while not self._is_port_available(test_port) and test_port < 9000: + test_port += 1 + + if test_port >= 9000: + pytest.skip("No available ports found for testing") + + # Wait for space to become available + print(f"Waiting for space '{space_name}' to become available...") + if not self._wait_for_space_available(space_name, NAMESPACE): + pytest.skip(f"Space '{space_name}' did not become available within timeout") + + # Start port forwarding in a separate thread + portforward_thread = None + portforward_exception = None + + def run_portforward(): + nonlocal portforward_exception + try: + result = runner.invoke(space_portforward, [ + "--name", space_name, + "--namespace", NAMESPACE, + "--local-port", str(test_port) + ], catch_exceptions=False) + if result.exit_code != 0: + portforward_exception = Exception(f"Port forward failed: {result.output}") + except Exception as e: + portforward_exception = e + + portforward_thread = threading.Thread(target=run_portforward, daemon=True) + portforward_thread.start() + + # Check if port forwarding thread encountered an error + if portforward_exception: + raise portforward_exception + + # Test localhost HTTP endpoint + print(f"Testing HTTP endpoint at localhost:{test_port}") + if self._test_http_endpoint(test_port): + print("✓ HTTP endpoint returned 200 status") + else: + pytest.fail(f"HTTP endpoint at localhost:{test_port} did not return 200 status within timeout") + @pytest.mark.dependency(depends=["create"]) def test_space_get_logs(self, runner, space_name): """Test getting logs from a space.""" @@ -138,15 +229,6 @@ def test_space_delete(self, runner, space_name): assert result.exit_code == 0, result.output assert f"Requested deletion for Space '{space_name}'" in result.output - def test_space_list_empty_namespace(self, runner): - """Test listing spaces in an empty namespace.""" - result = runner.invoke(space_list, [ - "--namespace", "nonexistent-namespace", - "--output", "table" - ]) - assert result.exit_code == 0, result.output - assert "No spaces found" in result.output - def test_space_describe_nonexistent(self, runner): """Test describing a nonexistent space.""" result = runner.invoke(space_describe, [ diff --git a/test/integration_tests/space/cli/test_cli_space_template.py b/test/integration_tests/space/cli/test_cli_space_template.py index baee8b50..cf9606fb 100644 --- a/test/integration_tests/space/cli/test_cli_space_template.py +++ b/test/integration_tests/space/cli/test_cli_space_template.py @@ -225,15 +225,6 @@ def test_space_template_delete(self, runner, template_name): assert result.exit_code == 0, result.output assert f"Requested deletion for Space template '{template_name}' in namespace '{NAMESPACE}'" in result.output - def test_space_template_list_empty_namespace(self, runner): - """Test listing space templates in an empty namespace.""" - result = runner.invoke(space_template_list, [ - "--namespace", "nonexistent-namespace", - "--output", "table" - ]) - assert result.exit_code == 0, result.output - assert "No space templates found" in result.output - def test_space_template_describe_nonexistent(self, runner): """Test describing a nonexistent space template.""" result = runner.invoke(space_template_describe, [ diff --git a/test/unit_tests/cli/test_space.py b/test/unit_tests/cli/test_space.py index 8d9eaf63..6fea1202 100644 --- a/test/unit_tests/cli/test_space.py +++ b/test/unit_tests/cli/test_space.py @@ -12,19 +12,21 @@ space_start, space_stop, space_get_logs, + space_portforward, ) +@patch('sagemaker.hyperpod.common.cli_decorators._namespace_exists', return_value=True) class TestSpaceCommands: """Test cases for space commands""" - def setup_method(self): + def setup_method(self, mock_namespace_exists): self.runner = CliRunner() self.mock_hp_space = Mock() @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') - def test_space_create_success(self, mock_load_schema, mock_hp_space_class): + def test_space_create_success(self, mock_load_schema, mock_hp_space_class, mock_namespace_exists): """Test successful space creation""" # Mock schema loading mock_load_schema.return_value = { @@ -67,7 +69,7 @@ def test_space_create_success(self, mock_load_schema, mock_hp_space_class): mock_hp_space_instance.create.assert_called_once() @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') - def test_space_create_missing_required_args(self, mock_load_schema): + def test_space_create_missing_required_args(self, mock_load_schema, mock_namespace_exists): """Test space creation with missing required arguments""" mock_load_schema.return_value = { "properties": {"name": {"type": "string"}}, @@ -79,7 +81,7 @@ def test_space_create_missing_required_args(self, mock_load_schema): assert 'Missing option' in result.output @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_list_table_output(self, mock_hp_space_class): + def test_space_list_table_output(self, mock_hp_space_class, mock_namespace_exists): """Test space list with table output""" # Mock HPSpace instances with config and status mock_space1 = Mock() @@ -111,7 +113,7 @@ def test_space_list_table_output(self, mock_hp_space_class): mock_hp_space_class.list.assert_called_once_with(namespace='test-ns') @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_list_json_output(self, mock_hp_space_class): + def test_space_list_json_output(self, mock_hp_space_class, mock_namespace_exists): """Test space list with JSON output""" # Mock HPSpace instances mock_space1 = Mock() @@ -129,7 +131,7 @@ def test_space_list_json_output(self, mock_hp_space_class): assert output_json == [{"name": "space1", "namespace": "ns1"}] @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_list_empty(self, mock_hp_space_class): + def test_space_list_empty(self, mock_hp_space_class, mock_namespace_exists): """Test space list with no items""" mock_hp_space_class.list.return_value = [] @@ -141,7 +143,7 @@ def test_space_list_empty(self, mock_hp_space_class): assert "No spaces found" in result.output @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_describe_yaml_output(self, mock_hp_space_class): + def test_space_describe_yaml_output(self, mock_hp_space_class, mock_namespace_exists): """Test space describe with YAML output""" mock_resource = {"metadata": {"name": "test-space"}} @@ -157,7 +159,7 @@ def test_space_describe_yaml_output(self, mock_hp_space_class): mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_describe_json_output(self, mock_hp_space_class): + def test_space_describe_json_output(self, mock_hp_space_class, mock_namespace_exists): """Test space describe with JSON output""" mock_resource = {"metadata": {"name": "test-space"}} mock_hp_space_instance = Mock() @@ -175,7 +177,7 @@ def test_space_describe_json_output(self, mock_hp_space_class): assert output_json == mock_resource @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_delete_success(self, mock_hp_space_class): + def test_space_delete_success(self, mock_hp_space_class, mock_namespace_exists): """Test successful space deletion""" mock_hp_space_instance = Mock() mock_hp_space_class.get.return_value = mock_hp_space_instance @@ -193,7 +195,7 @@ def test_space_delete_success(self, mock_hp_space_class): @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') @patch('sagemaker.hyperpod.cli.space_utils.load_schema_for_version') - def test_space_update_success(self, mock_load_schema, mock_hp_space_class): + def test_space_update_success(self, mock_load_schema, mock_hp_space_class, mock_namespace_exists): """Test successful space update""" # Mock schema loading mock_load_schema.return_value = { @@ -234,7 +236,7 @@ def test_space_update_success(self, mock_load_schema, mock_hp_space_class): mock_hp_space_instance.update.assert_called_once() @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_start_success(self, mock_hp_space_class): + def test_space_start_success(self, mock_hp_space_class, mock_namespace_exists): """Test successful space start""" mock_hp_space_instance = Mock() mock_hp_space_class.get.return_value = mock_hp_space_instance @@ -250,7 +252,7 @@ def test_space_start_success(self, mock_hp_space_class): mock_hp_space_instance.start.assert_called_once() @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_stop_success(self, mock_hp_space_class): + def test_space_stop_success(self, mock_hp_space_class, mock_namespace_exists): """Test successful space stop""" mock_hp_space_instance = Mock() mock_hp_space_class.get.return_value = mock_hp_space_instance @@ -266,7 +268,7 @@ def test_space_stop_success(self, mock_hp_space_class): mock_hp_space_instance.stop.assert_called_once() @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_get_logs_success(self, mock_hp_space_class): + def test_space_get_logs_success(self, mock_hp_space_class, mock_namespace_exists): """Test successful space get logs""" mock_hp_space_instance = Mock() mock_hp_space_instance.get_logs.return_value = "test logs" @@ -283,7 +285,7 @@ def test_space_get_logs_success(self, mock_hp_space_class): mock_hp_space_instance.get_logs.assert_called_once_with(pod_name=None, container=None) @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') - def test_space_get_logs_no_pods(self, mock_hp_space_class): + def test_space_get_logs_no_pods(self, mock_hp_space_class, mock_namespace_exists): """Test space get logs with no pods""" mock_hp_space_instance = Mock() mock_hp_space_instance.get_logs.return_value = "" @@ -297,7 +299,7 @@ def test_space_get_logs_no_pods(self, mock_hp_space_class): assert result.exit_code == 0 # HPSpace.get_logs() handles the "no pods" case internally - def test_missing_required_arguments(self): + def test_missing_required_arguments(self, mock_namespace_exists): """Test commands with missing required arguments""" # Test create without name result = self.runner.invoke(space_create, ['--namespace', 'test-ns']) @@ -333,3 +335,26 @@ def test_missing_required_arguments(self): result = self.runner.invoke(space_get_logs, ['--namespace', 'test-ns']) assert result.exit_code == 2 assert "Missing option '--name'" in result.output + + # Test portforward without name + result = self.runner.invoke(space_portforward, ['--namespace', 'test-ns', '--local-port', '8080']) + assert result.exit_code == 2 + assert "Missing option '--name'" in result.output + + @patch('sagemaker.hyperpod.cli.commands.space.HPSpace') + def test_space_portforward_success(self, mock_hp_space_class, mock_namespace_exists): + """Test successful space port forwarding""" + mock_hp_space_instance = Mock() + mock_hp_space_class.get.return_value = mock_hp_space_instance + + result = self.runner.invoke(space_portforward, [ + '--name', 'test-space', + '--namespace', 'test-ns', + '--local-port', '8080' + ]) + + assert result.exit_code == 0 + assert "Forwarding from local port 8080 to space `test-space` in namespace `test-ns`" in result.output + assert "Please access the service via `http://localhost:8080`. Press Ctrl+C to stop." in result.output + mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='test-ns') + mock_hp_space_instance.portforward_space.assert_called_once_with(8080) diff --git a/test/unit_tests/cli/test_space_access.py b/test/unit_tests/cli/test_space_access.py index 717047e7..71bb0326 100644 --- a/test/unit_tests/cli/test_space_access.py +++ b/test/unit_tests/cli/test_space_access.py @@ -51,3 +51,4 @@ def test_space_access_create_default_values(self, mock_hp_space_class): assert "https://default-url.com" in result.output mock_hp_space_class.get.assert_called_once_with(name='test-space', namespace='default') mock_space_instance.create_space_access.assert_called_once_with(connection_type='vscode-remote') + diff --git a/test/unit_tests/cli/test_space_template.py b/test/unit_tests/cli/test_space_template.py index fa9f25ae..f435408a 100644 --- a/test/unit_tests/cli/test_space_template.py +++ b/test/unit_tests/cli/test_space_template.py @@ -102,17 +102,19 @@ def test_space_template_list_empty(self, mock_hp_space_template): self.assertIn("No space templates found", result.output) mock_hp_space_template.list.assert_called_once_with(None) + @patch("sagemaker.hyperpod.common.cli_decorators._namespace_exists") @patch("sagemaker.hyperpod.cli.commands.space_template.HPSpaceTemplate") - def test_space_template_list_with_namespace(self, mock_hp_space_template): + def test_space_template_list_with_namespace(self, mock_hp_space_template, mock_namespace_exists): """Test space template list with namespace parameter""" + mock_namespace_exists.return_value = True mock_template1 = Mock() mock_template1.name = "template1" mock_template1.namespace = "test-namespace" mock_template1.config_data = {"spec": {"displayName": "Template 1", "defaultImage": "image1"}} mock_hp_space_template.list.return_value = [mock_template1] - + result = self.runner.invoke(space_template_list, ["--namespace", "test-namespace", "--output", "table"]) - + self.assertEqual(result.exit_code, 0) self.assertIn("template1", result.output) self.assertIn("test-namespace", result.output) diff --git a/test/unit_tests/test_hyperpod_space.py b/test/unit_tests/test_hyperpod_space.py index b0de3933..a9d51a77 100644 --- a/test/unit_tests/test_hyperpod_space.py +++ b/test/unit_tests/test_hyperpod_space.py @@ -3,7 +3,7 @@ from kubernetes.client.rest import ApiException from sagemaker.hyperpod.space.hyperpod_space import HPSpace -from hyperpod_space_template.v1_0.model import SpaceConfig +from hyperpod_space_template.v1_0.model import SpaceConfig, ResourceRequirements class TestHPSpace(unittest.TestCase): @@ -723,3 +723,357 @@ def test_create_space_access_failure(self, mock_handle_exception, mock_verify_co "test-space", "test-namespace" ) + + @patch('sagemaker.hyperpod.space.hyperpod_space.Pod') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + def test_portforward_space_success(self, mock_list_pods, mock_verify_config, mock_pod_class): + """Test successful port forwarding""" + mock_list_pods.return_value = ["test-pod"] + + mock_pod = Mock() + mock_pf = Mock() + mock_pod.portforward.return_value = mock_pf + mock_pod_class.get.return_value = mock_pod + + self.hp_space.portforward_space("8080", "8888") + + mock_verify_config.assert_called_once() + mock_list_pods.assert_called_once() + mock_pod_class.get.assert_called_once_with(name="test-pod", namespace="test-namespace") + mock_pod.portforward.assert_called_once_with(remote_port=8888, local_port=8080) + mock_pf.run_forever.assert_called_once() + mock_pf.stop.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.Pod') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + def test_portforward_space_default_remote_port(self, mock_list_pods, mock_verify_config, mock_pod_class): + """Test port forwarding with default remote port""" + mock_list_pods.return_value = ["test-pod"] + + mock_pod = Mock() + mock_pf = Mock() + mock_pod.portforward.return_value = mock_pf + mock_pod_class.get.return_value = mock_pod + + self.hp_space.portforward_space("8080") + + mock_pod.portforward.assert_called_once_with(remote_port=8888, local_port=8080) + + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + def test_portforward_space_no_pods(self, mock_list_pods, mock_verify_config): + """Test port forwarding when no pods are found""" + mock_list_pods.return_value = [] + + with self.assertRaises(RuntimeError) as context: + self.hp_space.portforward_space("8080") + + self.assertIn("No pods found for space 'test-space'", str(context.exception)) + mock_verify_config.assert_called_once() + mock_list_pods.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.Pod') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + def test_portforward_space_keyboard_interrupt(self, mock_list_pods, mock_verify_config, mock_pod_class): + """Test port forwarding with KeyboardInterrupt""" + mock_list_pods.return_value = ["test-pod"] + + mock_pod = Mock() + mock_pf = Mock() + mock_pf.run_forever.side_effect = KeyboardInterrupt() + mock_pod.portforward.return_value = mock_pf + mock_pod_class.get.return_value = mock_pod + + # Should not raise exception, should handle gracefully + self.hp_space.portforward_space("8080", "8888") + + mock_pf.run_forever.assert_called_once() + mock_pf.stop.assert_called_once() + + @patch('sagemaker.hyperpod.space.hyperpod_space.Pod') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'list_pods') + def test_portforward_space_exception_in_finally(self, mock_list_pods, mock_verify_config, mock_pod_class): + """Test port forwarding ensures cleanup even with exceptions""" + mock_list_pods.return_value = ["test-pod"] + + mock_pod = Mock() + mock_pf = Mock() + mock_pf.run_forever.side_effect = Exception("Port forward failed") + mock_pod.portforward.return_value = mock_pf + mock_pod_class.get.return_value = mock_pod + + with self.assertRaises(Exception) as context: + self.hp_space.portforward_space("8080", "8888") + + self.assertIn("Port forward failed", str(context.exception)) + mock_pf.run_forever.assert_called_once() + mock_pf.stop.assert_called_once() # Ensure cleanup happens + + def test_extract_mig_profiles_no_mig(self): + """Test extraction with no MIG profiles""" + resources = ResourceRequirements( + requests={"cpu": "2", "memory": "4Gi"}, + limits={"cpu": "4", "memory": "8Gi"} + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, set()) + + def test_extract_mig_profiles_requests_only(self): + """Test extraction with MIG profiles in requests only""" + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2", "cpu": "2"} + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, {"nvidia.com/mig-1g.5gb"}) + + def test_extract_mig_profiles_limits_only(self): + """Test extraction with MIG profiles in limits only""" + resources = ResourceRequirements( + limits={"nvidia.com/mig-2g.10gb": "4", "memory": "8Gi"} + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, {"nvidia.com/mig-2g.10gb"}) + + def test_extract_mig_profiles_both_same(self): + """Test extraction with same MIG profile in both requests and limits""" + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2", "cpu": "2"}, + limits={"nvidia.com/mig-1g.5gb": "2", "cpu": "4"} + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, {"nvidia.com/mig-1g.5gb"}) + + def test_extract_mig_profiles_both_different(self): + """Test extraction with different MIG profiles in requests and limits""" + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"}, + limits={"nvidia.com/mig-2g.10gb": "2"} + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, {"nvidia.com/mig-1g.5gb", "nvidia.com/mig-2g.10gb"}) + + def test_extract_mig_profiles_multiple_in_requests(self): + """Test extraction with multiple MIG profiles in requests""" + resources = ResourceRequirements( + requests={ + "nvidia.com/mig-1g.5gb": "2", + "nvidia.com/mig-2g.10gb": "1", + "cpu": "4" + } + ) + + result = HPSpace._extract_mig_profiles(resources) + self.assertEqual(result, {"nvidia.com/mig-1g.5gb", "nvidia.com/mig-2g.10gb"}) + + def test_validate_and_extract_mig_profiles_none_resources(self): + """Test validation with None resources""" + result = self.hp_space._validate_and_extract_mig_profiles(None) + self.assertEqual(result, set()) + + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_space_mig_resources') + def test_validate_and_extract_mig_profiles_no_mig(self, mock_validate): + """Test validation with no MIG profiles""" + + mock_validate.return_value = (True, "") + resources = ResourceRequirements( + requests={"cpu": "2", "memory": "4Gi"}, + limits={"cpu": "4", "memory": "8Gi"} + ) + + result = self.hp_space._validate_and_extract_mig_profiles(resources) + + self.assertEqual(result, set()) + + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_mig_profile_in_cluster') + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_space_mig_resources') + def test_validate_and_extract_mig_profiles_single_mig(self, mock_validate_resources, mock_validate_cluster): + """Test validation with single MIG profile""" + + mock_validate_resources.return_value = (True, "") + mock_validate_cluster.return_value = (True, "") + + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2", "cpu": "2"}, + limits={"nvidia.com/mig-1g.5gb": "2", "cpu": "4"} + ) + + result = self.hp_space._validate_and_extract_mig_profiles(resources) + + self.assertEqual(result, {"nvidia.com/mig-1g.5gb"}) + mock_validate_cluster.assert_called_once_with("nvidia.com/mig-1g.5gb") + + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_space_mig_resources') + def test_validate_and_extract_mig_profiles_mismatch(self, mock_validate): + """Test validation fails with mismatched MIG profiles in requests and limits""" + + mock_validate.return_value = (True, "") + + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"}, + limits={"nvidia.com/mig-2g.10gb": "2"} + ) + + with self.assertRaises(RuntimeError) as context: + self.hp_space._validate_and_extract_mig_profiles(resources) + + self.assertIn("MIG profile mismatch", str(context.exception)) + self.assertIn("nvidia.com/mig-1g.5gb", str(context.exception)) + self.assertIn("nvidia.com/mig-2g.10gb", str(context.exception)) + + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_space_mig_resources') + def test_validate_and_extract_mig_profiles_validation_fails_requests(self, mock_validate): + """Test validation fails when requests validation fails""" + + mock_validate.return_value = (False, "Multiple MIG profiles not allowed") + + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2", "nvidia.com/mig-2g.10gb": "1"} + ) + + with self.assertRaises(RuntimeError) as context: + self.hp_space._validate_and_extract_mig_profiles(resources) + + self.assertIn("Multiple MIG profiles not allowed", str(context.exception)) + + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_mig_profile_in_cluster') + @patch('sagemaker.hyperpod.space.hyperpod_space.validate_space_mig_resources') + def test_validate_and_extract_mig_profiles_cluster_validation_fails(self, mock_validate_resources, mock_validate_cluster): + """Test validation fails when cluster validation fails""" + + mock_validate_resources.return_value = (True, "") + mock_validate_cluster.return_value = (False, "MIG profile not found in cluster") + + resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"} + ) + + with self.assertRaises(RuntimeError) as context: + self.hp_space._validate_and_extract_mig_profiles(resources) + + self.assertIn("MIG profile not found in cluster", str(context.exception)) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, '_validate_and_extract_mig_profiles') + def test_create_with_mig_validation(self, mock_validate_mig, mock_verify_config, mock_custom_api_class): + """Test create calls MIG validation""" + + mock_validate_mig.return_value = {"nvidia.com/mig-1g.5gb"} + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + self.hp_space.config.resources = ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"} + ) + + mock_domain_config = { + "space_spec": { + "apiVersion": "workspace.jupyter.org/v1alpha1", + "kind": "Workspace", + "metadata": {"name": "test-space", "namespace": "test-namespace"}, + "spec": {"image": "test-image:latest"} + } + } + + with patch('hyperpod_space_template.v1_0.model.SpaceConfig.to_domain', return_value=mock_domain_config): + self.hp_space.create() + + mock_validate_mig.assert_called_once_with(self.hp_space.config.resources) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'get') + @patch.object(HPSpace, '_validate_and_extract_mig_profiles') + def test_update_with_mig_profile_change(self, mock_validate_mig, mock_get, mock_verify_config, mock_custom_api_class): + """Test update removes existing MIG profile when changing to new one""" + + # Setup existing space with existing MIG profile + existing_space = HPSpace(config=SpaceConfig( + name="test-space", + display_name="Test Space", + namespace="test-namespace", + resources=ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"}, + limits={"nvidia.com/mig-1g.5gb": "2"} + ) + )) + mock_get.return_value = existing_space + + # New MIG profile + mock_validate_mig.return_value = {"nvidia.com/mig-2g.10gb"} + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + # Update with new MIG profile + new_resources = { + "requests": {"nvidia.com/mig-2g.10gb": "1"}, + "limits": {"nvidia.com/mig-2g.10gb": "1"} + } + + self.hp_space.update(resources=new_resources) + + # Verify old MIG profile is set to None + self.assertEqual(new_resources["requests"]["nvidia.com/mig-1g.5gb"], None) + self.assertEqual(new_resources["limits"]["nvidia.com/mig-1g.5gb"], None) + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, 'get') + @patch.object(HPSpace, '_validate_and_extract_mig_profiles') + def test_update_with_same_mig_profile(self, mock_validate_mig, mock_get, mock_verify_config, mock_custom_api_class): + """Test update doesn't remove MIG profile when it's the same""" + + # Setup existing space with MIG profile + existing_space = HPSpace(config=SpaceConfig( + name="test-space", + display_name="Test Space", + namespace="test-namespace", + resources=ResourceRequirements( + requests={"nvidia.com/mig-1g.5gb": "2"} + ) + )) + mock_get.return_value = existing_space + + # Same MIG profile + mock_validate_mig.return_value = {"nvidia.com/mig-1g.5gb"} + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + # Update with same MIG profile + new_resources = { + "requests": {"nvidia.com/mig-1g.5gb": "4"} + } + + self.hp_space.update(resources=new_resources) + + # Verify MIG profile value remains "4" (not changed to "0") + self.assertEqual(new_resources["requests"]["nvidia.com/mig-1g.5gb"], "4") + + @patch('sagemaker.hyperpod.space.hyperpod_space.client.CustomObjectsApi') + @patch.object(HPSpace, 'verify_kube_config') + @patch.object(HPSpace, '_validate_and_extract_mig_profiles') + def test_update_converts_dict_to_resource_requirements(self, mock_validate_mig, mock_verify_config, mock_custom_api_class): + """Test update converts dict resources to ResourceRequirements""" + mock_validate_mig.return_value = set() + mock_custom_api = Mock() + mock_custom_api_class.return_value = mock_custom_api + + # Pass resources as dict + resources_dict = { + "requests": {"cpu": "2", "memory": "4Gi"} + } + + self.hp_space.update(resources=resources_dict) + + # Verify _validate_and_extract_mig_profiles was called with ResourceRequirements object + call_args = mock_validate_mig.call_args[0][0] + self.assertIsInstance(call_args, ResourceRequirements) diff --git a/test/unit_tests/test_space_utils.py b/test/unit_tests/test_space_utils.py index a0e6a3ef..eed38c05 100644 --- a/test/unit_tests/test_space_utils.py +++ b/test/unit_tests/test_space_utils.py @@ -1,9 +1,17 @@ """Unit tests for space utils module.""" +import os import unittest from unittest.mock import Mock, patch from kubernetes import client -from sagemaker.hyperpod.space.utils import camel_to_snake, get_model_fields, map_kubernetes_response_to_model, get_pod_instance_type +from sagemaker.hyperpod.space.utils import ( + camel_to_snake, + get_model_fields, + map_kubernetes_response_to_model, + get_pod_instance_type, + validate_space_mig_resources, + validate_mig_profile_in_cluster, +) from hyperpod_space_template.v1_0.model import SpaceConfig @@ -133,3 +141,142 @@ def test_get_pod_instance_type_no_instance_type_label(self, mock_core_v1): get_pod_instance_type('test-pod') self.assertIn("Instance type not found for node 'test-node'", str(context.exception)) + + def test_validate_space_mig_resources_none(self): + """Test validation with None resources.""" + valid, err = validate_space_mig_resources(None) + self.assertTrue(valid) + self.assertEqual(err, "") + + def test_validate_space_mig_resources_empty(self): + """Test validation with empty resources.""" + valid, err = validate_space_mig_resources({}) + self.assertTrue(valid) + self.assertEqual(err, "") + + def test_validate_space_mig_resources_single_mig_profile(self): + """Test validation with single MIG profile.""" + resources = {"nvidia.com/mig-1g.5gb": "2", "cpu": "4"} + valid, err = validate_space_mig_resources(resources) + self.assertTrue(valid) + self.assertEqual(err, "") + + def test_validate_space_mig_resources_multiple_mig_profiles(self): + """Test validation fails with multiple MIG profiles.""" + resources = { + "nvidia.com/mig-1g.5gb": "2", + "nvidia.com/mig-2g.10gb": "1", + "cpu": "4" + } + valid, err = validate_space_mig_resources(resources) + self.assertFalse(valid) + self.assertEqual(err, "Space only supports one MIG profile") + + def test_validate_space_mig_resources_mixed_gpu_and_mig(self): + """Test validation fails when mixing full GPU with MIG.""" + resources = { + "nvidia.com/gpu": "1", + "nvidia.com/mig-1g.5gb": "2", + "cpu": "4" + } + valid, err = validate_space_mig_resources(resources) + self.assertFalse(valid) + self.assertEqual(err, "Cannot mix full GPU (nvidia.com/gpu) with MIG partitions (nvidia.com/mig-*)") + + def test_validate_space_mig_resources_full_gpu_only(self): + """Test validation passes with full GPU only.""" + resources = {"nvidia.com/gpu": "1", "cpu": "4", "memory": "8Gi"} + valid, err = validate_space_mig_resources(resources) + self.assertTrue(valid) + self.assertEqual(err, "") + + @patch.dict(os.environ, {"VALIDATE_PROFILE_IN_CLUSTER": "false"}) + def test_validate_mig_profile_in_cluster_disabled(self): + """Test validation skipped when env var is false.""" + valid, err = validate_mig_profile_in_cluster("nvidia.com/mig-1g.5gb") + self.assertTrue(valid) + self.assertEqual(err, "") + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_validate_mig_profile_in_cluster_found(self, mock_core_v1): + """Test validation succeeds when MIG profile exists on a node.""" + # Mock node with MIG profile + mock_node1 = Mock() + mock_node1.status.allocatable = {"nvidia.com/mig-1g.5gb": "7"} + + mock_node2 = Mock() + mock_node2.status.allocatable = {"nvidia.com/gpu": "8"} + + mock_nodes = Mock() + mock_nodes.items = [mock_node1, mock_node2] + + mock_api = Mock() + mock_api.list_node.return_value = mock_nodes + mock_core_v1.return_value = mock_api + + valid, err = validate_mig_profile_in_cluster("nvidia.com/mig-1g.5gb") + + self.assertTrue(valid) + self.assertEqual(err, "") + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_validate_mig_profile_in_cluster_not_found(self, mock_core_v1): + """Test validation fails when MIG profile doesn't exist on any node.""" + # Mock nodes without the requested MIG profile + mock_node1 = Mock() + mock_node1.status.allocatable = {"nvidia.com/mig-2g.10gb": "4"} + + mock_node2 = Mock() + mock_node2.status.allocatable = {"nvidia.com/gpu": "8"} + + mock_nodes = Mock() + mock_nodes.items = [mock_node1, mock_node2] + + mock_api = Mock() + mock_api.list_node.return_value = mock_nodes + mock_core_v1.return_value = mock_api + + valid, err = validate_mig_profile_in_cluster("nvidia.com/mig-1g.5gb") + + self.assertFalse(valid) + self.assertIn("Accelerator partition type 'nvidia.com/mig-1g.5gb' does not exist", err) + self.assertIn("Use 'hyp list-accelerator-partition-type'", err) + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_validate_mig_profile_in_cluster_zero_allocatable(self, mock_core_v1): + """Test validation fails when MIG profile exists but has zero allocatable.""" + mock_node = Mock() + mock_node.status.allocatable = {"nvidia.com/mig-1g.5gb": "0"} + + mock_nodes = Mock() + mock_nodes.items = [mock_node] + + mock_api = Mock() + mock_api.list_node.return_value = mock_nodes + mock_core_v1.return_value = mock_api + + valid, err = validate_mig_profile_in_cluster("nvidia.com/mig-1g.5gb") + + self.assertFalse(valid) + self.assertIn("does not exist in this cluster", err) + + @patch('sagemaker.hyperpod.space.utils.client.CoreV1Api') + def test_validate_mig_profile_in_cluster_no_status(self, mock_core_v1): + """Test validation handles nodes without status.""" + mock_node1 = Mock() + mock_node1.status = None + + mock_node2 = Mock() + mock_node2.status.allocatable = {"nvidia.com/mig-1g.5gb": "7"} + + mock_nodes = Mock() + mock_nodes.items = [mock_node1, mock_node2] + + mock_api = Mock() + mock_api.list_node.return_value = mock_nodes + mock_core_v1.return_value = mock_api + + valid, err = validate_mig_profile_in_cluster("nvidia.com/mig-1g.5gb") + + self.assertTrue(valid) + self.assertEqual(err, "")