Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ PROJECT_LICENSE=MIT
PROJECT_LICENSE_URL=https://github.com/graphql-java-kickstart/spring-java-servlet/blob/master/LICENSE.md
PROJECT_DEV_ID=oliemansm
PROJECT_DEV_NAME=Michiel Oliemans
LIB_GRAPHQL_JAVA_VER=22.3
LIB_GRAPHQL_JAVA_VER=25.0
LIB_JACKSON_VER=2.17.2
LIB_SLF4J_VER=2.0.16
LIB_LOMBOK_VER=1.18.34
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,29 @@

import graphql.ExecutionResult;
import graphql.GraphQLError;
import graphql.incremental.IncrementalExecutionResult;
import java.util.List;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.reactivestreams.Publisher;

@RequiredArgsConstructor
class DecoratedExecutionResult implements ExecutionResult {
public class DecoratedExecutionResult implements ExecutionResult {

private final ExecutionResult result;

boolean isAsynchronous() {
return result.getData() instanceof Publisher;
}

boolean isIncremental() {
return result instanceof IncrementalExecutionResult;
}

public IncrementalExecutionResult asIncrementalExecutionResult() {
return (IncrementalExecutionResult) result;
}

@Override
public List<GraphQLError> getErrors() {
return result.getErrors();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,9 @@ public boolean isBatched() {
public boolean isAsynchronous() {
return false;
}

@Override
public boolean isIncremental() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ public boolean isAsynchronous() {
return false;
}

@Override
public boolean isIncremental() {
return false;
}

@Override
public boolean isError() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import graphql.ExecutionResult;
import graphql.ExecutionResultImpl;
import graphql.GraphQLError;
import graphql.incremental.DelayedIncrementalPartialResult;
import graphql.incremental.IncrementalPayload;
import graphql.kickstart.execution.config.ConfiguringObjectMapperProvider;
import graphql.kickstart.execution.config.GraphQLServletObjectMapperConfigurer;
import graphql.kickstart.execution.config.ObjectMapperProvider;
Expand All @@ -18,6 +20,7 @@
import java.io.InputStream;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -118,51 +121,71 @@ public byte[] serializeResultAsBytes(ExecutionResult executionResult) {
return getJacksonMapper().writeValueAsBytes(createResultFromExecutionResult(executionResult));
}

@SneakyThrows
public byte[] serializeDelayedIncrementalResultsAsBytes(DelayedIncrementalPartialResult delayedIncrementalPartialResult) {
return getJacksonMapper().writeValueAsBytes(createResultFromDelayedIncrementalPayloadResult(delayedIncrementalPartialResult));
}

public boolean areErrorsPresent(ExecutionResult executionResult) {
return graphQLErrorHandlerSupplier.get().errorsPresent(executionResult.getErrors());
}

public ExecutionResult sanitizeErrors(ExecutionResult executionResult) {
Object data = executionResult.getData();
public boolean areExtensionsPresent(ExecutionResult executionResult) {
Map<Object, Object> extensions = executionResult.getExtensions();
List<GraphQLError> errors = executionResult.getErrors();
return extensions != null && !extensions.isEmpty();
}

public ExecutionResult sanitizeErrors(ExecutionResult executionResult) {
GraphQLErrorHandler errorHandler = graphQLErrorHandlerSupplier.get();
if (errorHandler.errorsPresent(errors)) {
errors = errorHandler.processErrors(errors);
} else {
errors = null;
}
return new ExecutionResultImpl(data, errors, extensions);
return executionResult.transform(er -> {
List<GraphQLError> errors = executionResult.getErrors();
if (errorHandler.errorsPresent(errors)) {
errors = errorHandler.processErrors(errors);
} else {
errors = List.of();
}
er.errors(errors);
});
}

public DelayedIncrementalPartialResult sanitizeErrors(DelayedIncrementalPartialResult delayedIncrementalPartialResult) {
return delayedIncrementalPartialResult;
}

public Map<String, Object> createResultFromExecutionResult(ExecutionResult executionResult) {
ExecutionResult sanitizedExecutionResult = sanitizeErrors(executionResult);
return convertSanitizedExecutionResult(sanitizedExecutionResult);
}

public Map<String, Object> createResultFromDelayedIncrementalPayloadResult(DelayedIncrementalPartialResult delayedIncrementalPartialResult) {
DelayedIncrementalPartialResult sanitizedDelayedIncrementalPartialResult = sanitizeErrors(delayedIncrementalPartialResult);
return convertSanitizedDelayedIncrementalPartialResult(sanitizedDelayedIncrementalPartialResult);
}

public Map<String, Object> convertSanitizedExecutionResult(ExecutionResult executionResult) {
return convertSanitizedExecutionResult(executionResult, true);
}

public Map<String, Object> convertSanitizedDelayedIncrementalPartialResult(
DelayedIncrementalPartialResult delayedIncrementalPartialResult) {
return delayedIncrementalPartialResult.toSpecification();
}

public Map<String, Object> convertSanitizedExecutionResult(
ExecutionResult executionResult, boolean includeData) {
final Map<String, Object> result = new LinkedHashMap<>();

if (areErrorsPresent(executionResult)) {
result.put(
"errors",
executionResult.getErrors().stream()
.map(GraphQLError::toSpecification)
.collect(toList()));
final Map<String, Object> result = new HashMap<>(executionResult.toSpecification());

if (!areErrorsPresent(executionResult)) {
result.remove("errors");
}

if (executionResult.getExtensions() != null && !executionResult.getExtensions().isEmpty()) {
result.put("extensions", executionResult.getExtensions());
if (!includeData) {
result.remove("data");
}
result.putIfAbsent("data", null);

if (includeData) {
result.put("data", executionResult.getData());
if (!areExtensionsPresent(executionResult)) {
result.remove("extensions");
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ static GraphQLErrorQueryResult createError(int statusCode, String message) {

boolean isAsynchronous();

boolean isIncremental();

default DecoratedExecutionResult getResult() {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,9 @@ public boolean isBatched() {
public boolean isAsynchronous() {
return result.isAsynchronous();
}

@Override
public boolean isIncremental() {
return result.isIncremental();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public void put(Object key, Object value) {
map.put(key, value);
}

public void putAll(Map<Object, Object> values) {
map.putAll(values);
}

@Override
public DataLoaderRegistry getDataLoaderRegistry() {
return dataLoaderRegistry;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package graphql.kickstart.servlet;

import graphql.incremental.DelayedIncrementalPartialResult;
import graphql.kickstart.execution.GraphQLObjectMapper;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.ServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

class DelayedIncrementalPartialResultSubscriber implements Subscriber<DelayedIncrementalPartialResult> {

private final AtomicReference<Subscription> subscriptionRef;
private final AsyncContext asyncContext;
private final GraphQLObjectMapper graphQLObjectMapper;
private final CountDownLatch completedLatch = new CountDownLatch(1);

DelayedIncrementalPartialResultSubscriber(
AtomicReference<Subscription> subscriptionRef,
AsyncContext asyncContext,
GraphQLObjectMapper graphQLObjectMapper) {
this.subscriptionRef = subscriptionRef;
this.asyncContext = asyncContext;
this.graphQLObjectMapper = graphQLObjectMapper;
}

@Override
public void onSubscribe(Subscription subscription) {
subscriptionRef.set(subscription);
subscriptionRef.get().request(1);
}

@Override
public void onNext(DelayedIncrementalPartialResult delayedIncrementalPartialResult) {
try {
ServletResponse response = asyncContext.getResponse();
ServletOutputStream outputStream = response.getOutputStream();
outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8));
outputStream.write(HttpRequestHandler.MULTIPART_CONTENT_TYPE.getBytes(
StandardCharsets.UTF_8));
byte[] contentBytes = graphQLObjectMapper.serializeDelayedIncrementalResultsAsBytes(delayedIncrementalPartialResult);
outputStream.write(contentBytes);
outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8));
if (!delayedIncrementalPartialResult.hasNext()) {
outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8));
}
outputStream.flush();
subscriptionRef.get().request(1);
} catch (IOException ignored) {
// ignore
}
}

@Override
public void onError(Throwable t) {
asyncContext.complete();
completedLatch.countDown();
}

@Override
public void onComplete() {
asyncContext.complete();
completedLatch.countDown();
}

void await() throws InterruptedException {
completedLatch.await();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ public interface HttpRequestHandler {

String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8";
String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8";
String MULTIPART_MIXED = "multipart/mixed; boundary=\"-\"";
String MULTIPART_BOUNDARY = "---\r\n";
String MULTIPART_CONTENT_TYPE = "Content-Type: application/json; charset=UTF-8\r\n\r\n";

int STATUS_OK = 200;
int STATUS_BAD_REQUEST = 400;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ public QueryResponseWriter createWriter(
configuration.getObjectMapper(),
configuration.getSubscriptionTimeout());
}

if (queryResult.isIncremental()) {
return new SingleIncrementalQueryResponseWriter(
queryResult.getResult().asIncrementalExecutionResult(),
configuration.getObjectMapper(),
configuration.getSubscriptionTimeout());
}

if (queryResult.isError()) {
return new ErrorQueryResponseWriter(queryResult.getStatusCode(), queryResult.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package graphql.kickstart.servlet;

import graphql.ExecutionResult;
import graphql.incremental.DelayedIncrementalPartialResult;
import graphql.incremental.IncrementalExecutionResult;
import graphql.incremental.IncrementalPayload;
import graphql.kickstart.execution.GraphQLObjectMapper;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;

@RequiredArgsConstructor
class SingleIncrementalQueryResponseWriter implements QueryResponseWriter {

@Getter private final IncrementalExecutionResult result;
private final GraphQLObjectMapper graphQLObjectMapper;
private final long subscriptionTimeout;

@Override
public void write(HttpServletRequest request, HttpServletResponse response) throws IOException {
Objects.requireNonNull(request, "Http servlet request cannot be null");
response.setContentType(HttpRequestHandler.MULTIPART_MIXED);
response.setStatus(HttpRequestHandler.STATUS_OK);
response.setCharacterEncoding(StandardCharsets.UTF_8.name());

// Write the initial data
ServletOutputStream outputStream = response.getOutputStream();
outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8));
outputStream.write(HttpRequestHandler.MULTIPART_BOUNDARY.getBytes(StandardCharsets.UTF_8));
outputStream.write(HttpRequestHandler.MULTIPART_CONTENT_TYPE.getBytes(
StandardCharsets.UTF_8));
byte[] contentBytes = graphQLObjectMapper.serializeResultAsBytes(result);
outputStream.write(contentBytes);
outputStream.write("\r\n".getBytes(StandardCharsets.UTF_8));
outputStream.flush();

// If no more data is expected, we can just complete the response here
boolean isInAsyncThread = request.isAsyncStarted();
AsyncContext asyncContext =
isInAsyncThread ? request.getAsyncContext() : request.startAsync(request, response);
if (!result.hasNext()) {
asyncContext.complete();
return;
}

// Now handle any delayed incremental payloads
asyncContext.setTimeout(subscriptionTimeout);
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef));
DelayedIncrementalPartialResultSubscriber subscriber =
new DelayedIncrementalPartialResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper);
var publisher = result.getIncrementalItemPublisher();
publisher.subscribe(subscriber);

if (isInAsyncThread) {
// We need to delay the completion of async context until after the subscription has
// terminated, otherwise the AsyncContext is prematurely closed.
try {
subscriber.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
Loading