From 84de9c69b195be94cb483300d6b304da820603c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Sun, 17 May 2026 22:36:20 +0200 Subject: [PATCH] GH-3530: Optimize BYTE_STREAM_SPLIT encoding/decoding Reader: replace generic ByteBuffer.get() transpose loop in decodeData() with specialized single-pass loops for element sizes 2/4/8/12/16 bytes plus a stream-oriented generic fallback. Bulk-access the backing array directly when available, falling back to a single bulk copy for direct buffers. Writer: replace per-value scatterBytes() (which allocates a temp byte[] and issues N single-byte stream writes) with batched scatter buffers. Int/Long values accumulate in int[]/long[] batches of 64 and flush as bulk write(byte[], off, len) calls -- one per stream. FLBA uses per-stream byte[][] scratch buffers with the same batching strategy. getBufferedSize() now accounts for unflushed batch values. Add JMH benchmarks for scalar encode/decode of all 5 BSS types (FLOAT, DOUBLE, INT32, INT64, FIXED_LEN_BYTE_ARRAY). Add TestDataFactory for deterministic FLBA benchmark data generation. Add unit tests for transpose specializations, batch-boundary crossing, getBufferedSize with partial batches, direct ByteBuffer decode paths, and close/reset with pending unflushed batches. 
--- .../ByteStreamSplitDecodingBenchmark.java | 273 +++++++++++ .../ByteStreamSplitEncodingBenchmark.java | 166 +++++++ .../parquet/benchmarks/TestDataFactory.java | 66 +++ .../ByteStreamSplitValuesReader.java | 137 +++++- .../ByteStreamSplitValuesWriter.java | 145 +++++- .../ByteStreamSplitScalarOptTest.java | 462 ++++++++++++++++++ pom.xml | 1 + 7 files changed, 1227 insertions(+), 23 deletions(-) create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitDecodingBenchmark.java create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitEncodingBenchmark.java create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/TestDataFactory.java create mode 100644 parquet-column/src/test/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitScalarOptTest.java diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitDecodingBenchmark.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitDecodingBenchmark.java new file mode 100644 index 0000000000..f0eeca4405 --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitDecodingBenchmark.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.benchmarks; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.HeapByteBufferAllocator; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReader; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReaderForDouble; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReaderForFLBA; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReaderForFloat; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReaderForInteger; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReaderForLong; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesWriter; +import org.apache.parquet.io.api.Binary; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * Decoding-level micro-benchmarks for the BYTE_STREAM_SPLIT encoding across all + * Parquet types that support it: {@code FLOAT}, {@code DOUBLE}, {@code INT32}, + * {@code INT64}, and {@code FIXED_LEN_BYTE_ARRAY}. 
+ * + *

<p>Fixed-width numeric types are benchmarked directly by top-level methods. + * {@code FIXED_LEN_BYTE_ARRAY} uses an inner {@link FlbaState} parameterised by + * {@code fixedLength} to avoid cross-product pollution with the numeric benchmarks. + * + *
<p>
Each invocation decodes {@value #VALUE_COUNT} values; throughput is reported + * per-value via {@link OperationsPerInvocation}. The cost includes both + * {@code initFromPage} (which eagerly transposes the entire page) and the per-value + * read calls. Page transposition is the part this benchmark is primarily designed + * to exercise. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(1) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +@State(Scope.Thread) +public class ByteStreamSplitDecodingBenchmark { + + static final int VALUE_COUNT = 100_000; + private static final int INIT_SLAB_SIZE = 64 * 1024; + private static final int PAGE_SIZE = 4 * 1024 * 1024; + + private byte[] floatPage; + private byte[] doublePage; + private byte[] intPage; + private byte[] longPage; + + @Setup(Level.Trial) + public void setup() throws IOException { + Random random = new Random(42); + int[] intData = new int[VALUE_COUNT]; + long[] longData = new long[VALUE_COUNT]; + float[] floatData = new float[VALUE_COUNT]; + double[] doubleData = new double[VALUE_COUNT]; + for (int i = 0; i < VALUE_COUNT; i++) { + intData[i] = random.nextInt(); + longData[i] = random.nextLong(); + floatData[i] = random.nextFloat(); + doubleData[i] = random.nextDouble(); + } + + { + ValuesWriter w = new ByteStreamSplitValuesWriter.FloatByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (float v : floatData) { + w.writeFloat(v); + } + floatPage = w.getBytes().toByteArray(); + w.close(); + } + { + ValuesWriter w = new ByteStreamSplitValuesWriter.DoubleByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (double v : doubleData) { + w.writeDouble(v); + } + doublePage = w.getBytes().toByteArray(); + w.close(); + } + { + ValuesWriter w = new ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (int v 
: intData) { + w.writeInteger(v); + } + intPage = w.getBytes().toByteArray(); + w.close(); + } + { + ValuesWriter w = new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (long v : longData) { + w.writeLong(v); + } + longPage = w.getBytes().toByteArray(); + w.close(); + } + } + + private static void init(ByteStreamSplitValuesReader r, byte[] page) throws IOException { + r.initFromPage(VALUE_COUNT, ByteBufferInputStream.wrap(ByteBuffer.wrap(page))); + } + + private static void initDirect(ByteStreamSplitValuesReader r, ByteBuffer directPage) throws IOException { + directPage.clear(); // reset position to 0 for each invocation + r.initFromPage(VALUE_COUNT, ByteBufferInputStream.wrap(directPage)); + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeFloat(Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForFloat r = new ByteStreamSplitValuesReaderForFloat(); + init(r, floatPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readFloat()); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeDouble(Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForDouble r = new ByteStreamSplitValuesReaderForDouble(); + init(r, doublePage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readDouble()); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeInt(Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForInteger r = new ByteStreamSplitValuesReaderForInteger(); + init(r, intPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readInteger()); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeLong(Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForLong r = new ByteStreamSplitValuesReaderForLong(); + init(r, longPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readLong()); + } + } + + 
// ---- FIXED_LEN_BYTE_ARRAY (parameterised by fixedLength) ---- + + @State(Scope.Thread) + public static class FlbaState { + @Param({"2", "7", "12", "16"}) + public int fixedLength; + + byte[] flbaPage; + + @Setup(Level.Trial) + public void setup() throws IOException { + Binary[] data = TestDataFactory.generateFixedLenByteArrays( + VALUE_COUNT, fixedLength, 0, TestDataFactory.DEFAULT_SEED); + ValuesWriter w = new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + fixedLength, INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (Binary v : data) { + w.writeBytes(v); + } + flbaPage = w.getBytes().toByteArray(); + w.close(); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeFlba(FlbaState state, Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForFLBA r = new ByteStreamSplitValuesReaderForFLBA(state.fixedLength); + init(r, state.flbaPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readBytes()); + } + } + + // ---- Direct ByteBuffer decode (exercises the !hasArray() path in decodeData) ---- + + @State(Scope.Thread) + public static class DirectBufferState { + ByteBuffer directFloatPage; + ByteBuffer directLongPage; + + @Setup(Level.Trial) + public void setup() throws IOException { + Random random = new Random(42); + { + ValuesWriter w = new ByteStreamSplitValuesWriter.FloatByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (int i = 0; i < VALUE_COUNT; i++) { + w.writeFloat(random.nextFloat()); + } + byte[] page = w.getBytes().toByteArray(); + w.close(); + directFloatPage = ByteBuffer.allocateDirect(page.length); + directFloatPage.put(page); + directFloatPage.flip(); + } + { + ValuesWriter w = new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (int i = 0; i < VALUE_COUNT; i++) { + w.writeLong(random.nextLong()); + } + byte[] page = 
w.getBytes().toByteArray(); + w.close(); + directLongPage = ByteBuffer.allocateDirect(page.length); + directLongPage.put(page); + directLongPage.flip(); + } + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeFloatDirect(DirectBufferState state, Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForFloat r = new ByteStreamSplitValuesReaderForFloat(); + initDirect(r, state.directFloatPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readFloat()); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public void decodeLongDirect(DirectBufferState state, Blackhole bh) throws IOException { + ByteStreamSplitValuesReaderForLong r = new ByteStreamSplitValuesReaderForLong(); + initDirect(r, state.directLongPage); + for (int i = 0; i < VALUE_COUNT; i++) { + bh.consume(r.readLong()); + } + } +} diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitEncodingBenchmark.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitEncodingBenchmark.java new file mode 100644 index 0000000000..fa1516d1da --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ByteStreamSplitEncodingBenchmark.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.benchmarks; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.parquet.bytes.HeapByteBufferAllocator; +import org.apache.parquet.column.values.ValuesWriter; +import org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesWriter; +import org.apache.parquet.io.api.Binary; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +/** + * Encoding-level micro-benchmarks for the BYTE_STREAM_SPLIT encoding across all + * Parquet types that support it: {@code FLOAT}, {@code DOUBLE}, {@code INT32}, + * {@code INT64}, and {@code FIXED_LEN_BYTE_ARRAY}. + * + *

<p>Fixed-width numeric types are benchmarked directly by top-level methods. + * {@code FIXED_LEN_BYTE_ARRAY} uses an inner {@link FlbaState} parameterised by + * {@code fixedLength} to avoid cross-product pollution with the numeric benchmarks. + * + *
<p>
Each invocation encodes {@value #VALUE_COUNT} values; throughput is reported + * per-value via {@link OperationsPerInvocation}. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(1) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +@State(Scope.Thread) +public class ByteStreamSplitEncodingBenchmark { + + static final int VALUE_COUNT = 100_000; + private static final int INIT_SLAB_SIZE = 64 * 1024; + private static final int PAGE_SIZE = 4 * 1024 * 1024; + + private int[] intData; + private long[] longData; + private float[] floatData; + private double[] doubleData; + + @Setup(Level.Trial) + public void setup() { + Random random = new Random(42); + intData = new int[VALUE_COUNT]; + longData = new long[VALUE_COUNT]; + floatData = new float[VALUE_COUNT]; + doubleData = new double[VALUE_COUNT]; + for (int i = 0; i < VALUE_COUNT; i++) { + intData[i] = random.nextInt(); + longData[i] = random.nextLong(); + floatData[i] = random.nextFloat(); + doubleData[i] = random.nextDouble(); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public byte[] encodeFloat() throws IOException { + ValuesWriter w = new ByteStreamSplitValuesWriter.FloatByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (float v : floatData) { + w.writeFloat(v); + } + byte[] bytes = w.getBytes().toByteArray(); + w.close(); + return bytes; + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public byte[] encodeDouble() throws IOException { + ValuesWriter w = new ByteStreamSplitValuesWriter.DoubleByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (double v : doubleData) { + w.writeDouble(v); + } + byte[] bytes = w.getBytes().toByteArray(); + w.close(); + return bytes; + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public byte[] encodeInt() throws IOException { + ValuesWriter w = new 
ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (int v : intData) { + w.writeInteger(v); + } + byte[] bytes = w.getBytes().toByteArray(); + w.close(); + return bytes; + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public byte[] encodeLong() throws IOException { + ValuesWriter w = new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (long v : longData) { + w.writeLong(v); + } + byte[] bytes = w.getBytes().toByteArray(); + w.close(); + return bytes; + } + + // ---- FIXED_LEN_BYTE_ARRAY (parameterised by fixedLength) ---- + + @State(Scope.Thread) + public static class FlbaState { + @Param({"2", "7", "12", "16"}) + public int fixedLength; + + Binary[] data; + + @Setup(Level.Trial) + public void setup() { + data = TestDataFactory.generateFixedLenByteArrays( + VALUE_COUNT, fixedLength, 0, TestDataFactory.DEFAULT_SEED); + } + } + + @Benchmark + @OperationsPerInvocation(VALUE_COUNT) + public byte[] encodeFlba(FlbaState state) throws IOException { + ValuesWriter w = new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + state.fixedLength, INIT_SLAB_SIZE, PAGE_SIZE, new HeapByteBufferAllocator()); + for (Binary v : state.data) { + w.writeBytes(v); + } + byte[] bytes = w.getBytes().toByteArray(); + w.close(); + return bytes; + } +} diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/TestDataFactory.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/TestDataFactory.java new file mode 100644 index 0000000000..f1df60e07d --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/TestDataFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.benchmarks; + +import java.util.Random; +import org.apache.parquet.io.api.Binary; + +/** + * Utility class for generating test data for encoding benchmarks. + */ +public final class TestDataFactory { + + /** Default RNG seed used across benchmarks for deterministic data. */ + public static final long DEFAULT_SEED = 42L; + + private TestDataFactory() {} + + /** + * Generates fixed-length byte arrays with the specified cardinality. 
+ * + * @param count number of values + * @param length byte length of each value + * @param distinct number of distinct values (0 means all unique) + * @param seed RNG seed + */ + public static Binary[] generateFixedLenByteArrays(int count, int length, int distinct, long seed) { + Random random = new Random(seed); + if (distinct > 0) { + Binary[] palette = new Binary[distinct]; + for (int i = 0; i < distinct; i++) { + byte[] bytes = new byte[length]; + random.nextBytes(bytes); + palette[i] = Binary.fromConstantByteArray(bytes); + } + Binary[] data = new Binary[count]; + for (int i = 0; i < count; i++) { + data[i] = palette[random.nextInt(distinct)]; + } + return data; + } else { + Binary[] data = new Binary[count]; + for (int i = 0; i < count; i++) { + byte[] bytes = new byte[length]; + random.nextBytes(bytes); + data[i] = Binary.fromConstantByteArray(bytes); + } + return data; + } + } +} diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesReader.java b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesReader.java index c8ab3043bd..2dddf84d0a 100644 --- a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesReader.java +++ b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesReader.java @@ -49,17 +49,138 @@ protected int nextElementByteOffset() { return offset; } - // Decode an entire data page + // Decode an entire data page by transposing from stream-split layout to interleaved layout. 
private byte[] decodeData(ByteBuffer encoded, int valuesCount) { - assert encoded.limit() == valuesCount * elementSizeInBytes; - byte[] decoded = new byte[encoded.limit()]; - int destByteIndex = 0; - for (int srcValueIndex = 0; srcValueIndex < valuesCount; ++srcValueIndex) { - for (int stream = 0; stream < elementSizeInBytes; ++stream, ++destByteIndex) { - decoded[destByteIndex] = encoded.get(srcValueIndex + stream * valuesCount); + int totalBytes = valuesCount * elementSizeInBytes; + assert encoded.remaining() >= totalBytes; + + // Bulk access: use the backing array directly if available, otherwise copy once. + byte[] src; + int srcBase; + if (encoded.hasArray()) { + src = encoded.array(); + srcBase = encoded.arrayOffset() + encoded.position(); + } else { + src = new byte[totalBytes]; + encoded.get(src); + srcBase = 0; + } + + byte[] decoded = new byte[totalBytes]; + + // Specialized single-pass loops for common element sizes. + if (elementSizeInBytes == 2) { + int s0 = srcBase, s1 = srcBase + valuesCount; + for (int i = 0; i < valuesCount; ++i) { + int di = i * 2; + decoded[di] = src[s0 + i]; + decoded[di + 1] = src[s1 + i]; + } + } else if (elementSizeInBytes == 4) { + int s0 = srcBase, + s1 = srcBase + valuesCount, + s2 = srcBase + 2 * valuesCount, + s3 = srcBase + 3 * valuesCount; + for (int i = 0; i < valuesCount; ++i) { + int di = i * 4; + decoded[di] = src[s0 + i]; + decoded[di + 1] = src[s1 + i]; + decoded[di + 2] = src[s2 + i]; + decoded[di + 3] = src[s3 + i]; + } + } else if (elementSizeInBytes == 8) { + int s0 = srcBase, + s1 = srcBase + valuesCount, + s2 = srcBase + 2 * valuesCount, + s3 = srcBase + 3 * valuesCount, + s4 = srcBase + 4 * valuesCount, + s5 = srcBase + 5 * valuesCount, + s6 = srcBase + 6 * valuesCount, + s7 = srcBase + 7 * valuesCount; + for (int i = 0; i < valuesCount; ++i) { + int di = i * 8; + decoded[di] = src[s0 + i]; + decoded[di + 1] = src[s1 + i]; + decoded[di + 2] = src[s2 + i]; + decoded[di + 3] = src[s3 + i]; + decoded[di + 4] 
= src[s4 + i]; + decoded[di + 5] = src[s5 + i]; + decoded[di + 6] = src[s6 + i]; + decoded[di + 7] = src[s7 + i]; + } + } else if (elementSizeInBytes == 12) { + int s0 = srcBase, + s1 = srcBase + valuesCount, + s2 = srcBase + 2 * valuesCount, + s3 = srcBase + 3 * valuesCount, + s4 = srcBase + 4 * valuesCount, + s5 = srcBase + 5 * valuesCount, + s6 = srcBase + 6 * valuesCount, + s7 = srcBase + 7 * valuesCount, + s8 = srcBase + 8 * valuesCount, + s9 = srcBase + 9 * valuesCount, + s10 = srcBase + 10 * valuesCount, + s11 = srcBase + 11 * valuesCount; + for (int i = 0; i < valuesCount; ++i) { + int di = i * 12; + decoded[di] = src[s0 + i]; + decoded[di + 1] = src[s1 + i]; + decoded[di + 2] = src[s2 + i]; + decoded[di + 3] = src[s3 + i]; + decoded[di + 4] = src[s4 + i]; + decoded[di + 5] = src[s5 + i]; + decoded[di + 6] = src[s6 + i]; + decoded[di + 7] = src[s7 + i]; + decoded[di + 8] = src[s8 + i]; + decoded[di + 9] = src[s9 + i]; + decoded[di + 10] = src[s10 + i]; + decoded[di + 11] = src[s11 + i]; + } + } else if (elementSizeInBytes == 16) { + int s0 = srcBase, + s1 = srcBase + valuesCount, + s2 = srcBase + 2 * valuesCount, + s3 = srcBase + 3 * valuesCount, + s4 = srcBase + 4 * valuesCount, + s5 = srcBase + 5 * valuesCount, + s6 = srcBase + 6 * valuesCount, + s7 = srcBase + 7 * valuesCount, + s8 = srcBase + 8 * valuesCount, + s9 = srcBase + 9 * valuesCount, + s10 = srcBase + 10 * valuesCount, + s11 = srcBase + 11 * valuesCount, + s12 = srcBase + 12 * valuesCount, + s13 = srcBase + 13 * valuesCount, + s14 = srcBase + 14 * valuesCount, + s15 = srcBase + 15 * valuesCount; + for (int i = 0; i < valuesCount; ++i) { + int di = i * 16; + decoded[di] = src[s0 + i]; + decoded[di + 1] = src[s1 + i]; + decoded[di + 2] = src[s2 + i]; + decoded[di + 3] = src[s3 + i]; + decoded[di + 4] = src[s4 + i]; + decoded[di + 5] = src[s5 + i]; + decoded[di + 6] = src[s6 + i]; + decoded[di + 7] = src[s7 + i]; + decoded[di + 8] = src[s8 + i]; + decoded[di + 9] = src[s9 + i]; + decoded[di + 10] 
= src[s10 + i]; + decoded[di + 11] = src[s11 + i]; + decoded[di + 12] = src[s12 + i]; + decoded[di + 13] = src[s13 + i]; + decoded[di + 14] = src[s14 + i]; + decoded[di + 15] = src[s15 + i]; + } + } else { + // Generic fallback for arbitrary element sizes + for (int stream = 0; stream < elementSizeInBytes; ++stream) { + int srcOffset = srcBase + stream * valuesCount; + for (int i = 0; i < valuesCount; ++i) { + decoded[i * elementSizeInBytes + stream] = src[srcOffset + i]; + } } } - assert destByteIndex == decoded.length; return decoded; } diff --git a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java index c197a4fd6f..889c0d5f5c 100644 --- a/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java +++ b/parquet-column/src/main/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitValuesWriter.java @@ -20,7 +20,6 @@ import org.apache.parquet.bytes.ByteBufferAllocator; import org.apache.parquet.bytes.BytesInput; -import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.bytes.CapacityByteArrayOutputStream; import org.apache.parquet.column.Encoding; import org.apache.parquet.column.values.ValuesWriter; @@ -29,9 +28,23 @@ public abstract class ByteStreamSplitValuesWriter extends ValuesWriter { + /** + * Batch size for buffered scatter writes. Values are accumulated in a batch buffer + * and flushed as bulk {@code write(byte[], off, len)} calls to each stream, replacing + * N individual single-byte writes with one bulk write per stream per flush. 
+ */ + private static final int BATCH_SIZE = 64; + protected final int numStreams; protected final int elementSizeInBytes; - private final CapacityByteArrayOutputStream[] byteStreams; + protected final CapacityByteArrayOutputStream[] byteStreams; + + // Batch buffers for int (4-byte) and long (8-byte) scatter writes. + // Only one of these is ever non-null per instance. + private int[] intBatch; + private long[] longBatch; + private byte[] scatterBuf; + private int batchCount; public ByteStreamSplitValuesWriter( int elementSizeInBytes, int initialCapacity, int pageSize, ByteBufferAllocator allocator) { @@ -53,7 +66,8 @@ public ByteStreamSplitValuesWriter( @Override public long getBufferedSize() { - long totalSize = 0; + // Include unflushed batch values without triggering a flush + long totalSize = (long) batchCount * elementSizeInBytes; for (CapacityByteArrayOutputStream stream : this.byteStreams) { totalSize += stream.size(); } @@ -62,6 +76,7 @@ public long getBufferedSize() { @Override public BytesInput getBytes() { + flushBatch(); BytesInput[] allInputs = new BytesInput[this.numStreams]; for (int i = 0; i < this.numStreams; ++i) { allInputs[i] = BytesInput.from(this.byteStreams[i]); @@ -76,6 +91,7 @@ public Encoding getEncoding() { @Override public void reset() { + batchCount = 0; for (CapacityByteArrayOutputStream stream : this.byteStreams) { stream.reset(); } @@ -83,20 +99,75 @@ public void reset() { @Override public void close() { + batchCount = 0; for (CapacityByteArrayOutputStream stream : byteStreams) { stream.close(); } } - protected void scatterBytes(byte[] bytes) { - if (bytes.length != this.numStreams) { - throw new ParquetEncodingException(String.format( - "Number of bytes doesn't match the number of streams. Num butes: %d, Num streams: %d", - bytes.length, this.numStreams)); + /** + * Buffer a 4-byte integer value for batched scatter to the byte streams. 
+ * Values are accumulated until the batch is full, then flushed as bulk + * {@code write(byte[], off, len)} calls -- one per stream. + */ + protected void bufferInt(int v) { + if (intBatch == null) { + intBatch = new int[BATCH_SIZE]; + scatterBuf = new byte[BATCH_SIZE]; + } + intBatch[batchCount++] = v; + if (batchCount == BATCH_SIZE) { + flushIntBatch(); + } + } + + /** + * Buffer an 8-byte long value for batched scatter to the byte streams. + */ + protected void bufferLong(long v) { + if (longBatch == null) { + longBatch = new long[BATCH_SIZE]; + scatterBuf = new byte[BATCH_SIZE]; } - for (int i = 0; i < bytes.length; ++i) { - this.byteStreams[i].write(bytes[i]); + longBatch[batchCount++] = v; + if (batchCount == BATCH_SIZE) { + flushLongBatch(); + } + } + + private void flushBatch() { + if (batchCount == 0) return; + if (intBatch != null) { + flushIntBatch(); + } else if (longBatch != null) { + flushLongBatch(); + } + } + + private void flushIntBatch() { + if (batchCount == 0) return; + final int count = batchCount; + for (int stream = 0; stream < 4; stream++) { + final int shift = stream << 3; // stream * 8 + for (int i = 0; i < count; i++) { + scatterBuf[i] = (byte) (intBatch[i] >>> shift); + } + byteStreams[stream].write(scatterBuf, 0, count); + } + batchCount = 0; + } + + private void flushLongBatch() { + if (batchCount == 0) return; + final int count = batchCount; + for (int stream = 0; stream < 8; stream++) { + final int shift = stream << 3; // stream * 8 + for (int i = 0; i < count; i++) { + scatterBuf[i] = (byte) (longBatch[i] >>> shift); + } + byteStreams[stream].write(scatterBuf, 0, count); } + batchCount = 0; } @Override @@ -116,7 +187,7 @@ public FloatByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteB @Override public void writeFloat(float v) { - super.scatterBytes(BytesUtils.intToBytes(Float.floatToIntBits(v))); + bufferInt(Float.floatToIntBits(v)); } @Override @@ -133,7 +204,7 @@ public DoubleByteStreamSplitValuesWriter(int 
initialCapacity, int pageSize, Byte @Override public void writeDouble(double v) { - super.scatterBytes(BytesUtils.longToBytes(Double.doubleToLongBits(v))); + bufferLong(Double.doubleToLongBits(v)); } @Override @@ -149,7 +220,7 @@ public IntegerByteStreamSplitValuesWriter(int initialCapacity, int pageSize, Byt @Override public void writeInteger(int v) { - super.scatterBytes(BytesUtils.intToBytes(v)); + bufferInt(v); } @Override @@ -165,7 +236,7 @@ public LongByteStreamSplitValuesWriter(int initialCapacity, int pageSize, ByteBu @Override public void writeLong(long v) { - super.scatterBytes(BytesUtils.longToBytes(v)); + bufferLong(v); } @Override @@ -176,6 +247,8 @@ public String memUsageString(String prefix) { public static class FixedLenByteArrayByteStreamSplitValuesWriter extends ByteStreamSplitValuesWriter { private final int length; + private byte[][] batchBufs; // [stream][batchIndex] scratch buffers + private int flbaBatchCount; public FixedLenByteArrayByteStreamSplitValuesWriter( int length, int initialCapacity, int pageSize, ByteBufferAllocator allocator) { @@ -187,7 +260,49 @@ public FixedLenByteArrayByteStreamSplitValuesWriter( public final void writeBytes(Binary v) { assert (v.length() == length) : ("Fixed Binary size " + v.length() + " does not match field type length " + length); - super.scatterBytes(v.getBytesUnsafe()); + if (batchBufs == null) { + batchBufs = new byte[length][BATCH_SIZE]; + } + byte[] bytes = v.getBytesUnsafe(); + for (int stream = 0; stream < length; stream++) { + batchBufs[stream][flbaBatchCount] = bytes[stream]; + } + flbaBatchCount++; + if (flbaBatchCount == BATCH_SIZE) { + flushFlbaBatch(); + } + } + + private void flushFlbaBatch() { + if (flbaBatchCount == 0) return; + final int count = flbaBatchCount; + for (int stream = 0; stream < length; stream++) { + byteStreams[stream].write(batchBufs[stream], 0, count); + } + flbaBatchCount = 0; + } + + @Override + public BytesInput getBytes() { + flushFlbaBatch(); + return 
super.getBytes(); + } + + @Override + public void reset() { + flbaBatchCount = 0; + super.reset(); + } + + @Override + public void close() { + flbaBatchCount = 0; + super.close(); + } + + @Override + public long getBufferedSize() { + return super.getBufferedSize() + (long) flbaBatchCount * length; } @Override diff --git a/parquet-column/src/test/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitScalarOptTest.java b/parquet-column/src/test/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitScalarOptTest.java new file mode 100644 index 0000000000..af2bcd7cd0 --- /dev/null +++ b/parquet-column/src/test/java/org/apache/parquet/column/values/bytestreamsplit/ByteStreamSplitScalarOptTest.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.parquet.column.values.bytestreamsplit; + +import static org.junit.Assert.assertEquals; + +import java.nio.ByteBuffer; +import java.util.Random; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.DirectByteBufferAllocator; +import org.apache.parquet.io.api.Binary; +import org.junit.Test; + +/** + * Tests for the BYTE_STREAM_SPLIT scalar performance optimizations: + * internal scatter-buffer batching, decodeData transpose specializations, + * direct ByteBuffer decode path, getBufferedSize accuracy, and close/reset + * edge cases. + */ +public class ByteStreamSplitScalarOptTest { + + private static final int BATCH_SIZE = 64; // matches ByteStreamSplitValuesWriter.BATCH_SIZE + + // --------------------------------------------------------------------------- + // decodeData transpose specializations: element sizes 2, 12, 16 + // (sizes 4 and 8 are already covered by int/float and long/double tests) + // --------------------------------------------------------------------------- + + @Test + public void testFlbaTransposeSize2() throws Exception { + flbaRoundTrip(2, 512); + } + + @Test + public void testFlbaTransposeSize12() throws Exception { + flbaRoundTrip(12, 256); + } + + @Test + public void testFlbaTransposeSize16() throws Exception { + flbaRoundTrip(16, 256); + } + + /** Generic FLBA round-trip used to exercise specific element-size transpose paths. 
*/ + private void flbaRoundTrip(int typeLength, int numElements) throws Exception { + Random rand = new Random(42); + Binary[] values = new Binary[numElements]; + for (int i = 0; i < numElements; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + values[i] = Binary.fromConstantByteArray(bytes); + } + + ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + typeLength, + numElements * typeLength, + numElements * typeLength, + new DirectByteBufferAllocator()); + // Use scalar writes to test the writer's scatter path independently from batch writes. + for (Binary v : values) { + writer.writeBytes(v); + } + BytesInput input = writer.getBytes(); + assertEquals(numElements * typeLength, input.size()); + + ByteStreamSplitValuesReaderForFLBA reader = new ByteStreamSplitValuesReaderForFLBA(typeLength); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(input.toByteBuffer())); + + // Scalar read to verify each value + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readBytes()); + } + + writer.reset(); + writer.close(); + } + + /** Also test the generic fallback path with an odd element size (e.g. 5, 7). 
*/ + @Test + public void testFlbaTransposeGenericFallback() throws Exception { + flbaRoundTrip(5, 256); + flbaRoundTrip(7, 256); + } + + // --------------------------------------------------------------------------- + // BATCH_SIZE boundary crossing: tests that internal batch flush works correctly + // when writing exactly BATCH_SIZE, BATCH_SIZE+1, and multi-BATCH_SIZE counts + // --------------------------------------------------------------------------- + + @Test + public void testIntegerWriteExactBatchSize() throws Exception { + intScalarRoundTrip(BATCH_SIZE); + } + + @Test + public void testIntegerWriteBatchSizePlusOne() throws Exception { + intScalarRoundTrip(BATCH_SIZE + 1); + } + + @Test + public void testIntegerWriteMultipleBatches() throws Exception { + intScalarRoundTrip(BATCH_SIZE * 3 + 17); + } + + @Test + public void testLongWriteExactBatchSize() throws Exception { + longScalarRoundTrip(BATCH_SIZE); + } + + @Test + public void testLongWriteBatchSizePlusOne() throws Exception { + longScalarRoundTrip(BATCH_SIZE + 1); + } + + @Test + public void testLongWriteMultipleBatches() throws Exception { + longScalarRoundTrip(BATCH_SIZE * 3 + 17); + } + + private void intScalarRoundTrip(int numElements) throws Exception { + Random rand = new Random(42); + int[] values = rand.ints(numElements).toArray(); + + ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter( + numElements * 4, numElements * 4, new DirectByteBufferAllocator()); + for (int v : values) { + writer.writeInteger(v); + } + BytesInput input = writer.getBytes(); + + ByteStreamSplitValuesReaderForInteger reader = new ByteStreamSplitValuesReaderForInteger(); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(input.toByteBuffer())); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readInteger()); + } + + writer.reset(); + writer.close(); + } + + private void 
longScalarRoundTrip(int numElements) throws Exception { + Random rand = new Random(42); + long[] values = rand.longs(numElements).toArray(); + + ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + numElements * 8, numElements * 8, new DirectByteBufferAllocator()); + for (long v : values) { + writer.writeLong(v); + } + BytesInput input = writer.getBytes(); + + ByteStreamSplitValuesReaderForLong reader = new ByteStreamSplitValuesReaderForLong(); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(input.toByteBuffer())); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readLong()); + } + + writer.reset(); + writer.close(); + } + + // --------------------------------------------------------------------------- + // getBufferedSize accounts for unflushed batch + // --------------------------------------------------------------------------- + + @Test + public void testGetBufferedSizeWithPartialBatch() throws Exception { + ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter( + 256, 256, new DirectByteBufferAllocator()); + + // Write fewer than BATCH_SIZE values -- they sit in the internal batch buffer + for (int i = 0; i < 10; i++) { + writer.writeInteger(i); + } + assertEquals(10 * 4, writer.getBufferedSize()); + + // Write more to cross a batch boundary + for (int i = 0; i < BATCH_SIZE; i++) { + writer.writeInteger(i); + } + assertEquals((10 + BATCH_SIZE) * 4, writer.getBufferedSize()); + + writer.reset(); + assertEquals(0, writer.getBufferedSize()); + writer.close(); + } + + @Test + public void testGetBufferedSizeWithPartialLongBatch() throws Exception { + ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + 256, 256, new DirectByteBufferAllocator()); 
+ + for (int i = 0; i < 10; i++) { + writer.writeLong(i); + } + assertEquals(10 * 8, writer.getBufferedSize()); + + writer.reset(); + assertEquals(0, writer.getBufferedSize()); + writer.close(); + } + + // --------------------------------------------------------------------------- + // decodeData direct ByteBuffer path (no backing array) + // --------------------------------------------------------------------------- + + @Test + public void testDecodeFromDirectByteBuffer() throws Exception { + Random rand = new Random(42); + final int numElements = 256; + float[] values = new float[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = rand.nextFloat(); + } + + // Encode using standard writer + ByteStreamSplitValuesWriter.FloatByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FloatByteStreamSplitValuesWriter( + numElements * 4, numElements * 4, new DirectByteBufferAllocator()); + for (float v : values) { + writer.writeFloat(v); + } + byte[] encoded = writer.getBytes().toByteArray(); + + // Copy into a direct ByteBuffer (no backing array) to exercise the else branch in decodeData + ByteBuffer direct = ByteBuffer.allocateDirect(encoded.length); + direct.put(encoded); + direct.flip(); + + ByteStreamSplitValuesReaderForFloat reader = new ByteStreamSplitValuesReaderForFloat(); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(direct)); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readFloat(), 0.0f); + } + + writer.reset(); + writer.close(); + } + + @Test + public void testDecodeFromDirectByteBufferLong() throws Exception { + Random rand = new Random(42); + final int numElements = 256; + long[] values = rand.longs(numElements).toArray(); + + ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + numElements * 8, numElements * 8, new DirectByteBufferAllocator()); + for (long v : values) 
{ + writer.writeLong(v); + } + byte[] encoded = writer.getBytes().toByteArray(); + + ByteBuffer direct = ByteBuffer.allocateDirect(encoded.length); + direct.put(encoded); + direct.flip(); + + ByteStreamSplitValuesReaderForLong reader = new ByteStreamSplitValuesReaderForLong(); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(direct)); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readLong()); + } + + writer.reset(); + writer.close(); + } + + @Test + public void testDecodeFromDirectByteBufferFlba() throws Exception { + Random rand = new Random(42); + final int numElements = 256; + final int typeLength = 12; + Binary[] values = new Binary[numElements]; + for (int i = 0; i < numElements; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + values[i] = Binary.fromConstantByteArray(bytes); + } + + ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + typeLength, + numElements * typeLength, + numElements * typeLength, + new DirectByteBufferAllocator()); + for (Binary v : values) { + writer.writeBytes(v); + } + byte[] encoded = writer.getBytes().toByteArray(); + + ByteBuffer direct = ByteBuffer.allocateDirect(encoded.length); + direct.put(encoded); + direct.flip(); + + ByteStreamSplitValuesReaderForFLBA reader = new ByteStreamSplitValuesReaderForFLBA(typeLength); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(direct)); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readBytes()); + } + + writer.reset(); + writer.close(); + } + + // --------------------------------------------------------------------------- + // FLBA getBufferedSize with partial batch + // --------------------------------------------------------------------------- + + @Test + public void testFlbaGetBufferedSizeWithPartialBatch() throws 
Exception { + final int typeLength = 12; + ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + typeLength, 256, 256, new DirectByteBufferAllocator()); + + Random rand = new Random(42); + // Write fewer than BATCH_SIZE values -- they sit in batchBufs + for (int i = 0; i < 10; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + writer.writeBytes(Binary.fromConstantByteArray(bytes)); + } + assertEquals(10 * typeLength, writer.getBufferedSize()); + + // Write more to cross a batch boundary + for (int i = 0; i < BATCH_SIZE; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + writer.writeBytes(Binary.fromConstantByteArray(bytes)); + } + assertEquals((10 + BATCH_SIZE) * typeLength, writer.getBufferedSize()); + + writer.reset(); + assertEquals(0, writer.getBufferedSize()); + writer.close(); + } + + // --------------------------------------------------------------------------- + // close() with a pending (unflushed) batch resets batch state + // --------------------------------------------------------------------------- + + @Test + public void testIntCloseWithPendingBatch() throws Exception { + ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.IntegerByteStreamSplitValuesWriter( + 256, 256, new DirectByteBufferAllocator()); + // Write fewer than BATCH_SIZE values so the batch stays unflushed + for (int i = 0; i < 10; i++) { + writer.writeInteger(i); + } + assertEquals(10 * 4, writer.getBufferedSize()); + writer.close(); + assertEquals(0, writer.getBufferedSize()); + } + + @Test + public void testLongCloseWithPendingBatch() throws Exception { + ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.LongByteStreamSplitValuesWriter( + 256, 256, new DirectByteBufferAllocator()); + for (int i = 0; i < 10; i++) { + 
writer.writeLong(i); + } + assertEquals(10 * 8, writer.getBufferedSize()); + writer.close(); + assertEquals(0, writer.getBufferedSize()); + } + + @Test + public void testFlbaCloseWithPendingBatch() throws Exception { + final int typeLength = 12; + ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + typeLength, 256, 256, new DirectByteBufferAllocator()); + Random rand = new Random(42); + for (int i = 0; i < 10; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + writer.writeBytes(Binary.fromConstantByteArray(bytes)); + } + assertEquals(10 * typeLength, writer.getBufferedSize()); + writer.close(); + assertEquals(0, writer.getBufferedSize()); + } + + // --------------------------------------------------------------------------- + // Direct ByteBuffer decode for 2-byte and 16-byte element sizes + // --------------------------------------------------------------------------- + + @Test + public void testDecodeFromDirectByteBufferFlba2() throws Exception { + directByteBufferFlbaRoundTrip(2, 256); + } + + @Test + public void testDecodeFromDirectByteBufferFlba16() throws Exception { + directByteBufferFlbaRoundTrip(16, 256); + } + + private void directByteBufferFlbaRoundTrip(int typeLength, int numElements) throws Exception { + Random rand = new Random(42); + Binary[] values = new Binary[numElements]; + for (int i = 0; i < numElements; i++) { + byte[] bytes = new byte[typeLength]; + rand.nextBytes(bytes); + values[i] = Binary.fromConstantByteArray(bytes); + } + + ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter writer = + new ByteStreamSplitValuesWriter.FixedLenByteArrayByteStreamSplitValuesWriter( + typeLength, + numElements * typeLength, + numElements * typeLength, + new DirectByteBufferAllocator()); + for (Binary v : values) { + writer.writeBytes(v); + } + byte[] encoded = writer.getBytes().toByteArray(); + + // Use a 
direct ByteBuffer (no backing array) to exercise the else branch in decodeData + ByteBuffer direct = ByteBuffer.allocateDirect(encoded.length); + direct.put(encoded); + direct.flip(); + + ByteStreamSplitValuesReaderForFLBA reader = new ByteStreamSplitValuesReaderForFLBA(typeLength); + reader.initFromPage(numElements, ByteBufferInputStream.wrap(direct)); + + for (int i = 0; i < numElements; i++) { + assertEquals("Mismatch at index " + i, values[i], reader.readBytes()); + } + + writer.reset(); + writer.close(); + } +} diff --git a/pom.xml b/pom.xml index 1bd9893d87..c941159798 100644 --- a/pom.xml +++ b/pom.xml @@ -594,6 +594,7 @@ org.apache.parquet.internal.column.columnindex.IndexIterator org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesReader#gatherElementDataFromStreams(byte[]) + org.apache.parquet.column.values.bytestreamsplit.ByteStreamSplitValuesWriter#scatterBytes(byte[]) org.apache.parquet.arrow.schema.SchemaMapping$TypeMappingVisitor#visit(org.apache.parquet.arrow.schema.SchemaMapping$MapTypeMapping)