Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion cpp/fory/serialization/collection_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ inline void collection_insert(Container &result, T &&elem) {
/// Read collection data for polymorphic or shared-ref elements.
template <typename T, typename Container>
inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) {
// Guardrail: Enforce max_collection_size for collection reads
if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return Container{};
}

Container result;
if constexpr (has_reserve_v<Container>) {
result.reserve(length);
Expand Down Expand Up @@ -611,15 +618,22 @@ struct Serializer<
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::vector<T, Alloc>();
}
// Guardrail: Enforce max_binary_size for binary byte-length reads
if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) {
ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size"));
return std::vector<T, Alloc>();
}
if (sizeof(T) == 0) {
return std::vector<T, Alloc>();
}

size_t elem_count = total_bytes_u32 / sizeof(T);

if (total_bytes_u32 % sizeof(T) != 0) {
ctx.set_error(Error::invalid_data(
"Vector byte size not aligned with element size"));
return std::vector<T, Alloc>();
}
size_t elem_count = total_bytes_u32 / sizeof(T);
std::vector<T, Alloc> result(elem_count);
if (total_bytes_u32 > 0) {
ctx.read_bytes(result.data(), static_cast<uint32_t>(total_bytes_u32),
Expand Down Expand Up @@ -677,6 +691,13 @@ struct Serializer<
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::vector<T, Alloc>();
}

if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::vector<T, Alloc>();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition on the read() path. One gap remains: several read_data() paths still trust the payload size without a max_collection_size check (for example the non-arithmetic vector<T>, vector<bool>, list, deque, and forward_list serializers). Those paths are used in native/nested deserialization, so the guardrails can still be bypassed there.

// Per xlang spec: header and type_info are omitted when length is 0
if (length == 0) {
return std::vector<T, Alloc>();
Expand Down Expand Up @@ -808,6 +829,13 @@ struct Serializer<
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::vector<T, Alloc>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::vector<T, Alloc>();
}

std::vector<T, Alloc> result;
result.reserve(size);
for (uint32_t i = 0; i < size; ++i) {
Expand Down Expand Up @@ -897,6 +925,12 @@ template <typename Alloc> struct Serializer<std::vector<bool, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::vector<bool, Alloc>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_binary_size)) {
ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size"));
return std::vector<bool, Alloc>();
}

std::vector<bool, Alloc> result(size);
// Fast path: bulk read all bytes at once if we have enough buffer
Buffer &buffer = ctx.buffer();
Expand Down Expand Up @@ -971,6 +1005,13 @@ template <typename T, typename Alloc> struct Serializer<std::list<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::list<T, Alloc>();
}

if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::list<T, Alloc>();
}

// Per xlang spec: header and type_info are omitted when length is 0
if (length == 0) {
return std::list<T, Alloc>();
Expand Down Expand Up @@ -1101,6 +1142,13 @@ template <typename T, typename Alloc> struct Serializer<std::list<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::list<T, Alloc>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::list<T, Alloc>();
}

std::list<T, Alloc> result;
for (uint32_t i = 0; i < size; ++i) {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
Expand Down Expand Up @@ -1161,6 +1209,13 @@ template <typename T, typename Alloc> struct Serializer<std::deque<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::deque<T, Alloc>();
}

if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::deque<T, Alloc>();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is still an unguarded preallocation in forward_list::read: temp.reserve(length) happens before any max_collection_size validation. A malicious length can force a large allocation before we fail. Please validate length immediately after reading it, before reserve.

// Per xlang spec: header and type_info are omitted when length is 0
if (length == 0) {
return std::deque<T, Alloc>();
Expand Down Expand Up @@ -1291,6 +1346,13 @@ template <typename T, typename Alloc> struct Serializer<std::deque<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::deque<T, Alloc>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::deque<T, Alloc>();
}

std::deque<T, Alloc> result;
for (uint32_t i = 0; i < size; ++i) {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
Expand Down Expand Up @@ -1352,6 +1414,13 @@ struct Serializer<std::forward_list<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::forward_list<T, Alloc>();
}

if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::forward_list<T, Alloc>();
}

// Per xlang spec: header and type_info are omitted when length is 0
if (length == 0) {
return std::forward_list<T, Alloc>();
Expand Down Expand Up @@ -1716,6 +1785,13 @@ struct Serializer<std::forward_list<T, Alloc>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::forward_list<T, Alloc>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::forward_list<T, Alloc>();
}

std::vector<T> temp;
temp.reserve(size);
for (uint32_t i = 0; i < size; ++i) {
Expand Down Expand Up @@ -1814,6 +1890,13 @@ struct Serializer<std::set<T, Args...>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::set<T, Args...>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::set<T, Args...>();
}

// Per xlang spec: header and type_info are omitted when length is 0
if (size == 0) {
return std::set<T, Args...>();
Expand Down Expand Up @@ -1894,6 +1977,13 @@ struct Serializer<std::set<T, Args...>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::set<T, Args...>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::set<T, Args...>();
}

std::set<T, Args...> result;
for (uint32_t i = 0; i < size; ++i) {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
Expand Down Expand Up @@ -1988,6 +2078,12 @@ struct Serializer<std::unordered_set<T, Args...>> {
return std::unordered_set<T, Args...>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::unordered_set<T, Args...>();
}

// Per xlang spec: header and type_info are omitted when length is 0
if (size == 0) {
return std::unordered_set<T, Args...>();
Expand Down Expand Up @@ -2070,6 +2166,13 @@ struct Serializer<std::unordered_set<T, Args...>> {
if (FORY_PREDICT_FALSE(ctx.has_error())) {
return std::unordered_set<T, Args...>();
}

if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Collection length exceeds max_collection_size"));
return std::unordered_set<T, Args...>();
}

std::unordered_set<T, Args...> result;
result.reserve(size);
for (uint32_t i = 0; i < size; ++i) {
Expand Down
36 changes: 36 additions & 0 deletions cpp/fory/serialization/collection_serializer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,42 @@ TEST(CollectionSerializerTest, ForwardListEmptyRoundTrip) {
EXPECT_TRUE(deserialized.strings.empty());
}

// Guardrail check with xlang disabled: a struct holding a three-element string
// vector is serialized, then read back by a Fory instance built with
// max_collection_size(2). Serialization is expected to succeed and
// deserialization to fail with a max_collection_size error message.
TEST(CollectionSerializerTest, MaxCollectionSizeNativeGuardrail) {
  auto fory = Fory::builder().xlang(false).max_collection_size(2).build();
  fory.register_struct<VectorStringHolder>(200);

  // Three elements: one more than the configured limit of two.
  VectorStringHolder holder;
  holder.strings = {"A", "B", "C"};

  auto encoded = fory.serialize(holder);
  ASSERT_TRUE(encoded.ok());

  auto decoded =
      fory.deserialize<VectorStringHolder>(encoded->data(), encoded->size());
  ASSERT_FALSE(decoded.ok());

  // The error text must mention the limit that was violated.
  const auto hit =
      decoded.error().message().find("exceeds max_collection_size");
  EXPECT_TRUE(hit != std::string::npos);
}

// Guardrail check with xlang disabled: five int32_t values are serialized and
// then read back by a Fory instance built with max_binary_size(10).
// Serialization is expected to succeed and deserialization to fail with a
// max_binary_size error message.
TEST(CollectionSerializerTest, MaxBinarySizeNativeGuardrail) {
  auto fory = Fory::builder().xlang(false).max_binary_size(10).build();

  // 5 * sizeof(int32_t) = 20 bytes of element data, above the 10-byte limit.
  std::vector<int32_t> payload{1, 2, 3, 4, 5};

  auto encoded = fory.serialize(payload);
  ASSERT_TRUE(encoded.ok());

  auto decoded = fory.deserialize<std::vector<int32_t>>(encoded->data(),
                                                        encoded->size());
  ASSERT_FALSE(decoded.ok());

  // The error text must mention the limit that was violated.
  const auto hit = decoded.error().message().find("exceeds max_binary_size");
  EXPECT_TRUE(hit != std::string::npos);
}

} // namespace
} // namespace serialization
} // namespace fory
6 changes: 6 additions & 0 deletions cpp/fory/serialization/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ struct Config {
/// When enabled, avoids duplicating shared objects and handles cycles.
bool track_ref = true;

/// Maximum allowed size for binary data in bytes.
uint32_t max_binary_size = 64 * 1024 * 1024; // 64MB default

/// Maximum allowed number of elements in a collection or entries in a map.
uint32_t max_collection_size = 1024 * 1024; // 1M elements default

/// Default constructor with sensible defaults
Config() = default;
};
Expand Down
3 changes: 3 additions & 0 deletions cpp/fory/serialization/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ class ReadContext {
/// reset context for reuse.
void reset();

/// Read-only access to the configuration attached to this context.
/// Deserializers use it to look up limits such as max_collection_size and
/// max_binary_size.
/// NOTE(review): dereferences config_ without a null check - presumably it is
/// always set before the context is used; confirm against the constructor.
const Config &config() const { return *config_; }

private:
// Error state - accumulated during deserialization, checked at the end
Error error_;
Expand Down
13 changes: 13 additions & 0 deletions cpp/fory/serialization/fory.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ class ForyBuilder {
/// Build a thread-safe Fory instance (uses context pools).
ThreadSafeFory build_thread_safe();

/// Set the maximum allowed size for binary data, in bytes.
///
/// The value is stored in the builder's Config (`max_binary_size`).
///
/// @param size upper bound in bytes.
/// @return this builder, so calls can be chained.
ForyBuilder &max_binary_size(uint32_t size) {
  config_.max_binary_size = size;
  return *this;
}

/// Set the maximum allowed number of elements in a collection, or entries in
/// a map.
///
/// The value is stored in the builder's Config (`max_collection_size`).
///
/// @param size upper bound on the element / entry count.
/// @return this builder, so calls can be chained.
ForyBuilder &max_collection_size(uint32_t size) {
  config_.max_collection_size = size;
  return *this;
}

private:
Config config_;
std::shared_ptr<TypeResolver> type_resolver_;
Expand Down
14 changes: 14 additions & 0 deletions cpp/fory/serialization/map_serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,13 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) {
static_assert(!is_shared_ref_v<K> && !is_shared_ref_v<V>,
"Fast path is for non-shared-ref types only");

// Guardrail: Enforce max_collection_size for map reads (entry count)
if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Map entry count exceeds max_collection_size"));
return MapType{};
}

MapType result;
MapReserver<MapType>::reserve(result, length);

Expand Down Expand Up @@ -682,6 +689,13 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) {
/// Read map data for polymorphic or shared-ref maps
template <typename K, typename V, typename MapType>
inline MapType read_map_data_slow(ReadContext &ctx, uint32_t length) {
// Guardrail: Enforce max_collection_size for map reads (entry count)
if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) {
ctx.set_error(
Error::invalid_data("Map entry count exceeds max_collection_size"));
return MapType{};
}

MapType result;
MapReserver<MapType>::reserve(result, length);

Expand Down
16 changes: 16 additions & 0 deletions cpp/fory/serialization/map_serializer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,22 @@ TEST(MapSerializerTest, LargeMapWithPolymorphicValues) {
EXPECT_EQ(deserialized[299]->name, "value_y_299");
}

// Guardrail check with xlang enabled: a three-entry map is serialized and then
// read back by a Fory instance built with max_collection_size(2).
// Serialization is expected to succeed and deserialization to fail with a
// max_collection_size error message.
TEST(MapSerializerTest, MaxMapSizeGuardrail) {
  auto fory = Fory::builder().xlang(true).max_collection_size(2).build();

  // Three entries: one more than the configured limit of two.
  std::map<std::string, int32_t> entries{{"a", 1}, {"b", 2}, {"c", 3}};

  auto encoded = fory.serialize(entries);
  ASSERT_TRUE(encoded.ok());

  auto decoded = fory.deserialize<std::map<std::string, int32_t>>(
      encoded->data(), encoded->size());
  ASSERT_FALSE(decoded.ok());

  // The error text must mention the limit that was violated.
  const auto hit =
      decoded.error().message().find("exceeds max_collection_size");
  EXPECT_TRUE(hit != std::string::npos);
}

int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
Expand Down
Loading
Loading