|
28 | 28 | #endif |
29 | 29 |
|
30 | 30 | #ifdef ZSTD_CODEC_AVAILABLE |
31 | | -#include <zstd.h> |
| 31 | +#include "ZstdCompressWrapper.hh" |
| 32 | +#include "ZstdDecompressWrapper.hh" |
32 | 33 | #endif |
33 | 34 |
|
34 | 35 | #include <zlib.h> |
@@ -244,24 +245,12 @@ void DataFileWriterBase::sync() { |
244 | 245 | reinterpret_cast<const char *>(data) + len); |
245 | 246 | } |
246 | 247 |
|
247 | | - // Pre-allocate buffer for compressed data |
248 | | - size_t max_compressed_size = ZSTD_compressBound(uncompressed.size()); |
249 | | - std::vector<char> compressed(max_compressed_size); |
| 248 | + ZstdCompressWrapper zstdCompressWrapper; |
| 249 | + std::vector<char> compressed = zstdCompressWrapper.compress(uncompressed); |
250 | 250 |
|
251 | | - // Compress the data using ZSTD block API |
252 | | - size_t compressed_size = ZSTD_compress( |
253 | | - compressed.data(), max_compressed_size, |
254 | | - uncompressed.data(), uncompressed.size(), |
255 | | - ZSTD_CLEVEL_DEFAULT); |
256 | | - |
257 | | - if (ZSTD_isError(compressed_size)) { |
258 | | - throw Exception("ZSTD compression error: {}", ZSTD_getErrorName(compressed_size)); |
259 | | - } |
260 | | - |
261 | | - compressed.resize(compressed_size); |
262 | 251 | std::unique_ptr<InputStream> in = memoryInputStream( |
263 | 252 | reinterpret_cast<const uint8_t *>(compressed.data()), compressed.size()); |
264 | | - avro::encode(*encoderPtr_, static_cast<int64_t>(compressed_size)); |
| 253 | + avro::encode(*encoderPtr_, static_cast<int64_t>(compressed.size())); |
265 | 254 | encoderPtr_->flush(); |
266 | 255 | copy(*in, *stream_); |
267 | 256 | #endif |
@@ -482,35 +471,15 @@ void DataFileReaderBase::readDataBlock() { |
482 | 471 | #ifdef ZSTD_CODEC_AVAILABLE |
483 | 472 | } else if (codec_ == ZSTD_CODEC) { |
484 | 473 | compressed_.clear(); |
| 474 | + uncompressed.clear(); |
485 | 475 | const uint8_t *data; |
486 | 476 | size_t len; |
487 | 477 | while (st->next(&data, &len)) { |
488 | 478 | compressed_.insert(compressed_.end(), data, data + len); |
489 | 479 | } |
490 | 480 |
|
491 | | - // Get the decompressed size |
492 | | - size_t decompressed_size = ZSTD_getFrameContentSize( |
493 | | - reinterpret_cast<const char *>(compressed_.data()), compressed_.size()); |
494 | | - if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { |
495 | | - throw Exception("ZSTD: Not a valid compressed frame"); |
496 | | - } else if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) { |
497 | | - throw Exception("ZSTD: Unable to determine decompressed size"); |
498 | | - } |
499 | | - |
500 | | - // Decompress the data |
501 | | - uncompressed.clear(); |
502 | | - uncompressed.resize(decompressed_size); |
503 | | - size_t result = ZSTD_decompress( |
504 | | - uncompressed.data(), decompressed_size, |
505 | | - reinterpret_cast<const char *>(compressed_.data()), compressed_.size()); |
506 | | - |
507 | | - if (ZSTD_isError(result)) { |
508 | | - throw Exception("ZSTD decompression error: {}", ZSTD_getErrorName(result)); |
509 | | - } |
510 | | - if (result != decompressed_size) { |
511 | | - throw Exception("ZSTD: Decompressed size mismatch: expected {}, got {}", |
512 | | - decompressed_size, result); |
513 | | - } |
| 481 | + ZstdDecompressWrapper zstdDecompressWrapper; |
| 482 | + uncompressed = zstdDecompressWrapper.decompress(compressed_); |
514 | 483 |
|
515 | 484 | std::unique_ptr<InputStream> in = memoryInputStream( |
516 | 485 | reinterpret_cast<const uint8_t *>(uncompressed.data()), |
@@ -620,8 +589,7 @@ void DataFileReaderBase::readHeader() { |
620 | 589 | codec_ = SNAPPY_CODEC; |
621 | 590 | #endif |
622 | 591 | #ifdef ZSTD_CODEC_AVAILABLE |
623 | | - } else if (it != metadata_.end() |
624 | | - && toString(it->second) == AVRO_ZSTD_CODEC) { |
| 592 | + } else if (it != metadata_.end() && toString(it->second) == AVRO_ZSTD_CODEC) { |
625 | 593 | codec_ = ZSTD_CODEC; |
626 | 594 | #endif |
627 | 595 | } else { |
|
0 commit comments