forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
128 lines (102 loc) · 4.47 KB
/
inference.py
File metadata and controls
128 lines (102 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""This module is for SageMaker inference.py."""
from __future__ import absolute_import
import os
import io
import cloudpickle
import shutil
import platform
from pathlib import Path
from functools import partial
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.validations.check_integrity import perform_integrity_check
import logging
logger = logging.getLogger(__name__)
inference_spec = None
schema_builder = None
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
def model_fn(model_dir, context=None):
    """Overrides default method for loading a model.

    Args:
        model_dir: Directory the model artifacts were extracted into.
        context: Optional model-server context (unused).

    Returns:
        A callable that invokes the InferenceSpec against the loaded model,
        or None when serve.pkl does not hold an ``(InferenceSpec, schema)`` pair.
    """
    shared_libs_path = Path(model_dir) / "shared_libs"
    if shared_libs_path.exists():
        # before importing, place dynamic linked libraries in shared lib path
        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)

    # Reuse the module-level SERVE_PATH constant instead of re-deriving the path.
    with open(str(SERVE_PATH), mode="rb") as file:
        global inference_spec, schema_builder
        # NOTE: serve.pkl is produced by the SDK itself; unpickling untrusted
        # data would be unsafe, so this file must never come from user input.
        obj = cloudpickle.load(file)
        if isinstance(obj[0], InferenceSpec):
            inference_spec, schema_builder = obj

    if inference_spec:
        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
def input_fn(input_data, content_type, context=None):
    """Deserializes the bytes that were received from the model server.

    Args:
        input_data: Raw request payload (str, bytes, or bytearray).
        content_type: Declared MIME type(s) of the payload.
        context: Optional model-server context (unused).

    Returns:
        The deserialized input, run through ``inference_spec.preprocess``
        when the spec defines one.

    Raises:
        Exception: wraps any failure during deserialization.
    """
    try:
        # Normalize the payload into a binary stream once, instead of
        # duplicating the conversion expression in both branches below.
        if isinstance(input_data, (bytes, bytearray)):
            buffered = io.BytesIO(input_data)
        else:
            buffered = io.BytesIO(input_data.encode("utf-8"))

        if hasattr(schema_builder, "custom_input_translator"):
            deserialized_data = schema_builder.custom_input_translator.deserialize(
                buffered, content_type
            )
        else:
            # NOTE(review): the default deserializer is handed content_type[0] —
            # presumably the server passes a list here; preserved as-is.
            deserialized_data = schema_builder.input_deserializer.deserialize(
                buffered, content_type[0]
            )

        # Check if preprocess method is defined and call it
        if hasattr(inference_spec, "preprocess"):
            return inference_spec.preprocess(deserialized_data)
        return deserialized_data
    except Exception as e:
        # Lazy %-args: formatting only happens when the record is emitted.
        logger.error("Encountered error: %s in deserialize_response.", e)
        raise Exception("Encountered error in deserialize_request.") from e
def predict_fn(input_data, predict_callable, context=None):
    """Apply the callable produced by model_fn to the deserialized input."""
    result = predict_callable(input_data)
    return result
def output_fn(predictions, accept_type, context=None):
    """Prediction is serialized to bytes and sent back to the customer.

    Args:
        predictions: Raw model output, optionally post-processed by the spec.
        accept_type: MIME type the caller accepts.
        context: Optional model-server context (unused).

    Returns:
        Serialized prediction bytes.

    Raises:
        Exception: wraps any failure during serialization.
    """
    try:
        if hasattr(inference_spec, "postprocess"):
            predictions = inference_spec.postprocess(predictions)
        if hasattr(schema_builder, "custom_output_translator"):
            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
        # NOTE(review): default serializer ignores accept_type — presumably it
        # carries its own content type; confirm against SchemaBuilder.
        return schema_builder.output_serializer.serialize(predictions)
    except Exception as e:
        # Lazy %-args: formatting only happens when the record is emitted.
        logger.error("Encountered error: %s in serialize_response.", e)
        raise Exception("Encountered error in serialize_response.") from e
def _run_preflight_diagnostics():
    """Run import-time sanity checks: interpreter parity, then pickle integrity."""
    for check in (_py_vs_parity_check, _pickle_file_integrity_check):
        check()
def _py_vs_parity_check():
container_py_vs = platform.python_version()
local_py_vs = os.getenv("LOCAL_PYTHON")
if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
logger.warning(
f"The local python version {local_py_vs} differs from the python version "
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
)
def _pickle_file_integrity_check():
    """Verify serve.pkl against its recorded metadata before it is unpickled."""
    payload = SERVE_PATH.read_bytes()
    perform_integrity_check(buffer=payload, metadata_path=METADATA_PATH)
# on import, execute: version-parity warning and serve.pkl integrity check
# run once when the model server loads this handler module.
_run_preflight_diagnostics()