diff --git a/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java b/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
index 9b9a89034..67f65162c 100644
--- a/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
+++ b/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
@@ -121,31 +121,33 @@ private void loadBuffers(
     int bufferLayoutCount =
         (int) (variadicBufferLayoutCount + TypeLayout.getTypeBufferCount(field.getType()));
     List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
-    for (int j = 0; j < bufferLayoutCount; j++) {
-      if (!buffers.hasNext()) {
-        throw new IllegalArgumentException(
-            "no more buffers for field " + field + ". Expected " + bufferLayoutCount);
-      }
-      ArrowBuf nextBuf = buffers.next();
-      // for vectors without nulls, the buffer is empty, so there is no need to decompress it.
-      ArrowBuf bufferToAdd =
-          nextBuf.writerIndex() > 0 ? codec.decompress(vector.getAllocator(), nextBuf) : nextBuf;
-      ownBuffers.add(bufferToAdd);
-      if (decompressionNeeded) {
-        // decompression performed
-        nextBuf.getReferenceManager().retain();
-      }
-    }
-    try {
-      vector.loadFieldBuffers(fieldNode, ownBuffers);
-      if (decompressionNeeded) {
-        for (ArrowBuf buf : ownBuffers) {
-          buf.close();
-        }
-      }
-    } catch (RuntimeException e) {
-      throw new IllegalArgumentException(
-          "Could not load buffers for field " + field + ". error message: " + e.getMessage(), e);
-    }
+    try {
+      for (int j = 0; j < bufferLayoutCount; j++) {
+        if (!buffers.hasNext()) {
+          throw new IllegalArgumentException(
+              "no more buffers for field " + field + ". Expected " + bufferLayoutCount);
+        }
+        ArrowBuf nextBuf = buffers.next();
+        // for vectors without nulls, the buffer is empty, so there is no need to decompress it.
+        ArrowBuf bufferToAdd =
+            nextBuf.writerIndex() > 0 ? codec.decompress(vector.getAllocator(), nextBuf) : nextBuf;
+        ownBuffers.add(bufferToAdd);
+        if (decompressionNeeded) {
+          nextBuf.getReferenceManager().retain();
+        }
+      }
+      try {
+        vector.loadFieldBuffers(fieldNode, ownBuffers);
+      } catch (RuntimeException e) {
+        throw new IllegalArgumentException(
+            "Could not load buffers for field " + field + ". error message: " + e.getMessage(), e);
+      }
+    } finally {
+      if (decompressionNeeded) {
+        for (ArrowBuf buf : ownBuffers) {
+          buf.close();
+        }
+      }
+    }
     List<Field> children = field.getChildren();
     if (children.size() > 0) {
diff --git a/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java b/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java
index 782535fcc..35b9a7116 100644
--- a/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java
+++ b/vector/src/test/java/org/apache/arrow/vector/TestVectorUnloadLoad.java
@@ -37,6 +37,9 @@
 import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter;
 import org.apache.arrow.vector.complex.writer.BigIntWriter;
 import org.apache.arrow.vector.complex.writer.IntWriter;
+import org.apache.arrow.vector.compression.CompressionCodec;
+import org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.arrow.vector.ipc.message.ArrowBodyCompression;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.pojo.ArrowType;
@@ -348,4 +351,64 @@ public static VectorUnloader newVectorUnloader(FieldVector root) {
     VectorSchemaRoot vsr = new VectorSchemaRoot(schema.getFields(), fields, valueCount);
     return new VectorUnloader(vsr);
   }
+
+  @Test
+  public void testLoadReleasesBuffersOnDecompressionFailure() {
+    Schema schema = new Schema(asList(Field.nullable("int", new ArrowType.Int(32, true))));
+    CompressionCodec.Factory failingFactory =
+        new CompressionCodec.Factory() {
+          @Override
+          public CompressionCodec createCodec(CompressionUtil.CodecType codecType) {
+            return new CompressionCodec() {
+              @Override
+              public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public ArrowBuf decompress(BufferAllocator allocator, ArrowBuf compressedBuffer) {
+                throw new RuntimeException("simulated decompression failure");
+              }
+
+              @Override
+              public CompressionUtil.CodecType getCodecType() {
+                return codecType;
+              }
+            };
+          }
+
+          @Override
+          public CompressionCodec createCodec(
+              CompressionUtil.CodecType codecType, int compressionLevel) {
+            return createCodec(codecType);
+          }
+        };
+
+    try (BufferAllocator testAllocator =
+        allocator.newChildAllocator("test", 0, Integer.MAX_VALUE)) {
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, testAllocator)) {
+        VectorLoader loader = new VectorLoader(root, failingFactory);
+        ArrowBodyCompression compression =
+            new ArrowBodyCompression(
+                CompressionUtil.CodecType.LZ4_FRAME.getType(),
+                org.apache.arrow.flatbuf.BodyCompressionMethod.BUFFER);
+        List<ArrowFieldNode> nodes = asList(new ArrowFieldNode(1, 0));
+        ArrowBuf validityBuf = testAllocator.buffer(8);
+        validityBuf.writerIndex(8);
+        ArrowBuf dataBuf = testAllocator.buffer(4);
+        dataBuf.writerIndex(4);
+        try (ArrowRecordBatch batch =
+            new ArrowRecordBatch(1, nodes, asList(validityBuf, dataBuf), compression)) {
+          RuntimeException ex =
+              org.junit.jupiter.api.Assertions.assertThrows(
+                  RuntimeException.class, () -> loader.load(batch));
+          assertTrue(ex.getMessage().contains("simulated decompression failure"));
+        } finally {
+          validityBuf.close();
+          dataBuf.close();
+        }
+      }
+      assertEquals(0, testAllocator.getAllocatedMemory());
+    }
+  }
 }