-
Notifications
You must be signed in to change notification settings - Fork 120
Expand file tree
/
Copy pathExecutionParameters.java
More file actions
98 lines (79 loc) · 3.4 KB
/
ExecutionParameters.java
File metadata and controls
98 lines (79 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package software.amazon.ai.mms.plugins.endpoint;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import software.amazon.ai.mms.servingsdk.Context;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;
import software.amazon.ai.mms.servingsdk.http.Request;
import software.amazon.ai.mms.servingsdk.http.Response;
/**
 * Model server endpoint plugin, registered under the URL pattern {@code execution-parameters},
 * that answers GET requests with a JSON document describing the server's execution parameters
 * ({@code MaxConcurrentTransforms}, {@code BatchStrategy} and {@code MaxPayloadInMB}). The values
 * are derived from the configuration returned by {@link Context#getConfig()}.
 *
 * <p>This is the modified endpoint source code for the jar used in this container. To rebuild
 * the jar, first clone the MMS repo:
 *
 * <pre>
 * git clone https://github.com/awslabs/mxnet-model-server.git
 * </pre>
 *
 * <p>Copy this file into
 * plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoints/ and then, from the
 * plugins directory, run:
 *
 * <pre>
 * ./gradlew fJ
 * </pre>
 *
 * <p>Modify the file(s) in plugins/endpoint/resources/META-INF/services/ to specify this file
 * location, then build the JAR:
 *
 * <pre>
 * ./gradlew build
 * </pre>
 *
 * <p>The jar should be available in plugins/endpoints/build/libs as endpoints-1.0.jar.
 *
 * <p>NOTE(review): this file declares the package {@code software.amazon.ai.mms.plugins.endpoint}
 * (singular), while the copy destination above ends in {@code endpoints} - confirm the target
 * directory against the MMS repo layout.
 */
@Endpoint(
        urlPattern = "execution-parameters",
        endpointType = EndpointTypes.INFERENCE,
        description = "Execution parameters endpoint")
public class ExecutionParameters extends ModelServerEndpoint {

    /** Number of bytes in one megabyte; used to convert the byte limit into whole megabytes. */
    private static final int BYTES_PER_MB = 1024 * 1024;

    /** Value used for {@code max_request_size} when the property is absent: 6 * 1024 * 1024. */
    private static final String DEFAULT_MAX_REQUEST_SIZE = "6291456";

    /** Value used for {@code NUM_WORKERS} when the property is absent. */
    private static final String DEFAULT_NUM_WORKERS = "1";

    /**
     * Writes the execution parameters as pretty-printed, UTF-8 encoded JSON to the response.
     *
     * <p>{@code MaxConcurrentTransforms} is read from the {@code NUM_WORKERS} property,
     * {@code MaxPayloadInMB} is the {@code max_request_size} property (in bytes) divided by
     * 1024 * 1024, and {@code BatchStrategy} is always {@code "MULTI_RECORD"}.
     *
     * @param req the incoming request; not consulted by this endpoint
     * @param rsp the response whose output stream receives the JSON body
     * @param ctx the plugin context supplying the server configuration
     * @throws IOException if writing to the response output stream fails
     * @throws NumberFormatException if either property is set to something that is not a decimal
     *     {@code int} once surrounding whitespace has been removed
     */
    @Override
    public void doGet(Request req, Response rsp, Context ctx) throws IOException {
        Properties prop = ctx.getConfig();

        int maxRequestSize = intProperty(prop, "max_request_size", DEFAULT_MAX_REQUEST_SIZE);
        int numWorkers = intProperty(prop, "NUM_WORKERS", DEFAULT_NUM_WORKERS);

        SagemakerXgboostResponse response = new SagemakerXgboostResponse();
        response.setMaxConcurrentTransforms(numWorkers);
        response.setBatchStrategy("MULTI_RECORD");
        // Integer division truncates, so a limit below one megabyte is reported as 0.
        // NOTE(review): confirm how the consumer of this endpoint interprets a value of 0.
        response.setMaxPayloadInMB(maxRequestSize / BYTES_PER_MB);

        String json = new GsonBuilder().setPrettyPrinting().create().toJson(response);
        rsp.getOutputStream().write(json.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Reads an integer-valued property, tolerating whitespace around the number.
     *
     * <p>Values loaded from a properties file keep any trailing whitespace, which
     * {@link Integer#parseInt(String)} rejects; trimming first keeps an entry such as
     * {@code "NUM_WORKERS=4 "} from failing every request. Values that are not numeric at all
     * still raise {@link NumberFormatException} so that misconfiguration stays visible.
     *
     * @param props the configuration to read from
     * @param key the property name
     * @param defaultValue the textual value to parse when the property is absent
     * @return the parsed property value
     */
    private static int intProperty(Properties props, String key, String defaultValue) {
        return Integer.parseInt(props.getProperty(key, defaultValue).trim());
    }

    /**
     * JSON payload returned by the execution-parameters endpoint. Field names in the serialized
     * form are fixed by the {@link SerializedName} annotations.
     */
    public static class SagemakerXgboostResponse {

        @SerializedName("MaxConcurrentTransforms")
        private int maxConcurrentTransforms;

        @SerializedName("BatchStrategy")
        private String batchStrategy;

        @SerializedName("MaxPayloadInMB")
        private int maxPayloadInMB;

        /**
         * Creates a response pre-populated with 4 concurrent transforms, the
         * {@code "MULTI_RECORD"} batch strategy and a 6 MB maximum payload.
         */
        public SagemakerXgboostResponse() {
            maxConcurrentTransforms = 4;
            batchStrategy = "MULTI_RECORD";
            maxPayloadInMB = 6;
        }

        /** @return the value serialized as {@code MaxConcurrentTransforms} */
        public int getMaxConcurrentTransforms() {
            return maxConcurrentTransforms;
        }

        /** @return the value serialized as {@code BatchStrategy} */
        public String getBatchStrategy() {
            return batchStrategy;
        }

        /** @return the value serialized as {@code MaxPayloadInMB} */
        public int getMaxPayloadInMB() {
            return maxPayloadInMB;
        }

        /** @param newMaxConcurrentTransforms the value to serialize as {@code MaxConcurrentTransforms} */
        public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) {
            maxConcurrentTransforms = newMaxConcurrentTransforms;
        }

        /** @param newBatchStrategy the value to serialize as {@code BatchStrategy} */
        public void setBatchStrategy(String newBatchStrategy) {
            batchStrategy = newBatchStrategy;
        }

        /** @param newMaxPayloadInMB the value to serialize as {@code MaxPayloadInMB} */
        public void setMaxPayloadInMB(int newMaxPayloadInMB) {
            maxPayloadInMB = newMaxPayloadInMB;
        }
    }
}