Skip to content

Commit 54b3321

Browse files
authored
AVRO-4172: [C++] Fix ZSTD codec compatibility (#3457)
* AVRO-4172: [C++] Fix ZSTD codec compatibility * AVRO-4172: [C++] add codec compatibility test * fix * fix * fix * add ZstdCodecWrapper for compression * fix * fix * split zstd compress and decompress wrapper * fix
1 parent 2b11dba commit 54b3321

9 files changed

Lines changed: 326 additions & 41 deletions

lang/c++/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ set (AVRO_SOURCE_FILES
129129
impl/Stream.cc impl/FileStream.cc
130130
impl/Generic.cc impl/GenericDatum.cc
131131
impl/DataFile.cc
132+
impl/ZstdCompressWrapper.cc
133+
impl/ZstdDecompressWrapper.cc
132134
impl/parsing/Symbol.cc
133135
impl/parsing/ValidatingCodec.cc
134136
impl/parsing/JsonCodec.cc

lang/c++/impl/DataFile.cc

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
#endif
2929

3030
#ifdef ZSTD_CODEC_AVAILABLE
31-
#include <zstd.h>
31+
#include "ZstdCompressWrapper.hh"
32+
#include "ZstdDecompressWrapper.hh"
3233
#endif
3334

3435
#include <zlib.h>
@@ -244,24 +245,12 @@ void DataFileWriterBase::sync() {
244245
reinterpret_cast<const char *>(data) + len);
245246
}
246247

247-
// Pre-allocate buffer for compressed data
248-
size_t max_compressed_size = ZSTD_compressBound(uncompressed.size());
249-
std::vector<char> compressed(max_compressed_size);
248+
ZstdCompressWrapper zstdCompressWrapper;
249+
std::vector<char> compressed = zstdCompressWrapper.compress(uncompressed);
250250

251-
// Compress the data using ZSTD block API
252-
size_t compressed_size = ZSTD_compress(
253-
compressed.data(), max_compressed_size,
254-
uncompressed.data(), uncompressed.size(),
255-
ZSTD_CLEVEL_DEFAULT);
256-
257-
if (ZSTD_isError(compressed_size)) {
258-
throw Exception("ZSTD compression error: {}", ZSTD_getErrorName(compressed_size));
259-
}
260-
261-
compressed.resize(compressed_size);
262251
std::unique_ptr<InputStream> in = memoryInputStream(
263252
reinterpret_cast<const uint8_t *>(compressed.data()), compressed.size());
264-
avro::encode(*encoderPtr_, static_cast<int64_t>(compressed_size));
253+
avro::encode(*encoderPtr_, static_cast<int64_t>(compressed.size()));
265254
encoderPtr_->flush();
266255
copy(*in, *stream_);
267256
#endif
@@ -482,35 +471,15 @@ void DataFileReaderBase::readDataBlock() {
482471
#ifdef ZSTD_CODEC_AVAILABLE
483472
} else if (codec_ == ZSTD_CODEC) {
484473
compressed_.clear();
474+
uncompressed.clear();
485475
const uint8_t *data;
486476
size_t len;
487477
while (st->next(&data, &len)) {
488478
compressed_.insert(compressed_.end(), data, data + len);
489479
}
490480

491-
// Get the decompressed size
492-
size_t decompressed_size = ZSTD_getFrameContentSize(
493-
reinterpret_cast<const char *>(compressed_.data()), compressed_.size());
494-
if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) {
495-
throw Exception("ZSTD: Not a valid compressed frame");
496-
} else if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
497-
throw Exception("ZSTD: Unable to determine decompressed size");
498-
}
499-
500-
// Decompress the data
501-
uncompressed.clear();
502-
uncompressed.resize(decompressed_size);
503-
size_t result = ZSTD_decompress(
504-
uncompressed.data(), decompressed_size,
505-
reinterpret_cast<const char *>(compressed_.data()), compressed_.size());
506-
507-
if (ZSTD_isError(result)) {
508-
throw Exception("ZSTD decompression error: {}", ZSTD_getErrorName(result));
509-
}
510-
if (result != decompressed_size) {
511-
throw Exception("ZSTD: Decompressed size mismatch: expected {}, got {}",
512-
decompressed_size, result);
513-
}
481+
ZstdDecompressWrapper zstdDecompressWrapper;
482+
uncompressed = zstdDecompressWrapper.decompress(compressed_);
514483

515484
std::unique_ptr<InputStream> in = memoryInputStream(
516485
reinterpret_cast<const uint8_t *>(uncompressed.data()),
@@ -620,8 +589,7 @@ void DataFileReaderBase::readHeader() {
620589
codec_ = SNAPPY_CODEC;
621590
#endif
622591
#ifdef ZSTD_CODEC_AVAILABLE
623-
} else if (it != metadata_.end()
624-
&& toString(it->second) == AVRO_ZSTD_CODEC) {
592+
} else if (it != metadata_.end() && toString(it->second) == AVRO_ZSTD_CODEC) {
625593
codec_ = ZSTD_CODEC;
626594
#endif
627595
} else {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* https://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#ifdef ZSTD_CODEC_AVAILABLE
20+
21+
#include "ZstdCompressWrapper.hh"
22+
#include "Exception.hh"
23+
24+
#include <zstd.h>
25+
26+
namespace avro {
27+
28+
std::vector<char> ZstdCompressWrapper::compress(const std::vector<char> &uncompressed) {
29+
// Pre-allocate buffer for compressed data
30+
size_t max_compressed_size = ZSTD_compressBound(uncompressed.size());
31+
if (ZSTD_isError(max_compressed_size)) {
32+
throw Exception("ZSTD compression error: {}", ZSTD_getErrorName(max_compressed_size));
33+
}
34+
std::vector<char> compressed(max_compressed_size);
35+
36+
// Compress the data using ZSTD block API
37+
size_t compressed_size = ZSTD_compress(
38+
compressed.data(), max_compressed_size,
39+
uncompressed.data(), uncompressed.size(),
40+
ZSTD_CLEVEL_DEFAULT);
41+
42+
if (ZSTD_isError(compressed_size)) {
43+
throw Exception("ZSTD compression error: {}", ZSTD_getErrorName(compressed_size));
44+
}
45+
compressed.resize(compressed_size);
46+
return compressed;
47+
}
48+
49+
ZstdCompressWrapper::ZstdCompressWrapper() {
50+
cctx_ = ZSTD_createCCtx();
51+
if (!cctx_) {
52+
throw Exception("ZSTD_createCCtx() failed");
53+
}
54+
}
55+
56+
ZstdCompressWrapper::~ZstdCompressWrapper() {
57+
ZSTD_freeCCtx(cctx_);
58+
}
59+
60+
} // namespace avro
61+
62+
#endif // ZSTD_CODEC_AVAILABLE
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* https://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#ifndef avro_ZstdCompressWrapper_hh__
20+
#define avro_ZstdCompressWrapper_hh__
21+
22+
#ifdef ZSTD_CODEC_AVAILABLE
23+
24+
#include <vector>
25+
26+
#include <zstd.h>
27+
28+
namespace avro {
29+
30+
class ZstdCompressWrapper {
31+
public:
32+
ZstdCompressWrapper();
33+
~ZstdCompressWrapper();
34+
35+
std::vector<char> compress(const std::vector<char> &uncompressed);
36+
37+
private:
38+
ZSTD_CCtx *cctx_ = nullptr;
39+
};
40+
41+
} // namespace avro
42+
43+
#endif // ZSTD_CODEC_AVAILABLE
44+
45+
#endif // avro_ZstdCompressWrapper_hh__
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* https://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#ifdef ZSTD_CODEC_AVAILABLE
20+
21+
#include "ZstdDecompressWrapper.hh"
22+
#include "Exception.hh"
23+
24+
#include <zstd.h>
25+
26+
namespace avro {
27+
28+
std::string ZstdDecompressWrapper::decompress(const std::vector<char> &compressed) {
29+
std::string uncompressed;
30+
// Get the decompressed size
31+
size_t decompressed_size = ZSTD_getFrameContentSize(compressed.data(), compressed.size());
32+
if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) {
33+
throw Exception("ZSTD: Not a valid compressed frame");
34+
} else if (decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
35+
// Stream decompress the data
36+
ZSTD_inBuffer in{compressed.data(), compressed.size(), 0};
37+
std::vector<char> tmp(ZSTD_DStreamOutSize());
38+
ZSTD_outBuffer out{tmp.data(), tmp.size(), 0};
39+
size_t ret;
40+
do {
41+
out.pos = 0;
42+
ret = ZSTD_decompressStream(dctx_, &out, &in);
43+
if (ZSTD_isError(ret)) {
44+
throw Exception("ZSTD decompression error: {}", ZSTD_getErrorName(ret));
45+
}
46+
uncompressed.append(tmp.data(), out.pos);
47+
} while (ret != 0);
48+
} else {
49+
// Batch decompress the data
50+
uncompressed.resize(decompressed_size);
51+
size_t result = ZSTD_decompress(
52+
uncompressed.data(), decompressed_size, compressed.data(), compressed.size());
53+
54+
if (ZSTD_isError(result)) {
55+
throw Exception("ZSTD decompression error: {}", ZSTD_getErrorName(result));
56+
}
57+
if (result != decompressed_size) {
58+
throw Exception("ZSTD: Decompressed size mismatch: expected {}, got {}",
59+
decompressed_size, result);
60+
}
61+
}
62+
return uncompressed;
63+
}
64+
65+
ZstdDecompressWrapper::ZstdDecompressWrapper() {
66+
dctx_ = ZSTD_createDCtx();
67+
if (!dctx_) {
68+
throw Exception("ZSTD_createDCtx() failed");
69+
}
70+
}
71+
72+
ZstdDecompressWrapper::~ZstdDecompressWrapper() {
73+
ZSTD_freeDCtx(dctx_);
74+
}
75+
76+
} // namespace avro
77+
78+
#endif // ZSTD_CODEC_AVAILABLE
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* https://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#ifndef avro_ZstdDecompressWrapper_hh__
20+
#define avro_ZstdDecompressWrapper_hh__
21+
22+
#ifdef ZSTD_CODEC_AVAILABLE
23+
24+
#include <string>
25+
#include <vector>
26+
27+
#include <zstd.h>
28+
29+
namespace avro {
30+
31+
class ZstdDecompressWrapper {
32+
public:
33+
ZstdDecompressWrapper();
34+
~ZstdDecompressWrapper();
35+
36+
std::string decompress(const std::vector<char> &compressed);
37+
38+
private:
39+
ZSTD_DCtx *dctx_ = nullptr;
40+
};
41+
42+
} // namespace avro
43+
44+
#endif // ZSTD_CODEC_AVAILABLE
45+
46+
#endif // avro_ZstdDecompressWrapper_hh__

0 commit comments

Comments
 (0)