"""This module is for SageMaker inference.py."""
from __future__ import absolute_import
import os
import io
import cloudpickle
import shutil
import platform
from pathlib import Path
from functools import partial
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.validations.check_integrity import perform_integrity_check
import logging
logger = logging.getLogger(__name__)
inference_spec = None
schema_builder = None
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


def model_fn(model_dir):
    """Overrides the default method for loading a model."""
    shared_libs_path = Path(model_dir + "/shared_libs")
    if shared_libs_path.exists():
        # Before loading the model, place dynamically linked libraries in the shared lib path.
        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)

    with open(str(SERVE_PATH), mode="rb") as file:
        global inference_spec, schema_builder
        obj = cloudpickle.load(file)
        if isinstance(obj[0], InferenceSpec):
            inference_spec, schema_builder = obj
    if inference_spec:
        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
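

# Illustrative only: a minimal sketch of the kind of InferenceSpec subclass a
# user might pickle into serve.pkl. The class name and the model details are
# hypothetical; model_fn above relies only on the load and invoke hooks.
class _ExampleInferenceSpec(InferenceSpec):
    def load(self, model_dir: str):
        # Restore the in-memory model object from model_dir (file name is assumed).
        with open(os.path.join(model_dir, "model.pkl"), "rb") as f:
            return cloudpickle.load(f)

    def invoke(self, input_object, model):
        # Run one prediction; `predict` is a placeholder for the model's own API.
        return model.predict(input_object)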


def input_fn(input_data, content_type):
    """Deserializes the bytes that were received from the model server."""
    try:
        if hasattr(schema_builder, "custom_input_translator"):
            deserialized_data = schema_builder.custom_input_translator.deserialize(
                io.BytesIO(input_data)
                if isinstance(input_data, bytes)
                else io.BytesIO(input_data.encode("utf-8")),
                content_type,
            )
        else:
            deserialized_data = schema_builder.input_deserializer.deserialize(
                io.BytesIO(input_data)
                if isinstance(input_data, bytes)
                else io.BytesIO(input_data.encode("utf-8")),
                content_type[0],
            )
        # Call the preprocess hook if the InferenceSpec defines one.
        if hasattr(inference_spec, "preprocess"):
            return inference_spec.preprocess(deserialized_data)
        return deserialized_data
    except Exception as e:
        logger.error("Encountered error: %s in deserialize_request.", e)
        raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
    """Invokes the model that is taken in by the model server."""
    return predict_callable(input_data)


def output_fn(predictions, accept_type):
    """Serializes predictions to bytes to be sent back to the customer."""
    try:
        # Call the postprocess hook if the InferenceSpec defines one.
        if hasattr(inference_spec, "postprocess"):
            predictions = inference_spec.postprocess(predictions)
        if hasattr(schema_builder, "custom_output_translator"):
            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
        return schema_builder.output_serializer.serialize(predictions)
    except Exception as e:
        logger.error("Encountered error: %s in serialize_response.", e)
        raise Exception("Encountered error in serialize_response.") from e


def _run_preflight_diagnostics():
    _py_vs_parity_check()
    _pickle_file_integrity_check()


def _py_vs_parity_check():
    container_py_vs = platform.python_version()
    local_py_vs = os.getenv("LOCAL_PYTHON")
    # Compare minor versions only (e.g. the "10" in 3.10.x).
    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
        logger.warning(
            "The local python version %s differs from the python version %s on the "
            "container. Please align the two to avoid unexpected behavior.",
            local_py_vs,
            container_py_vs,
        )


def _pickle_file_integrity_check():
    with open(SERVE_PATH, "rb") as f:
        buffer = f.read()

    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)


# On import, execute preflight diagnostics.
_run_preflight_diagnostics()
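
# For reference, an illustrative sketch of how the model server drives the
# handlers above, in order. The request_body, content_type, and accept values
# are placeholders; the serving stack supplies the real ones.
#
#   predictor = model_fn("/opt/ml/model")
#   data = input_fn(request_body, content_type)
#   predictions = predict_fn(data, predictor)
#   response = output_fn(predictions, accept)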