diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index e726a6b44d..c620fb8016 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -865,6 +865,9 @@ impl<'a> Reader<'a> { // ============ STRING (TypeId = 19) ============ + /// # Caller Contract + /// Validate `len` against `ReadContext::check_string_bytes` before calling this. + /// `check_bound` only verifies buffer has `len` bytes; it does not enforce size limits. #[inline(always)] pub fn read_latin1_string(&mut self, len: usize) -> Result { self.check_bound(len)?; @@ -913,6 +916,9 @@ impl<'a> Reader<'a> { } } + /// # Caller Contract + /// Validate `len` against `ReadContext::check_string_bytes` before calling this. + /// `check_bound` only verifies buffer has `len` bytes; it does not enforce size limits. #[inline(always)] pub fn read_utf8_string(&mut self, len: usize) -> Result { self.check_bound(len)?; @@ -930,6 +936,9 @@ impl<'a> Reader<'a> { } } + /// # Caller Contract + /// Validate `len` against `ReadContext::check_string_bytes` before calling this. + /// `check_bound` only verifies buffer has `len` bytes; it does not enforce size limits. #[inline(always)] pub fn read_utf16_string(&mut self, len: usize) -> Result { self.check_bound(len)?; diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 991db1d724..e5849e6313 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -38,6 +38,12 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, + /// Maximum byte length of a single deserialized string. + pub max_string_bytes: usize, + /// Maximum element count of a single deserialized collection (Vec, HashSet, …). + pub max_collection_size: usize, + /// Maximum entry count of a single deserialized map (HashMap, BTreeMap, …). + pub max_map_size: usize, } impl Default for Config { @@ -50,6 +56,9 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, + max_string_bytes: i32::MAX as usize, + max_collection_size: i32::MAX as usize, + max_map_size: i32::MAX as usize, } } } @@ -101,4 +110,19 @@ impl Config { pub fn is_track_ref(&self) -> bool { self.track_ref } + + #[inline(always)] + pub fn max_string_bytes(&self) -> usize { + self.max_string_bytes + } + + #[inline(always)] + pub fn max_collection_size(&self) -> usize { + self.max_collection_size + } + + #[inline(always)] + pub fn max_map_size(&self) -> usize { + self.max_map_size + } } diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 9d3f826941..7237f799c8 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -321,6 +321,48 @@ impl Fory { self } + /// Sets the maximum byte length of a single deserialized string. Default is no limit. + /// + /// # Examples + /// + /// ```rust + /// use fory_core::Fory; + /// + /// let fory = Fory::default().max_string_bytes(1024 * 1024); + /// ``` + pub fn max_string_bytes(mut self, max: usize) -> Self { + self.config.max_string_bytes = max; + self + } + + /// Sets the maximum element count of a single deserialized collection. Default is no limit. + /// + /// # Examples + /// + /// ```rust + /// use fory_core::Fory; + /// + /// let fory = Fory::default().max_collection_size(10_000); + /// ``` + pub fn max_collection_size(mut self, max: usize) -> Self { + self.config.max_collection_size = max; + self + } + + /// Sets the maximum entry count of a single deserialized map. Default is no limit. + /// + /// # Examples + /// + /// ```rust + /// use fory_core::Fory; + /// + /// let fory = Fory::default().max_map_size(10_000); + /// ``` + pub fn max_map_size(mut self, max: usize) -> Self { + self.config.max_map_size = max; + self + } + /// Returns whether cross-language serialization is enabled. pub fn is_xlang(&self) -> bool { self.config.xlang diff --git a/rust/fory-core/src/resolver/context.rs b/rust/fory-core/src/resolver/context.rs index 4f7f08c835..35c516c586 100644 --- a/rust/fory-core/src/resolver/context.rs +++ b/rust/fory-core/src/resolver/context.rs @@ -315,6 +315,9 @@ pub struct ReadContext<'a> { xlang: bool, max_dyn_depth: u32, check_struct_version: bool, + max_string_bytes: usize, + max_collection_size: usize, + max_map_size: usize, // Context-specific fields pub reader: Reader<'a>, @@ -342,6 +345,9 @@ impl<'a> ReadContext<'a> { xlang: config.xlang, max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, + max_string_bytes: config.max_string_bytes, + max_collection_size: config.max_collection_size, + max_map_size: config.max_map_size, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -472,6 +478,39 @@ impl<'a> ReadContext<'a> { self.meta_string_resolver.read_meta_string(&mut self.reader) } + #[inline(always)] + pub fn check_string_bytes(&self, byte_len: usize) -> Result<(), Error> { + if byte_len > self.max_string_bytes { + return Err(Error::invalid_data(format!( + "string byte length {} exceeds configured limit {}", + byte_len, self.max_string_bytes + ))); + } + Ok(()) + } + + #[inline(always)] + pub fn check_collection_size(&self, len: usize) -> Result<(), Error> { + if len > self.max_collection_size { + return Err(Error::invalid_data(format!( + "collection length {} exceeds configured limit {}", + len, self.max_collection_size + ))); + } + Ok(()) + } + + #[inline(always)] + pub fn check_map_size(&self, len: usize) -> Result<(), Error> { + if len > self.max_map_size { + return Err(Error::invalid_data(format!( + "map entry count {} exceeds configured limit {}", + len, self.max_map_size + ))); + } + Ok(()) + } + #[inline(always)] pub fn inc_depth(&mut self) -> Result<(), Error> { self.current_depth += 1; diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 68a6dc6a4d..ba655f9674 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -230,6 +230,22 @@ where if len == 0 { return Ok(C::from_iter(std::iter::empty())); } + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + // Coarse lower-bound check: every element occupies at least 1 byte on the wire. + // This guards against trivially impossible element counts before allocation. + // For typed collections use read_vec_data which performs a precise per-element check. + if len as usize > remaining { + return Err(Error::invalid_data(format!( + "collection element count {} exceeds available buffer bytes {} \ + (each element requires at least 1 byte on the wire)", + len, remaining + ))); + } + context.check_collection_size(len as usize)?; if T::fory_is_polymorphic() || T::fory_is_shared_ref() { return read_collection_data_dyn_ref(context, len); } @@ -271,6 +287,24 @@ where if len == 0 { return Ok(Vec::new()); } + // Precise buffer-remaining check: T is statically known here so we can compute + // the exact minimum bytes required (len * size_of::(), floored at 1 byte per + // element for zero-sized types). This prevents Vec::with_capacity(len) from + // allocating memory that the buffer could never actually supply. + let elem_size = std::mem::size_of::().max(1); + let min_bytes = (len as usize).saturating_mul(elem_size); + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + if min_bytes > remaining { + return Err(Error::invalid_data(format!( + "Vec of {} elements requires at least {} bytes but only {} remain in buffer", + len, min_bytes, remaining + ))); + } + context.check_collection_size(len as usize)?; if T::fory_is_polymorphic() || T::fory_is_shared_ref() { return read_vec_data_dyn_ref(context, len); } diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 4e90a1c3d9..341cd09596 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -547,18 +547,30 @@ impl Result { let len = context.reader.read_varuint32()?; - let mut map = HashMap::::with_capacity(len as usize); if len == 0 { - return Ok(map); + return Ok(HashMap::new()); + } + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + if len as usize > remaining { + return Err(Error::invalid_data(format!( + "map entry count {} exceeds buffer remaining {}", + len, remaining + ))); } + context.check_map_size(len as usize)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() || V::fory_is_shared_ref() { - let map: HashMap = HashMap::with_capacity(len as usize); + let map = HashMap::::with_capacity(len as usize); return read_hashmap_data_dyn_ref(context, map, len); } + let mut map = HashMap::::with_capacity(len as usize); let mut len_counter = 0; loop { if len_counter == len { @@ -698,18 +710,30 @@ impl Result { let len = context.reader.read_varuint32()?; - let mut map = BTreeMap::::new(); if len == 0 { - return Ok(map); + return Ok(BTreeMap::new()); + } + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + if len as usize > remaining { + return Err(Error::invalid_data(format!( + "map entry count {} exceeds buffer remaining {}", + len, remaining + ))); } + context.check_map_size(len as usize)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() || V::fory_is_shared_ref() { - let map: BTreeMap = BTreeMap::new(); + let map = BTreeMap::::new(); return read_btreemap_data_dyn_ref(context, map, len); } + let mut map = BTreeMap::::new(); let mut len_counter = 0; loop { if len_counter == len { diff --git a/rust/fory-core/src/serializer/primitive_list.rs b/rust/fory-core/src/serializer/primitive_list.rs index 2fc8029e32..b7072da62b 100644 --- a/rust/fory-core/src/serializer/primitive_list.rs +++ b/rust/fory-core/src/serializer/primitive_list.rs @@ -83,7 +83,23 @@ pub fn fory_read_data(context: &mut ReadContext) -> Result if size_bytes % std::mem::size_of::() != 0 { return Err(Error::invalid_data("Invalid data length")); } + // Check that the declared byte length is actually available in the buffer + // BEFORE allocating. Without this, a crafted size_bytes can trigger an OOM + // allocation (Vec::with_capacity) even though read_bytes would later reject + // the payload — but only after the damage is done. + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + if size_bytes > remaining { + return Err(Error::invalid_data(format!( + "primitive list byte length {} exceeds buffer remaining {}", + size_bytes, remaining + ))); + } let len = size_bytes / std::mem::size_of::(); + context.check_collection_size(len)?; let mut vec: Vec = Vec::with_capacity(len); #[cfg(target_endian = "little")] diff --git a/rust/fory-core/src/serializer/string.rs b/rust/fory-core/src/serializer/string.rs index 1893cba9a8..c087b04fd2 100644 --- a/rust/fory-core/src/serializer/string.rs +++ b/rust/fory-core/src/serializer/string.rs @@ -46,6 +46,28 @@ impl Serializer for String { let bitor = context.reader.read_varuint36small()?; let len = bitor >> 2; let encoding = bitor & 0b11; + let len_usize = usize::try_from(len).map_err(|_| { + Error::invalid_data(format!( + "string length {} overflows usize on this platform", + len + )) + })?; + let byte_len = match encoding { + 1 => len_usize.saturating_mul(2), + _ => len_usize, + }; + let remaining = context + .reader + .bf + .len() + .saturating_sub(context.reader.cursor); + if byte_len > remaining { + return Err(Error::invalid_data(format!( + "string byte length {} exceeds buffer remaining {}", + byte_len, remaining + ))); + } + context.check_string_bytes(byte_len)?; let s = match encoding { 0 => context.reader.read_latin1_string(len as usize), 1 => context.reader.read_utf16_string(len as usize), diff --git a/rust/tests/tests/mod.rs b/rust/tests/tests/mod.rs index 2f83762a50..a2be1f9e40 100644 --- a/rust/tests/tests/mod.rs +++ b/rust/tests/tests/mod.rs @@ -19,4 +19,5 @@ mod compatible; mod test_any; mod test_collection; mod test_max_dyn_depth; +mod test_size_guardrails; mod test_tuple; diff --git a/rust/tests/tests/test_size_guardrails.rs b/rust/tests/tests/test_size_guardrails.rs new file mode 100644 index 0000000000..b790107aa3 --- /dev/null +++ b/rust/tests/tests/test_size_guardrails.rs @@ -0,0 +1,309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fory_core::fory::Fory; +use std::collections::{BTreeMap, HashMap, HashSet}; + +// ── Collection (Vec) ───────────────────────────────────────────────────── + +#[test] +fn test_collection_size_limit_exceeded() { + let fory_write = Fory::default(); + let items: Vec = vec![1, 2, 3, 4, 5]; + let bytes = fory_write.serialize(&items).unwrap(); + + let fory_read = Fory::default().max_collection_size(3); + + // FORY_PANIC_ON_ERROR is a compile-time constant (see error.rs). + // When set, Error constructors panic instead of returning Err, so we + // catch_unwind in that build variant rather than asserting is_err(). + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(&bytes); + }); + } else { + let result: Result, _> = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to collection size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("collection length")); + } +} + +#[test] +fn test_collection_size_within_limit() { + let fory = Fory::default().max_collection_size(5); + let items: Vec = vec![1, 2, 3, 4, 5]; + let bytes = fory.serialize(&items).unwrap(); + let result: Result, _> = fory.deserialize(&bytes); + assert!(result.is_ok()); +} + +// ── Map (HashMap) ───────────────────────────────────────────────────────────── + +#[test] +fn test_map_size_limit_exceeded() { + let fory_write = Fory::default(); + let mut map: HashMap = HashMap::new(); + map.insert("a".to_string(), 1); + map.insert("b".to_string(), 2); + map.insert("c".to_string(), 3); + let bytes = fory_write.serialize(&map).unwrap(); + + let fory_read = Fory::default().max_map_size(2); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(&bytes); + }); + } else { + let result: Result, _> = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to map size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("map entry count")); + } +} + +#[test] +fn test_map_size_within_limit() { + let fory = Fory::default().max_map_size(3); + let mut map: HashMap = HashMap::new(); + map.insert("a".to_string(), 1); + map.insert("b".to_string(), 2); + map.insert("c".to_string(), 3); + let bytes = fory.serialize(&map).unwrap(); + let result: Result, _> = fory.deserialize(&bytes); + assert!(result.is_ok()); +} + +// ── Map (BTreeMap) ──────────────────────────────────────────────────────────── + +#[test] +fn test_btreemap_size_limit_exceeded() { + // Regression: HashMap guard was previously placed before the len==0 early + // return and before with_capacity; this test also covers BTreeMap ordering. + let fory_write = Fory::default(); + let mut map: BTreeMap = BTreeMap::new(); + map.insert("x".to_string(), 1); + map.insert("y".to_string(), 2); + map.insert("z".to_string(), 3); + let bytes = fory_write.serialize(&map).unwrap(); + + let fory_read = Fory::default().max_map_size(2); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(&bytes); + }); + } else { + let result: Result, _> = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to BTreeMap size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("map entry count")); + } +} + +#[test] +fn test_btreemap_size_within_limit() { + let fory = Fory::default().max_map_size(3); + let mut map: BTreeMap = BTreeMap::new(); + map.insert("x".to_string(), 1); + map.insert("y".to_string(), 2); + map.insert("z".to_string(), 3); + let bytes = fory.serialize(&map).unwrap(); + let result: Result, _> = fory.deserialize(&bytes); + assert!(result.is_ok()); +} + +// ── Collection (HashSet) ────────────────────────────────────────────────────── + +#[test] +fn test_hashset_size_limit_exceeded() { + let fory_write = Fory::default(); + let set: HashSet = vec![1, 2, 3, 4, 5].into_iter().collect(); + let bytes = fory_write.serialize(&set).unwrap(); + + let fory_read = Fory::default().max_collection_size(3); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(&bytes); + }); + } else { + let result: Result, _> = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to HashSet size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("collection length")); + } +} + +// ── String ──────────────────────────────────────────────────────────────────── + +#[test] +fn test_string_size_limit_exceeded() { + let fory_write = Fory::default(); + let s = "hello world".to_string(); + let bytes = fory_write.serialize(&s).unwrap(); + + let fory_read = Fory::default().max_string_bytes(5); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result = fory_read.deserialize(&bytes); + }); + } else { + let result: Result = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to string size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("string byte length")); + } +} + +#[test] +fn test_string_size_within_limit() { + let fory = Fory::default().max_string_bytes(20); + let s = "hello world".to_string(); + let bytes = fory.serialize(&s).unwrap(); + let result: Result = fory.deserialize(&bytes); + assert!(result.is_ok()); +} + +// ── Primitive list (Vec) ────────────────────────────────────────────────── + +#[test] +fn test_primitive_vec_size_limit_exceeded() { + // Vec uses the primitive_list.rs bulk-copy path which is separate from + // the generic collection path — verifies the guard fires there too. + let fory_write = Fory::default(); + let data: Vec = vec![0u8; 100]; + let bytes = fory_write.serialize(&data).unwrap(); + + let fory_read = Fory::default().max_collection_size(50); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(&bytes); + }); + } else { + let result: Result, _> = fory_read.deserialize(&bytes); + assert!( + result.is_err(), + "Expected deserialization to fail due to primitive list size limit" + ); + let err_msg = format!("{:?}", result.unwrap_err()); + assert!(err_msg.contains("collection length")); + } +} + +// ── Buffer truncation (buffer-remaining cross-check) ───────────────────────── + +#[test] +fn test_buffer_truncation_rejected() { + // Validates the buffer-remaining cross-check in check_string_bytes + // independently of any configured limit: a structurally truncated buffer + // must be rejected even with default (i32::MAX) limits. + let fory_write = Fory::default(); + let s = "hello world".to_string(); + let bytes = fory_write.serialize(&s).unwrap(); + + // Drop the last 4 bytes so the payload is structurally incomplete. + let truncated = &bytes[..bytes.len().saturating_sub(4)]; + let fory_read = Fory::default(); + + // Truncation causes a read past the buffer end — this always returns Err + // (or panics in PANIC_ON_ERROR builds), never silently succeeds. + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result = fory_read.deserialize(truncated); + }); + } else { + let result: Result = fory_read.deserialize(truncated); + assert!( + result.is_err(), + "Truncated buffer must be rejected during deserialization" + ); + } +} + +#[test] +fn test_vec_buffer_truncation_rejected() { + // Exercises the buffer-remaining guard in read_vec_data (collection.rs). + // Serializes a valid Vec, then truncates the payload so the declared + // element count cannot possibly be satisfied — must be rejected with + // default (i32::MAX) limits, proving the check fires before with_capacity. + let fory_write = Fory::default(); + let items: Vec = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let bytes = fory_write.serialize(&items).unwrap(); + + // Keep only the first 6 bytes — enough for the outer framing but not + // the element payload (8 i32s = 32 bytes minimum). + let truncated = &bytes[..bytes.len().saturating_sub(bytes.len() / 2)]; + let fory_read = Fory::default(); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(truncated); + }); + } else { + let result: Result, _> = fory_read.deserialize(truncated); + assert!( + result.is_err(), + "Truncated Vec buffer must be rejected by buffer-remaining check" + ); + } +} + +#[test] +fn test_primitive_list_buffer_truncation_rejected() { + // Exercises the buffer-remaining guard in primitive_list.rs fory_read_data. + // Vec uses the bulk-copy (primitive_list) path, NOT the generic + // collection path. Previously there was no buffer check before + // Vec::with_capacity on this path — this test pins that fix. + let fory_write = Fory::default(); + let items: Vec = vec![100i64, 200, 300, 400, 500, 600, 700, 800]; + let bytes = fory_write.serialize(&items).unwrap(); + + // Truncate to about half — enough for headers but not the full i64 payload. + let truncated = &bytes[..bytes.len().saturating_sub(bytes.len() / 2)]; + let fory_read = Fory::default(); + + if fory_core::error::PANIC_ON_ERROR { + let _ = std::panic::catch_unwind(|| { + let _: Result, _> = fory_read.deserialize(truncated); + }); + } else { + let result: Result, _> = fory_read.deserialize(truncated); + assert!( + result.is_err(), + "Truncated Vec buffer must be rejected by primitive_list buffer-remaining check" + ); + } +}