Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ static class EventStreamDeserializer extends SpecificShapeDeserializer {
public <T> void readStruct(Schema schema, T builder, ShapeDeserializer.StructMemberConsumer<T> consumer) {
var payloadWritten = false;
for (Schema member : schema.members()) {
if (member.hasTrait(TraitKey.EVENT_HEADER_TRAIT)) {
if (member.hasTrait(TraitKey.EVENT_HEADER_TRAIT)
&& headersDeserializer.headers.containsKey(member.memberName())) {
consumer.accept(builder, member, headersDeserializer);
} else if (member.hasTrait(TraitKey.EVENT_PAYLOAD_TRAIT)) {
consumer.accept(builder, member, codecDeserializer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.nio.charset.StandardCharsets;
import software.amazon.smithy.java.io.ByteBufferUtils;
import software.amazon.smithy.java.io.datastream.DataStream;
import software.amazon.smithy.java.protocoltests.harness.EventStreamClientTests;
import software.amazon.smithy.java.protocoltests.harness.HttpClientRequestTests;
import software.amazon.smithy.java.protocoltests.harness.HttpClientResponseTests;
import software.amazon.smithy.java.protocoltests.harness.ProtocolTest;
Expand All @@ -26,6 +27,40 @@
skipOperations = {
// We dont ignore defaults on input shapes
"aws.protocoltests.restjson#OperationWithDefaults",
},
skipTests = {
// Need to add exception type to header
"ClientErrorInput",
"DuplexClientErrorInput",
// Currently we are using JSON codec for plain text payload, need to correct it.
"StringPayloadOutput",
"DuplexStringPayloadOutput",
// eventstream:1.0.1 made ByteValue.encodeValue() and
// ShortValue.encodeValue() no-ops, producing malformed frames.
"ByteHeaderInput",
"DuplexByteHeaderInput",
"ShortHeaderInput",
"DuplexShortHeaderInput",
// Blob test params use inconsistent encoding conventions —
// headers use base64, payloads use raw strings.
"BlobPayloadInput",
"DuplexBlobPayloadInput",
"BlobPayloadOutput",
"DuplexBlobPayloadOutput",
"BlobHeaderInput",
"DuplexBlobHeaderInput",
"BlobHeaderOutput",
"DuplexBlobHeaderOutput",
"MultipleHeaderInput",
"DuplexMultipleHeaderInput",
"MultipleHeaderOutput",
"DuplexMultipleHeaderOutput",
// Decoder returns modeled error events instead of throwing
"ClientErrorOutput",
"DuplexClientErrorOutput",
// Client doesn't validate missing @required initial response members
"MissingRequiredInitialResponseOutput",
"DuplexMissingRequiredInitialResponseOutput"
})
public class RestJson1ProtocolTests {
private static final String EMPTY_BODY = "";
Expand Down Expand Up @@ -57,4 +92,9 @@ public void requestTest(DataStream expected, DataStream actual) {
public void responseTest(Runnable test) {
test.run();
}

@EventStreamClientTests
public void eventStreamClientTest(Runnable test) {
test.run();
}
}
1 change: 1 addition & 0 deletions protocol-test-harness/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
implementation(project(":client:client-http"))
implementation(project(":codecs:json-codec", configuration = "shadow"))
implementation(libs.assertj.core)
implementation(project(":aws:aws-event-streams"))

api(platform(libs.junit.bom))
api(libs.junit.jupiter.api)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,35 @@

package software.amazon.smithy.java.protocoltests.harness;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import software.amazon.eventstream.HeaderValue;
import software.amazon.eventstream.MessageDecoder;
import software.amazon.smithy.java.core.error.ModeledException;
import software.amazon.smithy.java.http.api.HttpMessage;
import software.amazon.smithy.java.http.api.HttpRequest;
import software.amazon.smithy.java.io.uri.SmithyUri;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.node.ObjectNode;
import software.amazon.smithy.model.node.StringNode;
import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase;
import software.amazon.smithy.protocoltests.traits.TestFailureExpectation;
import software.amazon.smithy.protocoltests.traits.eventstream.Event;
import software.amazon.smithy.protocoltests.traits.eventstream.EventHeaderValue;
import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase;

/**
* Provides a number of testing utilities for validating protocol test results.
Expand Down Expand Up @@ -52,7 +69,7 @@ static void assertUriEquals(HttpRequestTestCase testCase, SmithyUri uri) {
}

private static void assertQueryParamsEquals(List<String> expectedParams, String actualQuery) {
var expectedSet = paserQueryParamsList(expectedParams);
var expectedSet = parseQueryParamsList(expectedParams);
var actualSet = parseQueryParamsString(actualQuery);
assertEquals(expectedSet, actualSet, "Query parameters mismatch");
}
Expand All @@ -67,7 +84,7 @@ private static Set<String> parseQueryParamsString(String query) {
return result;
}

private static Set<String> paserQueryParamsList(List<String> params) {
private static Set<String> parseQueryParamsList(List<String> params) {
Set<String> result = new HashSet<>();
for (String paramPair : params) {
var pair = paramPair.split("=", 2);
Expand Down Expand Up @@ -106,4 +123,123 @@ private static String convertHeaderToString(String key, List<String> values) {
return value;
}).collect(Collectors.joining(", "));
}

static void assertEventHeaderEquals(String key, EventHeaderValue<?> expected, HeaderValue actual) {
switch (expected.getType()) {
case BOOLEAN -> assertEquals(expected.asBoolean(), actual.getBoolean(), key);
case BYTE -> assertEquals(expected.asByte(), actual.getByte(), key);
case SHORT -> assertEquals(expected.asShort(), actual.getShort(), key);
case INTEGER -> assertEquals(expected.asInteger(), actual.getInteger(), key);
case LONG -> assertEquals(expected.asLong(), actual.getLong(), key);
case STRING -> assertEquals(expected.asString(), actual.getString(), key);
case BLOB -> assertArrayEquals(expected.asBlob(), actual.getByteArray(), key);
case TIMESTAMP -> assertEquals(expected.asTimestamp(), actual.getTimestamp(), key);
}
}

static void assertEventStreamRequestEquals(HttpRequest request, Event event) {
var bodyBytes = request.body().asByteBuffer();
var decoder = new MessageDecoder();
decoder.feed(bodyBytes.duplicate());
var messages = decoder.getDecodedMessages();
var message = messages.getFirst();
var actualHeaders = message.getHeaders();
for (var entry : event.getHeaders().entrySet()) {
var key = entry.getKey();
var expected = entry.getValue();
var actual = actualHeaders.get(key);
assertThat(actual).as("Missing header: " + key).isNotNull();
Assertions.assertEventHeaderEquals(key, expected, actual);
}
for (var header : event.getForbidHeaders()) {
assertFalse(actualHeaders.containsKey(header));
}
for (var header : event.getRequireHeaders()) {
assertTrue(actualHeaders.containsKey(header));
}
event.getBody().ifPresent(expectedBody -> {
assertEventStreamBodyEquals(expectedBody,
new String(message.getPayload(), StandardCharsets.UTF_8),
event.getBodyMediaType().orElse(null));
});
}

static void assertInitialRequestEquals(EventStreamTestCase testCase, HttpRequest request) {
if (testCase.getInitialRequest().isPresent()) {
var initialRequest = testCase.getInitialRequest().get();
assertEquals(initialRequest.expectStringMember("uri").getValue(), request.uri().getPath());
assertEquals(initialRequest.expectStringMember("method").getValue(), request.method());
initialRequest.getStringMember("resolvedHost").ifPresent(host -> {
assertEquals(host.getValue(), request.uri().getHost());
});
var actualQueryParams = request.uri().getQuery();
if (actualQueryParams != null) {
assertInitialRequestQueryMatches(initialRequest, actualQueryParams);
}
assertInitialRequestHeaderMatches(initialRequest, request);
initialRequest.getStringMember("body").ifPresent(bodyNode -> {
assertEventStreamBodyEquals(bodyNode.getValue(),
new StringBuildingSubscriber(request.body()).getResult(),
initialRequest.getStringMember("bodyMediaType").map(StringNode::getValue).orElse(null));
});
}
}

private static void assertInitialRequestQueryMatches(ObjectNode initialRequest, String actualQuery) {
initialRequest.getArrayMember("queryParams").ifPresent(params -> {
assertQueryParamsEquals(params.getElementsAs(StringNode::getValue), actualQuery);
});
var queryParamSet = parseQueryParamsString(actualQuery);
initialRequest.getArrayMember("forbidQueryParams").ifPresent(params -> {
for (var param : params.getElementsAs(StringNode::getValue)) {
assertFalse(queryParamSet.contains(param));
}
});

initialRequest.getArrayMember("requireQueryParams").ifPresent(params -> {
for (var param : params.getElementsAs(StringNode::getValue)) {
assertTrue(queryParamSet.contains(param));
}
});
}

private static void assertInitialRequestHeaderMatches(ObjectNode initialRequest, HttpRequest actualRequest) {
var actualHeaders = actualRequest.headers().map();
initialRequest.getObjectMember("headers").ifPresent(headersNode -> {
Map<String, String> headers = new HashMap<>();
headersNode.getStringMap().forEach((k, v) -> headers.put(k, v.expectStringNode().getValue()));
assertHeadersEqual(actualRequest, headers);
});

initialRequest.getArrayMember("forbidHeaders").ifPresent(headers -> {
for (var header : headers.getElementsAs(StringNode::getValue)) {
assertFalse(actualHeaders.containsKey(header));
}
});

initialRequest.getArrayMember("requireHeaders").ifPresent(headers -> {
for (var header : headers.getElementsAs(StringNode::getValue)) {
assertTrue(actualHeaders.containsKey(header));
}
});
}

private static void assertEventStreamBodyEquals(String expectedBody, String actualBody, String bodyType) {
if ("application/json".equals(bodyType)) {
Node.assertEquals(Node.parse(expectedBody), Node.parse(actualBody));
} else {
assertEquals(expectedBody, actualBody);
}
}

static void assertExpectationEquals(EventStreamTestCase testCase, Throwable e) {
testCase.getExpectation()
.getFailure()
.flatMap(TestFailureExpectation::getErrorId)
.ifPresent(errorId -> {
assertInstanceOf(ModeledException.class, e);
assertEquals(errorId.getName(),
((ModeledException) e).schema().id().getName());
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.java.protocoltests.harness;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.platform.commons.annotation.Testable;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@TestTemplate
@Testable
@Timeout(5)
@ExtendWith(EventStreamClientTestsProtocolTestProvider.class)
public @interface EventStreamClientTests {}
Loading