diff --git a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java index ce0646a82b8..e09dc95d2dd 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java +++ b/lang/java/avro/src/main/java/org/apache/avro/generic/GenericDatumReader.java @@ -17,12 +17,16 @@ */ package org.apache.avro.generic; +import java.io.EOFException; import java.io.IOException; import java.lang.reflect.Constructor; import java.nio.ByteBuffer; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; @@ -291,6 +295,7 @@ protected Object readArray(Object old, Schema expected, ResolvingDecoder in) thr long l = in.readArrayStart(); long base = 0; if (l > 0) { + ensureAvailableCollectionBytes(in, l, expectedType); LogicalType logicalType = expectedType.getLogicalType(); Conversion> conversion = getData().getConversionFor(logicalType); Object array = newArray(old, (int) l, expected); @@ -306,13 +311,25 @@ protected Object readArray(Object old, Schema expected, ResolvingDecoder in) thr } } base += l; - } while ((l = in.arrayNext()) > 0); + } while ((l = arrayNext(in, expectedType)) > 0); return pruneArray(array); } else { return pruneArray(newArray(old, 0, expected)); } } + /** + * Reads the next array block count and validates remaining bytes before the + * caller allocates storage. + */ + private long arrayNext(ResolvingDecoder in, Schema elementType) throws IOException { + long l = in.arrayNext(); + if (l > 0) { + ensureAvailableCollectionBytes(in, l, elementType); + } + return l; + } + private Object pruneArray(Object object) { if (object instanceof GenericArray>) { ((GenericArray>) object).prune(); @@ -348,6 +365,9 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw long l = in.readMapStart(); LogicalType logicalType = eValue.getLogicalType(); Conversion> conversion = getData().getConversionFor(logicalType); + if (l > 0) { + ensureAvailableMapBytes(in, l, eValue); + } Object map = newMap(old, (int) l); if (l > 0) { do { @@ -361,11 +381,40 @@ protected Object readMap(Object old, Schema expected, ResolvingDecoder in) throw addToMap(map, readMapKey(null, expected, in), readWithoutConversion(null, eValue, in)); } } - } while ((l = in.mapNext()) > 0); + } while ((l = mapNext(in, eValue)) > 0); } return map; } + /** + * Reads the next map block count and validates remaining bytes before the + * caller allocates storage. + */ + private long mapNext(ResolvingDecoder in, Schema valueType) throws IOException { + long l = in.mapNext(); + if (l > 0) { + ensureAvailableMapBytes(in, l, valueType); + } + return l; + } + + /** + * Validates remaining bytes for a map block. Each map entry has a string key + * (at least 1 byte for the length varint) plus a value, so the minimum bytes + * per entry is {@code 1 + minBytesPerElement(valueSchema)}. + */ + private static void ensureAvailableMapBytes(Decoder decoder, long count, Schema valueSchema) throws EOFException { + // Map keys are always strings: at least 1 byte for the length varint + int minBytesPerEntry = 1 + minBytesPerElement(valueSchema); + if (count > 0) { + int remaining = decoder.remainingBytes(); + if (remaining >= 0 && count * (long) minBytesPerEntry > remaining) { + throw new EOFException("Map claims " + count + " entries with at least " + minBytesPerEntry + + " bytes each, but only " + remaining + " bytes are available"); + } + } + } + /** * Called by the default implementation of {@link #readMap} to read a key value. * The default implementation returns delegates to @@ -384,6 +433,73 @@ protected void addToMap(Object map, Object key, Object value) { ((Map) map).put(key, value); } + /** + * Returns the minimum number of bytes required to encode a single value of the + * given schema in Avro binary format. Used to validate that the decoder has + * enough data remaining before allocating collection backing arrays. + *
+ * Returns 0 for types whose binary encoding is empty ({@code null}, zero-length
+ * {@code fixed}, records with only zero-byte fields). Returns a positive value
+ * for all other types.
+ */
+ static int minBytesPerElement(Schema schema) {
+ return minBytesPerElement(schema, Collections.newSetFromMap(new IdentityHashMap<>()));
+ }
+
+ private static int minBytesPerElement(Schema schema, Set
+ * This check prevents out-of-memory errors from pre-allocating huge backing
+ * arrays when the source data is truncated or malicious.
+ */
+ private static void ensureAvailableCollectionBytes(Decoder decoder, long count, Schema elementSchema)
+ throws EOFException {
+ int minBytes = minBytesPerElement(elementSchema);
+ if (minBytes > 0 && count > 0) {
+ int remaining = decoder.remainingBytes();
+ if (remaining >= 0 && count * (long) minBytes > remaining) {
+ throw new EOFException("Collection claims " + count + " elements with at least " + minBytes
+ + " bytes each, but only " + remaining + " bytes are available");
+ }
+ }
+ }
+
/**
* Called to read a fixed value. May be overridden for alternate fixed
* representations. By default, returns {@link GenericFixed}.
diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java
index 827b2fea3c7..22d86ca6504 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/io/BinaryDecoder.java
@@ -20,8 +20,10 @@
import org.apache.avro.AvroRuntimeException;
import org.apache.avro.InvalidNumberEncodingException;
import org.apache.avro.SystemLimitException;
+import org.apache.avro.util.ByteBufferInputStream;
import org.apache.avro.util.Utf8;
+import java.io.ByteArrayInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
@@ -295,6 +297,7 @@ public double readDouble() throws IOException {
@Override
public Utf8 readString(Utf8 old) throws IOException {
int length = SystemLimitException.checkMaxStringLength(readLong());
+ ensureAvailableBytes(length);
Utf8 result = (old != null ? old : new Utf8());
result.setByteLength(length);
if (0 != length) {
@@ -318,6 +321,7 @@ public void skipString() throws IOException {
@Override
public ByteBuffer readBytes(ByteBuffer old) throws IOException {
int length = SystemLimitException.checkMaxBytesLength(readLong());
+ ensureAvailableBytes(length);
final ByteBuffer result;
if (old != null && length <= old.capacity()) {
result = old;
@@ -508,6 +512,21 @@ public boolean isEnd() throws IOException {
return (0 == read);
}
+ /**
+ * Returns the total number of bytes remaining that can be read from this
+ * decoder (including any buffered bytes), or {@code -1} if the total is
+ * unknown.
+ *
+ * Byte-array-backed decoders return an exact count. InputStream-backed decoders
+ * return an exact count only when the wrapped stream can report one.
+ *
+ * {@link DirectBinaryDecoder} always returns {@code -1}.
+ */
+ @Override
+ public int remainingBytes() {
+ return source != null ? source.remainingBytes() : -1;
+ }
+
/**
* Ensures that buf[pos + num - 1] is not out of the buffer array bounds.
* However, buf[pos + num -1] may be >= limit if there is not enough data left
@@ -530,6 +549,27 @@ private void ensureBounds(int num) throws IOException {
}
}
+ /**
+ * Validates that the source has at least {@code length} bytes remaining before
+ * proceeding. Throws early if the declared length is inconsistent with the
+ * available data.
+ *
+ * This check is only applied when the decoder knows the exact remaining byte
+ * count.
+ *
+ * @param length the number of bytes expected to be available
+ * @throws EOFException if the source is known to have fewer bytes remaining
+ */
+ private void ensureAvailableBytes(int length) throws EOFException {
+ if (source != null && length > 0) {
+ int remaining = source.remainingBytes();
+ if (remaining >= 0 && length > remaining) {
+ throw new EOFException(
+ "Attempted to read " + length + " bytes, but only " + remaining + " bytes are available");
+ }
+ }
+ }
+
/**
* Returns an {@link java.io.InputStream} that is aware of any buffering that
* may occur in this BinaryDecoder. Readers that need to interleave decoding
@@ -664,6 +704,12 @@ protected ByteSource() {
abstract boolean isEof();
+ /**
+ * Returns the total number of bytes remaining that can be read from this source
+ * (including any buffered bytes), or {@code -1} if the total is unknown.
+ */
+ protected abstract int remainingBytes();
+
protected void attach(int bufferSize, BinaryDecoder decoder) {
decoder.buf = new byte[bufferSize];
decoder.pos = 0;
@@ -910,6 +956,19 @@ public boolean isEof() {
return isEof;
}
+ @Override
+ protected int remainingBytes() {
+ int buffered = ba.getLim() - ba.getPos();
+ try {
+ if (in.getClass() == ByteArrayInputStream.class || in.getClass() == ByteBufferInputStream.class) {
+ return buffered + in.available();
+ }
+ } catch (IOException e) {
+ return -1;
+ }
+ return -1;
+ }
+
@Override
public void close() throws IOException {
in.close();
@@ -1028,5 +1087,10 @@ public boolean isEof() {
int remaining = ba.getLim() - ba.getPos();
return (remaining == 0);
}
+
+ @Override
+ protected int remainingBytes() {
+ return ba.getLim() - ba.getPos();
+ }
}
}
diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java
index 11fc28d762e..80640a61aa0 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/io/Decoder.java
@@ -299,4 +299,14 @@ public void readFixed(byte[] bytes) throws IOException {
* type of the next value to be read
*/
public abstract int readIndex() throws IOException;
+
+ /**
+ * Returns the total number of bytes remaining that can be read from this
+ * decoder, or {@code -1} if the total is unknown. Implementations that can
+ * determine remaining capacity (for example, byte-array-backed decoders) should
+ * override this method. The default returns {@code -1}.
+ */
+ public int remainingBytes() {
+ return -1;
+ }
}
diff --git a/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java b/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java
index dbee4458575..26f79a16ff2 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/io/ValidatingDecoder.java
@@ -246,4 +246,9 @@ public int readIndex() throws IOException {
public Symbol doAction(Symbol input, Symbol top) throws IOException {
return null;
}
+
+ @Override
+ public int remainingBytes() {
+ return in != null ? in.remainingBytes() : -1;
+ }
}
diff --git a/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java b/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java
index 6abb62015dc..375abc23fbf 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/util/ByteBufferInputStream.java
@@ -65,6 +65,18 @@ public int read(byte[] b, int off, int len) throws IOException {
}
}
+ @Override
+ public int available() throws IOException {
+ long remaining = 0;
+ for (int i = current; i < buffers.size(); i++) {
+ remaining += buffers.get(i).remaining();
+ if (remaining >= Integer.MAX_VALUE) {
+ return Integer.MAX_VALUE;
+ }
+ }
+ return (int) remaining;
+ }
+
/**
* Read a buffer from the input without copying, if possible.
*/
diff --git a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
index f74dab95b0f..5586b828999 100644
--- a/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
+++ b/lang/java/avro/src/test/java/org/apache/avro/generic/TestGenericDatumReader.java
@@ -18,15 +18,24 @@
package org.apache.avro.generic;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import java.io.ByteArrayOutputStream;
+import java.io.EOFException;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.avro.Schema;
+import org.apache.avro.io.BinaryDecoder;
+import org.apache.avro.io.BinaryEncoder;
+import org.apache.avro.io.DecoderFactory;
+import org.apache.avro.io.EncoderFactory;
import org.junit.jupiter.api.Test;
public class TestGenericDatumReader {
@@ -117,4 +126,181 @@ private void sleep() {
}
}
}
+
+ // --- minBytesPerElement tests ---
+
+ @Test
+ void testMinBytesPerElementPrimitives() {
+ assertEquals(0, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.NULL)));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.BOOLEAN)));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.INT)));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.LONG)));
+ assertEquals(4, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.FLOAT)));
+ assertEquals(8, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.DOUBLE)));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.STRING)));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.create(Schema.Type.BYTES)));
+ }
+
+ @Test
+ void testMinBytesPerElementFixed() {
+ assertEquals(0, GenericDatumReader.minBytesPerElement(Schema.createFixed("ZeroFixed", null, "test", 0)));
+ assertEquals(5, GenericDatumReader.minBytesPerElement(Schema.createFixed("FiveFixed", null, "test", 5)));
+ assertEquals(16, GenericDatumReader.minBytesPerElement(Schema.createFixed("SixteenFixed", null, "test", 16)));
+ }
+
+ @Test
+ void testMinBytesPerElementUnion() {
+ // Union always >= 1 byte (branch index varint)
+ Schema nullableInt = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(nullableInt));
+ }
+
+ @Test
+ void testMinBytesPerElementRecord() {
+ // Empty record = 0 bytes
+ Schema emptyRecord = Schema.createRecord("Empty", null, "test", false);
+ emptyRecord.setFields(Collections.emptyList());
+ assertEquals(0, GenericDatumReader.minBytesPerElement(emptyRecord));
+
+ // Record with a single non-null field >= 1 byte
+ Schema recWithInt = Schema.createRecord("WithInt", null, "test", false);
+ recWithInt.setFields(Collections.singletonList(new Schema.Field("x", Schema.create(Schema.Type.INT))));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(recWithInt));
+
+ // Record with only null fields = 0 bytes
+ Schema recWithNull = Schema.createRecord("WithNull", null, "test", false);
+ recWithNull.setFields(Collections.singletonList(new Schema.Field("n", Schema.create(Schema.Type.NULL))));
+ assertEquals(0, GenericDatumReader.minBytesPerElement(recWithNull));
+
+ Schema recWithMultipleFields = Schema.createRecord("WithMultipleFields", null, "test", false);
+ recWithMultipleFields.setFields(Arrays.asList(new Schema.Field("f", Schema.create(Schema.Type.FLOAT)),
+ new Schema.Field("d", Schema.create(Schema.Type.DOUBLE))));
+ assertEquals(12, GenericDatumReader.minBytesPerElement(recWithMultipleFields));
+ }
+
+ @Test
+ void testMinBytesPerElementNestedCollections() {
+ // Array and map types are >= 1 byte (count varint)
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.createArray(Schema.create(Schema.Type.INT))));
+ assertEquals(1, GenericDatumReader.minBytesPerElement(Schema.createMap(Schema.create(Schema.Type.INT))));
+ }
+
+ // --- Collection byte validation end-to-end tests ---
+
+ /**
+ * Encodes the given longs as Avro varints into a byte array.
+ */
+ private static byte[] encodeVarints(long... values) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ BinaryEncoder enc = EncoderFactory.get().directBinaryEncoder(baos, null);
+ for (long v : values) {
+ enc.writeLong(v);
+ }
+ enc.flush();
+ return baos.toByteArray();
+ }
+
+ /**
+ * Verify that reading an array of ints with a huge count but no element data
+ * throws EOFException from the schema-aware byte check.
+ */
+ @Test
+ void arrayOfIntsRejectsHugeCount() throws Exception {
+ Schema schema = Schema.createArray(Schema.create(Schema.Type.INT));
+ GenericDatumReader