forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_tgi_integration.py
More file actions
141 lines (109 loc) · 4.63 KB
/
test_tgi_integration.py
File metadata and controls
141 lines (109 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import json
import uuid
import pytest
import logging
import boto3
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.train.configs import Compute
from sagemaker.core.resources import EndpointConfig
from sagemaker.core.helper.session_helper import Session
# Module-level logger used by every helper in this test file.
logger = logging.getLogger(__name__)
# Configuration - easily customizable
MODEL_ID = "t5-small" # Small text generation model
MODEL_NAME_PREFIX = "tgi-test-model"  # prefix for created SageMaker model names
ENDPOINT_NAME_PREFIX = "tgi-test-endpoint"  # prefix for created endpoint names
@pytest.mark.slow_test
def test_tgi_build_deploy_invoke_cleanup():
    """Integration test for the TGI model build, deploy, invoke, cleanup workflow.

    Builds a TGI model from a HuggingFace model id, deploys it to a real
    endpoint, invokes it once, and always tears down whatever resources were
    actually created in the ``finally`` block.
    """
    logger.info("Starting TGI integration test...")
    core_model = None
    core_endpoint = None
    try:
        # Build and deploy
        logger.info("Building and deploying TGI model...")
        core_model, core_endpoint = build_and_deploy()
        # Make prediction
        logger.info("Making prediction...")
        make_prediction(core_endpoint)
        # Test passed successfully
        logger.info("TGI integration test completed successfully")
    except Exception as e:
        # Broad catch is intentional at the test boundary: log the failure,
        # then re-raise so pytest still reports it.
        logger.error(f"TGI integration test failed: {str(e)}")
        raise
    finally:
        # Clean up whatever actually got created. The previous version skipped
        # cleanup entirely unless BOTH resources existed, leaking the model
        # when deploy() failed after a successful build().
        if core_model and core_endpoint:
            logger.info("Cleaning up resources...")
            cleanup_resources(core_model, core_endpoint)
        elif core_model:
            logger.info("Deploy did not complete; deleting model only...")
            core_model.delete()
def create_schema_builder():
    """Return a SchemaBuilder seeded with a text-generation request/response pair."""
    from sagemaker.serve.builder.schema_builder import SchemaBuilder

    request_example = {"inputs": "What are falcons?", "parameters": {"max_new_tokens": 32}}
    response_example = [{"generated_text": "Falcons are small to medium-sized birds of prey."}]
    return SchemaBuilder(request_example, response_example)
def build_and_deploy():
    """Create a TGI model via ModelBuilder and deploy it to a live endpoint.

    Returns:
        tuple: ``(core_model, core_endpoint)`` SageMaker Core objects.
    """
    suffix = str(uuid.uuid4())[:8]  # unique-ify resource names across runs
    schema = create_schema_builder()

    # Container environment: keep LoRA auto-detection and remote code disabled,
    # and turn on verbose container logging for easier debugging.
    environment = {
        "MERGE_LORA": "false",
        "TRUST_REMOTE_CODE": "false",
        "DEBUG_ENV": "true",
        "SAGEMAKER_CONTAINER_LOG_LEVEL": "DEBUG",
    }

    builder = ModelBuilder(
        model=MODEL_ID,  # HuggingFace model id string; no local artifacts needed
        model_server=ModelServer.TGI,
        schema_builder=schema,
        compute=Compute(instance_type="ml.g5.xlarge", instance_count=1),
        env_vars=environment,
    )

    # build()/deploy() return SageMaker Core Model and Endpoint objects.
    core_model = builder.build(model_name=f"{MODEL_NAME_PREFIX}-{suffix}")
    logger.info(f"Model Successfully Created: {core_model.model_name}")

    core_endpoint = builder.deploy(
        endpoint_name=f"{ENDPOINT_NAME_PREFIX}-{suffix}",
        initial_instance_count=1,
    )
    logger.info(f"Endpoint Successfully Created: {core_endpoint.endpoint_name}")

    return core_model, core_endpoint
def make_prediction(core_endpoint):
    """Invoke the deployed endpoint once and log the decoded generation result."""
    payload = json.dumps(
        {"inputs": "What are falcons?", "parameters": {"max_new_tokens": 32}}
    )
    response = core_endpoint.invoke(body=payload, content_type="application/json")
    # The invoke response body is a stream: read and decode it before parsing.
    prediction = json.loads(response.body.read().decode('utf-8'))
    logger.info(f"Result of invoking endpoint: {prediction}")
def cleanup_resources(core_model, core_endpoint):
    """Delete the model, endpoint, and endpoint config created by the test.

    Each deletion is attempted independently so one failure does not leave
    the remaining (billable) resources behind; the first failure is re-raised
    afterwards so the test run still flags the leak.

    Args:
        core_model: SageMaker Core Model to delete.
        core_endpoint: SageMaker Core Endpoint to delete. Its endpoint config
            is assumed to share the endpoint's name — TODO confirm this holds
            for all ModelBuilder deploy paths.
    """
    # Look up the config first: once the endpoint is gone we only have its name.
    try:
        core_endpoint_config = EndpointConfig.get(
            endpoint_config_name=core_endpoint.endpoint_name
        )
    except Exception as e:
        logger.warning(f"Could not fetch endpoint config: {e}")
        core_endpoint_config = None

    failures = []
    for resource in (core_model, core_endpoint, core_endpoint_config):
        if resource is None:
            continue
        try:
            resource.delete()
        except Exception as e:
            failures.append(e)
            logger.warning(f"Failed to delete {resource}: {e}")

    if failures:
        raise failures[0]
    logger.info("Model and Endpoint Successfully Deleted!")