Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,36 @@
<td>Map</td>
<td>Custom HTTP headers as key-value pairs. Example: <code class="highlighter-rouge">'X-Custom-Header:value,X-Another:value2'</code></td>
</tr>
<tr>
<td><h5>default-value</h5></td>
<td style="word-wrap: break-word;">(none)</td>
<td>String</td>
<td>Fallback value returned when all retry attempts are exhausted. The value is always configured as a STRING but is parsed to match the output column type at runtime. For example, use <code class="highlighter-rouge">'FAILED'</code> for a STRING output column, <code class="highlighter-rouge">'-1'</code> for an INT output column, or <code class="highlighter-rouge">'0.0'</code> for a DOUBLE output column. If not configured, an exception is thrown after all retries fail. Enables downstream routing of failed records, e.g., <code class="highlighter-rouge">WHERE result != 'FAILED'</code> to filter out failures.</td>
</tr>
<tr>
<td><h5>flatten-batch-dim</h5></td>
<td style="word-wrap: break-word;">false</td>
<td>Boolean</td>
<td>Whether to flatten the batch dimension for array inputs. When true, shape [1,N] becomes [N]. Defaults to false.</td>
</tr>
<tr>
<td><h5>max-retries</h5></td>
<td style="word-wrap: break-word;">0</td>
<td>Integer</td>
<td>Maximum number of retry attempts for failed inference requests. Retries are triggered by network errors and by retryable HTTP status codes (HTTP 408 Request Timeout, HTTP 429 Too Many Requests, HTTP 503 Service Unavailable, HTTP 504 Gateway Timeout). All other HTTP error statuses are not retried. Defaults to 0 (no retries).</td>
</tr>
<tr>
<td><h5>priority</h5></td>
<td style="word-wrap: break-word;">(none)</td>
<td>Integer</td>
<td>Request priority level (0-255). Higher values indicate higher priority.</td>
</tr>
<tr>
<td><h5>retry-backoff</h5></td>
<td style="word-wrap: break-word;">100 ms</td>
<td>Duration</td>
<td>Initial backoff duration for the exponential retry strategy. Each subsequent retry doubles the wait time: 100ms, 200ms, 400ms, etc. Only used when max-retries &gt; 0. Defaults to 100ms.</td>
</tr>
<tr>
<td><h5>sequence-end</h5></td>
<td style="word-wrap: break-word;">false</td>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ public abstract class AbstractTritonModelFunction extends AsyncPredictFunction {
private final String compression;
private final String authToken;
private final Map<String, String> customHeaders;
private final int maxRetries;
private final Duration retryBackoff;
private final String defaultValue;

public AbstractTritonModelFunction(
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
Expand All @@ -100,6 +103,9 @@ public AbstractTritonModelFunction(
this.compression = config.get(TritonOptions.COMPRESSION);
this.authToken = config.get(TritonOptions.AUTH_TOKEN);
this.customHeaders = config.get(TritonOptions.CUSTOM_HEADERS);
this.maxRetries = config.get(TritonOptions.MAX_RETRIES);
this.retryBackoff = config.get(TritonOptions.RETRY_BACKOFF);
this.defaultValue = config.get(TritonOptions.DEFAULT_VALUE);

// Validate input schema - support multiple types
validateInputSchema(factoryContext.getCatalogModel().getResolvedInputSchema());
Expand Down Expand Up @@ -168,7 +174,7 @@ protected void validateSingleColumnSchema(
column.getClass());

Preconditions.checkArgument(
expectedType != null && !expectedType.equals(column.getDataType().getLogicalType()),
expectedType == null || expectedType.equals(column.getDataType().getLogicalType()),
"%s column %s should be %s, but is a %s.",
inputOrOutput,
column.getName(),
Expand Down Expand Up @@ -319,4 +325,16 @@ protected String getAuthToken() {
/** Returns the custom HTTP headers attached to each Triton request, as configured. */
protected Map<String, String> getCustomHeaders() {
return customHeaders;
}

/** Returns the configured maximum number of retry attempts; 0 means no retries. */
protected int getMaxRetries() {
return maxRetries;
}

/**
 * Returns the initial backoff duration for the exponential retry strategy
 * (doubled on each subsequent attempt). Only relevant when max-retries &gt; 0.
 */
protected Duration getRetryBackoff() {
return retryBackoff;
}

/**
 * Returns the configured fallback value (as a raw string) to emit after all retry
 * attempts are exhausted, or {@code null} if no default value was configured.
 */
protected String getDefaultValue() {
return defaultValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.DoubleType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.util.Preconditions;
Expand All @@ -52,8 +56,10 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPOutputStream;

/**
Expand Down Expand Up @@ -90,7 +96,7 @@ public class TritonInferenceModelFunction extends AbstractTritonModelFunction {
private static final ObjectMapper objectMapper = new ObjectMapper();

/** Reusable buffer for gzip compression to avoid repeated allocations. */
private final ByteArrayOutputStream compressionBuffer = new ByteArrayOutputStream(1024);
private transient ByteArrayOutputStream compressionBuffer;

private final LogicalType inputType;
private final LogicalType outputType;
Expand Down Expand Up @@ -137,6 +143,9 @@ public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {
"Unsupported compression algorithm: '%s'. Currently only 'gzip' is supported.",
getCompression());
// Only support GZIP: Compress request body with gzip using reusable buffer.
if (compressionBuffer == null) {
compressionBuffer = new ByteArrayOutputStream(1024);
}
compressionBuffer.reset();
try (GZIPOutputStream gzos = new GZIPOutputStream(compressionBuffer)) {
gzos.write(requestBody.getBytes(StandardCharsets.UTF_8));
Expand All @@ -161,64 +170,175 @@ public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {

Request request = requestBuilder.build();

httpClient
.newCall(request)
.enqueue(
new Callback() {
@Override
public void onFailure(Call call, IOException e) {
executeWithRetry(request, url, 0, future);

} catch (Exception e) {
LOG.error("Failed to build Triton inference request", e);
future.completeExceptionally(e);
}

return future;
}

private void executeWithRetry(
Request request,
String url,
int attempt,
CompletableFuture<Collection<RowData>> future) {
httpClient
.newCall(request)
.enqueue(
new Callback() {
@Override
public void onFailure(Call call, IOException e) {
if (attempt < getMaxRetries()) {
long backoffMs = getRetryBackoff().toMillis() << attempt;
LOG.warn(
"Triton inference network error on attempt {}/{}, "
+ "retrying in {}ms: {}",
attempt + 1,
getMaxRetries() + 1,
backoffMs,
e.getMessage());
CompletableFuture.delayedExecutor(
backoffMs, TimeUnit.MILLISECONDS)
.execute(
() ->
executeWithRetry(
request,
url,
attempt + 1,
future));
} else {
LOG.error(
"Triton inference request failed due to network error",
"Triton inference request failed after {} attempt(s) "
+ "due to network error",
attempt + 1,
e);

// Wrap IOException in TritonNetworkException
TritonNetworkException networkException =
completeWithDefaultOrException(
new TritonNetworkException(
String.format(
"Failed to connect to Triton server at %s: %s. "
+ "This may indicate network connectivity issues, DNS resolution failure, or server unavailability.",
url, e.getMessage()),
e);

future.completeExceptionally(networkException);
"Failed to connect to Triton server at %s "
+ "after %d attempt(s): %s. "
+ "This may indicate network connectivity issues, "
+ "DNS resolution failure, or server unavailability.",
url, attempt + 1, e.getMessage()),
e),
future);
}

@Override
public void onResponse(Call call, Response response)
throws IOException {
try {
if (!response.isSuccessful()) {
}

@Override
public void onResponse(Call call, Response response)
throws IOException {
try {
if (!response.isSuccessful()) {
int statusCode = response.code();
if (isRetryable(statusCode) && attempt < getMaxRetries()) {
long backoffMs =
getRetryBackoff().toMillis() << attempt;
LOG.warn(
"Triton inference failed with HTTP {} on attempt "
+ "{}/{}, retrying in {}ms",
statusCode,
attempt + 1,
getMaxRetries() + 1,
backoffMs);
CompletableFuture.delayedExecutor(
backoffMs, TimeUnit.MILLISECONDS)
.execute(
() ->
executeWithRetry(
request,
url,
attempt + 1,
future));
} else if (isRetryable(statusCode)) {
String errorBody =
response.body() != null
? response.body().string()
: "No error details provided";
LOG.error(
"Triton inference failed with HTTP {} after {} "
+ "attempt(s)",
statusCode,
attempt + 1);
completeWithDefaultOrException(
new TritonServerException(
String.format(
"Triton inference failed with HTTP %d "
+ "after %d attempt(s): %s",
statusCode,
attempt + 1,
errorBody),
statusCode),
future);
} else {
handleErrorResponse(response, future);
return;
}

String responseBody = response.body().string();
Collection<RowData> result =
parseInferenceResponse(responseBody);
future.complete(result);
} catch (JsonProcessingException e) {
LOG.error("Failed to parse Triton inference response", e);
future.completeExceptionally(
new TritonClientException(
"Failed to parse Triton response JSON: "
+ e.getMessage()
+ ". This may indicate an incompatible response format.",
400));
} catch (Exception e) {
LOG.error("Failed to process Triton inference response", e);
future.completeExceptionally(e);
} finally {
response.close();
return;
}

String responseBody = response.body().string();
Collection<RowData> result =
parseInferenceResponse(responseBody);
future.complete(result);
} catch (JsonProcessingException e) {
LOG.error("Failed to parse Triton inference response", e);
future.completeExceptionally(
new TritonClientException(
"Failed to parse Triton response JSON: "
+ e.getMessage()
+ ". This may indicate an incompatible response format.",
400));
} catch (Exception e) {
LOG.error("Failed to process Triton inference response", e);
future.completeExceptionally(e);
} finally {
response.close();
}
});
}
});
}

} catch (Exception e) {
LOG.error("Failed to build Triton inference request", e);
/**
 * Determines whether an HTTP status code should trigger a retry attempt.
 *
 * <p>Retryable codes: 408 (Request Timeout), 429 (Too Many Requests),
 * 503 (Service Unavailable) and 504 (Gateway Timeout).
 *
 * @param statusCode the HTTP response status code
 * @return true if the request may be retried for this status
 */
private boolean isRetryable(int statusCode) {
switch (statusCode) {
case 408:
case 429:
case 503:
case 504:
return true;
default:
return false;
}
}

/**
 * Terminates the pending future once all retry attempts are exhausted: completes it
 * with a row holding the configured default value when one is set, otherwise
 * propagates the given failure to the caller.
 *
 * @param e the exception describing the final (post-retry) failure
 * @param future the pending inference result to complete
 */
private void completeWithDefaultOrException(
Exception e, CompletableFuture<Collection<RowData>> future) {
String fallback = getDefaultValue();
if (fallback == null) {
// No fallback configured: surface the failure as-is.
future.completeExceptionally(e);
return;
}
LOG.warn(
"Triton inference failed after all retries, "
+ "returning configured default value: {}",
fallback);
try {
future.complete(Collections.singletonList(buildDefaultRowData()));
} catch (Exception parseEx) {
// The default value could not be converted to the output type.
future.completeExceptionally(parseEx);
}
}

return future;
/**
 * Builds a single-column row carrying the configured default value, parsed to match
 * the declared output column type (INT, BIGINT, FLOAT, DOUBLE). Any other output
 * type — including VARCHAR — receives the raw string form.
 *
 * <p>NOTE(review): numeric parsing may throw {@link NumberFormatException} for a
 * malformed default value; callers are expected to handle conversion failures.
 *
 * @return a one-field row containing the converted default value
 */
private RowData buildDefaultRowData() {
final String raw = getDefaultValue();
final Object converted;
if (outputType instanceof IntType) {
converted = Integer.parseInt(raw);
} else if (outputType instanceof BigIntType) {
converted = Long.parseLong(raw);
} else if (outputType instanceof FloatType) {
converted = Float.parseFloat(raw);
} else if (outputType instanceof DoubleType) {
converted = Double.parseDouble(raw);
} else {
// VARCHAR and any unrecognized type: use the string representation.
converted = BinaryStringData.fromString(raw);
}
return GenericRowData.of(converted);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public Set<ConfigOption<?>> optionalOptions() {
set.add(TritonOptions.COMPRESSION);
set.add(TritonOptions.AUTH_TOKEN);
set.add(TritonOptions.CUSTOM_HEADERS);
set.add(TritonOptions.MAX_RETRIES);
set.add(TritonOptions.RETRY_BACKOFF);
set.add(TritonOptions.DEFAULT_VALUE);
return set;
}

Expand Down
Loading