diff --git a/diskann-tools/src/utils/gen_associated_data_from_range.rs b/diskann-tools/src/utils/gen_associated_data_from_range.rs index 7d91ec865..e28ce7d21 100644 --- a/diskann-tools/src/utils/gen_associated_data_from_range.rs +++ b/diskann-tools/src/utils/gen_associated_data_from_range.rs @@ -8,7 +8,7 @@ use std::io::Write; use diskann_providers::storage::StorageWriteProvider; use diskann_utils::io::Metadata; -use super::CMDResult; +use super::{CMDResult, CMDToolError}; pub fn gen_associated_data_from_range( storage_provider: &S, @@ -16,10 +16,21 @@ pub fn gen_associated_data_from_range( start: u32, end: u32, ) -> CMDResult<()> { + if end < start { + return Err(CMDToolError { + details: format!( + "invalid range: end ({end}) must be greater than or equal to start ({start})" + ), + }); + } + let mut file = storage_provider.create_for_write(associated_data_path)?; // Calculate the number of integers and the number of integers in associated data - let num_ints = end - start + 1; + // Use checked arithmetic to avoid overflow when end == u32::MAX and start == 0 + let num_ints = (end - start).checked_add(1).ok_or_else(|| CMDToolError { + details: format!("range [{start}, {end}] is too large: count overflows u32"), + })?; let int_length: u32 = 1; // Write the number of integers and the length of each integer as little endian @@ -105,4 +116,37 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn test_gen_associated_data_from_range_end_less_than_start() { + let storage_provider = VirtualStorageProvider::new_memory(); + let path = "/test_gen_associated_data_invalid.bin"; + + let result = gen_associated_data_from_range(&storage_provider, path, 10, 5); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.details.contains("end") && err.details.contains("start"), + "error message should mention end and start: {err}" + ); + assert!( + err.details.contains("5") && err.details.contains("10"), + "error message should include the specific values 5 and 10: {err}" + ); + } + + #[test] + fn test_gen_associated_data_from_range_max_overflow() { + let storage_provider = VirtualStorageProvider::new_memory(); + let path = "/test_gen_associated_data_overflow.bin"; + + // end == u32::MAX and start == 0 would make count == u32::MAX + 1, which overflows + let result = gen_associated_data_from_range(&storage_provider, path, 0, u32::MAX); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.details.contains("overflow") || err.details.contains("too large"), + "error message should mention overflow or too large: {err}" + ); + } }