2626#include " shuffle/Utils.h"
2727#include " utils/Exception.h"
2828#include " utils/Timer.h"
29+ #include " utils/tac/TypeAwareCompressCodec.h"
2930
3031namespace gluten {
3132namespace {
@@ -36,6 +37,7 @@ static const Payload::Type kUncompressedType = gluten::BlockPayload::kUncompress
3637static constexpr int64_t kZeroLengthBuffer = 0 ;
3738static constexpr int64_t kNullBuffer = -1 ;
3839static constexpr int64_t kUncompressedBuffer = -2 ;
40+ static constexpr int64_t kTypeAwareBuffer = -3 ;
3941
4042template <typename T>
4143void write (uint8_t ** dst, T data) {
@@ -86,6 +88,51 @@ arrow::Result<int64_t> compressBuffer(
8688 return kCompressedBufferHeaderLength + compressedLength;
8789}
8890
91+ // Type-aware buffer compression via TypeAwareCompressCodec.
92+ // Same wire format as compressBuffer:
93+ // kTypeAwareBuffer (int64) | uncompressedLength (int64) | compressedLength (int64) | compressed data
94+ // If compressed size >= uncompressed size, falls back to kUncompressedBuffer (same as standard codec).
95+ arrow::Result<int64_t > compressTypeAwareBuffer (
96+ const std::shared_ptr<arrow::Buffer>& buffer,
97+ uint8_t * output,
98+ int64_t outputLength,
99+ int8_t typeKind) {
100+ auto outputPtr = &output;
101+ if (!buffer) {
102+ write<int64_t >(outputPtr, kNullBuffer );
103+ return sizeof (int64_t );
104+ }
105+ if (buffer->size () == 0 ) {
106+ write<int64_t >(outputPtr, kZeroLengthBuffer );
107+ return sizeof (int64_t );
108+ }
109+
110+ static const int64_t kHeaderLength = 3 * sizeof (int64_t ); // marker + uncompressedLen + compressedLen
111+ if (outputLength < kHeaderLength + buffer->size ()) {
112+ return arrow::Status::Invalid (" Output buffer too small for type-aware compression." );
113+ }
114+ auto * dataOutput = output + kHeaderLength ;
115+ auto availableOutput = outputLength - kHeaderLength ;
116+
117+ ARROW_ASSIGN_OR_RAISE (
118+ auto compressedSize,
119+ TypeAwareCompressCodec::compress (buffer->data (), buffer->size (), dataOutput, availableOutput, typeKind));
120+
121+ if (compressedSize >= buffer->size ()) {
122+ // Compression didn't help. Fall back to uncompressed, same as compressBuffer.
123+ write<int64_t >(outputPtr, kUncompressedBuffer );
124+ write (outputPtr, static_cast <int64_t >(buffer->size ()));
125+ memcpy (*outputPtr, buffer->data (), buffer->size ());
126+ return 2 * sizeof (int64_t ) + buffer->size ();
127+ }
128+
129+ write<int64_t >(outputPtr, kTypeAwareBuffer );
130+ write (outputPtr, static_cast <int64_t >(buffer->size ()));
131+ write (outputPtr, static_cast <int64_t >(compressedSize));
132+ // compressed data already written at dataOutput by TypeAwareCompressCodec::compress.
133+ return kHeaderLength + compressedSize;
134+ }
135+
89136arrow::Status compressAndFlush (
90137 const std::shared_ptr<arrow::Buffer>& buffer,
91138 arrow::io::OutputStream* outputStream,
@@ -146,6 +193,24 @@ arrow::Result<std::shared_ptr<arrow::Buffer>> readCompressedBuffer(
146193
147194 int64_t uncompressedLength;
148195 RETURN_NOT_OK (inputStream->Read (sizeof (int64_t ), &uncompressedLength));
196+
197+ if (compressedLength == kTypeAwareBuffer ) {
198+ // Type-aware compressed buffer. This marker only appears when compression helped.
199+ // Wire format: compressedLength (int64) already consumed above as kTypeAwareBuffer,
200+ // then uncompressedLength (already read), then actualCompressedLen, then data.
201+ int64_t actualCompressedLen;
202+ RETURN_NOT_OK (inputStream->Read (sizeof (int64_t ), &actualCompressedLen));
203+ ARROW_ASSIGN_OR_RAISE (auto compressed, arrow::AllocateResizableBuffer (actualCompressedLen, pool));
204+ RETURN_NOT_OK (inputStream->Read (actualCompressedLen, compressed->mutable_data ()));
205+
206+ timer.switchTo (&decompressTime);
207+ ARROW_ASSIGN_OR_RAISE (auto output, arrow::AllocateResizableBuffer (uncompressedLength, pool));
208+ RETURN_NOT_OK (TypeAwareCompressCodec::decompress (
209+ compressed->data (), actualCompressedLen, output->mutable_data (), uncompressedLength)
210+ .status ());
211+ return output;
212+ }
213+
149214 if (compressedLength == kUncompressedBuffer ) {
150215 ARROW_ASSIGN_OR_RAISE (auto uncompressed, arrow::AllocateResizableBuffer (uncompressedLength, pool));
151216 RETURN_NOT_OK (inputStream->Read (uncompressedLength, uncompressed->mutable_data ()));
@@ -185,25 +250,38 @@ arrow::Result<std::unique_ptr<BlockPayload>> BlockPayload::fromBuffers(
185250 std::vector<std::shared_ptr<arrow::Buffer>> buffers,
186251 const std::vector<bool >* isValidityBuffer,
187252 arrow::MemoryPool* pool,
188- arrow::util::Codec* codec) {
253+ arrow::util::Codec* codec,
254+ const std::vector<int8_t >* bufferTypes) {
189255 const uint32_t numBuffers = buffers.size ();
190256
191257 if (payloadType == Payload::Type::kCompressed ) {
192258 Timer compressionTime;
193259 compressionTime.start ();
194- // Compress.
195- auto maxLength = maxCompressedLength (buffers, codec);
196- std::shared_ptr<arrow::Buffer> compressedBuffer;
197260
261+ // Compute max compressed length, accounting for type-aware compression where applicable.
262+ auto maxLength = maxCompressedLength (buffers, codec, bufferTypes);
263+
264+ std::shared_ptr<arrow::Buffer> compressedBuffer;
198265 ARROW_ASSIGN_OR_RAISE (compressedBuffer, arrow::AllocateResizableBuffer (maxLength, pool));
199266 auto * output = compressedBuffer->mutable_data ();
200267
201268 int64_t actualLength = 0 ;
202269 // Compress buffers one by one.
203- for (auto & buffer : buffers) {
270+ for (size_t i = 0 ; i < buffers. size (); ++i ) {
204271 auto availableLength = maxLength - actualLength;
205- // Release buffer after compression.
206- ARROW_ASSIGN_OR_RAISE (auto compressedSize, compressBuffer (std::move (buffer), output, availableLength, codec));
272+ auto typeKind =
273+ (bufferTypes != nullptr && i < bufferTypes->size ()) ? (*bufferTypes)[i] : tac::kUnsupported ;
274+
275+ int64_t compressedSize = 0 ;
276+ if (TypeAwareCompressCodec::support (typeKind)) {
277+ // Use type-aware compression for supported types.
278+ ARROW_ASSIGN_OR_RAISE (
279+ compressedSize, compressTypeAwareBuffer (std::move (buffers[i]), output, availableLength, typeKind));
280+ } else {
281+ // Use standard codec (LZ4/ZSTD) for unsupported types.
282+ ARROW_ASSIGN_OR_RAISE (
283+ compressedSize, compressBuffer (std::move (buffers[i]), output, availableLength, codec));
284+ }
207285 output += compressedSize;
208286 actualLength += compressedSize;
209287 }
@@ -327,16 +405,29 @@ int64_t BlockPayload::rawSize() {
327405
328406int64_t BlockPayload::maxCompressedLength (
329407 const std::vector<std::shared_ptr<arrow::Buffer>>& buffers,
330- arrow::util::Codec* codec) {
408+ arrow::util::Codec* codec,
409+ const std::vector<int8_t >* bufferTypes) {
331410 // Compressed buffer layout: | buffer1 compressedLength | buffer1 uncompressedLength | buffer1 | ...
332- const auto metadataLength = sizeof (int64_t ) * 2 * buffers.size ();
333- int64_t totalCompressedLength =
334- std::accumulate (buffers.begin (), buffers.end (), 0LL , [&](auto sum, const auto & buffer) {
335- if (!buffer) {
336- return sum;
337- }
338- return sum + codec->MaxCompressedLen (buffer->size (), buffer->data ());
339- });
411+ int64_t metadataLength = sizeof (int64_t ) * 2 * buffers.size ();
412+ int64_t totalCompressedLength = 0 ;
413+ for (size_t i = 0 ; i < buffers.size (); ++i) {
414+ const auto & buffer = buffers[i];
415+ if (!buffer) {
416+ continue ;
417+ }
418+ if (bufferTypes != nullptr && i < bufferTypes->size ()) {
419+ auto typeKind = (*bufferTypes)[i];
420+ if (TypeAwareCompressCodec::support (typeKind)) {
421+ // Type-aware compressed buffer has an extra int64 marker to indicate type-aware compression.
422+ // buffer layout: | kTypeAwareBuffer (int64) | buffer 1 uncompressedLength | buffer 1 compressedLength | buffer 1 | ...
423+ metadataLength += sizeof (int64_t );
424+ totalCompressedLength += TypeAwareCompressCodec::maxCompressedLen (buffer->size (), typeKind);
425+ continue ;
426+ }
427+ }
428+ // Standard codec: compressed data.
429+ totalCompressedLength += codec->MaxCompressedLen (buffer->size (), buffer->data ());
430+ }
340431 return metadataLength + totalCompressedLength;
341432}
342433
@@ -413,12 +504,14 @@ arrow::Result<std::unique_ptr<InMemoryPayload>> InMemoryPayload::merge(
413504 }
414505 }
415506 }
416- return std::make_unique<InMemoryPayload>(mergedRows, isValidityBuffer, source->schema (), std::move (merged));
507+ return std::make_unique<InMemoryPayload>(
508+ mergedRows, isValidityBuffer, source->schema (), std::move (merged), false , source->bufferTypes_ );
417509}
418510
419511arrow::Result<std::unique_ptr<BlockPayload>>
420512InMemoryPayload::toBlockPayload (Payload::Type payloadType, arrow::MemoryPool* pool, arrow::util::Codec* codec) {
421- return BlockPayload::fromBuffers (payloadType, numRows_, std::move (buffers_), isValidityBuffer_, pool, codec);
513+ return BlockPayload::fromBuffers (
514+ payloadType, numRows_, std::move (buffers_), isValidityBuffer_, pool, codec, bufferTypes_);
422515}
423516
424517arrow::Status InMemoryPayload::serialize (arrow::io::OutputStream* outputStream) {
0 commit comments