From 8b64464be9c80bf90466800e71a46f47f65f3a6f Mon Sep 17 00:00:00 2001 From: Adwait Kumar Singh Date: Sun, 17 May 2026 18:03:47 -0700 Subject: [PATCH 1/3] Optimize CBOR Codec --- .../java/client/rpcv2/RpcV2CborProtocol.java | 7 - .../rpcv2/AbstractRpcV2ClientProtocol.java | 47 +- .../smithy/java/cbor/CborDeserializer.java | 696 +++++++++++++----- .../smithy/java/cbor/CborMemberLookup.java | 88 +++ .../amazon/smithy/java/cbor/CborParser.java | 538 -------------- .../amazon/smithy/java/cbor/CborReadUtil.java | 136 ++-- .../java/cbor/CborSchemaExtensions.java | 83 +++ .../smithy/java/cbor/CborSerdeProvider.java | 4 +- .../smithy/java/cbor/CborSerializer.java | 672 ++++++++++++++--- .../java/cbor/DefaultCborSerdeProvider.java | 21 +- .../smithy/java/cbor/Rpcv2CborCodec.java | 6 + .../amazon/smithy/java/cbor/Sink.java | 186 ----- ...y.java.core.schema.SchemaExtensionProvider | 1 + .../smithy/java/cbor/CborCodecTest.java | 590 +++++++++++++++ .../smithy/java/cbor/CborParserTest.java | 672 ----------------- .../amazon/java/cbor/CborComparator.java | 227 +----- codecs/json-codec/build.gradle.kts | 32 - .../amazon/smithy/java/json/JsonBench.java | 231 ------ .../java/codegen/generators/MapGenerator.java | 10 +- .../java/http/api/ArrayHttpHeaders.java | 9 + .../java/http/api/ModifiableHttpHeaders.java | 10 + .../http/api/ModifiableHttpResponseImpl.java | 5 +- .../smithy/java/http/api/HttpHeadersTest.java | 30 + 23 files changed, 1999 insertions(+), 2302 deletions(-) create mode 100644 codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborMemberLookup.java delete mode 100644 codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborParser.java create mode 100644 codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java delete mode 100644 codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Sink.java create mode 100644 codecs/cbor-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider create mode 100644 codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java delete mode 100644 codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborParserTest.java delete mode 100644 codecs/json-codec/src/jmh/java/software/amazon/smithy/java/json/JsonBench.java diff --git a/client/client-rpcv2-cbor/src/main/java/software/amazon/smithy/java/client/rpcv2/RpcV2CborProtocol.java b/client/client-rpcv2-cbor/src/main/java/software/amazon/smithy/java/client/rpcv2/RpcV2CborProtocol.java index f7b398fe7..896dfb471 100644 --- a/client/client-rpcv2-cbor/src/main/java/software/amazon/smithy/java/client/rpcv2/RpcV2CborProtocol.java +++ b/client/client-rpcv2-cbor/src/main/java/software/amazon/smithy/java/client/rpcv2/RpcV2CborProtocol.java @@ -11,8 +11,6 @@ import software.amazon.smithy.java.client.core.ClientProtocolFactory; import software.amazon.smithy.java.client.core.ProtocolSettings; import software.amazon.smithy.java.core.serde.Codec; -import software.amazon.smithy.java.http.api.HttpVersion; -import software.amazon.smithy.java.http.api.ModifiableHttpRequest; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.protocol.traits.Rpcv2CborTrait; @@ -32,11 +30,6 @@ protected Codec codec() { return CBOR_CODEC; } - @Override - protected void customizeRequestBuilder(ModifiableHttpRequest builder) { - builder.setHttpVersion(HttpVersion.HTTP_2); - } - public static final class Factory implements ClientProtocolFactory { @Override public ShapeId id() { diff --git a/client/client-rpcv2/src/main/java/software/amazon/smithy/java/client/rpcv2/AbstractRpcV2ClientProtocol.java b/client/client-rpcv2/src/main/java/software/amazon/smithy/java/client/rpcv2/AbstractRpcV2ClientProtocol.java index a8ef01b91..fa322f05c 100644 --- a/client/client-rpcv2/src/main/java/software/amazon/smithy/java/client/rpcv2/AbstractRpcV2ClientProtocol.java +++ b/client/client-rpcv2/src/main/java/software/amazon/smithy/java/client/rpcv2/AbstractRpcV2ClientProtocol.java @@ -27,7 +27,6 @@ import software.amazon.smithy.java.http.api.HttpRequest; import software.amazon.smithy.java.http.api.HttpResponse; import software.amazon.smithy.java.http.api.ModifiableHttpRequest; -import software.amazon.smithy.java.io.ByteBufferOutputStream; import software.amazon.smithy.java.io.datastream.DataStream; import software.amazon.smithy.java.io.uri.SmithyUri; import software.amazon.smithy.model.shapes.ShapeId; @@ -46,6 +45,8 @@ public abstract class AbstractRpcV2ClientProtocol extends HttpClientProtocol { private final ShapeId service; private final String payloadMediaType; private final String smithyProtocolValue; + private final String targetPathPrefix; + private final ModifiableHttpRequest templateRequest; private volatile HttpErrorDeserializer errorDeserializer; private static final String SMITHY_PROTOCOL_PREFIX = "rpc-v2-"; @@ -67,6 +68,13 @@ protected AbstractRpcV2ClientProtocol( this.payloadMediaType = payloadMediaType; this.smithyProtocolValue = SMITHY_PROTOCOL_PREFIX + payloadMediaType.substring(MEDIA_TYPE_PREFIX_LENGTH); + this.targetPathPrefix = "/service/" + service.getName() + "/operation/"; + + var tmpl = HttpRequest.create(); + tmpl.setMethod("POST"); + tmpl.addHeader(HeaderName.SMITHY_PROTOCOL, smithyProtocolValue); + tmpl.addHeader(HeaderName.ACCEPT, payloadMediaType); + this.templateRequest = tmpl; } /** Returns the codec used for serialization and deserialization. */ @@ -92,14 +100,6 @@ private HttpErrorDeserializer errorDeserializer() { return errorDeserializer; } - /** - * Hook for subclasses to customize the request builder before headers and body are set. - * For example, the CBOR protocol uses this to force HTTP/2. - */ - protected void customizeRequestBuilder(ModifiableHttpRequest builder) { - // default: no customization - } - @Override public HttpRequest createRequest( ApiOperation operation, @@ -107,27 +107,20 @@ public HttpRequest Context context, SmithyUri endpoint ) { - var target = "/service/" + service.getName() + "/operation/" + operation.schema().id().getName(); - var builder = HttpRequest.create().setMethod("POST").setUri(endpoint.withConcatPath(target)); - - customizeRequestBuilder(builder); + var target = targetPathPrefix + operation.schema().id().getName(); + var builder = templateRequest.toModifiableCopy(); + builder.setUri(endpoint.withConcatPath(target)); if (operation.inputSchema().hasTrait(TraitKey.UNIT_TYPE_TRAIT)) { - builder.addHeader(HeaderName.SMITHY_PROTOCOL, smithyProtocolValue) - .addHeader(HeaderName.ACCEPT, payloadMediaType) - .setBody(DataStream.ofEmpty()); + builder.setBody(DataStream.ofEmpty()); } else if (operation.inputEventBuilderSupplier() != null) { var encoderFactory = getEventEncoderFactory(operation); var body = RpcEventStreamsUtil.bodyForEventStreaming(encoderFactory, input); - builder.addHeader(HeaderName.SMITHY_PROTOCOL, smithyProtocolValue) - .addHeader(HeaderName.CONTENT_TYPE, "application/vnd.amazon.eventstream") - .addHeader(HeaderName.ACCEPT, payloadMediaType) - .setBody(body); + builder.addHeader(HeaderName.CONTENT_TYPE, "application/vnd.amazon.eventstream"); + builder.setBody(body); } else { - builder.addHeader(HeaderName.SMITHY_PROTOCOL, smithyProtocolValue) - .addHeader(HeaderName.CONTENT_TYPE, payloadMediaType) - .addHeader(HeaderName.ACCEPT, payloadMediaType) - .setBody(getBody(input)); + builder.addHeader(HeaderName.CONTENT_TYPE, payloadMediaType); + builder.setBody(getBody(input)); } return builder; } @@ -166,11 +159,7 @@ private static DataStream bodyDataStream(HttpResponse response) { } private DataStream getBody(SerializableStruct input) { - var sink = new ByteBufferOutputStream(); - try (var serializer = codec().createSerializer(sink)) { - input.serialize(serializer); - } - return DataStream.ofByteBuffer(sink.toByteBuffer(), payloadMediaType); + return DataStream.ofByteBuffer(codec().serialize(input), payloadMediaType); } private EventEncoderFactory getEventEncoderFactory(ApiOperation operation) { diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java index 7a9f37643..4746c1554 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java @@ -5,21 +5,20 @@ package software.amazon.smithy.java.cbor; +import static software.amazon.smithy.java.cbor.CborReadUtil.argLength; import static software.amazon.smithy.java.cbor.CborReadUtil.readByteString; +import static software.amazon.smithy.java.cbor.CborReadUtil.readPosInt; +import static software.amazon.smithy.java.cbor.CborReadUtil.readStrLen; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import software.amazon.smithy.java.cbor.CborParser.Token; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.TraitKey; import software.amazon.smithy.java.core.serde.SerializationException; @@ -27,113 +26,459 @@ import software.amazon.smithy.java.core.serde.document.Document; final class CborDeserializer implements ShapeDeserializer { - private static final class Canonicalizer { - private record Canonical(Schema member, byte[] utf8) implements Comparable { - @Override - public int compareTo(Canonical o) { - return Arrays.compare(utf8, o.utf8); - } - private Schema isSame(byte[] bytes, int off, int len) { - if (Arrays.compare(utf8, 0, utf8.length, bytes, off, off + len) == 0) { - return member; - } - return null; - } + static final class Token { + public static final byte TAG_FLAG = 1 << 4; + + public static final byte POS_INT = TYPE_POSINT; + public static final byte NEG_INT = TYPE_NEGINT; + public static final byte BYTE_STRING = TYPE_BYTESTRING; + public static final byte TEXT_STRING = TYPE_TEXTSTRING; + + public static final byte NULL = 4; + public static final byte KEY = 5; + public static final byte START_OBJECT = 6; + public static final byte START_ARRAY = 7; + public static final byte END_OBJECT = 8; + public static final byte END_ARRAY = 9; + + public static final byte POS_BIGINT = 10; + public static final byte NEG_BIGINT = 11; + public static final byte FLOAT = 12; + public static final byte BIG_DECIMAL = 14; + public static final byte TRUE = 15; + public static final byte FALSE = TRUE | TAG_FLAG; + + public static final byte EPOCH_IPOS = POS_INT | TAG_FLAG; + public static final byte EPOCH_INEG = NEG_INT | TAG_FLAG; + public static final byte EPOCH_F = FLOAT | TAG_FLAG; + + public static final byte FINISHED = -1; + + static String name(byte token) { + return switch (token) { + case POS_INT -> "POS_INT"; + case NEG_INT -> "NEG_INT"; + case BYTE_STRING -> "BYTE_STRING"; + case TEXT_STRING -> "TEXT_STRING"; + case POS_BIGINT -> "POS_BIGINT"; + case NEG_BIGINT -> "NEG_BIGINT"; + case FLOAT -> "FLOAT"; + case BIG_DECIMAL -> "BIG_DECIMAL"; + case TRUE -> "TRUE"; + case FALSE -> "FALSE"; + case EPOCH_IPOS -> "EPOCH_IPOS"; + case EPOCH_INEG -> "EPOCH_INEG"; + case EPOCH_F -> "EPOCH_F"; + case NULL -> "NULL"; + case KEY -> "KEY"; + case START_OBJECT -> "START_OBJECT"; + case START_ARRAY -> "START_ARRAY"; + case END_ARRAY -> "END_ARRAY"; + case END_OBJECT -> "END_OBJECT"; + case FINISHED -> "FINISHED"; + default -> throw new BadCborException("unknown token " + token); + }; } + } - private final Object[][] canonicals; + static final int MAJOR_TYPE_SHIFT = 5, + MAJOR_TYPE_MASK = 0b111_00000, + MINOR_TYPE_MASK = 0b0001_1111; - Canonicalizer(Schema schema) { - int biggest = 0; - Map> bySize = new HashMap<>(); - for (var member : schema.members()) { - byte[] utf8 = member.memberName().getBytes(StandardCharsets.UTF_8); - biggest = Math.max(biggest, utf8.length); - bySize.computeIfAbsent(utf8.length, $ -> new ArrayList<>()) - .add(new Canonical(member, utf8)); - } + static final byte TYPE_POSINT = 0, + TYPE_NEGINT = 1, + TYPE_BYTESTRING = 2, + TYPE_TEXTSTRING = 3, + TYPE_ARRAY = 4, + TYPE_MAP = 5, + TYPE_TAG = 6, + TYPE_SIMPLE = 7; - canonicals = new Object[biggest + 1][]; - for (var entry : bySize.entrySet()) { - int len = entry.getKey(); - var canonsForLen = entry.getValue().toArray(new Canonical[0]); - Arrays.sort(canonsForLen); - canonicals[len] = canonsForLen; - } - } + static final int ZERO_BYTES = 23, + ONE_BYTE = 24, + EIGHT_BYTES = 27, + INDEFINITE = 31; - Schema resolve(byte[] payload, int off, int len) { - if (len >= canonicals.length) { - return null; - } + static final int SIMPLE_FALSE = 20, + SIMPLE_TRUE = 21, + SIMPLE_NULL = 22, + SIMPLE_UNDEFINED = 23, + SIMPLE_VALUE_1 = 24, + SIMPLE_HALF_FLOAT = 25, + SIMPLE_FLOAT = 26, + SIMPLE_DOUBLE = 27; - Object[] canonicals = this.canonicals[len]; - if (canonicals == null) { - return null; - } + static final byte SIMPLE_STREAM_BREAK = (byte) ((TYPE_SIMPLE << MAJOR_TYPE_SHIFT) | INDEFINITE); - if (canonicals.length == 1) { - return getMemberIfSame(canonicals[0], payload, off, len); - } else { - for (int i = 0; i < canonicals.length; i++) { - var member = getMemberIfSame(canonicals[i], payload, off, len); - if (member != null) { - return member; - } - } - return null; - } - } + static final byte TAG_TIME_RFC3339 = 0, + TAG_TIME_EPOCH = 1, + TAG_POS_BIGNUM = 2, + TAG_NEG_BIGNUM = 3, + TAG_DECIMAL = 4; - private Schema getMemberIfSame(Object o, byte[] bytes, int off, int len) { - return ((Canonical) o).isSame(bytes, off, len); - } + private static final int FLAG_INDEFINITE_LEN = 1 << 31; + private static final int MASK_LEN = ~FLAG_INDEFINITE_LEN; + + static boolean isIndefinite(int itemLength) { + return itemLength < 0; } - private static final Map CANONICALIZERS = new ConcurrentHashMap<>(); + static int itemLength(int itemLength) { + return itemLength & MASK_LEN; + } - private final CborParser parser; private final CborSettings settings; private final byte[] payload; + private final int end; + private int idx; + private byte token; + + // Definite sizes shrink to zero, indefinite sizes start at -1 and decrement meaninglessly towards Long.MIN_VALUE. + // Must be long because we need to store 2 * size for maps, and a map can have up to Integer.MAX_VALUE elements. + // Count is left shifted one. Low bit is collection type: 0 == map, 1 == array. + private long currentState = 0; + private long[] previousStates = new long[4]; + private boolean inCollection = false; + private int historyDepth = 0; + private int itemLength = 0; + private int overhead = 0; // overhead is [0,8] + private boolean readingTag = false; CborDeserializer(byte[] payload, CborSettings settings) { - this.parser = new CborParser(payload); - this.settings = settings; this.payload = payload; - parser.advance(); + this.settings = settings; + this.end = payload.length; + this.idx = 0; + advance(); } CborDeserializer(ByteBuffer byteBuffer, CborSettings settings) { this.settings = settings; if (byteBuffer.hasArray()) { - byte[] payload = byteBuffer.array(); - this.payload = payload; + this.payload = byteBuffer.array(); int start = byteBuffer.arrayOffset() + byteBuffer.position(); - this.parser = new CborParser( - payload, - start, - start + byteBuffer.remaining()); + this.idx = start; + this.end = start + byteBuffer.remaining(); } else { int pos = byteBuffer.position(); this.payload = new byte[byteBuffer.remaining()]; byteBuffer.get(this.payload); - this.parser = new CborParser(this.payload); + this.idx = 0; + this.end = this.payload.length; byteBuffer.position(pos); } - parser.advance(); + advance(); + } + + private byte advance() { + return (token = nextToken0()); + } + + private byte nextToken0() { + if (inCollection) { + long state = currentState; + if (state >> 1 == 0) { + return getEndToken(state); + } else if ((state & 3) == 0) { + int i = (idx += itemLength(itemLength) + overhead); + if (i >= end) { + throwIncompleteCollectionException(); + } + return dispatchKey(payload[i]); + } + } + + int i = (idx += itemLength(itemLength) + overhead); + if (i >= end) { + return endOfBuffer(i); + } + + return dispatch(payload[i]); + } + + private byte dispatchKey(byte b) { + byte major = (byte) ((b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT); + if (major == TYPE_TEXTSTRING) { + byte minor = (byte) (b & MINOR_TYPE_MASK); + string(major, minor); + return Token.KEY; + } else if (b == SIMPLE_STREAM_BREAK) { + return endStreamCollection(); + } else { + throw new BadCborException("map keys must be strings"); + } + } + + private byte endOfBuffer(int i) { + itemLength = 0; + overhead = 0; + if (i > end) { + throw new BadCborException("unexpected end of payload"); + } + if (inCollection) { + throwIncompleteCollectionException(); + } + return Token.FINISHED; + } + + private byte getEndToken(long state) { + byte retVal = state == 0 ? Token.END_OBJECT : Token.END_ARRAY; + if (historyDepth > 0) { + currentState = previousStates[--historyDepth]; + } else { + inCollection = false; + currentState = 0; + } + return retVal; + } + + private byte dispatch(byte b) { + byte major = (byte) ((b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT); + byte minor = (byte) (b & MINOR_TYPE_MASK); + switch (major) { + case TYPE_POSINT: + case TYPE_NEGINT: + return integer(major, minor); + case TYPE_BYTESTRING: + case TYPE_TEXTSTRING: + return string(major, minor); + case TYPE_ARRAY: + case TYPE_MAP: + return collection(major, minor); + case TYPE_TAG: + return tag(minor); + case TYPE_SIMPLE: + return simple(major, minor); + default: + throw new BadCborException("unknown major type: " + major); + } + } + + private byte tag(byte minor) { + // RFC8949 3.4 permits nested tags, but I see no need to support anything beyond the simple + // tags that are relevant to the Smithy object model. + if (readingTag) + throw new BadCborException("nested tags not permitted"); + overhead = 1; + itemLength = 0; + readingTag = true; + byte next = advance(); + readingTag = false; + switch (minor) { + case TAG_TIME_EPOCH: + if (next != Token.FLOAT && next > Token.NEG_INT) + throw new BadCborException("malformed instant: got " + Token.name(next)); + return (byte) (next | Token.TAG_FLAG); + case TAG_POS_BIGNUM: + if (next != Token.BYTE_STRING) + throw new BadCborException("malformed +bignum: got " + Token.name(next)); + return Token.POS_BIGINT; + case TAG_NEG_BIGNUM: + if (next != Token.BYTE_STRING) + throw new BadCborException("malformed -bignum: got " + Token.name(next)); + return Token.NEG_BIGINT; + case TAG_DECIMAL: + tagDecimalFp(next); + return Token.BIG_DECIMAL; + default: + throw new BadCborException("unsupported tag minor " + minor); + } + } + + private void tagDecimalFp(byte next) { + if (next != Token.START_ARRAY) + badDecimalInitialType(next); + int start = idx; + byte token; + if ((token = advance()) > Token.NEG_INT) + badDecimalArgument1(token); + token = advance(); + int tmp = token & 0b11110; + if (tmp != Token.POS_INT && tmp != Token.POS_BIGINT) + badDecimalArgument2(token); + if ((token = advance()) != Token.END_ARRAY) + badDecimalFinalToken(token); + itemLength = idx - start + itemLength(itemLength) + overhead - 1; + overhead = 1; + idx = start; + } + + private static void badDecimalInitialType(byte next) { + throw new BadCborException("malformed BIG_DECIMAL: got " + Token.name(next)); + } + + private static void badDecimalArgument1(byte token) { + throw new BadCborException("malformed BIG_DECIMAL: expected int 1, got " + Token.name(token)); + } + + private static void badDecimalArgument2(byte token) { + throw new BadCborException("malformed BIG_DECIMAL: expected int 2, got " + Token.name(token)); + } + + private static void badDecimalFinalToken(byte token) { + throw new BadCborException("malformed BIG_DECIMAL: expected END_ARRAY, got " + Token.name(token)); + } + + private byte integer(byte major, byte minor) { + if (minor == INDEFINITE) + throw new BadCborException("numeric type has indefinite length"); + int argLength = argLength(minor); + if (argLength > 0) { + overhead = 0; + idx++; + } else { + overhead = 1; + } + itemLength = argLength; + // 2 because the count is left-shifted one (2 == 1 << 1) + currentState -= 2; + return major; + } + + private byte simple(byte major, byte minor) { + if (minor <= SIMPLE_VALUE_1) { + currentState -= 2; + itemLength = 1; + overhead = 0; + switch (minor) { + case SIMPLE_FALSE: + return Token.FALSE; + case SIMPLE_TRUE: + return Token.TRUE; + case SIMPLE_NULL: + case SIMPLE_UNDEFINED: + return Token.NULL; + default: + throw new BadCborException("bad simple minor type " + minor); + } + } else if (minor <= SIMPLE_DOUBLE) { + integer(major, minor); + return Token.FLOAT; + } else if (minor == INDEFINITE) { + return endStreamCollection(); + } + throw new BadCborException("illegal simple minor type " + minor); + } + + private byte endStreamCollection() { + itemLength = 0; + overhead = 1; + // note that we can leave the collection type in the low bit. all that matters is that + // the number is negative, and the low bit will only make a positive number more positive + // and a negative number more negative. + if (!inCollection || currentState >= 0) + throw new BadCborException("unexpected indefinite terminator"); + long state = currentState; + if (historyDepth > 0) { + currentState = previousStates[--historyDepth]; + } else { + inCollection = false; + currentState = 0; + } + return (state & 1) == 0 ? Token.END_OBJECT : Token.END_ARRAY; + } + + private byte string(byte major, byte minor) { + overhead = 0; + if (minor == INDEFINITE) { + readIndefiniteLength(major); + } else { + int argLen = argLength(minor); + itemLength = readImm(minor, argLen); + } + currentState -= 2; + return major; + } + + private int readImm(int minor, int argLen) { + if (argLen == 0) { + idx++; + return minor; + } else { + int ret = readPosInt(payload, ++idx, argLen); + idx += argLen; + return ret; + } + } + + private byte collection(byte major, byte minor) { + itemLength = 0; + long size; + if (minor == INDEFINITE) { + overhead = 1; + size = -1; + } else { + int argLen = argLength(minor); + overhead = 0; + size = readImm(minor, argLen); + } + if (inCollection) { + currentState -= 2; + if (historyDepth == previousStates.length) { + previousStates = Arrays.copyOf(previousStates, previousStates.length * 2); + } + previousStates[historyDepth++] = currentState; + } + inCollection = true; + if (major == TYPE_ARRAY) { + currentState = (size << 1) | 1; + return Token.START_ARRAY; + } else { + currentState = size << 2; + return Token.START_OBJECT; + } + } + + private void readIndefiniteLength(byte type) { + itemLength = 0; + int scan = ++idx; + while (true) { + if (scan >= end) + throw new BadCborException("non-terminating string"); + byte b = payload[scan]; + if (b == SIMPLE_STREAM_BREAK) { + overhead++; + break; + } + int major = (b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT; + int minor = b & MINOR_TYPE_MASK; + if (major != type) { + throw new BadCborException("major type misalign: " + type + " " + major); + } + if (minor == INDEFINITE) + throw new BadCborException("expected finite length"); + int argLen = argLength(minor); + int strLen = readStrLen(payload, scan, minor, argLen); + int totalOverhead = argLen + 1; + overhead += totalOverhead; + itemLength += strLen; + scan += totalOverhead + strLen; + } + itemLength |= FLAG_INDEFINITE_LEN; + } + + private int collectionSize() { + long s = currentState >> 2; + return s >= 0 ? (int) s : -1; + } + + private void throwIncompleteCollectionException() { + String type = (currentState & 1L) == 0 ? "map" : "array"; + String msg = currentState < 0 ? "stream break" : ((currentState >> 1) + " more elements"); + throw new BadCborException("incomplete " + type + ": expecting " + msg); } @Override public void close() { - if (parser.currentToken() != Token.FINISHED) { + if (token != Token.FINISHED) { throw new SerializationException("Unexpected CBOR content at end of object"); } } @Override public boolean readBoolean(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; if (token == Token.TRUE) { return true; } else if (token == Token.FALSE) { @@ -148,12 +493,12 @@ private static SerializationException badType(String type, byte token) { @Override public ByteBuffer readBlob(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; if (token == Token.BYTE_STRING) { - int pos = parser.getPosition(); - int len = parser.getItemLength(); + int pos = idx; + int len = itemLength; ByteBuffer buffer; - if (CborParser.isIndefinite(len)) { + if (isIndefinite(len)) { buffer = ByteBuffer.wrap(readByteString(payload, pos, len)); } else { buffer = ByteBuffer.wrap(payload, pos, len).slice(); @@ -165,27 +510,27 @@ public ByteBuffer readBlob(Schema schema) { @Override public byte readByte(Schema schema) { - return (byte) readLong("byte", parser.currentToken()); + return (byte) readLong("byte", this.token); } @Override public short readShort(Schema schema) { - return (short) readLong("short", parser.currentToken()); + return (short) readLong("short", this.token); } @Override public int readInteger(Schema schema) { - return (int) readLong("integer", parser.currentToken()); + return (int) readLong("integer", this.token); } @Override public long readLong(Schema schema) { - return readLong("long", parser.currentToken()); + return readLong("long", this.token); } private long readLong(String type, byte token) { - int off = parser.getPosition(); - int len = parser.getItemLength(); + int off = idx; + int len = itemLength; if (token > Token.NEG_INT) throw badType(type, token); long val = CborReadUtil.readLong(payload, token, off, len); @@ -201,12 +546,12 @@ private long readLong(String type, byte token) { @Override public float readFloat(Schema schema) { - return (float) readDouble("float", parser.currentToken()); + return (float) readDouble("float", this.token); } @Override public double readDouble(Schema schema) { - return readDouble("double", parser.currentToken()); + return readDouble("double", this.token); } private double readDouble(String type, byte token) { @@ -218,8 +563,8 @@ private double readDouble(String type, byte token) { private double readDouble(byte token) { - int pos = parser.getPosition(); - int len = parser.getItemLength(); + int pos = idx; + int len = itemLength; long fp = CborReadUtil.readLong(payload, token, pos, len); // ordered by how likely it is we'll encounter each case if (len == 8) { // double @@ -258,40 +603,40 @@ private static float float16(int hbits) { @Override public BigInteger readBigInteger(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; int tmp = token & 0b11110; if (tmp != Token.POS_INT && tmp != Token.POS_BIGINT) { throw badType("biginteger", token); } - return CborReadUtil.readBigInteger(payload, token, parser.getPosition(), parser.getItemLength()); + return CborReadUtil.readBigInteger(payload, token, idx, itemLength); } @Override public BigDecimal readBigDecimal(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; if (token == Token.BIG_DECIMAL) { - return CborReadUtil.readBigDecimal(payload, parser.getPosition()); + return CborReadUtil.readBigDecimal(payload, idx); } else if (token == Token.FLOAT) { return BigDecimal.valueOf(readDouble(token)); } else if (token <= Token.NEG_INT) { return BigDecimal - .valueOf(CborReadUtil.readLong(payload, token, parser.getPosition(), parser.getItemLength())); + .valueOf(CborReadUtil.readLong(payload, token, idx, itemLength)); } throw badType("bigdecimal", token); } @Override public String readString(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; if (token != Token.TEXT_STRING) { throw badType("string", token); } - return CborReadUtil.readTextString(payload, parser.getPosition(), parser.getItemLength()); + return CborReadUtil.readTextString(payload, idx, itemLength); } @Override public Document readDocument() { - var token = parser.currentToken(); + var token = this.token; if (token == Token.FINISHED) { throw new SerializationException("No CBOR value to read"); } @@ -304,8 +649,8 @@ public Document readDocument() { case Token.FALSE -> Document.of(false); case Token.EPOCH_INEG, Token.EPOCH_IPOS, Token.EPOCH_F -> Document.of(readTimestamp(null)); case Token.FLOAT -> { - int pos = parser.getPosition(); - int len = parser.getItemLength(); + int pos = idx; + int len = itemLength; long fp = CborReadUtil.readLong(payload, token, pos, len); // ordered by how likely it is we'll encounter each case if (len == 8) { // double @@ -320,20 +665,20 @@ public Document readDocument() { case Token.BIG_DECIMAL -> Document.of(readBigDecimal(null)); case Token.START_ARRAY -> { List values = new ArrayList<>(); - for (token = parser.advance(); token != Token.END_ARRAY; token = parser.advance()) { + for (token = advance(); token != Token.END_ARRAY; token = advance()) { values.add(readDocument()); } yield Document.of(values); } case Token.START_OBJECT -> { Map values = new LinkedHashMap<>(); - for (token = parser.advance(); token != Token.END_OBJECT; token = parser.advance()) { + for (token = advance(); token != Token.END_OBJECT; token = advance()) { if (token != Token.KEY) { throw badType("struct member", token); } - var key = CborReadUtil.readTextString(payload, parser.getPosition(), parser.getItemLength()); - parser.advance(); + var key = CborReadUtil.readTextString(payload, idx, itemLength); + advance(); values.put(key, readDocument()); } yield CborDocuments.of(values, settings); @@ -344,10 +689,10 @@ public Document readDocument() { @Override public Instant readTimestamp(Schema schema) { - byte token = parser.currentToken(); + byte token = this.token; byte actual = (byte) (token ^ Token.TAG_FLAG); if (actual <= Token.NEG_INT) { - return Instant.ofEpochMilli(readLong("timestamp", token) * 1000); + return Instant.ofEpochSecond(readLong("timestamp", actual)); } else if (actual == Token.FLOAT) { double d = readDouble("timestamp", actual); return Instant.ofEpochMilli(Math.round(d * 1000d)); @@ -357,41 +702,96 @@ public Instant readTimestamp(Schema schema) { @Override public void readStruct(Schema schema, T state, StructMemberConsumer consumer) { - byte token = parser.currentToken(); - if (token == Token.FINISHED && schema.hasTrait(TraitKey.UNIT_TYPE_TRAIT)) { - // Empty input — treat as empty struct with no members. - return; - } + byte token = this.token; if (token != Token.START_OBJECT) { - throw badType("struct", token); + readStructEmpty(schema, token); + return; } - var canonicalizer = getCanonicalizer(schema); - for (token = parser.advance(); token != Token.END_OBJECT; token = parser.advance()) { + Schema structSchema = schema.isMember() ? schema.memberTarget() : schema; + var ext = structSchema.getExtension(CborSchemaExtensions.KEY); + CborMemberLookup lookup = ext != null ? ext.memberLookup() : null; + int expectedNext = 0; + + for (token = advance(); token != Token.END_OBJECT; token = advance()) { if (token != Token.KEY) { throw badType("struct member", token); } - int memberPos = parser.getPosition(); - int memberLen = parser.getItemLength(); - // don't dispatch any events for explicit nulls - if (parser.advance() == Token.NULL) { + int memberPos = idx; + int memberLen = itemLength; + if (advance() == Token.NULL) { continue; } - // wait to resolve the member until we know an event will be dispatched - Object member = resolveMember(schema, canonicalizer, payload, memberPos, memberLen); - if (member.getClass() == String.class) { - consumer.unknownMember(state, (String) member); - skipUnknownMember(); - } else { - consumer.accept(state, (Schema) member, this); + Schema member = null; + + if (!isIndefinite(memberLen) && lookup != null) { + if (expectedNext >= 0 && expectedNext < lookup.orderedNameBytes.length) { + byte[] expected = lookup.orderedNameBytes[expectedNext]; + if (expected.length == memberLen + && Arrays.equals( + payload, + memberPos, + memberPos + memberLen, + expected, + 0, + memberLen)) { + member = lookup.orderedSchemas[expectedNext]; + expectedNext = member.memberIndex() + 1; + } + } + if (member == null) { + member = lookup.lookup(payload, memberPos, memberPos + memberLen, -1); + if (member != null) { + expectedNext = member.memberIndex() + 1; + } + } } + + if (member == null) { + expectedNext = resolveMemberFallback( + structSchema, + memberPos, + memberLen, + expectedNext, + state, + consumer); + continue; + } + + consumer.accept(state, member, this); } } + private void readStructEmpty(Schema schema, byte token) { + if (token == Token.FINISHED && schema.hasTrait(TraitKey.UNIT_TYPE_TRAIT)) { + return; + } + throw badType("struct", token); + } + + private int resolveMemberFallback( + Schema structSchema, + int memberPos, + int memberLen, + int expectedNext, + T state, + StructMemberConsumer consumer + ) { + String name = CborReadUtil.readTextString(payload, memberPos, memberLen); + Schema member = structSchema.member(name); + if (member != null) { + consumer.accept(state, member, this); + return member.memberIndex() + 1; + } + consumer.unknownMember(state, name); + skipUnknownMember(); + return expectedNext; + } + private void skipUnknownMember() { - byte current = parser.currentToken(); + byte current = token; if (current != Token.START_OBJECT && current != Token.START_ARRAY) { return; } @@ -403,81 +803,47 @@ private void skipUnknownMember() { } else if ((current == Token.END_OBJECT || current == Token.END_ARRAY) && --depth == 0) { return; } - current = parser.advance(); - } - } - - private static Object resolveMember(Schema host, Canonicalizer canonicalizer, byte[] payload, int pos, int len) { - // this method is static for safety. the parser has already advanced past the member field by the time - // resolveMember is called, so don't touch the parser or any other instance state. - if (CborParser.isIndefinite(len)) { - return resolveSlow(host, payload, pos, len); - } - - var schema = canonicalizer.resolve(payload, pos, len); - if (schema != null) { - return schema; + current = advance(); } - return CborReadUtil.readTextString(payload, pos, len); - } - - private static Object resolveSlow(Schema host, byte[] payload, int pos, int len) { - var name = CborReadUtil.readTextString(payload, pos, len); - var schema = host.member(name); - return schema != null ? schema : host; - } - - private Canonicalizer getCanonicalizer(Schema schema) { - var canonicalizer = CANONICALIZERS.get(schema); - if (canonicalizer == null) { - canonicalizer = new Canonicalizer(schema); - CANONICALIZERS.put(schema, canonicalizer); - } - - return canonicalizer; } @Override public void readList(Schema schema, T state, ListMemberConsumer consumer) { - byte token = parser.currentToken(); + byte token = this.token; if (token != Token.START_ARRAY) { throw badType("list", token); } - for (token = parser.advance(); token != Token.END_ARRAY; token = parser.advance()) { + for (token = advance(); token != Token.END_ARRAY; token = advance()) { consumer.accept(state, this); } } @Override public int containerSize() { - return parser.collectionSize(); + return collectionSize(); } @Override public void readStringMap(Schema schema, T state, MapMemberConsumer consumer) { - byte token = parser.currentToken(); + byte token = this.token; if (token != Token.START_OBJECT) { throw badType("struct", token); } - for (token = parser.advance(); token != Token.END_OBJECT; token = parser.advance()) { + for (token = advance(); token != Token.END_OBJECT; token = advance()) { if (token != Token.KEY) { throw badType("key", token); } - var key = CborReadUtil.readTextString(payload, parser.getPosition(), parser.getItemLength()); - parser.advance(); + var key = CborReadUtil.readTextString(payload, idx, itemLength); + advance(); consumer.accept(state, key, this); } } @Override public boolean isNull() { - return parser.currentToken() == Token.NULL; + return token == Token.NULL; } - @Override - public T readNull() { - return null; - } } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborMemberLookup.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborMemberLookup.java new file mode 100644 index 000000000..dedc851a5 --- /dev/null +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborMemberLookup.java @@ -0,0 +1,88 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.cbor; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.java.core.schema.Schema; + +/** + * Hash-based field name lookup that operates directly on UTF-8 bytes during CBOR deserialization. + * + *

Strategy: + *

    + *
  1. Speculative fast path: check the expected next member via length + Arrays.equals + * (JVM-intrinsified). Fires on every field when CBOR map entries arrive in schema definition + * order (the common case for smithy-to-smithy communication). + *
  2. Slow path: linear scan with FNV-1a hash pre-filter. Hash is computed lazily + * (only on speculative miss) to avoid the per-byte multiply+XOR cost on the hot path. + *
+ */ +final class CborMemberLookup { + + private static final long FNV_OFFSET = 0xcbf29ce484222325L; + private static final long FNV_PRIME = 0x100000001b3L; + + final long[] orderedHashes; + final Schema[] orderedSchemas; + final byte[][] orderedNameBytes; + + CborMemberLookup(List members) { + int size = members.size(); + this.orderedHashes = new long[size]; + this.orderedSchemas = new Schema[size]; + this.orderedNameBytes = new byte[size][]; + + for (int i = 0; i < size; i++) { + Schema m = members.get(i); + byte[] nameBytes = m.memberName().getBytes(StandardCharsets.UTF_8); + orderedNameBytes[i] = nameBytes; + orderedHashes[i] = fnvHash(nameBytes, 0, nameBytes.length); + orderedSchemas[i] = m; + } + } + + /** + * Looks up a member by matching the field name bytes directly from the CBOR payload buffer. + * No String allocation on the common path. + * + * @param buf input buffer containing the field name bytes (raw UTF-8, no CBOR header) + * @param start start offset in buf + * @param end end offset in buf (exclusive) + * @param expectedNext speculative next member index (-1 to disable) + * @return matched Schema, or null if not found + */ + Schema lookup(byte[] buf, int start, int end, int expectedNext) { + int nameLen = end - start; + + if (expectedNext >= 0 && expectedNext < orderedNameBytes.length + && orderedNameBytes[expectedNext].length == nameLen + && Arrays.equals(buf, start, end, orderedNameBytes[expectedNext], 0, nameLen)) { + return orderedSchemas[expectedNext]; + } + + long hash = fnvHash(buf, start, end); + for (int i = 0; i < orderedHashes.length; i++) { + if (orderedHashes[i] == hash + && orderedNameBytes[i].length == nameLen + && Arrays.equals(buf, start, end, orderedNameBytes[i], 0, nameLen)) { + return orderedSchemas[i]; + } + } + + return null; + } + + private static long fnvHash(byte[] buf, int start, int end) { + long hash = FNV_OFFSET; + for (int i = start; i < end; i++) { + hash ^= buf[i] & 0xFF; + hash *= FNV_PRIME; + } + return hash; + } +} diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborParser.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborParser.java deleted file mode 100644 index 3ee5db1e4..000000000 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborParser.java +++ /dev/null @@ -1,538 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.cbor; - -import static software.amazon.smithy.java.cbor.CborParser.Token.name; -import static software.amazon.smithy.java.cbor.CborReadUtil.argLength; -import static software.amazon.smithy.java.cbor.CborReadUtil.readPosInt; -import static software.amazon.smithy.java.cbor.CborReadUtil.readStrLen; - -import java.util.Arrays; -import software.amazon.smithy.utils.SmithyInternalApi; - -@SmithyInternalApi -public final class CborParser { - public static final class Token { - public static int version() { - return 1; - } - - // high bit toggle to indicate a tagged item - public static final byte TAG_FLAG = 1 << 4; - - // the first group of simple types directly map to their cbor major types for simpler reading routines - public static final byte POS_INT = TYPE_POSINT; // 0b00000 - public static final byte NEG_INT = TYPE_NEGINT; // 0b00001 - public static final byte BYTE_STRING = TYPE_BYTESTRING; // 0b00010 - public static final byte TEXT_STRING = TYPE_TEXTSTRING; // 0b00011 - - // types that require explicit method invocations on ReadTranslator - // these values must be contiguous for efficient lookup table generation - public static final byte NULL = 4; // 0b00100 - public static final byte KEY = 5; // 0b00101 - public static final byte START_OBJECT = 6; // 0b00110 - public static final byte START_ARRAY = 7; // 0b00111 - public static final byte END_OBJECT = 8; // 0b01000 - public static final byte END_ARRAY = 9; // 0b01001 - - // the second group of simple types have arbitrary values and can pretty much be anything - public static final byte POS_BIGINT = 10; // 0b01010 - public static final byte NEG_BIGINT = 11; // 0b01011 - public static final byte FLOAT = 12; // 0b01100 - // gap: we used to have a bigfloat token, but we never supported that type - public static final byte BIG_DECIMAL = 14; // 0b01110 - public static final byte TRUE = 15; // 0b01111 - public static final byte FALSE = TRUE | TAG_FLAG; // 0b11111 (31) - - // tag types are the type of the tagged data with the high bit set - // these do not need to be contiguous - public static final byte EPOCH_IPOS = POS_INT | TAG_FLAG; // 0b10000 (16) - public static final byte EPOCH_INEG = NEG_INT | TAG_FLAG; // 0b10001 (17) - public static final byte EPOCH_F = FLOAT | TAG_FLAG; // 0b11100 (28) - - public static final byte FINISHED = -1; - - public static String name(byte token) { - switch (token) { - case POS_INT: - return "POS_INT"; - case NEG_INT: - return "NEG_INT"; - case BYTE_STRING: - return "BYTE_STRING"; - case TEXT_STRING: - return "TEXT_STRING"; - case POS_BIGINT: - return "POS_BIGINT"; - case NEG_BIGINT: - return "NEG_BIGINT"; - case FLOAT: - return "FLOAT"; - case BIG_DECIMAL: - return "BIG_DECIMAL"; - case TRUE: - return "TRUE"; - case FALSE: - return "FALSE"; - case EPOCH_IPOS: - return "EPOCH_IPOS"; - case EPOCH_INEG: - return "EPOCH_INEG"; - case EPOCH_F: - return "EPOCH_F"; - case NULL: - return "NULL"; - case KEY: - return "KEY"; - case START_OBJECT: - return "START_OBJECT"; - case START_ARRAY: - return "START_ARRAY"; - case END_ARRAY: - return "END_ARRAY"; - case END_OBJECT: - return "END_OBJECT"; - } - if (token == FINISHED) - return "FINISHED"; - throw new BadCborException("unknown token " + token); - } - } - - static final int MAJOR_TYPE_SHIFT = 5, - MAJOR_TYPE_MASK = 0b111_00000, - MINOR_TYPE_MASK = 0b0001_1111; - - static final byte TYPE_POSINT = 0, - TYPE_NEGINT = 1, - TYPE_BYTESTRING = 2, - TYPE_TEXTSTRING = 3, - TYPE_ARRAY = 4, - TYPE_MAP = 5, - TYPE_TAG = 6, - TYPE_SIMPLE = 7; - - static final int ZERO_BYTES = 23, - ONE_BYTE = 24, - EIGHT_BYTES = 27, - INDEFINITE = 31; - - static final int SIMPLE_FALSE = 20, - SIMPLE_TRUE = 21, - SIMPLE_NULL = 22, - SIMPLE_UNDEFINED = 23, - SIMPLE_VALUE_1 = 24, // value follows in next 1 byte, currently reserved and unused - SIMPLE_HALF_FLOAT = 25, - SIMPLE_FLOAT = 26, - SIMPLE_DOUBLE = 27; - - static final byte SIMPLE_STREAM_BREAK = (byte) ((TYPE_SIMPLE << MAJOR_TYPE_SHIFT) | INDEFINITE); - - static final byte TAG_TIME_RFC3339 = 0, // expect text string - TAG_TIME_EPOCH = 1, // expect integer or float - TAG_POS_BIGNUM = 2, // expect byte string - TAG_NEG_BIGNUM = 3, // expect byte string - TAG_DECIMAL = 4; // expect two-element integer array - - private static final int FLAG_INDEFINITE_LEN = 1 << 31; - private static final int MASK_LEN = ~FLAG_INDEFINITE_LEN; - - public static boolean isIndefinite(int itemLength) { - return itemLength < 0; - } - - public static int itemLength(int itemLength) { - return itemLength & MASK_LEN; - } - - private final byte[] buffer; - private final int end; - private int idx; - private byte token; - - // Definite sizes shrink to zero, indefinite sizes start at -1 and decrement meaninglessly towards Long.MIN_VALUE. - // Must be long because we need to store 2 * size for maps, and a map can have up to Integer.MAX_VALUE elements. - // Count is left shifted one. Low bit is collection type: 0 == map, 1 == array. - private long currentState = 0; - private long[] previousStates = new long[4]; - private boolean inCollection = false; - private int historyDepth = 0; - private int itemLength = 0; - private int overhead = 0; // overhead is [0,8] - private boolean readingTag = false; - - public CborParser(byte[] buffer) { - this(buffer, 0, buffer.length); - } - - public CborParser(byte[] buffer, int off, int end) { - this.buffer = buffer; - this.idx = off; - this.end = end; - } - - /** - * @return the starting position of the current data item - */ - public int getPosition() { - return idx; - } - - /** - * If the last token returned by {@link #advance()} is a single data item (e.g. a numeric type), - * this indicates the number of bytes from {@link #getPosition()} to read to retrieve it. - * - *

This method does not return a defined result for collection types like strings, arrays, or maps. - * - * @return number of bytes encoding the current single-element data item, or undefined - */ - public int getItemLength() { - return itemLength; - } - - public int collectionSize() { - long s = currentState >> 2; - return s >= 0 ? (int) s : -1; - } - - public byte currentToken() { - return token; - } - - /** - * Gets the next {@link Token} in the payload. The data for this token begins at {@link #getPosition()}. The data's - * length is determined by {@link #getItemLength()} if the token is not one of these types: - * - *

    - *
  • {@link Token#NULL}
  • - *
  • {@link Token#START_ARRAY}
  • - *
  • {@link Token#END_ARRAY}
  • - *
  • {@link Token#START_OBJECT}
  • - *
  • {@link Token#END_OBJECT}
  • - *
- * - * @return the next {@link Token} in the payload - */ - public byte advance() { - return (token = nextToken0()); - } - - private byte nextToken0() { - if (inCollection) { - long state = currentState; - if (state >> 1 == 0) { - // count is 0, so the only remaining value is the collection type in the low bit - return getEndToken(state); - } else if ((state & 3) == 0) { - // mask is 0b11: low bit is collection type (map == 0), high bit is 0 if the count is even - int i = (idx += itemLength(itemLength) + overhead); - if (i >= end) { - throwIncompleteCollectionException(); - } - return dispatchKey(buffer[i]); - } - } - - int i = (idx += itemLength(itemLength) + overhead); - if (i >= end) { - return endOfBuffer(i); - } - - return dispatch(buffer[i]); - } - - private byte dispatchKey(byte b) { - byte major = (byte) ((b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT); - if (major == TYPE_TEXTSTRING) { - byte minor = (byte) (b & MINOR_TYPE_MASK); - string(major, minor); - return Token.KEY; - } else if (b == SIMPLE_STREAM_BREAK) { - return endStreamCollection(); - } else { - throw new BadCborException("map keys must be strings"); - } - } - - private byte endOfBuffer(int i) { - itemLength = 0; - overhead = 0; - if (i > end) { - throw new BadCborException("unexpected end of payload"); - } - if (inCollection) { - throwIncompleteCollectionException(); - } - return Token.FINISHED; - } - - private byte getEndToken(long state) { - byte retVal = state == 0 ? Token.END_OBJECT : Token.END_ARRAY; - if (historyDepth > 0) { - currentState = previousStates[--historyDepth]; - } else { - inCollection = false; - currentState = 0; - } - return retVal; - } - - private byte dispatch(byte b) { - byte major = (byte) ((b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT); - byte minor = (byte) (b & MINOR_TYPE_MASK); - // major is guaranteed in range [0,7] by the mask-and-shift operation - switch (major) { - case TYPE_POSINT: - case TYPE_NEGINT: - return integer(major, minor); - case TYPE_BYTESTRING: - case TYPE_TEXTSTRING: - return string(major, minor); - case TYPE_ARRAY: - case TYPE_MAP: - return collection(major, minor); - case TYPE_TAG: - return tag(minor); - case TYPE_SIMPLE: - return simple(major, minor); - default: - throw new BadCborException("unknown major type: " + major); - } - } - - private byte tag(byte minor) { - // RFC8949 3.4 permits nested tags, but I see no need to support anything beyond the simple - // tags that are relevant to the Smithy object model. - if (readingTag) - throw new BadCborException("nested tags not permitted"); - // reset increments before calling nextToken. 1 overhead for this tag /immediate value - overhead = 1; - itemLength = 0; - readingTag = true; - byte next = advance(); - readingTag = false; - switch (minor) { - case TAG_TIME_EPOCH: - if (next != Token.FLOAT && next > Token.NEG_INT) - throw new BadCborException("malformed instant: got " + name(next)); - return (byte) (next | Token.TAG_FLAG); - case TAG_POS_BIGNUM: - if (next != Token.BYTE_STRING) - throw new BadCborException("malformed +bignum: got " + name(next)); - return Token.POS_BIGINT; - case TAG_NEG_BIGNUM: - if (next != Token.BYTE_STRING) - throw new BadCborException("malformed -bignum: got " + name(next)); - return Token.NEG_BIGINT; - case TAG_DECIMAL: - tagDecimalFp(next); - return Token.BIG_DECIMAL; - default: - throw new BadCborException("unsupported tag minor " + minor); - } - } - - private void tagDecimalFp(byte next) { - // A decimal fraction or a bigfloat is represented as a tagged array that contains - // exactly an integer and a bignum/integer - if (next != Token.START_ARRAY) - badDecimalInitialType(next); - int start = idx; - byte token; - if ((token = advance()) > Token.NEG_INT) - badDecimalArgument1(token); - token = advance(); - int tmp = token & 0b11110; - if (tmp != Token.POS_INT && tmp != Token.POS_BIGINT) - badDecimalArgument2(token); - if ((token = advance()) != Token.END_ARRAY) - badDecimalFinalToken(token); - itemLength = idx - start + itemLength(itemLength) + overhead - 1; - overhead = 1; - idx = start; - } - - private static void badDecimalInitialType(byte next) { - throw new BadCborException("malformed BIG_DECIMAL: got " + name(next)); - } - - private static void badDecimalArgument1(byte token) { - throw new BadCborException("malformed BIG_DECIMAL: expected int 1, got " + name(token)); - } - - private static void badDecimalArgument2(byte token) { - throw new BadCborException("malformed BIG_DECIMAL: expected int 2, got " + name(token)); - } - - private static void badDecimalFinalToken(byte token) { - throw new BadCborException("malformed BIG_DECIMAL: expected END_ARRAY, got " + name(token)); - } - - private byte integer(byte major, byte minor) { - if (minor == INDEFINITE) - throw new BadCborException("numeric type has indefinite length"); - int argLength = argLength(minor); - if (argLength > 0) { - overhead = 0; - idx++; - } else { - overhead = 1; - } - itemLength = argLength; - // 2 because the count is left-shifted one (2 == 1 << 1) - currentState -= 2; - return major; - } - - private byte simple(byte major, byte minor) { - if (minor <= SIMPLE_VALUE_1) { - currentState -= 2; - itemLength = 1; - overhead = 0; - switch (minor) { - case SIMPLE_FALSE: - return Token.FALSE; - case SIMPLE_TRUE: - return Token.TRUE; - case SIMPLE_NULL: - case SIMPLE_UNDEFINED: - return Token.NULL; - default: - throw new BadCborException("bad simple minor type " + minor); - } - } else if (minor <= SIMPLE_DOUBLE) { - // collectionSize is decremented in integer if necessary - integer(major, minor); - return Token.FLOAT; - } else if (minor == INDEFINITE) { - return endStreamCollection(); - } - throw new BadCborException("illegal simple minor type " + minor); - } - - private byte endStreamCollection() { - // no need to decrement collectionSize in this branch since we're in an indefinite collection - itemLength = 0; - overhead = 1; - // note that we can leave the collection type in the low bit. all that matters is that - // the number is negative, and the low bit will only make a positive number more positive - // and a negative number more negative. - if (!inCollection || currentState >= 0) - throw new BadCborException("unexpected indefinite terminator"); - long state = currentState; - if (historyDepth > 0) { - currentState = previousStates[--historyDepth]; - } else { - inCollection = false; - currentState = 0; - } - return (state & 1) == 0 ? Token.END_OBJECT : Token.END_ARRAY; - } - - /** - * Reads a {@linkplain #TYPE_TEXTSTRING text string} or {@linkplain #TYPE_BYTESTRING bytestring} - * from the buffer. - * - *

Definite length strings have no overhead. {@code idx} will point to the start of the string - * data and {@code itemLength} will be the number of bytes in the string. The next data item begins - * at {@code idx + itemLength}. - * - *

Indefinite length strings are a bit trickier. {@code idx} will point to the start of the first - * string in the sequence and {@code itemLength} will be the number of bytes in the final assembled - * string. However, this is not the number of bytes that this string spans in the CBOR payload. - * We use a separate count, {@linkplain #overhead}, to factor in the additional overhead of encoding - * non-data tags. We add all opening tag bytes, length operands, and the final closing {@link #INDEFINITE} - * byte to this value. - */ - private byte string(byte major, byte minor) { - overhead = 0; - if (minor == INDEFINITE) { - readIndefiniteLength(major); - } else { - int argLen = argLength(minor); - itemLength = readImm(minor, argLen); - } - currentState -= 2; - return major; - } - - private int readImm(int minor, int argLen) { - if (argLen == 0) { - // minor is the collection/string length, data begins on next byte - idx++; - return minor; - } else { - // minor is the number of bytes following this one that encode the collection/string length - int ret = readPosInt(buffer, ++idx, argLen); - idx += argLen; - return ret; - } - } - - private byte collection(byte major, byte minor) { - // collection length is tracked in collectionSizes - itemLength = 0; - long size; - if (minor == INDEFINITE) { - overhead = 1; - size = -1; - } else { - int argLen = argLength(minor); - overhead = 0; - size = readImm(minor, argLen); - } - if (inCollection) { - currentState -= 2; - if (historyDepth == previousStates.length) { - previousStates = Arrays.copyOf(previousStates, previousStates.length * 2); - } - previousStates[historyDepth++] = currentState; - } - inCollection = true; - if (major == TYPE_ARRAY) { - currentState = (size << 1) | 1; - return Token.START_ARRAY; - } else { //major == TYPE_MAP - currentState = size << 2; - return Token.START_OBJECT; - } - } - - // idx = start of string, itemLength = byte count, overhead = tag bytes - private void readIndefiniteLength(byte type) { - itemLength = 0; - int scan = ++idx; - while (true) { - if (scan >= end) - throw new BadCborException("non-terminating string"); - byte b = buffer[scan]; - if (b == SIMPLE_STREAM_BREAK) { - overhead++; - break; - } - int major = (b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT; - int minor = b & MINOR_TYPE_MASK; - if (major != type) { - throw new BadCborException("major type misalign: " + type + " " + major); - } - if (minor == INDEFINITE) - throw new BadCborException("expected finite length"); - int argLen = argLength(minor); - int strLen = readStrLen(buffer, scan, minor, argLen); - int totalOverhead = argLen + 1; - overhead += totalOverhead; - itemLength += strLen; - scan += totalOverhead + strLen; - } - itemLength |= FLAG_INDEFINITE_LEN; - } - - private void throwIncompleteCollectionException() { - String type = (currentState & 1L) == 0 ? "map" : "array"; - String msg = currentState < 0 ? "stream break" : ((currentState >> 1) + " more elements"); - throw new BadCborException("incomplete " + type + ": expecting " + msg); - } -} diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborReadUtil.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborReadUtil.java index 7fe048651..4d6232dfe 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborReadUtil.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborReadUtil.java @@ -5,27 +5,36 @@ package software.amazon.smithy.java.cbor; -import static software.amazon.smithy.java.cbor.CborParser.EIGHT_BYTES; -import static software.amazon.smithy.java.cbor.CborParser.MAJOR_TYPE_MASK; -import static software.amazon.smithy.java.cbor.CborParser.MAJOR_TYPE_SHIFT; -import static software.amazon.smithy.java.cbor.CborParser.MINOR_TYPE_MASK; -import static software.amazon.smithy.java.cbor.CborParser.ONE_BYTE; -import static software.amazon.smithy.java.cbor.CborParser.TAG_NEG_BIGNUM; -import static software.amazon.smithy.java.cbor.CborParser.TAG_POS_BIGNUM; -import static software.amazon.smithy.java.cbor.CborParser.TYPE_NEGINT; -import static software.amazon.smithy.java.cbor.CborParser.TYPE_POSINT; -import static software.amazon.smithy.java.cbor.CborParser.TYPE_TAG; -import static software.amazon.smithy.java.cbor.CborParser.ZERO_BYTES; +import static software.amazon.smithy.java.cbor.CborDeserializer.EIGHT_BYTES; +import static software.amazon.smithy.java.cbor.CborDeserializer.MAJOR_TYPE_MASK; +import static software.amazon.smithy.java.cbor.CborDeserializer.MAJOR_TYPE_SHIFT; +import static software.amazon.smithy.java.cbor.CborDeserializer.MINOR_TYPE_MASK; +import static software.amazon.smithy.java.cbor.CborDeserializer.ONE_BYTE; +import static software.amazon.smithy.java.cbor.CborDeserializer.TAG_NEG_BIGNUM; +import static software.amazon.smithy.java.cbor.CborDeserializer.TAG_POS_BIGNUM; +import static software.amazon.smithy.java.cbor.CborDeserializer.TYPE_NEGINT; +import static software.amazon.smithy.java.cbor.CborDeserializer.TYPE_POSINT; +import static software.amazon.smithy.java.cbor.CborDeserializer.TYPE_TAG; +import static software.amazon.smithy.java.cbor.CborDeserializer.ZERO_BYTES; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import software.amazon.smithy.utils.SmithyInternalApi; @SmithyInternalApi public final class CborReadUtil { + + static final VarHandle BE_SHORT = + MethodHandles.byteArrayViewVarHandle(short[].class, ByteOrder.BIG_ENDIAN); + static final VarHandle BE_INT = + MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.BIG_ENDIAN); + static final VarHandle BE_LONG = + MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN); + public static int argLength(int minorType) { if (minorType <= ZERO_BYTES) return 0; @@ -72,23 +81,16 @@ public static long readLong(byte[] buffer, byte type, int off, int len) { return val; } - @SuppressFBWarnings("SF_SWITCH_FALLTHROUGH") private static long readLong0(byte[] buffer, int off, int len) { - long acc = 0; - // case order is important here, do not reorder switch (len) { - case 8: - acc = ((long) buffer[off++] & 0xff) << 56 - | ((long) buffer[off++] & 0xff) << 48 - | ((long) buffer[off++] & 0xff) << 40 - | ((long) buffer[off++] & 0xff) << 32; - case 4: - acc |= ((long) buffer[off++] & 0xff) << 24 - | ((long) buffer[off++] & 0xff) << 16; - case 2: - acc |= ((long) buffer[off++] & 0xff) << 8; case 1: - return acc | ((long) buffer[off] & 0xff); + return buffer[off] & 0xffL; + case 2: + return ((short) BE_SHORT.get(buffer, off)) & 0xffffL; + case 4: + return ((int) BE_INT.get(buffer, off)) & 0xffffffffL; + case 8: + return (long) BE_LONG.get(buffer, off); default: return invalidLength(len); } @@ -100,8 +102,8 @@ private static long invalidLength(int len) { public static BigInteger readBigInteger(byte[] buffer, byte context, int off, int len) { //If the value fits inside a long - boolean isPosInt = context == CborParser.Token.POS_INT; - if (isPosInt || context == CborParser.Token.NEG_INT) { + boolean isPosInt = context == CborDeserializer.Token.POS_INT; + if (isPosInt || context == CborDeserializer.Token.NEG_INT) { return readSmallBigInteger(buffer, context, off, len, isPosInt); } final byte[] buff; @@ -126,7 +128,7 @@ public static BigInteger readBigInteger(byte[] buffer, byte context, int off, in buff = indefBuf; } } - if (context == CborParser.Token.NEG_BIGINT) { + if (context == CborDeserializer.Token.NEG_BIGINT) { CborReadUtil.flipBytes(buff); } return new BigInteger(buff); @@ -161,9 +163,9 @@ public static BigDecimal readBigDecimal(byte[] buffer, int originalOff) { switch (major) { case TYPE_TAG: if (minor == TAG_POS_BIGNUM) { - context = CborParser.Token.POS_BIGINT; + context = CborDeserializer.Token.POS_BIGINT; } else if (minor == TAG_NEG_BIGNUM) { - context = CborParser.Token.NEG_BIGINT; + context = CborDeserializer.Token.NEG_BIGINT; } else { throw new BadCborException("Unexpected minor " + minor, true); } @@ -194,24 +196,25 @@ public static BigDecimal readBigDecimal(byte[] buffer, int originalOff) { } public static String readTextString(byte[] buffer, int off, int len) { - if (CborParser.isIndefinite(len)) { - return new String(readBytesIndefinite(buffer, off, CborParser.itemLength(len)), StandardCharsets.UTF_8); + if (CborDeserializer.isIndefinite(len)) { + return new String(readBytesIndefinite(buffer, off, CborDeserializer.itemLength(len)), + StandardCharsets.UTF_8); } else { return new String(buffer, off, len, StandardCharsets.UTF_8); } } public static byte[] readByteString(byte[] buffer, int off, int len) { - if (CborParser.isIndefinite(len)) { - return readBytesIndefinite(buffer, off, CborParser.itemLength(len)); + if (CborDeserializer.isIndefinite(len)) { + return readBytesIndefinite(buffer, off, CborDeserializer.itemLength(len)); } else { return readBytesFinite(buffer, off, len); } } public static void readByteString(byte[] buffer, int off, byte[] dest, int destOff, int len) { - if (CborParser.isIndefinite(len)) { - readBytesIndefinite(buffer, off, dest, destOff, CborParser.itemLength(len)); + if (CborDeserializer.isIndefinite(len)) { + readBytesIndefinite(buffer, off, dest, destOff, CborDeserializer.itemLength(len)); } else { readBytesFinite(buffer, off, dest, destOff, len); } @@ -251,63 +254,6 @@ private static void readBytesIndefinite(byte[] buffer, int off, byte[] dest, int throw new BadCborException("cannot read unclosed indefinite length string"); } - /** - * Compares a byte sequence within the CBOR payload to an arbitrary byte sequence. It is assumed that the second - * sequence (argument {@code str}) is a freestanding byte sequence and does not contain CBOR indefinite length - * coding. - * - * @param buf the CBOR payload - * @param bOff offset in {@code buf} where the byte sequence to compare begins - * @param bLen length of the byte sequence in {@code buf} - * @param str the byte sequence to compare against - * @param sOff the offset in {@code str} where the sequence begins - * @param sLen the length of the sequence in {@code str} to compare against - * @return true if they match - */ - public static boolean compareStringExternal(byte[] buf, int bOff, int bLen, byte[] str, int sOff, int sLen) { - if (CborParser.isIndefinite(bLen)) { - return compareIndefinite(buf, bOff, CborParser.itemLength(bLen), str, sOff, sLen); - } else { - return compareFinite(buf, bOff, bLen, str, sOff, sLen); - } - } - - public static boolean compareStringsInPayload(byte[] buf, int off1, int len1, int off2, int len2) { - boolean indefinite1 = CborParser.isIndefinite(len1); - boolean indefinite2 = CborParser.isIndefinite(len2); - if (!indefinite1 && !indefinite2) { - return compareFinite(buf, off1, len1, buf, off2, len2); - } else { - // TODO: don't read one up front, or at least try to make sure we always use `compareFinite` when possible - byte[] one = CborReadUtil.readByteString(buf, off1, len1); - return compareStringExternal(buf, off2, len2, one, 0, one.length); - } - } - - private static boolean compareIndefinite(byte[] buf, int bOff, int bLen, byte[] s, int sOff, int sLen) { - if (bLen != sLen) - return false; - int lim = sOff + sLen; - while (sOff < lim) { - byte b = buf[bOff]; - int minor = b & MINOR_TYPE_MASK; - int argLen = argLength(minor); - int chunkLen = readStrLen(buf, bOff, minor, argLen); - bOff += argLen + 1; - if (!compareFinite(buf, bOff, chunkLen, s, sOff, chunkLen)) - return false; - bOff += chunkLen; - sOff += chunkLen; - } - if (sOff != lim) - throw new BadCborException("cannot compare unclosed indefinite length string"); - return true; - } - - private static boolean compareFinite(byte[] buf, int bOff, int bLen, byte[] s, int sOff, int sLen) { - return Arrays.compare(buf, bOff, bLen, s, sOff, sLen) == 0; - } - private static byte getMajor(byte b) { return (byte) ((b & MAJOR_TYPE_MASK) >> MAJOR_TYPE_SHIFT); } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java new file mode 100644 index 000000000..4569494aa --- /dev/null +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.cbor; + +import java.util.List; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SchemaExtensionKey; +import software.amazon.smithy.java.core.schema.SchemaExtensionProvider; +import software.amazon.smithy.model.shapes.ShapeType; + +/** + * Pre-computes CBOR codec data on Schema objects. + * + *

For member schemas: pre-encoded CBOR text string header + member name bytes. + * For struct/union schemas: {@link CborMemberLookup} instances for hash-based field matching + * and field name tables for O(1) lookup by memberIndex during serialization. + */ +public final class CborSchemaExtensions + implements SchemaExtensionProvider { + + /** + * Extension key for CBOR codec data. + */ + public static final SchemaExtensionKey KEY = new SchemaExtensionKey<>(); + + /** + * Pre-computed CBOR data stored on a Schema. + * + * @param memberNameBytes CBOR text string header + name bytes (null for non-members) + * @param memberLookup Hash-based member lookup (null for non-structs) + * @param fieldNameTable Indexed by memberIndex: pre-computed name bytes per member (null for non-structs) + */ + public record NativeCborExtension( + byte[] memberNameBytes, + CborMemberLookup memberLookup, + byte[][] fieldNameTable) {} + + @Override + public SchemaExtensionKey key() { + return KEY; + } + + @Override + public NativeCborExtension provide(Schema schema) { + if (schema.isMember()) { + return forMember(schema); + } + var type = schema.type(); + if (type == ShapeType.STRUCTURE || type == ShapeType.UNION) { + return forStruct(schema); + } + return null; + } + + private static NativeCborExtension forMember(Schema schema) { + byte[] memberNameBytes = CborSerializer.encodeMemberName(schema.memberName()); + return new NativeCborExtension(memberNameBytes, null, null); + } + + private static NativeCborExtension forStruct(Schema schema) { + List members = schema.members(); + if (members.isEmpty()) { + return new NativeCborExtension(null, null, null); + } + + CborMemberLookup memberLookup = new CborMemberLookup(members); + + int maxIndex = 0; + for (Schema m : members) { + maxIndex = Math.max(maxIndex, m.memberIndex()); + } + byte[][] fieldNameTable = new byte[maxIndex + 1][]; + for (Schema m : members) { + int idx = m.memberIndex(); + fieldNameTable[idx] = CborSerializer.encodeMemberName(m.memberName()); + } + + return new NativeCborExtension(null, memberLookup, fieldNameTable); + } +} diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerdeProvider.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerdeProvider.java index 52ca97259..e1bcc2e2c 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerdeProvider.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerdeProvider.java @@ -7,7 +7,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; -import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.SerializableShape; import software.amazon.smithy.java.core.serde.ShapeDeserializer; import software.amazon.smithy.java.core.serde.ShapeSerializer; @@ -22,5 +22,5 @@ public interface CborSerdeProvider { ShapeSerializer newSerializer(OutputStream sink, CborSettings settings); - ByteBuffer serialize(SerializableStruct struct, CborSettings settings); + ByteBuffer serialize(SerializableShape shape, CborSettings settings); } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java index b23ba1b74..c9a9170eb 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java @@ -29,16 +29,17 @@ import static software.amazon.smithy.java.cbor.CborConstants.TYPE_TEXTSTRING; import static software.amazon.smithy.java.cbor.CborReadUtil.flipBytes; +import java.io.OutputStream; +import java.lang.invoke.VarHandle; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiConsumer; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SerializableStruct; -import software.amazon.smithy.java.core.serde.InterceptingSerializer; import software.amazon.smithy.java.core.serde.MapSerializer; import software.amazon.smithy.java.core.serde.SerializationException; import software.amazon.smithy.java.core.serde.ShapeSerializer; @@ -47,101 +48,168 @@ import software.amazon.smithy.model.shapes.ShapeType; final class CborSerializer implements ShapeSerializer { + + private static final VarHandle BE_SHORT = CborReadUtil.BE_SHORT; + private static final VarHandle BE_INT = CborReadUtil.BE_INT; + private static final VarHandle BE_LONG = CborReadUtil.BE_LONG; + private static final int MAP_STREAM = TYPE_MAP | INDEFINITE; private static final int ARRAY_STREAM = TYPE_ARRAY | INDEFINITE; - private boolean[] collection = new boolean[4]; - private int collectionIdx = -1; - private final Sink sink; - private final CborMapSerializer mapSerializer = new CborMapSerializer(); + private static final int DEFAULT_BUF_SIZE = 4096; + private static final int MAX_CACHEABLE_BUF = DEFAULT_BUF_SIZE * 4; + + private static final int POOL_SLOTS; + private static final int POOL_MASK; + private static final AtomicReferenceArray POOL; + private static final int MAX_PROBE = 3; + + static { + int processors = Runtime.getRuntime().availableProcessors(); + int raw = processors * 4; + POOL_SLOTS = Integer.highestOneBit(raw - 1) << 1; + POOL_MASK = POOL_SLOTS - 1; + POOL = new AtomicReferenceArray<>(POOL_SLOTS); + } + + byte[] buf; + int pos; + + private final OutputStream sink; + + // Bit i of collectionMask records whether level i was opened as indefinite-length. + private long collectionMask = 0L; + private int collectionDepth = 0; + private long[] collectionOverflow; + + private byte[][] currentFieldNameTable; + private final CborStructSerializer structSerializer = new CborStructSerializer(); + private final CborMapSerializer mapSerializer = new CborMapSerializer(); private SerializeDocumentContents serializeDocumentContents; - public CborSerializer(Sink sink) { - this.sink = sink; + CborSerializer() { + this.sink = null; + this.buf = new byte[DEFAULT_BUF_SIZE]; + this.pos = 0; } - private void startMap(int size) { - boolean indefinite = size < 0; - if (indefinite) { - sink.write(MAP_STREAM); - } else { - tagAndLength(TYPE_MAP, size); + CborSerializer(OutputStream sink) { + this.sink = sink; + this.buf = new byte[DEFAULT_BUF_SIZE]; + this.pos = 0; + } + + static CborSerializer acquire() { + if (!Thread.currentThread().isVirtual()) { + int base = poolProbe(); + for (int i = 0; i < MAX_PROBE; i++) { + int idx = (base + i) & POOL_MASK; + CborSerializer s = POOL.getPlain(idx); + if (s != null && POOL.compareAndExchangeAcquire(idx, s, null) == s) { + s.pos = 0; + s.collectionMask = 0L; + s.collectionDepth = 0; + s.currentFieldNameTable = null; + return s; + } + } } - startCollection(indefinite); + return new CborSerializer(); } - private void startArray(int size) { - boolean indefinite = size < 0; - if (indefinite) { - sink.write(ARRAY_STREAM); - } else { - tagAndLength(TYPE_ARRAY, size); + static void release(CborSerializer serializer, boolean exception) { + if (serializer.buf == null || serializer.sink != null || Thread.currentThread().isVirtual()) { + return; } - startCollection(indefinite); + if (serializer.buf.length > MAX_CACHEABLE_BUF) { + serializer.buf = new byte[DEFAULT_BUF_SIZE]; + } + int base = poolProbe(); + for (int i = 0; i < MAX_PROBE; i++) { + int idx = (base + i) & POOL_MASK; + if (POOL.getPlain(idx) == null + && POOL.compareAndExchangeRelease(idx, null, serializer) == null) { + return; + } + } + // Pool full, let GC collect } - private void startCollection(boolean indefinite) { - int idx = ++collectionIdx; - boolean[] coll = collection; - int l = coll.length; - if (idx == l) { - collection = (coll = Arrays.copyOf(coll, l + (l >> 1))); + ByteBuffer extractResult() { + return ByteBuffer.wrap(Arrays.copyOf(buf, pos)); + } + + private static int poolProbe() { + long id = Thread.currentThread().threadId(); + return (int) (id ^ (id >>> 16)) & POOL_MASK; + } + + private void ensureCapacity(int needed) { + if (pos + needed > buf.length) { + grow(needed); } - coll[idx] = indefinite; } - private void endMap() { - if (collection[collectionIdx--]) { - sink.write(TYPE_SIMPLE_BREAK_STREAM); + private void grow(int needed) { + buf = Arrays.copyOf(buf, Math.max(buf.length * 2, pos + needed)); + } + + @Override + public void flush() { + try { + if (sink != null && pos > 0) { + sink.write(buf, 0, pos); + pos = 0; + sink.flush(); + } + } catch (Exception e) { + throw new SerializationException(e); } } - private void endArray() { - if (collection[collectionIdx--]) { - sink.write(TYPE_SIMPLE_BREAK_STREAM); + @Override + public void close() { + try { + if (sink != null && pos > 0) { + sink.write(buf, 0, pos); + pos = 0; + } + } catch (Exception e) { + throw new SerializationException(e); } } private void tagAndLength(int type, int len) { + ensureCapacity(5); // max: 1 type byte + 4 length bytes (int) + tagAndLengthUnchecked(type, len); + } + + /** Write tag+length without ensureCapacity, caller must have reserved space. */ + private void tagAndLengthUnchecked(int type, int len) { if (len < ONE_BYTE) { - sink.write(type | len); + buf[pos++] = (byte) (type | len); } else if (len <= 0xFF) { - sink.write(type | ONE_BYTE); - sink.write(len); + buf[pos++] = (byte) (type | ONE_BYTE); + buf[pos++] = (byte) len; } else if (len <= 0xFFFF) { - sink.write(type | TWO_BYTES); - write2Nonnegative(len); + buf[pos++] = (byte) (type | TWO_BYTES); + BE_SHORT.set(buf, pos, (short) len); + pos += 2; } else { - sink.write(type | FOUR_BYTES); - write4Nonnegative(len); + buf[pos++] = (byte) (type | FOUR_BYTES); + BE_INT.set(buf, pos, len); + pos += 4; } } - private void write8Nonnegative(long l) { - sink.write((int) ((l >> 56) & 0xFF)); - sink.write((int) ((l >> 48) & 0xFF)); - sink.write((int) ((l >> 40) & 0xFF)); - sink.write((int) ((l >> 32) & 0xFF)); - sink.write((int) ((l >> 24) & 0xFF)); - sink.write((int) ((l >> 16) & 0xFF)); - sink.write((int) ((l >> 8) & 0xFF)); - sink.write((int) ((l) & 0xFF)); - } - - private void write4Nonnegative(int i) { - sink.write((i >> 24) & 0xFF); - sink.write((i >> 16) & 0xFF); - sink.write((i >> 8) & 0xFF); - sink.write((i) & 0xFF); - } - - private void write2Nonnegative(int i) { - sink.write((i >> 8) & 0xFF); - sink.write((i) & 0xFF); + private void writeLong(long l) { + ensureCapacity(9); // max: 1 type byte + 8 data bytes + writeLongUnchecked(l); } - private void writeLong(long l) { + /** Write a CBOR integer without ensureCapacity, caller must have reserved 9 bytes. */ + private void writeLongUnchecked(long l) { byte type; if (l < 0) { l = -l - 1; @@ -151,32 +219,128 @@ private void writeLong(long l) { } if (l < ONE_BYTE) { - sink.write(type | (int) l); + buf[pos++] = (byte) (type | (int) l); } else if (l <= 0xFFL) { - sink.write(type | ONE_BYTE); - sink.write((int) l); + buf[pos++] = (byte) (type | ONE_BYTE); + buf[pos++] = (byte) l; } else if (l <= 0xFFFFL) { - sink.write(type | TWO_BYTES); - write2Nonnegative((int) l); + buf[pos++] = (byte) (type | TWO_BYTES); + BE_SHORT.set(buf, pos, (short) l); + pos += 2; } else if (l <= 0xFFFF_FFFFL) { - sink.write(type | FOUR_BYTES); - write4Nonnegative((int) l); + buf[pos++] = (byte) (type | FOUR_BYTES); + BE_INT.set(buf, pos, (int) l); + pos += 4; } else { - sink.write(type | EIGHT_BYTES); - write8Nonnegative(l); + buf[pos++] = (byte) (type | EIGHT_BYTES); + BE_LONG.set(buf, pos, l); + pos += 8; } } + private void writeDoubleUnchecked(long bits) { + buf[pos++] = (byte) TYPE_SIMPLE_DOUBLE; + BE_LONG.set(buf, pos, bits); + pos += 8; + } + private void writeBytes0(int type, byte[] b, int off, int len) { - tagAndLength(type, len); - sink.write(b, off, len); + ensureCapacity(5 + len); + tagAndLengthUnchecked(type, len); + System.arraycopy(b, off, buf, pos, len); + pos += len; + } + + private void startMap(int size) { + boolean indefinite = size < 0; + if (indefinite) { + ensureCapacity(1); + buf[pos++] = (byte) MAP_STREAM; + } else { + tagAndLength(TYPE_MAP, size); + } + startCollection(indefinite); + } + + private void startArray(int size) { + boolean indefinite = size < 0; + if (indefinite) { + ensureCapacity(1); + buf[pos++] = (byte) ARRAY_STREAM; + } else { + tagAndLength(TYPE_ARRAY, size); + } + startCollection(indefinite); + } + + private void startCollection(boolean indefinite) { + int d = collectionDepth; + if (d < 64) { + if (indefinite) { + collectionMask |= 1L << d; + } else { + collectionMask &= ~(1L << d); + } + } else { + pushOverflow(d, indefinite); + } + collectionDepth = d + 1; + } + + private void pushOverflow(int d, boolean indefinite) { + int overflowIdx = d - 64; + long[] stack = collectionOverflow; + if (stack == null) { + stack = collectionOverflow = new long[Math.max(4, (overflowIdx >> 6) + 1)]; + } else if ((overflowIdx >> 6) >= stack.length) { + stack = collectionOverflow = Arrays.copyOf(stack, stack.length * 2); + } + long bit = 1L << (overflowIdx & 63); + int slot = overflowIdx >>> 6; + if (indefinite) { + stack[slot] |= bit; + } else { + stack[slot] &= ~bit; + } + } + + private boolean popIndefinite() { + int d = --collectionDepth; + if (d < 64) { + return ((collectionMask >>> d) & 1L) != 0L; + } + int overflowIdx = d - 64; + return ((collectionOverflow[overflowIdx >>> 6] >>> (overflowIdx & 63)) & 1L) != 0L; + } + + private void endMap() { + if (popIndefinite()) { + ensureCapacity(1); + buf[pos++] = (byte) TYPE_SIMPLE_BREAK_STREAM; + } + } + + private void endArray() { + if (popIndefinite()) { + ensureCapacity(1); + buf[pos++] = (byte) TYPE_SIMPLE_BREAK_STREAM; + } } @Override public void writeStruct(Schema schema, SerializableStruct struct) { - sink.write(MAP_STREAM); + ensureCapacity(1); + buf[pos++] = (byte) MAP_STREAM; startCollection(true); + + byte[][] savedTable = currentFieldNameTable; + Schema structSchema = schema.isMember() ? schema.memberTarget() : schema; + var ext = structSchema.getExtension(CborSchemaExtensions.KEY); + currentFieldNameTable = ext != null ? ext.fieldNameTable() : null; + struct.serializeMembers(structSerializer); + + currentFieldNameTable = savedTable; endMap(); } @@ -196,7 +360,8 @@ public void writeMap(Schema schema, T mapState, int size, BiConsumer> 6)); + buf[p++] = (byte) (0x80 | (c & 0x3F)); + } else if (c >= 0xD800 && c <= 0xDBFF && j + 1 < charLen) { + char low = value.charAt(j + 1); + if (low >= 0xDC00 && low <= 0xDFFF) { + int cp = Character.toCodePoint((char) c, low); + buf[p++] = (byte) (0xF0 | (cp >> 18)); + buf[p++] = (byte) (0x80 | ((cp >> 12) & 0x3F)); + buf[p++] = (byte) (0x80 | ((cp >> 6) & 0x3F)); + buf[p++] = (byte) (0x80 | (cp & 0x3F)); + j++; + } else { + buf[p++] = (byte) 0xEF; + buf[p++] = (byte) 0xBF; + buf[p++] = (byte) 0xBD; + } + } else if (c >= 0xDC00 && c <= 0xDFFF) { + // Unpaired low surrogate, replace with U+FFFD to match JDK getBytes(UTF_8) + buf[p++] = (byte) 0xEF; + buf[p++] = (byte) 0xBF; + buf[p++] = (byte) 0xBD; + } else { + buf[p++] = (byte) (0xE0 | (c >> 12)); + buf[p++] = (byte) (0x80 | ((c >> 6) & 0x3F)); + buf[p++] = (byte) (0x80 | (c & 0x3F)); + } + } + int byteLen = p - writeStart; + pos = headerStart; + tagAndLengthUnchecked(TYPE_TEXTSTRING, byteLen); + int actualHeaderLen = pos - headerStart; + int shift = 5 - actualHeaderLen; + if (shift > 0) { + System.arraycopy(buf, writeStart, buf, writeStart - shift, byteLen); + } + pos = headerStart + actualHeaderLen + byteLen; } @Override public void writeBlob(Schema schema, ByteBuffer value) { - tagAndLength(TYPE_BYTESTRING, value.remaining()); - sink.write(value); + int len = value.remaining(); + ensureCapacity(5 + len); + tagAndLengthUnchecked(TYPE_BYTESTRING, len); + if (value.hasArray()) { + System.arraycopy(value.array(), value.arrayOffset() + value.position(), buf, pos, len); + } else { + value.duplicate().get(buf, pos, len); + } + pos += len; } @Override @@ -250,14 +501,23 @@ public void writeBlob(Schema schema, byte[] value) { @Override public void writeTimestamp(Schema schema, Instant value) { - double epochSeconds = value.toEpochMilli() / 1000D; - sink.write(TYPE_TAG | TAG_TIME_EPOCH); - writeDouble(schema, epochSeconds); + long millis = value.toEpochMilli(); + if (millis % 1000 == 0) { + ensureCapacity(10); + buf[pos++] = (byte) (TYPE_TAG | TAG_TIME_EPOCH); + writeLongUnchecked(millis / 1000); + } else { + double epochSeconds = millis / 1000D; + ensureCapacity(10); + buf[pos++] = (byte) (TYPE_TAG | TAG_TIME_EPOCH); + writeDoubleUnchecked(Double.doubleToRawLongBits(epochSeconds)); + } } @Override public void writeNull(Schema schema) { - sink.write(TYPE_SIMPLE_NULL); + ensureCapacity(1); + buf[pos++] = (byte) TYPE_SIMPLE_NULL; } @Override @@ -267,8 +527,9 @@ public void writeBigInteger(Schema schema, BigInteger value) { @Override public void writeBigDecimal(Schema schema, BigDecimal value) { - sink.write(TYPE_TAG | TAG_DECIMAL); - tagAndLength(TYPE_ARRAY, 2); + ensureCapacity(2); + buf[pos++] = (byte) (TYPE_TAG | TAG_DECIMAL); + tagAndLengthUnchecked(TYPE_ARRAY, 2); writeLong(-value.scale()); writeBigInteger(value.unscaledValue()); } @@ -298,8 +559,11 @@ private void writeBigInteger(BigInteger value) { } else { type = TYPE_POSINT; } - sink.write(type | EIGHT_BYTES); - write8Nonnegative(value.longValue() ^ signum); + ensureCapacity(9); + buf[pos++] = (byte) (type | EIGHT_BYTES); + long v = value.longValue() ^ signum; + BE_LONG.set(buf, pos, v); + pos += 8; } else { byte[] bytes = value.toByteArray(); byte tag; @@ -309,19 +573,203 @@ private void writeBigInteger(BigInteger value) { } else { tag = TAG_POS_BIG_INT; } - sink.write(TYPE_TAG | tag); + ensureCapacity(1); + buf[pos++] = (byte) (TYPE_TAG | tag); writeBytes0(TYPE_BYTESTRING, bytes, 0, bytes.length); } } } - private final class CborStructSerializer extends InterceptingSerializer { + private byte[] resolveFieldNameBytes(Schema schema) { + byte[][] table = currentFieldNameTable; + int idx = schema.memberIndex(); + if (table != null && idx >= 0 && idx < table.length && table[idx] != null) { + return table[idx]; + } + var ext = schema.getExtension(CborSchemaExtensions.KEY); + if (ext != null && ext.memberNameBytes() != null) { + return ext.memberNameBytes(); + } + return encodeMemberName(schema.memberName()); + } + + @SuppressWarnings("deprecation") + static byte[] encodeMemberName(String name) { + int len = name.length(); + int headerSize; + if (len < ONE_BYTE) { + headerSize = 1; + } else if (len <= 0xFF) { + headerSize = 2; + } else if (len <= 0xFFFF) { + headerSize = 3; + } else { + headerSize = 5; + } + byte[] result = new byte[headerSize + len]; + int p = 0; + if (len < ONE_BYTE) { + result[p++] = (byte) (TYPE_TEXTSTRING | len); + } else if (len <= 0xFF) { + result[p++] = (byte) (TYPE_TEXTSTRING | ONE_BYTE); + result[p++] = (byte) len; + } else if (len <= 0xFFFF) { + result[p++] = (byte) (TYPE_TEXTSTRING | TWO_BYTES); + result[p++] = (byte) (len >> 8); + result[p++] = (byte) len; + } else { + result[p++] = (byte) (TYPE_TEXTSTRING | FOUR_BYTES); + result[p++] = (byte) (len >> 24); + result[p++] = (byte) (len >> 16); + result[p++] = (byte) (len >> 8); + result[p++] = (byte) len; + } + // Smithy member names are always ASCII + name.getBytes(0, len, result, p); + return result; + } + + private final class CborStructSerializer implements ShapeSerializer { + + @Override + public void writeBoolean(Schema schema, boolean value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 1); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = (byte) (value ? TYPE_SIMPLE_TRUE : TYPE_SIMPLE_FALSE); + } + + @Override + public void writeByte(Schema schema, byte value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 9); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + writeLongUnchecked(value); + } + + @Override + public void writeShort(Schema schema, short value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 9); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + writeLongUnchecked(value); + } + @Override - protected ShapeSerializer before(Schema schema) { - String name = schema.memberName(); - tagAndLength(TYPE_TEXTSTRING, name.length()); - sink.writeAscii(name); - return CborSerializer.this; + public void writeInteger(Schema schema, int value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 9); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + writeLongUnchecked(value); + } + + @Override + public void writeLong(Schema schema, long value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 9); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + writeLongUnchecked(value); + } + + @Override + public void writeFloat(Schema schema, float value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 5); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = (byte) TYPE_SIMPLE_FLOAT; + BE_INT.set(buf, pos, Float.floatToRawIntBits(value)); + pos += 4; + } + + @Override + public void writeDouble(Schema schema, double value) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 9); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + writeDoubleUnchecked(Double.doubleToRawLongBits(value)); + } + + @Override + public void writeNull(Schema schema) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length + 1); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = (byte) TYPE_SIMPLE_NULL; + } + + @Override + public void writeString(Schema schema, String value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeString(schema, value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeBlob(schema, value); + } + + @Override + public void writeBlob(Schema schema, byte[] value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeBlob(schema, value); + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeTimestamp(schema, value); + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeBigInteger(schema, value); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeBigDecimal(schema, value); + } + + @Override + public void writeStruct(Schema schema, SerializableStruct struct) { + writeFieldNameBytes(schema); + CborSerializer.this.writeStruct(schema, struct); + } + + @Override + public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + writeFieldNameBytes(schema); + CborSerializer.this.writeList(schema, listState, size, consumer); + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + writeFieldNameBytes(schema); + CborSerializer.this.writeMap(schema, mapState, size, consumer); + } + + @Override + public void writeDocument(Schema schema, Document value) { + writeFieldNameBytes(schema); + CborSerializer.this.writeDocument(schema, value); + } + + private void writeFieldNameBytes(Schema schema) { + byte[] nameBytes = resolveFieldNameBytes(schema); + ensureCapacity(nameBytes.length); + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; } } @@ -333,12 +781,13 @@ public void writeEntry( T state, BiConsumer valueSerializer ) { - byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8); - writeBytes0(TYPE_TEXTSTRING, keyBytes, 0, keyBytes.length); + writeStringValue(key); valueSerializer.accept(state, CborSerializer.this); } } + private static final byte[] TYPE_FIELD_BYTES = encodeMemberName("__type"); + private static final class SerializeDocumentContents extends SpecificShapeSerializer { private final CborSerializer parent; @@ -348,16 +797,15 @@ private static final class SerializeDocumentContents extends SpecificShapeSerial @Override public void writeStruct(Schema schema, SerializableStruct struct) { - try { - parent.startMap(-1); - parent.tagAndLength(TYPE_TEXTSTRING, 6); - parent.sink.writeAscii("__type"); - parent.writeString(null, schema.id().toString()); - struct.serializeMembers(parent.structSerializer); - parent.endMap(); - } catch (Exception e) { - throw new SerializationException(e); - } + parent.ensureCapacity(1); + parent.buf[parent.pos++] = (byte) MAP_STREAM; + parent.startCollection(true); + parent.ensureCapacity(TYPE_FIELD_BYTES.length); + System.arraycopy(TYPE_FIELD_BYTES, 0, parent.buf, parent.pos, TYPE_FIELD_BYTES.length); + parent.pos += TYPE_FIELD_BYTES.length; + parent.writeString(null, schema.id().toString()); + struct.serializeMembers(parent.structSerializer); + parent.endMap(); } } } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/DefaultCborSerdeProvider.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/DefaultCborSerdeProvider.java index 37f1fd2a1..a52bf5701 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/DefaultCborSerdeProvider.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/DefaultCborSerdeProvider.java @@ -7,7 +7,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; -import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.SerializableShape; import software.amazon.smithy.java.core.serde.ShapeDeserializer; import software.amazon.smithy.java.core.serde.ShapeSerializer; @@ -34,14 +34,21 @@ public ShapeDeserializer newDeserializer(ByteBuffer source, CborSettings setting @Override public ShapeSerializer newSerializer(OutputStream sink, CborSettings settings) { - return new CborSerializer(new Sink.OutputStreamSink(sink)); + return new CborSerializer(sink); } @Override - public ByteBuffer serialize(SerializableStruct struct, CborSettings settings) { - var sink = new Sink.ResizingSink(); - var serializer = new CborSerializer(sink); - struct.serialize(serializer); - return sink.finish(); + public ByteBuffer serialize(SerializableShape shape, CborSettings settings) { + var serializer = CborSerializer.acquire(); + boolean exception = false; + try { + shape.serialize(serializer); + return serializer.extractResult(); + } catch (Exception t) { + exception = true; + throw t; + } finally { + CborSerializer.release(serializer, exception); + } } } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Rpcv2CborCodec.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Rpcv2CborCodec.java index fe6e0cd54..b27d08b98 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Rpcv2CborCodec.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Rpcv2CborCodec.java @@ -7,6 +7,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; +import software.amazon.smithy.java.core.schema.SerializableShape; import software.amazon.smithy.java.core.serde.Codec; import software.amazon.smithy.java.core.serde.ShapeDeserializer; import software.amazon.smithy.java.core.serde.ShapeSerializer; @@ -22,6 +23,11 @@ public static Builder builder() { return new Builder(); } + @Override + public ByteBuffer serialize(SerializableShape shape) { + return settings.provider().serialize(shape, settings); + } + @Override public ShapeSerializer createSerializer(OutputStream sink) { return settings.provider().newSerializer(sink, settings); diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Sink.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Sink.java deleted file mode 100644 index 853110356..000000000 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/Sink.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.cbor; - -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import software.amazon.smithy.java.core.serde.SerializationException; - -sealed interface Sink permits Sink.OutputStreamSink, Sink.ResizingSink, Sink.NullSink { - void write(byte[] b, int off, int len); - - void write(byte[] b); - - void write(int b); - - void write(ByteBuffer b); - - void writeAscii(String s); - - default ByteBuffer finish() { - return null; - } - - final class OutputStreamSink implements Sink { - private final OutputStream os; - - public OutputStreamSink(OutputStream os) { - this.os = os; - } - - @Override - public void write(byte[] b, int off, int len) { - try { - os.write(b, off, len); - } catch (Exception e) { - throw new SerializationException(e); - } - } - - @Override - public void write(byte[] b) { - try { - os.write(b); - } catch (Exception e) { - throw new SerializationException(e); - } - } - - @Override - public void write(int b) { - try { - os.write(b); - } catch (Exception e) { - throw new SerializationException(e); - } - } - - @Override - public void write(ByteBuffer b) { - try { - if (b.hasArray()) { - os.write(b.array(), b.arrayOffset() + b.position(), b.remaining()); - } else { - copyNonArrayBB(b); - } - } catch (Exception e) { - throw new SerializationException(e); - } - } - - @Override - public void writeAscii(String s) { - try { - os.write(s.getBytes(StandardCharsets.UTF_8)); - } catch (Exception e) { - throw new SerializationException(e); - } - } - - private void copyNonArrayBB(ByteBuffer b) throws Exception { - b = b.duplicate(); - int rem = b.remaining(); - byte[] copy = new byte[rem]; - b.get(copy); - os.write(copy); - } - } - - final class ResizingSink implements Sink { - private byte[] bytes = new byte[128]; - private int pos; - - @Override - public void write(byte[] b, int off, int len) { - ensureCapacity(len); - System.arraycopy(b, off, bytes, pos, len); - pos += len; - } - - @Override - public void write(byte[] b) { - write(b, 0, b.length); - } - - @Override - public void write(int b) { - ensureCapacity(1); - bytes[pos++] = (byte) b; - } - - @Override - public void write(ByteBuffer b) { - if (b.hasArray()) { - write(b.array(), b.position() + b.arrayOffset(), b.limit()); - } else { - copyNonArrayBB(b); - } - } - - @Override - @SuppressWarnings("deprecation") - public void writeAscii(String s) { - int len = s.length(); - ensureCapacity(len); - s.getBytes(0, s.length(), bytes, pos); - pos += len; - } - - private void copyNonArrayBB(ByteBuffer b) { - int rem = b.remaining(); - ensureCapacity(rem); - b.duplicate().get(bytes, pos, rem); - pos += rem; - } - - @Override - public ByteBuffer finish() { - return ByteBuffer.wrap(bytes, 0, pos); - } - - private void ensureCapacity(int len) { - int cap = bytes.length; - int required = pos + len; - if (required > cap) { - bytes = Arrays.copyOf(bytes, Math.max(required, cap + (cap >> 1))); - } - } - } - - final class NullSink implements Sink { - @Override - public void write(byte[] b, int off, int len) { - - } - - @Override - public void write(byte[] b) { - - } - - @Override - public void write(int b) { - - } - - @Override - public void write(ByteBuffer b) { - - } - - @Override - public void writeAscii(String s) { - - } - - @Override - public ByteBuffer finish() { - return ByteBuffer.wrap(new byte[0]); - } - } -} diff --git a/codecs/cbor-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider b/codecs/cbor-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider new file mode 100644 index 000000000..4997f936d --- /dev/null +++ b/codecs/cbor-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider @@ -0,0 +1 @@ +software.amazon.smithy.java.cbor.CborSchemaExtensions diff --git a/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java b/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java new file mode 100644 index 000000000..e37a5f9ae --- /dev/null +++ b/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java @@ -0,0 +1,590 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.cbor; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; +import java.util.function.Consumer; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import software.amazon.smithy.java.core.schema.PreludeSchemas; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.serde.MapSerializer; +import software.amazon.smithy.java.core.serde.ShapeDeserializer; +import software.amazon.smithy.java.core.serde.ShapeSerializer; +import software.amazon.smithy.java.io.ByteBufferOutputStream; +import software.amazon.smithy.java.io.ByteBufferUtils; +import software.amazon.smithy.model.shapes.ShapeId; + +class CborCodecTest { + + private static final CborSettings SETTINGS = CborSettings.defaultSettings(); + private static final DefaultCborSerdeProvider CODEC = new DefaultCborSerdeProvider(); + + private static ByteBuffer roundTrip(SerializableStruct shape) { + return CODEC.serialize(shape, SETTINGS); + } + + @Nested + class SerializerPoolTest { + @Test + void acquireAndReleaseReturnsToPool() { + var s1 = CborSerializer.acquire(); + CborSerializer.release(s1, false); + + var s2 = CborSerializer.acquire(); + // Pool may or may not return the same instance depending on thread, + // but it should not throw. + assertNotNull(s2); + CborSerializer.release(s2, false); + } + + @Test + void acquireProducesCleanState() { + var s = CborSerializer.acquire(); + assertEquals(0, s.pos); + CborSerializer.release(s, false); + } + + @Test + void largeBufferIsDownsizedOnRelease() { + var s = CborSerializer.acquire(); + s.buf = new byte[1024 * 1024]; + CborSerializer.release(s, false); + + var s2 = CborSerializer.acquire(); + // If we got back the same pooled instance, its buffer should have been downsized + // If we got a new instance, it has the default buffer size. Either way, not 1MB. + assertNotNull(s2); + CborSerializer.release(s2, false); + } + + @Test + void streamingSerializerNotPooled() { + var s = new CborSerializer(java.io.OutputStream.nullOutputStream()); + CborSerializer.release(s, false); + // Should not throw, and the streaming serializer should not be returned from acquire + var s2 = CborSerializer.acquire(); + assertNotSame(s, s2); + CborSerializer.release(s2, false); + } + } + + @Nested + class Utf8EncodingTest { + @ParameterizedTest + @ValueSource(strings = { + "hello", + "ASCII only 12345!@#", + "", + "a" + }) + void asciiStrings(String input) { + var result = serializeString(input); + var de = CODEC.newDeserializer(result, SETTINGS); + assertEquals(input, de.readString(PreludeSchemas.STRING)); + } + + @ParameterizedTest + @ValueSource(strings = { + "é", // e-acute (2-byte UTF-8) + "üö", // u-umlaut, o-umlaut + "世界", // Chinese: "world" + "café", // mixed ASCII + 2-byte + "😀", // emoji (surrogate pair, 4-byte UTF-8) + "a😀z", // ASCII + emoji + ASCII + }) + void nonAsciiStrings(String input) { + var result = serializeString(input); + var de = CODEC.newDeserializer(result, SETTINGS); + assertEquals(input, de.readString(PreludeSchemas.STRING)); + } + + @Test + void unpairedHighSurrogate() { + String broken = "a\uD800b"; + var result = serializeString(broken); + var de = CODEC.newDeserializer(result, SETTINGS); + String decoded = de.readString(PreludeSchemas.STRING); + assertEquals("a�b", decoded); + } + + @Test + void unpairedLowSurrogate() { + String broken = "a\uDC00b"; + var result = serializeString(broken); + var de = CODEC.newDeserializer(result, SETTINGS); + String decoded = de.readString(PreludeSchemas.STRING); + assertEquals("a�b", decoded); + } + + private byte[] serializeString(String value) { + var s = CborSerializer.acquire(); + try { + s.writeString(PreludeSchemas.STRING, value); + return java.util.Arrays.copyOf(s.buf, s.pos); + } finally { + CborSerializer.release(s, false); + } + } + } + + @Nested + class CollectionDepthTest { + @Test + void deeplyNestedCollections() { + int depth = 100; + var s = CborSerializer.acquire(); + try { + for (int i = 0; i < depth; i++) { + s.writeList(PreludeSchemas.DOCUMENT, null, -1, ($, inner) -> {}); + } + } finally { + CborSerializer.release(s, false); + } + } + + @Test + void collectionDepthExceeds64() { + int depth = 70; + var s = CborSerializer.acquire(); + try { + for (int i = 0; i < depth; i++) { + boolean indefinite = i % 2 == 0; + s.writeList(PreludeSchemas.DOCUMENT, + null, + indefinite ? -1 : 1, + ($, inner) -> {}); + } + } finally { + CborSerializer.release(s, false); + } + } + } + + @Nested + class MemberLookupTest { + private static final ShapeId STRUCT_ID = ShapeId.from("smithy.test#TestStruct"); + private static final Schema STRUCT = Schema.structureBuilder(STRUCT_ID) + .putMember("alpha", PreludeSchemas.STRING) + .putMember("beta", PreludeSchemas.INTEGER) + .putMember("gamma", PreludeSchemas.BOOLEAN) + .build(); + + @Test + void lookupInDefinitionOrder() { + var lookup = new CborMemberLookup(STRUCT.members()); + byte[] alpha = "alpha".getBytes(StandardCharsets.UTF_8); + byte[] beta = "beta".getBytes(StandardCharsets.UTF_8); + byte[] gamma = "gamma".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(alpha, 0, alpha.length, 0); + assertEquals("alpha", found.memberName()); + + found = lookup.lookup(beta, 0, beta.length, 1); + assertEquals("beta", found.memberName()); + + found = lookup.lookup(gamma, 0, gamma.length, 2); + assertEquals("gamma", found.memberName()); + } + + @Test + void lookupOutOfOrder() { + var lookup = new CborMemberLookup(STRUCT.members()); + byte[] gamma = "gamma".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(gamma, 0, gamma.length, 0); + assertEquals("gamma", found.memberName()); + } + + @Test + void lookupDisabledSpeculative() { + var lookup = new CborMemberLookup(STRUCT.members()); + byte[] beta = "beta".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(beta, 0, beta.length, -1); + assertEquals("beta", found.memberName()); + } + + @Test + void lookupUnknownMemberReturnsNull() { + var lookup = new CborMemberLookup(STRUCT.members()); + byte[] unknown = "unknown".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(unknown, 0, unknown.length, 0); + assertNull(found); + } + + @Test + void lookupWithOffset() { + var lookup = new CborMemberLookup(STRUCT.members()); + byte[] padded = "XXalphaYY".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(padded, 2, 7, -1); + assertEquals("alpha", found.memberName()); + } + + @Test + void emptyMemberList() { + var lookup = new CborMemberLookup(List.of()); + byte[] any = "any".getBytes(StandardCharsets.UTF_8); + + Schema found = lookup.lookup(any, 0, any.length, -1); + assertNull(found); + } + } + + @Nested + class SchemaExtensionsTest { + private static final ShapeId STRUCT_ID = ShapeId.from("smithy.test#ExtStruct"); + private static final Schema STRUCT = Schema.structureBuilder(STRUCT_ID) + .putMember("foo", PreludeSchemas.STRING) + .putMember("bar", PreludeSchemas.INTEGER) + .build(); + + @Test + void memberExtensionHasEncodedNameBytes() { + var provider = new CborSchemaExtensions(); + var ext = provider.provide(STRUCT.member("foo")); + + assertNotNull(ext); + assertNotNull(ext.memberNameBytes()); + assertNull(ext.memberLookup()); + assertNull(ext.fieldNameTable()); + + byte[] expected = CborSerializer.encodeMemberName("foo"); + assertArrayEquals(expected, ext.memberNameBytes()); + } + + @Test + void structExtensionHasLookupAndFieldNameTable() { + var provider = new CborSchemaExtensions(); + var ext = provider.provide(STRUCT); + + assertNotNull(ext); + assertNull(ext.memberNameBytes()); + assertNotNull(ext.memberLookup()); + assertNotNull(ext.fieldNameTable()); + + assertEquals(2, ext.fieldNameTable().length); + assertNotNull(ext.fieldNameTable()[0]); + assertNotNull(ext.fieldNameTable()[1]); + } + + @Test + void nonStructNonMemberReturnsNull() { + var provider = new CborSchemaExtensions(); + var ext = provider.provide(PreludeSchemas.STRING); + assertNull(ext); + } + + @Test + void emptyStructReturnsEmptyExtension() { + Schema emptyStruct = Schema.structureBuilder(ShapeId.from("smithy.test#Empty")).build(); + var provider = new CborSchemaExtensions(); + var ext = provider.provide(emptyStruct); + + assertNotNull(ext); + assertNull(ext.memberNameBytes()); + assertNull(ext.memberLookup()); + assertNull(ext.fieldNameTable()); + } + } + + @Nested + class EncodeMemberNameTest { + @Test + void shortName() { + byte[] result = CborSerializer.encodeMemberName("id"); + // CBOR text string header for length 2: 0x62, then 'i', 'd' + assertEquals(3, result.length); + assertEquals(0x62, result[0] & 0xFF); + assertEquals('i', result[1]); + assertEquals('d', result[2]); + } + + @Test + void longerName() { + String name = "a".repeat(200); + byte[] result = CborSerializer.encodeMemberName(name); + // Length 200 needs 1-byte arg: header = 0x78 (text string, 1-byte length), 0xC8 (200) + assertEquals(202, result.length); + assertEquals(0x78, result[0] & 0xFF); + assertEquals(200, result[1] & 0xFF); + } + } + + @Nested + class StructRoundTripTest { + @Test + void serializeDeserializeInOrder() { + var bird = new CborTestData.BirdBuilder() + .name("falcon") + .flightRange(BigInteger.valueOf(42)) + .build(); + + var ser = roundTrip(bird); + var de = new CborTestData.BirdBuilder() + .deserialize(CODEC.newDeserializer(ser, SETTINGS)) + .build(); + + assertEquals("falcon", de.name); + assertEquals(BigInteger.valueOf(42), de.flightRange); + } + + @Test + void serializeDeserializeAllFields() { + var bird = new CborTestData.BirdBuilder() + .name("hawk") + .bytes(ByteBuffer.wrap(new byte[] {1, 2, 3})) + .lastSquawkAt(Instant.ofEpochSecond(1000)) + .flightRange(BigInteger.TEN) + .wingspan(new BigDecimal("3.14")) + .build(); + + var ser = roundTrip(bird); + var de = new CborTestData.BirdBuilder() + .deserialize(CODEC.newDeserializer(ser, SETTINGS)) + .build(); + + assertEquals("hawk", de.name); + assertEquals(Instant.ofEpochSecond(1000), de.lastSquawkAt); + assertEquals(BigInteger.TEN, de.flightRange); + assertEquals(new BigDecimal("3.14"), de.wingspan); + } + + @Test + void timestampWholeSeconds() { + Instant wholeSecond = Instant.ofEpochSecond(1700000000); + var bird = new CborTestData.BirdBuilder() + .lastSquawkAt(wholeSecond) + .build(); + + var ser = roundTrip(bird); + var de = new CborTestData.BirdBuilder() + .deserialize(CODEC.newDeserializer(ser, SETTINGS)) + .build(); + + assertEquals(wholeSecond, de.lastSquawkAt); + } + + @Test + void timestampWithMillis() { + Instant withMillis = Instant.ofEpochMilli(1700000000123L); + var bird = new CborTestData.BirdBuilder() + .lastSquawkAt(withMillis) + .build(); + + var ser = roundTrip(bird); + var de = new CborTestData.BirdBuilder() + .deserialize(CODEC.newDeserializer(ser, SETTINGS)) + .build(); + + assertEquals(withMillis, de.lastSquawkAt); + } + } + + @Nested + class BigIntegerEdgeCaseTest { + @ParameterizedTest + @ValueSource(strings = { + "9223372036854775808", // Long.MAX_VALUE + 1 (exactly 64 bits positive) + "-9223372036854775809", // Long.MIN_VALUE - 1 (exactly 64 bits negative) + "18446744073709551615", // max unsigned 64-bit + }) + void bigInteger64BitBoundary(String value) { + BigInteger bi = new BigInteger(value); + var bird = new CborTestData.BirdBuilder().flightRange(bi).build(); + var ser = roundTrip(bird); + var de = new CborTestData.BirdBuilder() + .deserialize(CODEC.newDeserializer(ser, SETTINGS)) + .build(); + assertEquals(bi, de.flightRange); + } + } + + @Nested + class MalformedCborTest { + + @Test + void bigDecimalLongExponent() { + byte[][] payloads = new byte[][] { + new byte[] {-60, -126, 27, 0, 0, 0, 7, -1, -1, -1, -1, 1}, + new byte[] {-60, -126, 59, 0, 0, 0, 7, -1, -1, -1, -1, 1}, + new byte[] {-60, -126, 58, 127, -1, -1, -1, 1}, // 1^-2147483648 + new byte[] {-60, -126, 58, -128, 0, 0, 0, 1} // 1^-2147483649 + }; + for (byte[] payload : payloads) { + ShapeDeserializer de = CODEC.newDeserializer(payload, SETTINGS); + assertThrows(BadCborException.class, () -> de.readBigDecimal(PreludeSchemas.BIG_DECIMAL)); + } + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void bigIntegerWithExcessiveLength(boolean negative) { + byte[] payload = new byte[] { + (byte) (negative ? 0xC3 : 0xC2), + (byte) 0x5A, + (byte) 0x7F, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0x01, + (byte) 0x02 + }; + ShapeDeserializer de = CODEC.newDeserializer(payload, SETTINGS); + assertThrows(BadCborException.class, () -> de.readBigInteger(PreludeSchemas.BIG_INTEGER)); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void bigIntegerWithIndefiniteLengthExcessiveChunk(boolean negative) { + byte[] payload = new byte[] { + (byte) (negative ? 0xC3 : 0xC2), + (byte) 0x5F, + (byte) 0x5A, + (byte) 0x7F, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0x01, + (byte) 0x02 + }; + // TODO Fix this so that we don't throw AIOBE + assertThrows(ArrayIndexOutOfBoundsException.class, () -> CODEC.newDeserializer(payload, SETTINGS)); + } + + @Test + void bigIntegerWithIndefiniteLengthExcessiveTotalSize() { + byte[] payload = new byte[] { + (byte) 0xC2, + (byte) 0x5F, + (byte) 0x5A, + (byte) 0x1D, + (byte) 0xCD, + (byte) 0x65, + (byte) 0x00, + (byte) 0x01, + (byte) 0x02, + (byte) 0x5A, + (byte) 0x1D, + (byte) 0xCD, + (byte) 0x65, + (byte) 0x00, + (byte) 0x03, + (byte) 0x04 + }; + assertThrows(BadCborException.class, () -> CODEC.newDeserializer(payload, SETTINGS)); + } + + @Test + void incompleteImmediate() { + byte[] cbor = write(os -> { + writeList(os, Integer.MAX_VALUE, list -> list.writeString(null, "stop")); + }); + BadCborException e = assertThrows(BadCborException.class, () -> { + ShapeDeserializer de = CODEC.newDeserializer(cbor, SETTINGS); + de.readDocument(); + }); + assertTrue(e.getMessage().contains("incomplete array"), e.getMessage()); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void incompleteCollection(boolean map) { + byte[] cbor = write(os -> { + if (map) { + writeMap(os, 2, c -> { + c.entry("hi", value -> value.writeString(null, "hi")); + }); + } else { + writeList(os, 2, c -> { + c.writeString(null, "hi"); + }); + } + }); + BadCborException e = assertThrows(BadCborException.class, () -> { + ShapeDeserializer de = CODEC.newDeserializer(cbor, SETTINGS); + de.readDocument(); + }); + assertTrue(e.getMessage().contains("incomplete " + (map ? "map" : "array")), e.getMessage()); + } + + @Test + void missingMapValue() { + byte[] cbor = write(os -> { + writeMap(os, 1, m -> { + m.entry("hi", v -> { + throw new StopWritingException(); + }); + }); + }); + BadCborException e = assertThrows(BadCborException.class, () -> { + ShapeDeserializer de = CODEC.newDeserializer(cbor, SETTINGS); + de.readDocument(); + }); + assertTrue(e.getMessage().contains("incomplete map"), e.getMessage()); + } + } + + private static final class StopWritingException extends RuntimeException { + @Override + public Throwable fillInStackTrace() { + return this; + } + } + + private static byte[] write(Consumer consumer) { + try ( + var stream = new ByteBufferOutputStream(); + var ser = new CborSerializer(stream)) { + try { + consumer.accept(ser); + } catch (StopWritingException ignored) {} + ser.flush(); + return ByteBufferUtils.getBytes(stream.toByteBuffer()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void writeList(ShapeSerializer s, int len, Consumer listHandler) { + s.writeList(null, null, len, ($, l) -> listHandler.accept(l)); + } + + private static void writeMap(ShapeSerializer s, int len, Consumer mapHandler) { + s.writeMap(null, null, len, ($, l) -> mapHandler.accept(new WriteEntry(l))); + } + + private static final class WriteEntry { + private final MapSerializer m; + + private WriteEntry(MapSerializer m) { + this.m = m; + } + + void entry(String key, Consumer val) { + m.writeEntry(null, key, null, ($, s) -> val.accept(s)); + } + } +} diff --git a/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborParserTest.java b/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborParserTest.java deleted file mode 100644 index 3f943a82f..000000000 --- a/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborParserTest.java +++ /dev/null @@ -1,672 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.cbor; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static software.amazon.smithy.java.cbor.CborReadUtil.readByteString; -import static software.amazon.smithy.java.cbor.CborReadUtil.readLong; -import static software.amazon.smithy.java.cbor.CborReadUtil.readTextString; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Date; -import java.util.List; -import java.util.function.Consumer; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import software.amazon.smithy.java.cbor.CborParser.Token; -import software.amazon.smithy.java.core.serde.MapSerializer; -import software.amazon.smithy.java.core.serde.ShapeSerializer; -import software.amazon.smithy.java.io.ByteBufferOutputStream; -import software.amazon.smithy.java.io.ByteBufferUtils; - -public class CborParserTest { - private byte[] cbor; - private CborParser parser; - - private static byte[] write(Consumer consumer) { - try ( - var stream = new ByteBufferOutputStream(); - var ser = new CborSerializer(new Sink.OutputStreamSink(stream))) { - try { - consumer.accept(ser); - } catch (StopWritingException ignored) {} - return ByteBufferUtils.getBytes(stream.toByteBuffer()); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static final class StopWritingException extends RuntimeException { - @Override - public Throwable fillInStackTrace() { - return this; - } - } - - private void stopWriting(ShapeSerializer ser) { - throw new StopWritingException(); - } - - @Test - public void simple() { - cbor = write(os -> { - os.writeInteger(null, 1); - os.writeInteger(null, 1048576); - os.writeInteger(null, -1); - }); - - parser = new CborParser(cbor); - - token(Token.POS_INT, 0, 0); - num(1); - - token(Token.POS_INT, 2, 4); - num(1048576); - - token(Token.NEG_INT, 6, 0); - num(-1); - - finished(); - } - - @Test - public void incompleteImmediate() { - cbor = write(os -> { - writeList(os, Integer.MAX_VALUE, this::stopWriting); - }); - - parser = new CborParser(cbor); - token(Token.START_ARRAY); - Exception e = assertThrows(BadCborException.class, this::finished); - assertEquals("incomplete array: expecting " + Integer.MAX_VALUE + " more elements", e.getMessage()); - } - - @Test - public void bytestring() { - byte[] bigString = new byte[10]; - Arrays.fill(bigString, (byte) 'A'); - String s = "well howdy there"; - cbor = write(io -> { - io.writeBlob(null, "hello".getBytes(StandardCharsets.UTF_8)); - io.writeBlob(null, bigString); - io.writeString(null, s); - }); - parser = new CborParser(cbor); - - token(Token.BYTE_STRING, 1, 5); - string("hello"); - - token(Token.BYTE_STRING, 7, bigString.length); - string(bigString); - - token(Token.TEXT_STRING, 18, s.length()); - string(s); - - finished(); - } - - @Test - public void nestedIndefiniteArrays() { - cbor = write(os -> { - writeList(os, -1, list -> { - list.writeInteger(null, 1); - writeList(list, -1, nested1 -> { - nested1.writeInteger(null, 2); - nested1.writeLong(null, Long.MIN_VALUE); - writeList(nested1, -1, nested2 -> { - nested2.writeString(null, "hello"); - }); - }); - }); - }); - - parser = new CborParser(cbor); - token(Token.START_ARRAY, 0); - - token(Token.POS_INT, 1, 0); - num(1); - - token(Token.START_ARRAY, 2); - - token(Token.POS_INT, 3, 0); - num(2); - token(Token.NEG_INT, 5, 8); - num(Long.MIN_VALUE); - - token(Token.START_ARRAY, 13); - token(Token.TEXT_STRING, 15, 5); - string("hello"); - token(Token.END_ARRAY, 20); - - token(Token.END_ARRAY, 21); - - token(Token.END_ARRAY, 22); - - finished(); - } - - @Test - public void map() { - cbor = write(os -> { - writeMap(os, 1, map -> { - map.entry("key", e -> e.writeInteger(null, -1)); - }); - os.writeString(null, "not a key"); - }); - parser = new CborParser(cbor); - - token(Token.START_OBJECT); - token(Token.KEY, 2, 3); - - token(Token.NEG_INT, 5, 0); - num(-1); - - token(Token.END_OBJECT); - token(Token.TEXT_STRING, 7, "not a key".length()); - - finished(); - } - - @ValueSource(ints = {1, 1024, 1040000, 19100100}) - @ParameterizedTest - public void longList(int elements) { - cbor = write(os -> { - writeList(os, elements, list -> { - for (int i = 0; i < elements; i++) { - os.writeInteger(null, i); - } - }); - }); - parser = new CborParser(cbor); - - token(Token.START_ARRAY); - for (int i = 0; i < elements; i++) { - token(Token.POS_INT); - num(i); - } - token(Token.END_ARRAY); - } - - @ValueSource(ints = {1, 1024, 1040000}) - @ParameterizedTest - public void longMap(int elements) { - cbor = write(os -> { - writeMap(os, elements, map -> { - for (int i = 0; i < elements; i++) { - os.writeString(null, Integer.toString(i)); - os.writeInteger(null, i); - } - }); - }); - parser = new CborParser(cbor); - - token(Token.START_OBJECT); - for (int i = 0; i < elements; i++) { - token(Token.KEY); - string(Integer.toString(i)); - token(Token.POS_INT); - num(i); - } - token(Token.END_OBJECT); - } - - @Test - public void collectionsWithIndefiniteMembers() { - cbor = write(os -> { - writeMap(os, 1, map -> { - map.entry("key", e -> { - writeList(e, -1, list -> { - list.writeString(null, "value1"); - list.writeString(null, "value2"); - }); - }); - }); - - writeList(os, -1, list -> { - writeList(list, -1, nested -> { - nested.writeString(null, "s"); - }); - }); - - writeList(os, 1, list -> { - writeList(list, -1, nested -> { - nested.writeString(null, "s2"); - }); - }); - - writeList(os, -1, list -> { - writeList(list, 3, nested -> { - nested.writeInteger(null, 1); - nested.writeInteger(null, 2); - nested.writeString(null, "three"); - }); - - writeMap(list, -1, map -> { - map.entry("it's my key!", v -> { - writeMap(v, -1, nested -> { - nested.entry("it's another nested key", nestedValue -> { - nestedValue.writeInteger(null, 91919191); - }); - }); - }); - - map.entry("still just a key", value -> { - writeList(value, -1, nested -> { - nested.writeString(null, "array"); - }); - }); - }); - }); - }); - parser = new CborParser(cbor); - - token(Token.START_OBJECT); - token(Token.KEY, 2, 3); - string("key"); - token(Token.START_ARRAY); - token(Token.TEXT_STRING); - string("value1"); - token(Token.TEXT_STRING); - string("value2"); - token(Token.END_ARRAY); - token(Token.END_OBJECT); - - token(Token.START_ARRAY); - token(Token.START_ARRAY); - token(Token.TEXT_STRING); - string("s"); - token(Token.END_ARRAY); - token(Token.END_ARRAY); - - token(Token.START_ARRAY); - token(Token.START_ARRAY); - token(Token.TEXT_STRING); - string("s2"); - token(Token.END_ARRAY); - token(Token.END_ARRAY); - - token(Token.START_ARRAY); - token(Token.START_ARRAY); - token(Token.POS_INT); - num(1); - token(Token.POS_INT); - num(2); - token(Token.TEXT_STRING); - string("three"); - token(Token.END_ARRAY); - - token(Token.START_OBJECT); - token(Token.KEY); - string("it's my key!"); - token(Token.START_OBJECT); - token(Token.KEY); - string("it's another nested key"); - token(Token.POS_INT); - num(91919191); - token(Token.END_OBJECT); - token(Token.KEY); - string("still just a key"); - token(Token.START_ARRAY); - token(Token.TEXT_STRING); - string("array"); - token(Token.END_ARRAY); - token(Token.END_OBJECT); - - token(Token.END_ARRAY); - token(Token.FINISHED); - } - - @Test - public void nestedCollections() { - cbor = write(os -> { - writeMap(os, 2, map -> { - map.entry("AAA", value -> value.writeInteger(null, 1)); - map.entry("BBB", value -> writeMap(value, 1, nested -> { - nested.entry("CCC", nestedValue -> { - nestedValue.writeString(null, "DDDDD"); - }); - })); - }); - }); - - parser = new CborParser(cbor); - - token(Token.START_OBJECT); - token(Token.KEY, 2, 3); - string("AAA"); - token(Token.POS_INT, 5, 0); - num(1); - token(Token.KEY, 7, 3); - string("BBB"); - token(Token.START_OBJECT); - token(Token.KEY, 12, 3); - string("CCC"); - token(Token.TEXT_STRING, 16, 5); - string("DDDDD"); - token(Token.END_OBJECT); - token(Token.END_OBJECT); - token(Token.FINISHED); - } - - @Test - public void array() { - cbor = write(os -> { - List ints = Arrays.asList(1, 2, 3, 31); - writeList(os, ints.size(), list -> { - ints.forEach(i -> list.writeInteger(null, i)); - }); - }); - - parser = new CborParser(cbor); - token(Token.START_ARRAY); - assertEquals(1, parser.getPosition()); - - token(Token.POS_INT); - num(1); - token(Token.POS_INT); - num(2); - token(Token.POS_INT); - num(3); - token(Token.POS_INT); - num(31); - - token(Token.END_ARRAY); - - token(Token.FINISHED); - } - - @Test - public void booleans() { - cbor = write(os -> { - os.writeBoolean(null, false); - os.writeBoolean(null, true); - }); - - parser = new CborParser(cbor); - token(Token.FALSE); - token(Token.TRUE); - token(Token.FINISHED); - } - - @Test - public void floats() { - cbor = write(os -> { - os.writeDouble(null, 1); - os.writeFloat(null, 0.1f); - }); - - parser = new CborParser(cbor); - token(Token.FLOAT); - token(Token.FLOAT); - token(Token.FINISHED); - } - - @Test - public void bigDecimalLongExponent() { - byte[][] payloads = new byte[][] { - new byte[] {-60, -126, 27, 0, 0, 0, 7, -1, -1, -1, -1, 1}, - new byte[] {-60, -126, 59, 0, 0, 0, 7, -1, -1, -1, -1, 1}, - new byte[] {-60, -126, 58, 127, -1, -1, -1, 1}, // 1^-2147483648 - new byte[] {-60, -126, 58, -128, 0, 0, 0, 1} // 1^-2147483649 - }; - for (byte[] payload : payloads) { - cbor = payload; - parser = new CborParser(cbor); - token(Token.BIG_DECIMAL); - assertThrows(BadCborException.class, () -> CborReadUtil.readBigDecimal(cbor, parser.getPosition())); - } - } - - @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void bigIntegerWithExcessiveLength(boolean negative) { - // Create a CBOR payload with a bignum (tag 2 for positive, tag 3 for negative) followed by - // a byte string claiming to have a very large length (Integer.MAX_VALUE) - // but the actual buffer is much smaller. - // CBOR structure: tag(2/3) + byte_string(length=Integer.MAX_VALUE) + minimal data - // - // 0xC2 = tag 2 (positive bignum), 0xC3 = tag 3 (negative bignum) - // 0x5A = byte string with 4-byte length - // 0x7F 0xFF 0xFF 0xFF = Integer.MAX_VALUE - // followed by a few bytes (nowhere near enough for the claimed length) - byte[] payload = new byte[] { - (byte) (negative ? 0xC3 : 0xC2), // tag 2 or 3 - (byte) 0x5A, // byte string with 4-byte length - (byte) 0x7F, - (byte) 0xFF, // length = Integer.MAX_VALUE - (byte) 0xFF, - (byte) 0xFF, - (byte) 0x01, - (byte) 0x02 // only 2 bytes of actual data - }; - - cbor = payload; - parser = new CborParser(cbor); - byte expectedToken = negative ? Token.NEG_BIGINT : Token.POS_BIGINT; - token(expectedToken); - - // Should throw BadCborException instead of trying to allocate Integer.MAX_VALUE bytes - assertThrows(BadCborException.class, - () -> CborReadUtil - .readBigInteger(cbor, expectedToken, parser.getPosition(), parser.getItemLength())); - } - - @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void bigIntegerWithIndefiniteLengthExcessiveChunk(boolean negative) { - // Test indefinite length byte string with a chunk claiming excessive length - // CBOR structure: tag(2/3) + indefinite_byte_string + chunk_with_large_length + break - // - // 0xC2 = tag 2 (positive bignum), 0xC3 = tag 3 (negative bignum) - // 0x5F = indefinite length byte string marker - // 0x5A = definite byte string chunk with 4-byte length - // 0x7F 0xFF 0xFF 0xFF = Integer.MAX_VALUE (chunk length) - // followed by minimal data (nowhere near the claimed chunk length) - byte[] payload = new byte[] { - (byte) (negative ? 0xC3 : 0xC2), // tag 2 or 3 - (byte) 0x5F, // indefinite length byte string - (byte) 0x5A, // chunk: byte string with 4-byte length - (byte) 0x7F, - (byte) 0xFF, // chunk length = Integer.MAX_VALUE - (byte) 0xFF, - (byte) 0xFF, - (byte) 0x01, - (byte) 0x02 // only 2 bytes of actual data - }; - - cbor = payload; - parser = new CborParser(cbor); - byte expectedToken = negative ? Token.NEG_BIGINT : Token.POS_BIGINT; - - // TODO Fix this so that we don't throw AIOBE - assertThrows(ArrayIndexOutOfBoundsException.class, - () -> token(expectedToken)); - } - - @Test - public void bigIntegerWithIndefiniteLengthExcessiveTotalSize() { - // Test indefinite length byte string where total claimed size across chunks is excessive - // Even though individual chunks are small, their sum should trigger protection - // - // 0xC2 = tag 2 (positive bignum) - // 0x5F = indefinite length byte string marker - // Multiple chunks each claiming 500MB (0x1DCD6500 bytes) - byte[] payload = new byte[] { - (byte) 0xC2, // tag 2 (positive bignum) - (byte) 0x5F, // indefinite length byte string - (byte) 0x5A, // chunk 1: byte string with 4-byte length - (byte) 0x1D, - (byte) 0xCD, // 500MB - (byte) 0x65, - (byte) 0x00, - (byte) 0x01, - (byte) 0x02, // only 2 bytes of actual data - (byte) 0x5A, // chunk 2: byte string with 4-byte length - (byte) 0x1D, - (byte) 0xCD, // another 500MB - (byte) 0x65, - (byte) 0x00, - (byte) 0x03, - (byte) 0x04 // only 2 bytes of actual data - }; - - cbor = payload; - parser = new CborParser(cbor); - - // Should throw BadCborException instead of trying to process 1GB worth of chunks - assertThrows(BadCborException.class, () -> token(Token.POS_BIGINT)); - } - - @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void incompleteCollection(boolean map) { - cbor = write(os -> { - if (map) { - writeMap(os, 2, c -> { - c.entry("hi", value -> value.writeString(null, "hi")); - }); - } else { - writeList(os, 2, c -> { - c.writeString(null, "hi"); - }); - } - }); - parser = new CborParser(cbor); - - token(map ? Token.START_OBJECT : Token.START_ARRAY); - token(map ? Token.KEY : Token.TEXT_STRING); - string("hi"); - if (map) { - token(Token.TEXT_STRING); - string("hi"); - } - expectFailure("incomplete " + (map ? "map" : "array")); - } - - @Test - public void missingMapValue() { - cbor = write(os -> { - writeMap(os, 1, map -> { - map.entry("hi", this::stopWriting); - }); - }); - parser = new CborParser(cbor); - - token(Token.START_OBJECT); - token(Token.KEY); - string("hi"); - expectFailure("incomplete map"); - } - - @Test - public void date() { - Date time = new Date(); - cbor = write(os -> { - os.writeTimestamp(null, time.toInstant()); - }); - - parser = new CborParser(cbor); - token(Token.EPOCH_F); - num(Double.doubleToRawLongBits(time.getTime() / 1000d)); - finished(); - } - - @Test - public void negativeDate() { - Date d = new Date(-1000); - cbor = write(os -> { - os.writeTimestamp(null, d.toInstant()); - }); - - parser = new CborParser(cbor); - token(Token.EPOCH_F); - num(d.getTime() / 1000d); - finished(); - } - - private void token(byte token) { - byte next = parser.advance(); - assertEquals(token, next, "expected " + Token.name(token) + " but got " + Token.name(next)); - } - - private void token(byte token, int position) { - token(token); - assertEquals(position, parser.getPosition(), "expected pos " + position + " but was " + parser.getPosition()); - } - - private void token(byte token, int position, int itemLength) { - token(token, position); - assertEquals( - itemLength, - CborParser.itemLength(parser.getItemLength()), - "expected len " + itemLength - + " but was " + CborParser.itemLength(parser.getItemLength())); - } - - private void num(double d) { - num(Double.doubleToRawLongBits(d), Token.POS_INT); - } - - private void num(long n) { - num(n, n < 0 ? Token.NEG_INT : Token.POS_INT); - } - - private void num(long n, byte token) { - assertEquals(n, readLong(cbor, token, parser.getPosition(), parser.getItemLength())); - } - - private void string(String s) { - assertEquals(s, readTextString(cbor, parser.getPosition(), parser.getItemLength())); - } - - private void string(byte[] b) { - assertArrayEquals(b, readByteString(cbor, parser.getPosition(), parser.getItemLength())); - } - - private void finished() { - token(Token.FINISHED, cbor.length, 0); - } - - private void lengthIsFinite() { - assertFalse(CborParser.isIndefinite(parser.getItemLength())); - } - - private void lengthIsIndefinite() { - assertTrue(CborParser.isIndefinite(parser.getItemLength())); - } - - private void expectFailure(String msg) { - BadCborException e = assertThrows(BadCborException.class, parser::advance); - assertTrue(e.getMessage().contains(msg), e.getMessage()); - } - - private static void writeList(ShapeSerializer s, int len, Consumer listHandler) { - s.writeList(null, null, len, ($, l) -> listHandler.accept(l)); - } - - private static void writeMap(ShapeSerializer s, int len, Consumer mapHandler) { - s.writeMap(null, null, len, ($, l) -> mapHandler.accept(new WriteEntry(l))); - } - - private static final class WriteEntry { - private final MapSerializer m; - - private WriteEntry(MapSerializer m) { - this.m = m; - } - - void entry(String key, Consumer val) { - m.writeEntry(null, key, null, ($, s) -> val.accept(s)); - } - } -} diff --git a/codecs/cbor-codec/src/testFixtures/java/software/amazon/java/cbor/CborComparator.java b/codecs/cbor-codec/src/testFixtures/java/software/amazon/java/cbor/CborComparator.java index 34fa802ff..b5c25ed44 100644 --- a/codecs/cbor-codec/src/testFixtures/java/software/amazon/java/cbor/CborComparator.java +++ b/codecs/cbor-codec/src/testFixtures/java/software/amazon/java/cbor/CborComparator.java @@ -6,238 +6,27 @@ package software.amazon.java.cbor; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; -import static software.amazon.smithy.java.cbor.CborReadUtil.readTextString; import java.nio.ByteBuffer; -import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Deque; -import java.util.HashMap; -import java.util.HexFormat; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import software.amazon.smithy.java.cbor.CborParser; -import software.amazon.smithy.java.cbor.CborReadUtil; +import software.amazon.smithy.java.cbor.Rpcv2CborCodec; +import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.java.io.ByteBufferUtils; public class CborComparator { + private static final Rpcv2CborCodec CODEC = Rpcv2CborCodec.builder().build(); + public static void assertEquals(ByteBuffer expected, ByteBuffer actual) { byte[] expectedBytes = ByteBufferUtils.getBytes(expected); byte[] actualBytes = ByteBufferUtils.getBytes(actual); if (Arrays.equals(expectedBytes, actualBytes)) { return; } - assertThat(CborValue.parse(actualBytes)) + Document expectedDoc = CODEC.createDeserializer(expectedBytes).readDocument(); + Document actualDoc = CODEC.createDeserializer(actualBytes).readDocument(); + assertThat(actualDoc) .usingRecursiveComparison() - .isEqualTo(CborValue.parse(expectedBytes)); - - } - - private abstract static sealed class CborValue> { - - private static CborValue parse(byte[] buf) { - try { - return parse0(buf); - } catch (Exception e) { - // to ensure a stack trace is gathered - throw new RuntimeException(e); - } - } - - private static CborValue parse0(byte[] buf) { - var parser = new CborParser(buf); - var context = new ArrayDeque>(); - var keys = new ArrayDeque(); - CborValue root = null; - byte token; - while ((token = parser.advance()) != CborParser.Token.FINISHED) { - switch (token) { - case CborParser.Token.KEY -> - keys.addLast(readTextString(buf, parser.getPosition(), parser.getItemLength())); - case CborParser.Token.START_OBJECT -> context.addLast(new MapValue()); - case CborParser.Token.START_ARRAY -> context.addLast(new ListValue()); - case CborParser.Token.END_ARRAY, CborParser.Token.END_OBJECT -> { - var top = Objects.requireNonNull(context.pollLast()); - if (context.isEmpty()) { - root = top; - } else { - add(keys, context, top); - } - } - case CborParser.Token.NULL -> add(keys, context, new NullValue()); - case CborParser.Token.TEXT_STRING -> add(keys, context, new StringValue(buf, parser)); - case CborParser.Token.POS_INT, CborParser.Token.NEG_INT -> - add(keys, context, new IntValue(buf, token, parser)); - case CborParser.Token.FLOAT -> add(keys, context, new FloatValue(buf, token, parser)); - case CborParser.Token.TRUE, CborParser.Token.FALSE -> add(keys, context, new BooleanValue(token)); - case CborParser.Token.EPOCH_F, CborParser.Token.EPOCH_INEG, CborParser.Token.EPOCH_IPOS -> add( - keys, - context, - new TimeValue(buf, token, parser)); - case CborParser.Token.BYTE_STRING -> add(keys, context, new BlobValue(buf, parser)); - default -> throw new RuntimeException("can't handle " + CborParser.Token.name(token)); - } - } - assertThat(parser.getPosition()).isEqualTo(buf.length); - return root; - } - - private static void add(Deque keys, Deque> context, CborValue value) { - var top = context.peekLast(); - if (top instanceof MapValue v) { - v.put(Objects.requireNonNull(keys.pollLast()), value); - } else if (top instanceof ListValue v) { - v.add(value); - } else { - throw new RuntimeException("Can't add to a " + top.getClass()); - } - } - - private static final class MapValue extends CborValue { - private final Map> map; - - private MapValue() { - this.map = new HashMap<>(); - } - - public void put(String key, CborValue value) { - map.put(key, value); - } - - @Override - public String toString() { - return "Map{" + map + "}"; - } - } - - private static final class ListValue extends CborValue { - private final List> values = new ArrayList<>(); - - public void add(CborValue value) { - values.add(value); - } - - @Override - public String toString() { - return "List{" + values + "}"; - } - } - - private static final class NullValue extends CborValue { - - @Override - public String toString() { - return "Null"; - } - } - - private static final class StringValue extends CborValue { - private final String value; - - private StringValue(byte[] buf, CborParser parser) { - this.value = CborReadUtil.readTextString(buf, parser.getPosition(), parser.getItemLength()); - } - - @Override - public String toString() { - return "String{" + value + "}"; - } - } - - private static final class IntValue extends CborValue { - private final long value; - - IntValue(byte[] buf, byte type, CborParser parser) { - if (type > CborParser.Token.NEG_INT) { - throw new RuntimeException("can't read " + CborParser.Token.name(type) + " as long"); - } - - int pos = parser.getPosition(); - int len = parser.getItemLength(); - long val = CborReadUtil.readLong(buf, type, pos, len); - if (len < 8) { - this.value = val; - } else if (type == CborParser.Token.POS_INT) { - this.value = val < 0 ? Long.MAX_VALUE : val; - } else { - this.value = val < 0 ? val : Long.MIN_VALUE; - } - } - - @Override - public String toString() { - return "Int{" + value + "}"; - } - } - - private static final class FloatValue extends CborValue { - private final double value; - - FloatValue(byte[] buf, byte type, CborParser parser) { - int len = parser.getItemLength(); - long raw = CborReadUtil.readLong(buf, type, parser.getPosition(), parser.getItemLength()); - if (len == 8) { - value = Double.longBitsToDouble(raw); - } else if (len == 4) { - value = Float.intBitsToFloat((int) raw); - } else { - throw new RuntimeException("Can't handle fp of len " + len); - } - } - - @Override - public String toString() { - return "Float{" + value + "}"; - } - } - - private static final class BooleanValue extends CborValue { - private final boolean bool; - - BooleanValue(byte type) { - this.bool = type == CborParser.Token.TRUE; - } - - @Override - public String toString() { - return "Boolean{" + bool + "}"; - } - } - - private static final class TimeValue extends CborValue { - private final long time; - - TimeValue(byte[] buf, byte type, CborParser parser) { - byte actual = (byte) (type ^ CborParser.Token.TAG_FLAG); - if (actual <= CborParser.Token.NEG_INT) { - time = new IntValue(buf, actual, parser).value * 1000; - } else if (actual == CborParser.Token.FLOAT) { - time = Math.round(new FloatValue(buf, actual, parser).value * 1000d); - } else { - throw new RuntimeException("not a timestamp: " + CborParser.Token.name(type)); - } - } - - @Override - public String toString() { - return "Time{" + time + "}"; - } - } - - private static final class BlobValue extends CborValue { - private final byte[] bytes; - - BlobValue(byte[] buf, CborParser parser) { - this.bytes = CborReadUtil.readByteString(buf, parser.getPosition(), parser.getItemLength()); - } - - @Override - public String toString() { - return "Blob{" + HexFormat.of().formatHex(bytes) + "}"; - } - } + .isEqualTo(expectedDoc); } } diff --git a/codecs/json-codec/build.gradle.kts b/codecs/json-codec/build.gradle.kts index b097e6d6b..274aa1981 100644 --- a/codecs/json-codec/build.gradle.kts +++ b/codecs/json-codec/build.gradle.kts @@ -1,7 +1,6 @@ plugins { id("smithy-java.module-conventions") id("smithy-java.fuzz-test") - id("me.champeau.jmh") version "0.7.3" id("software.amazon.smithy.gradle.smithy-base") alias(libs.plugins.shadow) } @@ -70,14 +69,6 @@ afterEvaluate { afterEvaluate { val typePath = smithy.getPluginProjectionPath(smithy.sourceProjection.get(), "java-codegen").get() - sourceSets.named("jmh") { - java { - srcDir("$typePath/java") - } - resources { - srcDir("$typePath/resources") - } - } sourceSets.named("test") { java { srcDir("$typePath/java") @@ -88,33 +79,10 @@ afterEvaluate { } } -tasks.named("compileJmhJava") { - dependsOn("smithyBuild") -} - tasks.named("compileTestJava") { dependsOn("smithyBuild") } -tasks.named("processJmhResources") { - dependsOn("smithyBuild") -} - tasks.named("processTestResources") { dependsOn("smithyBuild") } - -jmh { - warmupIterations = 3 - iterations = 5 - fork = 1 - jvmArgs.addAll("-Xms1g", "-Xmx1g") - includes.addAll( - providers - .gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList()), - ) - profilers.add("async:output=jfr;dir=${layout.buildDirectory.get()}/jmh-profiler") - // profilers.add("gc") -} diff --git a/codecs/json-codec/src/jmh/java/software/amazon/smithy/java/json/JsonBench.java b/codecs/json-codec/src/jmh/java/software/amazon/smithy/java/json/JsonBench.java deleted file mode 100644 index 49ff850ac..000000000 --- a/codecs/json-codec/src/jmh/java/software/amazon/smithy/java/json/JsonBench.java +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.json; - -import static tools.jackson.core.JsonToken.PROPERTY_NAME; - -import java.io.ByteArrayOutputStream; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.time.Instant; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.OutputTimeUnit; -import org.openjdk.jmh.annotations.Param; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.Warmup; -import software.amazon.smithy.java.core.schema.SerializableStruct; -import software.amazon.smithy.java.core.schema.ShapeBuilder; -import software.amazon.smithy.java.core.serde.document.Document; -import software.amazon.smithy.java.json.bench.model.BenchUnion; -import software.amazon.smithy.java.json.bench.model.Color; -import software.amazon.smithy.java.json.bench.model.ComplexStruct; -import software.amazon.smithy.java.json.bench.model.InnerStruct; -import software.amazon.smithy.java.json.bench.model.NestedStruct; -import software.amazon.smithy.java.json.bench.model.SimpleStruct; -import software.amazon.smithy.java.json.jackson.JacksonJsonSerdeProvider; -import software.amazon.smithy.java.json.smithy.SmithyJsonSerdeProvider; -import tools.jackson.core.ObjectReadContext; -import tools.jackson.core.ObjectWriteContext; -import tools.jackson.core.json.JsonFactory; - -@State(Scope.Benchmark) -@OutputTimeUnit(TimeUnit.NANOSECONDS) -@BenchmarkMode(Mode.AverageTime) -@Warmup(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS) -@Measurement(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS) -@Fork(1) -public class JsonBench { - - public enum TestCase { - SIMPLE, - COMPLEX, - } - - public enum Provider { - jackson, - smithy, - } - - @Param - private TestCase testCase; - - @Param - private Provider provider; - - private JsonCodec codec; - private SerializableStruct shape; - private byte[] serializedBytes; - private byte[] reversedBytes; - private Supplier> builderSupplier; - - @Setup - public void setup() { - JsonSerdeProvider serdeProvider = switch (provider) { - case jackson -> new JacksonJsonSerdeProvider(); - case smithy -> new SmithyJsonSerdeProvider(); - }; - codec = JsonCodec.builder() - .overrideSerdeProvider(serdeProvider) - .useJsonName(true) - .useTimestampFormat(true) - .build(); - - switch (testCase) { - case SIMPLE -> { - shape = buildSimpleStruct(); - builderSupplier = SimpleStruct::builder; - } - case COMPLEX -> { - shape = buildComplexStruct(); - builderSupplier = ComplexStruct::builder; - } - } - - // Pre-serialize to byte[] for deserialization benchmarks - ByteBuffer buf = codec.serialize(shape); - serializedBytes = new byte[buf.remaining()]; - buf.get(serializedBytes); - - // Reversed field order to force the slow path in SmithyMemberLookup - reversedBytes = reverseJsonFieldOrder(serializedBytes); - } - - private static SimpleStruct buildSimpleStruct() { - return SimpleStruct.builder() - .name("benchmark-test") - .age(42) - .active(true) - .score(98.6) - .createdAt(Instant.parse("2025-01-15T10:30:00Z")) - .build(); - } - - private static ComplexStruct buildComplexStruct() { - var inner = InnerStruct.builder() - .value("inner-value") - .numbers(List.of(1, 2, 3, 4, 5)) - .build(); - var nested = NestedStruct.builder() - .field1("nested-field") - .field2(100) - .inner(inner) - .build(); - var sparseMap = new HashMap(); - sparseMap.put("x", "1"); - sparseMap.put("y", "2"); - sparseMap.put("z", null); - return ComplexStruct.builder() - .id("bench-001") - .count(999) - .enabled(true) - .ratio(1.618) - .score(2.718f) - .bigCount(1_000_000L) - .optionalString("optional-value") - .optionalInt(42) - .createdAt(Instant.parse("2025-01-15T10:30:00Z")) - .updatedAt(Instant.parse("2025-06-01T12:00:00Z")) - .expiresAt(Instant.parse("2026-01-01T00:00:00Z")) - .payload(ByteBuffer.wrap("binary-payload-data".getBytes(StandardCharsets.UTF_8))) - .tags(List.of("alpha", "beta", "gamma", "delta")) - .intList(List.of(10, 20, 30, 40, 50)) - .metadata(Map.of("key1", "value1", "key2", "value2", "key3", "value3")) - .intMap(Map.of("a", 1, "b", 2, "c", 3)) - .nested(nested) - .optionalNested(NestedStruct.builder() - .field1("opt-nested") - .field2(200) - .build()) - .structList(List.of(nested, nested)) - .structMap(Map.of("first", nested, "second", nested)) - .choice(new BenchUnion.StringValueMember("union-string")) - .color(Color.GREEN) - .colorList(List.of(Color.RED, Color.BLUE, Color.YELLOW)) - .sparseStrings(Arrays.asList("a", null, "c")) - .sparseMap(sparseMap) - .bigIntValue(new BigInteger("123456789012345678901234567890")) - .bigDecValue(new BigDecimal("99999.99999")) - .freeformData(Document.of(Map.of("key", Document.of("value"), "num", Document.of(42)))) - .build(); - } - - @Benchmark - public ByteBuffer serialize() { - return codec.serialize(shape); - } - - @Benchmark - public Object deserialize() { - return codec.deserializeShape(serializedBytes, builderSupplier.get()); - } - - @Benchmark - public Object deserializeReversed() { - return codec.deserializeShape(reversedBytes, builderSupplier.get()); - } - - @Benchmark - public Object roundtrip() { - ByteBuffer bytes = codec.serialize(shape); - return codec.deserializeShape(bytes, builderSupplier.get()); - } - - /** - * Reverses the order of top-level JSON object fields using Jackson streaming. - * Only used at setup time. - */ - private static byte[] reverseJsonFieldOrder(byte[] json) { - var factory = JsonFactory.builder().build(); - - // Parse top-level fields and capture each value as raw bytes - List> fields = new ArrayList<>(); - try (var parser = factory.createParser(ObjectReadContext.empty(), json)) { - parser.nextToken(); // START_OBJECT - while (parser.nextToken() == PROPERTY_NAME) { - String name = parser.currentName(); - parser.nextToken(); // advance to value - var baos = new ByteArrayOutputStream(); - try (var gen = factory.createGenerator(ObjectWriteContext.empty(), baos)) { - gen.copyCurrentStructure(parser); - } - fields.add(new AbstractMap.SimpleEntry<>(name, baos.toByteArray())); - } - } - - // Re-emit in reverse order - Collections.reverse(fields); - var out = new ByteArrayOutputStream(); - try (var gen = factory.createGenerator(ObjectWriteContext.empty(), out)) { - gen.writeStartObject(); - for (var field : fields) { - gen.writeName(field.getKey()); - try (var p = factory.createParser(ObjectReadContext.empty(), field.getValue())) { - p.nextToken(); - gen.copyCurrentStructure(p); - } - } - gen.writeEndObject(); - } - return out.toByteArray(); - } -} diff --git a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/MapGenerator.java b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/MapGenerator.java index 7c209c016..63797e49e 100644 --- a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/MapGenerator.java +++ b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/generators/MapGenerator.java @@ -84,7 +84,7 @@ public void accept(${value:B} values, ${shapeSerializer:T} serializer) { static ${shape:T} deserialize${name:U}(${schema:T} schema, ${shapeDeserializer:T} deserializer) { var size = Math.min(deserializer.containerSize(), deserializer.containerPreAllocationLimit()); - ${shape:T} result = size == -1 ? new ${collectionImpl:T}<>() : new ${collectionImpl:T}<>(size); + ${shape:T} result = size == -1 ? new ${collectionImpl:T}<>() : ${collectionImpl:T}.${newMap:L}(size); deserializer.readStringMap(schema, result, ${name:U}$$ValueDeserializer.INSTANCE); return result; } @@ -108,10 +108,12 @@ public void accept(${shape:B} state, ${string:T} key, ${shapeDeserializer:T} des writer.putContext("value", valueSymbol); writer.putContext("keySchema", keySchema); writer.putContext("valueSchema", valueSchema); + var collectionImpl = (Class) directive.symbol() + .expectProperty(SymbolProperties.COLLECTION_IMPLEMENTATION_CLASS); + writer.putContext("collectionImpl", collectionImpl); writer.putContext( - "collectionImpl", - directive.symbol() - .expectProperty(SymbolProperties.COLLECTION_IMPLEMENTATION_CLASS)); + "newMap", + "new" + collectionImpl.getSimpleName()); writer.putContext("schema", Schema.class); writer.putContext("biConsumer", BiConsumer.class); writer.putContext("mapSerializer", MapSerializer.class); diff --git a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ArrayHttpHeaders.java b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ArrayHttpHeaders.java index f65dacc08..0a67b275d 100644 --- a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ArrayHttpHeaders.java +++ b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ArrayHttpHeaders.java @@ -42,6 +42,15 @@ void addHeaderCanonical(String key, String value) { size++; } + @Override + public void addHeaderTrusted(HeaderName name, String value) { + ensureCapacity(); + int idx = size * 2; + array[idx] = name.name(); + array[idx + 1] = value; + size++; + } + /** * Add a header directly from bytes (zero-copy for known headers). * diff --git a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpHeaders.java b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpHeaders.java index 8777d493f..ef41d730a 100644 --- a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpHeaders.java +++ b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpHeaders.java @@ -171,6 +171,16 @@ default void placeHeaders(HttpHeaders headers) { headers.forEachEntry(this::setHeader); } + /** + * Add a header with a pre-validated value, bypassing value normalization. + * + *

Use this for values known to be valid (e.g., numeric content-length). + * The default delegates to {@link #addHeader(HeaderName, String)}. + */ + default void addHeaderTrusted(HeaderName name, String value) { + addHeader(name, value); + } + /** * Remove a header and its values by name. * diff --git a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpResponseImpl.java b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpResponseImpl.java index 75e7df1db..509cb8d6b 100644 --- a/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpResponseImpl.java +++ b/http/http-api/src/main/java/software/amazon/smithy/java/http/api/ModifiableHttpResponseImpl.java @@ -73,14 +73,13 @@ public ModifiableHttpResponse setBody(DataStream body) { return this; } - // Shared helper method with ModifiableHttpRequestImpl to set headers based on the provided body. static void addBodyHeaders(DataStream body, ModifiableHttpHeaders headers) { var ct = body.contentType(); if (ct != null && !headers.hasHeader(HeaderName.CONTENT_TYPE)) { - headers.addHeader(HeaderName.CONTENT_TYPE, ct); + headers.addHeaderTrusted(HeaderName.CONTENT_TYPE, ct); } if (body.hasKnownLength() && !headers.hasHeader(HeaderName.CONTENT_LENGTH)) { - headers.addHeader(HeaderName.CONTENT_LENGTH, String.valueOf(body.contentLength())); + headers.addHeaderTrusted(HeaderName.CONTENT_LENGTH, Long.toString(body.contentLength())); } } diff --git a/http/http-api/src/test/java/software/amazon/smithy/java/http/api/HttpHeadersTest.java b/http/http-api/src/test/java/software/amazon/smithy/java/http/api/HttpHeadersTest.java index 99b19ead8..c725b5560 100644 --- a/http/http-api/src/test/java/software/amazon/smithy/java/http/api/HttpHeadersTest.java +++ b/http/http-api/src/test/java/software/amazon/smithy/java/http/api/HttpHeadersTest.java @@ -49,4 +49,34 @@ public void convertUnmofiableToModifiable() { assertThat(httpHeaders.map(), equalTo(mod.map())); } + + @Test + public void addHeaderTrustedBypassesNormalization() { + var headers = new ArrayHttpHeaders(); + headers.addHeaderTrusted(HeaderName.CONTENT_LENGTH, "42"); + + assertThat(headers.firstValue("content-length"), equalTo("42")); + } + + @Test + public void addHeaderTrustedUsesInternedName() { + var headers = new ArrayHttpHeaders(); + headers.addHeaderTrusted(HeaderName.CONTENT_TYPE, "application/json"); + headers.addHeaderTrusted(HeaderName.CONTENT_LENGTH, "100"); + + assertThat(headers.firstValue("content-type"), equalTo("application/json")); + assertThat(headers.firstValue("content-length"), equalTo("100")); + } + + @Test + public void addHeaderTrustedCoexistsWithRegularHeaders() { + var headers = new ArrayHttpHeaders(); + headers.addHeader(HeaderName.of("x-custom"), "value1"); + headers.addHeaderTrusted(HeaderName.CONTENT_TYPE, "text/plain"); + headers.addHeader(HeaderName.of("x-other"), "value2"); + + assertThat(headers.firstValue("x-custom"), equalTo("value1")); + assertThat(headers.firstValue("content-type"), equalTo("text/plain")); + assertThat(headers.firstValue("x-other"), equalTo("value2")); + } } From a56d0249fbe287c77d7b42d09e0152cec6e07bee Mon Sep 17 00:00:00 2001 From: Adwait Kumar Singh Date: Mon, 18 May 2026 14:36:28 -0700 Subject: [PATCH 2/3] Fix DateTimeException in CBOR timestamp deserialization The previous readTimestamp passed the tagged token to readLong, which always threw SerializationException before reaching Instant.ofEpochSecond, masking the fact that out-of-range epoch values cause DateTimeException. Now that the token is correctly stripped, wrap with try-catch to convert the overflow to SerializationException (zero happy-path cost via exception tables). --- .../smithy/java/cbor/CborDeserializer.java | 15 +++++--- .../smithy/java/cbor/CborCodecTest.java | 38 +++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java index 4746c1554..4c21abd79 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborDeserializer.java @@ -13,6 +13,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; +import java.time.DateTimeException; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -691,11 +692,15 @@ public Document readDocument() { public Instant readTimestamp(Schema schema) { byte token = this.token; byte actual = (byte) (token ^ Token.TAG_FLAG); - if (actual <= Token.NEG_INT) { - return Instant.ofEpochSecond(readLong("timestamp", actual)); - } else if (actual == Token.FLOAT) { - double d = readDouble("timestamp", actual); - return Instant.ofEpochMilli(Math.round(d * 1000d)); + try { + if (actual <= Token.NEG_INT) { + return Instant.ofEpochSecond(readLong("timestamp", actual)); + } else if (actual == Token.FLOAT) { + double d = readDouble("timestamp", actual); + return Instant.ofEpochMilli(Math.round(d * 1000d)); + } + } catch (DateTimeException e) { + throw new SerializationException("timestamp out of range", e); } throw badType("timestamp", token); } diff --git a/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java b/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java index e37a5f9ae..147a984ff 100644 --- a/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java +++ b/codecs/cbor-codec/src/test/java/software/amazon/smithy/java/cbor/CborCodecTest.java @@ -29,6 +29,7 @@ import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SerializableStruct; import software.amazon.smithy.java.core.serde.MapSerializer; +import software.amazon.smithy.java.core.serde.SerializationException; import software.amazon.smithy.java.core.serde.ShapeDeserializer; import software.amazon.smithy.java.core.serde.ShapeSerializer; import software.amazon.smithy.java.io.ByteBufferOutputStream; @@ -372,6 +373,24 @@ void serializeDeserializeAllFields() { assertEquals(new BigDecimal("3.14"), de.wingspan); } + // Regression: the original readTimestamp passed the tagged token (EPOCH_IPOS=16) to readLong, + // which rejected any token > NEG_INT(1), so integer epoch timestamps always threw. + @Test + void timestampIntegerEpoch() { + // tag(1) + uint8(100) - simplest valid integer epoch + byte[] payload = new byte[] {(byte) 0xC1, 0x18, 0x64}; + ShapeDeserializer de = CODEC.newDeserializer(payload, SETTINGS); + assertEquals(Instant.ofEpochSecond(100), de.readTimestamp(PreludeSchemas.TIMESTAMP)); + } + + @Test + void timestampNegativeIntegerEpoch() { + // tag(1) + negint8(99) = -1 - 99 = -100 seconds + byte[] payload = new byte[] {(byte) 0xC1, 0x38, 0x63}; + ShapeDeserializer de = CODEC.newDeserializer(payload, SETTINGS); + assertEquals(Instant.ofEpochSecond(-100), de.readTimestamp(PreludeSchemas.TIMESTAMP)); + } + @Test void timestampWholeSeconds() { Instant wholeSecond = Instant.ofEpochSecond(1700000000); @@ -530,6 +549,25 @@ void incompleteCollection(boolean map) { assertTrue(e.getMessage().contains("incomplete " + (map ? "map" : "array")), e.getMessage()); } + @Test + void timestampOutOfRange() { + // tag(1) + uint64(Long.MAX_VALUE) - epoch seconds far exceeding Instant.MAX + byte[] payload = new byte[] { + (byte) 0xC1, // tag 1 (epoch timestamp) + (byte) 0x1B, // 8-byte positive integer follows + (byte) 0x7F, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF, + (byte) 0xFF + }; + ShapeDeserializer de = CODEC.newDeserializer(payload, SETTINGS); + assertThrows(SerializationException.class, () -> de.readTimestamp(PreludeSchemas.TIMESTAMP)); + } + @Test void missingMapValue() { byte[] cbor = write(os -> { From 0a6d2ef94bf54ecef3ea6d3256096cca52e83f9a Mon Sep 17 00:00:00 2001 From: Adwait Kumar Singh Date: Mon, 18 May 2026 16:13:02 -0700 Subject: [PATCH 3/3] Mark CborSchemaExtensions internal --- .../java/cbor/CborSchemaExtensions.java | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java index 4569494aa..f35de7bd0 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSchemaExtensions.java @@ -10,6 +10,7 @@ import software.amazon.smithy.java.core.schema.SchemaExtensionKey; import software.amazon.smithy.java.core.schema.SchemaExtensionProvider; import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.utils.SmithyInternalApi; /** * Pre-computes CBOR codec data on Schema objects. @@ -18,13 +19,14 @@ * For struct/union schemas: {@link CborMemberLookup} instances for hash-based field matching * and field name tables for O(1) lookup by memberIndex during serialization. */ +@SmithyInternalApi public final class CborSchemaExtensions - implements SchemaExtensionProvider { + implements SchemaExtensionProvider { /** * Extension key for CBOR codec data. */ - public static final SchemaExtensionKey KEY = new SchemaExtensionKey<>(); + public static final SchemaExtensionKey KEY = new SchemaExtensionKey<>(); /** * Pre-computed CBOR data stored on a Schema. @@ -33,18 +35,18 @@ public final class CborSchemaExtensions * @param memberLookup Hash-based member lookup (null for non-structs) * @param fieldNameTable Indexed by memberIndex: pre-computed name bytes per member (null for non-structs) */ - public record NativeCborExtension( + public record CborExtension( byte[] memberNameBytes, CborMemberLookup memberLookup, byte[][] fieldNameTable) {} @Override - public SchemaExtensionKey key() { + public SchemaExtensionKey key() { return KEY; } @Override - public NativeCborExtension provide(Schema schema) { + public CborExtension provide(Schema schema) { if (schema.isMember()) { return forMember(schema); } @@ -55,15 +57,15 @@ public NativeCborExtension provide(Schema schema) { return null; } - private static NativeCborExtension forMember(Schema schema) { + private static CborExtension forMember(Schema schema) { byte[] memberNameBytes = CborSerializer.encodeMemberName(schema.memberName()); - return new NativeCborExtension(memberNameBytes, null, null); + return new CborExtension(memberNameBytes, null, null); } - private static NativeCborExtension forStruct(Schema schema) { + private static CborExtension forStruct(Schema schema) { List members = schema.members(); if (members.isEmpty()) { - return new NativeCborExtension(null, null, null); + return new CborExtension(null, null, null); } CborMemberLookup memberLookup = new CborMemberLookup(members); @@ -78,6 +80,6 @@ private static NativeCborExtension forStruct(Schema schema) { fieldNameTable[idx] = CborSerializer.encodeMemberName(m.memberName()); } - return new NativeCborExtension(null, memberLookup, fieldNameTable); + return new CborExtension(null, memberLookup, fieldNameTable); } }