diff --git a/test/unit_tests/inference/test_crd_validation.py b/test/unit_tests/inference/test_crd_basic_validation.py similarity index 100% rename from test/unit_tests/inference/test_crd_validation.py rename to test/unit_tests/inference/test_crd_basic_validation.py diff --git a/test/unit_tests/inference/test_inferenceendpointconfigs_crd_validation.py b/test/unit_tests/inference/test_inferenceendpointconfigs_crd_validation.py new file mode 100644 index 00000000..353a2bc5 --- /dev/null +++ b/test/unit_tests/inference/test_inferenceendpointconfigs_crd_validation.py @@ -0,0 +1,289 @@ +""" +Unit tests for InferenceEndpointConfig CRD required field validation. + +This module validates that all required fields are properly defined in the +InferenceEndpointConfig CRD YAML file used by the inference operator. +""" + +import unittest +import yaml +from pathlib import Path + + +class TestInferenceEndpointConfigRequiredFields(unittest.TestCase): + """Test class for validating required fields in InferenceEndpointConfig CRD.""" + + @classmethod + def setUpClass(cls): + """Load the CRD file once for all tests.""" + cls.base_path = Path(__file__).parent.parent.parent.parent + cls.crd_path = cls.base_path / "helm_chart" / "HyperPodHelmChart" / "charts" / "inference-operator" / "config" / "crd" + cls.crd_file = cls.crd_path / "inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml" + + with open(cls.crd_file, 'r', encoding='utf-8') as f: + cls.crd_content = yaml.safe_load(f) + + # Get v1 version schema + cls.v1_schema = None + for version in cls.crd_content.get('spec', {}).get('versions', []): + if version.get('name') == 'v1': + cls.v1_schema = version.get('schema', {}).get('openAPIV3Schema', {}) + break + + def test_crd_file_exists(self): + """Test that the InferenceEndpointConfig CRD file exists.""" + self.assertTrue(self.crd_file.exists(), f"CRD file does not exist: {self.crd_file}") + + def test_v1_version_exists(self): + """Test that v1 version schema exists in the CRD.""" + self.assertIsNotNone(self.v1_schema, "v1 version schema not found in CRD") + + def test_spec_required_fields(self): + """Test that spec has the required top-level fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}) + required_fields = spec_properties.get('required', []) + + expected_required = ['modelName', 'modelSourceConfig', 'worker'] + for field in expected_required: + self.assertIn(field, required_fields, + f"Field '{field}' should be required in spec") + + def test_modelname_field_exists(self): + """Test that modelName field is defined in spec.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + self.assertIn('modelName', spec_properties, "modelName field should exist in spec") + + model_name = spec_properties.get('modelName', {}) + self.assertEqual(model_name.get('type'), 'string', "modelName should be of type string") + self.assertIn('pattern', model_name, "modelName should have a pattern validation") + self.assertIn('maxLength', model_name, "modelName should have a maxLength validation") + + def test_model_source_config_required_fields(self): + """Test that modelSourceConfig has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config = spec_properties.get('modelSourceConfig', {}) + + required_fields = model_source_config.get('required', []) + self.assertIn('modelSourceType', required_fields, + "modelSourceType should be required in modelSourceConfig") + + def test_model_source_config_model_source_type_field(self): + """Test that modelSourceType field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config_props = spec_properties.get('modelSourceConfig', {}).get('properties', {}) + + self.assertIn('modelSourceType', model_source_config_props, + "modelSourceType should exist in modelSourceConfig") + + model_source_type = model_source_config_props.get('modelSourceType', {}) + self.assertIn('enum', model_source_type, "modelSourceType should have enum values") + self.assertIn('fsx', model_source_type.get('enum', []), "modelSourceType should support 'fsx'") + self.assertIn('s3', model_source_type.get('enum', []), "modelSourceType should support 's3'") + + def test_model_source_config_fsx_storage_required_fields(self): + """Test that fsxStorage has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config_props = spec_properties.get('modelSourceConfig', {}).get('properties', {}) + fsx_storage = model_source_config_props.get('fsxStorage', {}) + + required_fields = fsx_storage.get('required', []) + self.assertIn('fileSystemId', required_fields, + "fileSystemId should be required in fsxStorage") + + def test_model_source_config_fsx_storage_fields(self): + """Test that fsxStorage fields are properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config_props = spec_properties.get('modelSourceConfig', {}).get('properties', {}) + fsx_storage_props = model_source_config_props.get('fsxStorage', {}).get('properties', {}) + + self.assertIn('fileSystemId', fsx_storage_props, "fileSystemId should exist in fsxStorage") + self.assertIn('dnsName', fsx_storage_props, "dnsName should exist in fsxStorage") + self.assertIn('mountName', fsx_storage_props, "mountName should exist in fsxStorage") + + def test_model_source_config_s3_storage_required_fields(self): + """Test that s3Storage has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config_props = spec_properties.get('modelSourceConfig', {}).get('properties', {}) + s3_storage = model_source_config_props.get('s3Storage', {}) + + required_fields = s3_storage.get('required', []) + self.assertIn('bucketName', required_fields, + "bucketName should be required in s3Storage") + self.assertIn('region', required_fields, + "region should be required in s3Storage") + + def test_model_source_config_s3_storage_fields(self): + """Test that s3Storage fields are properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_source_config_props = spec_properties.get('modelSourceConfig', {}).get('properties', {}) + s3_storage_props = model_source_config_props.get('s3Storage', {}).get('properties', {}) + + self.assertIn('bucketName', s3_storage_props, "bucketName should exist in s3Storage") + self.assertIn('region', s3_storage_props, "region should exist in s3Storage") + + def test_worker_required_fields(self): + """Test that worker section has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker = spec_properties.get('worker', {}) + + required_fields = worker.get('required', []) + expected_required = ['image', 'modelInvocationPort', 'modelVolumeMount', 'resources'] + for field in expected_required: + self.assertIn(field, required_fields, + f"Field '{field}' should be required in worker") + + def test_worker_image_field(self): + """Test that worker.image field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + + self.assertIn('image', worker_props, "image should exist in worker") + image = worker_props.get('image', {}) + self.assertEqual(image.get('type'), 'string', "image should be of type string") + + def test_worker_model_invocation_port_required_fields(self): + """Test that modelInvocationPort has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + model_invocation_port = worker_props.get('modelInvocationPort', {}) + + required_fields = model_invocation_port.get('required', []) + self.assertIn('containerPort', required_fields, + "containerPort should be required in modelInvocationPort") + + def test_worker_model_invocation_port_fields(self): + """Test that modelInvocationPort fields are properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + port_props = worker_props.get('modelInvocationPort', {}).get('properties', {}) + + self.assertIn('containerPort', port_props, "containerPort should exist in modelInvocationPort") + container_port = port_props.get('containerPort', {}) + self.assertEqual(container_port.get('type'), 'integer', "containerPort should be of type integer") + self.assertIn('minimum', container_port, "containerPort should have minimum validation") + self.assertIn('maximum', container_port, "containerPort should have maximum validation") + + self.assertIn('name', port_props, "name should exist in modelInvocationPort") + name = port_props.get('name', {}) + self.assertIn('pattern', name, "name should have pattern validation") + + def test_worker_model_volume_mount_required_fields(self): + """Test that modelVolumeMount has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + model_volume_mount = worker_props.get('modelVolumeMount', {}) + + required_fields = model_volume_mount.get('required', []) + self.assertIn('name', required_fields, + "name should be required in modelVolumeMount") + + def test_worker_model_volume_mount_fields(self): + """Test that modelVolumeMount fields are properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + volume_mount_props = worker_props.get('modelVolumeMount', {}).get('properties', {}) + + self.assertIn('name', volume_mount_props, "name should exist in modelVolumeMount") + self.assertIn('mountPath', volume_mount_props, "mountPath should exist in modelVolumeMount") + + def test_worker_resources_field(self): + """Test that worker.resources field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + + self.assertIn('resources', worker_props, "resources should exist in worker") + resources = worker_props.get('resources', {}) + self.assertEqual(resources.get('type'), 'object', "resources should be of type object") + + def test_worker_resources_fields(self): + """Test that resources has expected sub-fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + worker_props = spec_properties.get('worker', {}).get('properties', {}) + resources_props = worker_props.get('resources', {}).get('properties', {}) + + self.assertIn('limits', resources_props, "limits should exist in resources") + self.assertIn('requests', resources_props, "requests should exist in resources") + + def test_optional_fields_exist(self): + """Test that optional but commonly used fields exist in spec.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + + optional_fields = [ + 'endpointName', + 'instanceType', + 'instanceTypes', + 'replicas', + 'autoScalingSpec', + 'loadBalancer', + 'metrics', + 'tlsConfig', + 'tags' + ] + + for field in optional_fields: + self.assertIn(field, spec_properties, + f"Optional field '{field}' should exist in spec") + + def test_endpoint_name_validation(self): + """Test that endpointName has proper validation.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + endpoint_name = spec_properties.get('endpointName', {}) + + self.assertIn('pattern', endpoint_name, "endpointName should have pattern validation") + self.assertIn('maxLength', endpoint_name, "endpointName should have maxLength validation") + + def test_instance_type_validation(self): + """Test that instanceType has proper validation.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + instance_type = spec_properties.get('instanceType', {}) + + self.assertEqual(instance_type.get('type'), 'string', "instanceType should be of type string") + self.assertIn('pattern', instance_type, "instanceType should have pattern validation") + # Check that the pattern requires 'ml.' prefix + pattern = instance_type.get('pattern', '') + self.assertIn('ml', pattern.lower(), "instanceType pattern should require 'ml.' prefix") + + def test_replicas_field(self): + """Test that replicas field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + replicas = spec_properties.get('replicas', {}) + + self.assertEqual(replicas.get('type'), 'integer', "replicas should be of type integer") + self.assertIn('default', replicas, "replicas should have a default value") + + def test_auto_scaling_spec_fields(self): + """Test that autoScalingSpec has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + auto_scaling_spec = spec_properties.get('autoScalingSpec', {}).get('properties', {}) + + expected_fields = [ + 'minReplicaCount', + 'maxReplicaCount', + 'cooldownPeriod', + 'pollingInterval' + ] + + for field in expected_fields: + self.assertIn(field, auto_scaling_spec, + f"Field '{field}' should exist in autoScalingSpec") + + def test_metrics_fields(self): + """Test that metrics section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + metrics_props = spec_properties.get('metrics', {}).get('properties', {}) + + self.assertIn('enabled', metrics_props, "enabled should exist in metrics") + self.assertIn('metricsScrapeIntervalSeconds', metrics_props, + "metricsScrapeIntervalSeconds should exist in metrics") + + def test_load_balancer_fields(self): + """Test that loadBalancer section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + load_balancer_props = spec_properties.get('loadBalancer', {}).get('properties', {}) + + self.assertIn('healthCheckPath', load_balancer_props, "healthCheckPath should exist in loadBalancer") + self.assertIn('routingAlgorithm', load_balancer_props, "routingAlgorithm should exist in loadBalancer") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit_tests/inference/test_jumpstartmodels_crd_validation.py b/test/unit_tests/inference/test_jumpstartmodels_crd_validation.py new file mode 100644 index 00000000..ca6d204c --- /dev/null +++ b/test/unit_tests/inference/test_jumpstartmodels_crd_validation.py @@ -0,0 +1,359 @@ +""" +Unit tests for JumpStartModel CRD required field validation. + +This module validates that all required fields are properly defined in the +JumpStartModel CRD YAML file used by the inference operator. +""" + +import unittest +import yaml +from pathlib import Path + + +class TestJumpStartModelRequiredFields(unittest.TestCase): + """Test class for validating required fields in JumpStartModel CRD.""" + + @classmethod + def setUpClass(cls): + """Load the CRD file once for all tests.""" + cls.base_path = Path(__file__).parent.parent.parent.parent + cls.crd_path = cls.base_path / "helm_chart" / "HyperPodHelmChart" / "charts" / "inference-operator" / "config" / "crd" + cls.crd_file = cls.crd_path / "inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml" + + with open(cls.crd_file, 'r', encoding='utf-8') as f: + cls.crd_content = yaml.safe_load(f) + + # Get v1 version schema + cls.v1_schema = None + for version in cls.crd_content.get('spec', {}).get('versions', []): + if version.get('name') == 'v1': + cls.v1_schema = version.get('schema', {}).get('openAPIV3Schema', {}) + break + + def test_crd_file_exists(self): + """Test that the JumpStartModel CRD file exists.""" + self.assertTrue(self.crd_file.exists(), f"CRD file does not exist: {self.crd_file}") + + def test_v1_version_exists(self): + """Test that v1 version schema exists in the CRD.""" + self.assertIsNotNone(self.v1_schema, "v1 version schema not found in CRD") + + def test_spec_required_fields(self): + """Test that spec has the required top-level fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}) + required_fields = spec_properties.get('required', []) + + expected_required = ['model', 'server'] + for field in expected_required: + self.assertIn(field, required_fields, + f"Field '{field}' should be required in spec") + + def test_model_section_exists(self): + """Test that model section exists in spec.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + self.assertIn('model', spec_properties, "model section should exist in spec") + + def test_model_required_fields(self): + """Test that model section has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model = spec_properties.get('model', {}) + + required_fields = model.get('required', []) + expected_required = ['acceptEula', 'modelId'] + for field in expected_required: + self.assertIn(field, required_fields, + f"Field '{field}' should be required in model") + + def test_model_accept_eula_field(self): + """Test that model.acceptEula field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('acceptEula', model_props, "acceptEula should exist in model") + accept_eula = model_props.get('acceptEula', {}) + self.assertEqual(accept_eula.get('type'), 'boolean', "acceptEula should be of type boolean") + self.assertIn('default', accept_eula, "acceptEula should have a default value") + self.assertEqual(accept_eula.get('default'), False, "acceptEula default should be false") + + def test_model_model_id_field(self): + """Test that model.modelId field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('modelId', model_props, "modelId should exist in model") + model_id = model_props.get('modelId', {}) + self.assertEqual(model_id.get('type'), 'string', "modelId should be of type string") + self.assertIn('pattern', model_id, "modelId should have a pattern validation") + self.assertIn('maxLength', model_id, "modelId should have a maxLength validation") + + def test_model_model_version_field(self): + """Test that model.modelVersion field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('modelVersion', model_props, "modelVersion should exist in model") + model_version = model_props.get('modelVersion', {}) + self.assertEqual(model_version.get('type'), 'string', "modelVersion should be of type string") + self.assertIn('pattern', model_version, "modelVersion should have a pattern validation") + self.assertIn('minLength', model_version, "modelVersion should have a minLength validation") + self.assertIn('maxLength', model_version, "modelVersion should have a maxLength validation") + + def test_model_model_hub_name_field(self): + """Test that model.modelHubName field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('modelHubName', model_props, "modelHubName should exist in model") + model_hub_name = model_props.get('modelHubName', {}) + self.assertEqual(model_hub_name.get('type'), 'string', "modelHubName should be of type string") + self.assertIn('default', model_hub_name, "modelHubName should have a default value") + self.assertEqual(model_hub_name.get('default'), 'SageMakerPublicHub', + "modelHubName default should be 'SageMakerPublicHub'") + + def test_model_gated_model_download_role_field(self): + """Test that model.gatedModelDownloadRole field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('gatedModelDownloadRole', model_props, + "gatedModelDownloadRole should exist in model") + gated_role = model_props.get('gatedModelDownloadRole', {}) + self.assertEqual(gated_role.get('type'), 'string', + "gatedModelDownloadRole should be of type string") + self.assertIn('pattern', gated_role, "gatedModelDownloadRole should have a pattern validation") + self.assertIn('minLength', gated_role, "gatedModelDownloadRole should have a minLength validation") + self.assertIn('maxLength', gated_role, "gatedModelDownloadRole should have a maxLength validation") + + def test_model_additional_configs_field(self): + """Test that model.additionalConfigs field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + + self.assertIn('additionalConfigs', model_props, "additionalConfigs should exist in model") + additional_configs = model_props.get('additionalConfigs', {}) + self.assertEqual(additional_configs.get('type'), 'array', + "additionalConfigs should be of type array") + self.assertIn('maxItems', additional_configs, "additionalConfigs should have maxItems validation") + + def test_model_additional_configs_items_required_fields(self): + """Test that additionalConfigs items have required fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + model_props = spec_properties.get('model', {}).get('properties', {}) + additional_configs_items = model_props.get('additionalConfigs', {}).get('items', {}) + + required_fields = additional_configs_items.get('required', []) + self.assertIn('name', required_fields, "name should be required in additionalConfigs items") + self.assertIn('value', required_fields, "value should be required in additionalConfigs items") + + def test_server_section_exists(self): + """Test that server section exists in spec.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + self.assertIn('server', spec_properties, "server section should exist in spec") + + def test_server_required_fields(self): + """Test that server section has required fields defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + server = spec_properties.get('server', {}) + + required_fields = server.get('required', []) + self.assertIn('instanceType', required_fields, + "instanceType should be required in server") + + def test_server_instance_type_field(self): + """Test that server.instanceType field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + server_props = spec_properties.get('server', {}).get('properties', {}) + + self.assertIn('instanceType', server_props, "instanceType should exist in server") + instance_type = server_props.get('instanceType', {}) + self.assertEqual(instance_type.get('type'), 'string', "instanceType should be of type string") + self.assertIn('pattern', instance_type, "instanceType should have pattern validation") + # Check that the pattern requires 'ml.' prefix + pattern = instance_type.get('pattern', '') + self.assertIn('ml', pattern.lower(), "instanceType pattern should require 'ml.' prefix") + + def test_server_execution_role_field(self): + """Test that server.executionRole field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + server_props = spec_properties.get('server', {}).get('properties', {}) + + self.assertIn('executionRole', server_props, "executionRole should exist in server") + execution_role = server_props.get('executionRole', {}) + self.assertEqual(execution_role.get('type'), 'string', + "executionRole should be of type string") + self.assertIn('pattern', execution_role, "executionRole should have pattern validation") + self.assertIn('minLength', execution_role, "executionRole should have minLength validation") + self.assertIn('maxLength', execution_role, "executionRole should have maxLength validation") + + def test_server_accelerator_partition_type_field(self): + """Test that server.acceleratorPartitionType field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + server_props = spec_properties.get('server', {}).get('properties', {}) + + self.assertIn('acceleratorPartitionType', server_props, + "acceleratorPartitionType should exist in server") + accelerator_partition = server_props.get('acceleratorPartitionType', {}) + self.assertEqual(accelerator_partition.get('type'), 'string', + "acceleratorPartitionType should be of type string") + self.assertIn('pattern', accelerator_partition, + "acceleratorPartitionType should have pattern validation") + + def test_server_validations_field(self): + """Test that server.validations field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + server_props = spec_properties.get('server', {}).get('properties', {}) + + self.assertIn('validations', server_props, "validations should exist in server") + validations = server_props.get('validations', {}) + self.assertEqual(validations.get('type'), 'object', "validations should be of type object") + + def test_optional_fields_exist(self): + """Test that optional but commonly used fields exist in spec.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + + optional_fields = [ + 'replicas', + 'autoScalingSpec', + 'loadBalancer', + 'metrics', + 'tlsConfig', + 'sageMakerEndpoint', + 'environmentVariables', + 'maxDeployTimeInSeconds', + 'kvCacheSpec', + 'intelligentRoutingSpec' + ] + + for field in optional_fields: + self.assertIn(field, spec_properties, + f"Optional field '{field}' should exist in spec") + + def test_replicas_field(self): + """Test that replicas field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + replicas = spec_properties.get('replicas', {}) + + self.assertEqual(replicas.get('type'), 'integer', "replicas should be of type integer") + self.assertIn('default', replicas, "replicas should have a default value") + self.assertEqual(replicas.get('default'), 1, "replicas default should be 1") + + def test_sagemaker_endpoint_fields(self): + """Test that sageMakerEndpoint section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + sagemaker_endpoint_props = spec_properties.get('sageMakerEndpoint', {}).get('properties', {}) + + self.assertIn('name', sagemaker_endpoint_props, "name should exist in sageMakerEndpoint") + name = sagemaker_endpoint_props.get('name', {}) + self.assertEqual(name.get('type'), 'string', "name should be of type string") + self.assertIn('pattern', name, "name should have pattern validation") + self.assertIn('maxLength', name, "name should have maxLength validation") + + def test_environment_variables_field(self): + """Test that environmentVariables field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + env_vars = spec_properties.get('environmentVariables', {}) + + self.assertEqual(env_vars.get('type'), 'array', "environmentVariables should be of type array") + self.assertIn('maxItems', env_vars, "environmentVariables should have maxItems validation") + + def test_environment_variables_items_required_fields(self): + """Test that environmentVariables items have required fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + env_vars_items = spec_properties.get('environmentVariables', {}).get('items', {}) + + required_fields = env_vars_items.get('required', []) + self.assertIn('name', required_fields, "name should be required in environmentVariables items") + self.assertIn('value', required_fields, "value should be required in environmentVariables items") + + def test_auto_scaling_spec_fields(self): + """Test that autoScalingSpec has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + auto_scaling_spec = spec_properties.get('autoScalingSpec', {}).get('properties', {}) + + expected_fields = [ + 'minReplicaCount', + 'maxReplicaCount', + 'cooldownPeriod', + 'pollingInterval', + 'scaleDownStabilizationTime', + 'scaleUpStabilizationTime' + ] + + for field in expected_fields: + self.assertIn(field, auto_scaling_spec, + f"Field '{field}' should exist in autoScalingSpec") + + def test_metrics_fields(self): + """Test that metrics section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + metrics_props = spec_properties.get('metrics', {}).get('properties', {}) + + self.assertIn('enabled', metrics_props, "enabled should exist in metrics") + self.assertIn('metricsScrapeIntervalSeconds', metrics_props, + "metricsScrapeIntervalSeconds should exist in metrics") + self.assertIn('modelMetrics', metrics_props, "modelMetrics should exist in metrics") + + def test_load_balancer_fields(self): + """Test that loadBalancer section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + load_balancer_props = spec_properties.get('loadBalancer', {}).get('properties', {}) + + self.assertIn('healthCheckPath', load_balancer_props, "healthCheckPath should exist in loadBalancer") + self.assertIn('routingAlgorithm', load_balancer_props, "routingAlgorithm should exist in loadBalancer") + + routing_algorithm = load_balancer_props.get('routingAlgorithm', {}) + self.assertIn('enum', routing_algorithm, "routingAlgorithm should have enum values") + self.assertIn('least_outstanding_requests', routing_algorithm.get('enum', []), + "routingAlgorithm should support 'least_outstanding_requests'") + self.assertIn('round_robin', routing_algorithm.get('enum', []), + "routingAlgorithm should support 'round_robin'") + + def test_tls_config_fields(self): + """Test that tlsConfig section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + tls_config_props = spec_properties.get('tlsConfig', {}).get('properties', {}) + + self.assertIn('tlsCertificateOutputS3Uri', tls_config_props, + "tlsCertificateOutputS3Uri should exist in tlsConfig") + tls_uri = tls_config_props.get('tlsCertificateOutputS3Uri', {}) + self.assertIn('pattern', tls_uri, "tlsCertificateOutputS3Uri should have pattern validation") + + def test_kv_cache_spec_fields(self): + """Test that kvCacheSpec section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + kv_cache_props = spec_properties.get('kvCacheSpec', {}).get('properties', {}) + + self.assertIn('enableL1Cache', kv_cache_props, "enableL1Cache should exist in kvCacheSpec") + self.assertIn('enableL2Cache', kv_cache_props, "enableL2Cache should exist in kvCacheSpec") + self.assertIn('l2CacheSpec', kv_cache_props, "l2CacheSpec should exist in kvCacheSpec") + + def test_intelligent_routing_spec_fields(self): + """Test that intelligentRoutingSpec section has expected fields.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + intelligent_routing_props = spec_properties.get('intelligentRoutingSpec', {}).get('properties', {}) + + self.assertIn('enabled', intelligent_routing_props, "enabled should exist in intelligentRoutingSpec") + self.assertIn('routingStrategy', intelligent_routing_props, + "routingStrategy should exist in intelligentRoutingSpec") + + routing_strategy = intelligent_routing_props.get('routingStrategy', {}) + self.assertIn('enum', routing_strategy, "routingStrategy should have enum values") + expected_strategies = ['prefixaware', 'kvaware', 'session', 'roundrobin'] + for strategy in expected_strategies: + self.assertIn(strategy, routing_strategy.get('enum', []), + f"routingStrategy should support '{strategy}'") + + def test_max_deploy_time_field(self): + """Test that maxDeployTimeInSeconds field is properly defined.""" + spec_properties = self.v1_schema.get('properties', {}).get('spec', {}).get('properties', {}) + max_deploy_time = spec_properties.get('maxDeployTimeInSeconds', {}) + + self.assertEqual(max_deploy_time.get('type'), 'integer', + "maxDeployTimeInSeconds should be of type integer") + self.assertIn('default', max_deploy_time, "maxDeployTimeInSeconds should have a default value") + self.assertEqual(max_deploy_time.get('default'), 3600, + "maxDeployTimeInSeconds default should be 3600") + + +if __name__ == '__main__': + unittest.main()