diff --git a/datasketches/tests/countmin_serialization_test.rs b/datasketches/tests/countmin_serialization_test.rs new file mode 100644 index 0000000..c88bcd5 --- /dev/null +++ b/datasketches/tests/countmin_serialization_test.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +use std::fs; + +use common::serialization_test_data; +use datasketches::countmin::CountMinSketch; +use googletest::assert_that; +use googletest::prelude::contains_substring; + +// This test validates binary format compatibility (deserialize + byte round-trip) for +// C++ Count-Min snapshots. It intentionally does not assert estimate equivalence against +// original input keys because per-row hash seed derivation differs across implementations. 
+fn assert_cpp_snapshot( + filename: &str, + seed: u64, + expected_num_hashes: u8, + expected_num_buckets: u32, + expected_total_weight: u64, +) { + let path = serialization_test_data("cpp_generated_files", filename); + let bytes = fs::read(&path).unwrap(); + + let sketch = CountMinSketch::<u64>::deserialize_with_seed(&bytes, seed).unwrap(); + + assert_eq!(sketch.num_hashes(), expected_num_hashes); + assert_eq!(sketch.num_buckets(), expected_num_buckets); + assert_eq!(sketch.seed(), seed); + assert_eq!(sketch.total_weight(), expected_total_weight); + assert_eq!(sketch.is_empty(), expected_total_weight == 0); + + let roundtrip = sketch.serialize(); + assert_eq!(roundtrip, bytes, "round-trip bytes differ for {filename}"); +} + +#[test] +fn test_deserialize_cpp_empty_snapshot() { + assert_cpp_snapshot("countmin_empty_cpp.sk", 9001, 1, 5, 0); +} + +#[test] +fn test_deserialize_cpp_non_empty_snapshot() { + assert_cpp_snapshot("countmin_non_empty_cpp.sk", 9001, 3, 1024, 2850); +} + +#[test] +fn test_deserialize_cpp_snapshot_with_wrong_seed() { + let path = serialization_test_data("cpp_generated_files", "countmin_non_empty_cpp.sk"); + let bytes = fs::read(&path).unwrap(); + + let err = CountMinSketch::<u64>::deserialize_with_seed(&bytes, 9000).unwrap_err(); + assert_that!(err.message(), contains_substring("incompatible seed hash")); +} diff --git a/datasketches/tests/countmin_test.rs b/datasketches/tests/countmin_test.rs index dddc72c..c7e2068 100644 --- a/datasketches/tests/countmin_test.rs +++ b/datasketches/tests/countmin_test.rs @@ -16,6 +16,9 @@ // under the License. 
use datasketches::countmin::CountMinSketch; +use googletest::assert_that; +use googletest::prelude::ge; +use googletest::prelude::le; #[test] fn test_init_defaults() { @@ -41,7 +44,7 @@ fn test_parameter_suggestions() { let buckets = CountMinSketch::::suggest_num_buckets(0.1); let sketch = CountMinSketch::::new(3, buckets); - assert!(sketch.relative_error() <= 0.1); + assert_that!(sketch.relative_error(), le(0.1)); } #[test] @@ -54,8 +57,8 @@ fn test_update_and_bounds() { let estimate = sketch.estimate("x"); let upper = sketch.upper_bound("x"); let lower = sketch.lower_bound("x"); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); } #[test] @@ -67,8 +70,8 @@ fn test_update_and_bounds_with_scaling() { let upper = sketch.upper_bound("x"); let lower = sketch.lower_bound("x"); assert_eq!(estimate, 10); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); let eps = sketch.relative_error(); @@ -78,8 +81,8 @@ fn test_update_and_bounds_with_scaling() { let lower = sketch.lower_bound("x"); assert_eq!(sketch.total_weight(), 5); assert_eq!(estimate, 5); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); assert_eq!( upper, estimate + (eps * sketch.total_weight() as f64) as u64 @@ -91,8 +94,8 @@ fn test_update_and_bounds_with_scaling() { let lower = sketch.lower_bound("x"); assert_eq!(sketch.total_weight(), 2); assert_eq!(estimate, 2); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); assert_eq!( upper, estimate + (eps * sketch.total_weight() as f64) as u64 @@ -122,13 +125,13 @@ fn test_halve() { } for i in 0..1000usize { - assert!(sketch.estimate(i as u64) >= i as u64); + assert_that!(sketch.estimate(i as u64), ge(i as u64)); } sketch.halve(); for i in 
0..1000usize { - assert!(sketch.estimate(i as u64) >= (i as u64) / 2); + assert_that!(sketch.estimate(i as u64), ge((i as u64) / 2)); } } @@ -145,7 +148,7 @@ fn test_decay() { } for i in 0..1000usize { - assert!(sketch.estimate(i as u64) >= i as u64); + assert_that!(sketch.estimate(i as u64), ge(i as u64)); } const FACTOR: f64 = 0.5; @@ -153,7 +156,7 @@ fn test_decay() { for i in 0..1000usize { let expected = ((i as f64) * FACTOR).floor() as u64; - assert!(sketch.estimate(i as u64) >= expected); + assert_that!(sketch.estimate(i as u64), ge(expected)); } } @@ -170,8 +173,8 @@ fn test_merge() { } left.merge(&right); assert_eq!(left.total_weight(), 18); - assert!(left.estimate("a") >= 14); - assert!(left.estimate("b") >= 4); + assert_that!(left.estimate("a"), ge(14)); + assert_that!(left.estimate("b"), ge(4)); } #[test] @@ -245,6 +248,6 @@ fn test_increment_multi_like_rust_count_min_sketch() { sketch.update(i % 100); } for key in 0..100u64 { - assert!(sketch.estimate(key) >= 9_000); + assert_that!(sketch.estimate(key), ge(9_000)); } } diff --git a/tools/generate_serialization_test_data.py b/tools/generate_serialization_test_data.py index c965f39..ad8f102 100755 --- a/tools/generate_serialization_test_data.py +++ b/tools/generate_serialization_test_data.py @@ -166,13 +166,22 @@ def generate_cpp_files(workspace_dir, project_root): output_dir.mkdir(parents=True, exist_ok=True) files_copied = 0 - # Search recursively in build directory for *_cpp.sk + + # Search recursively in build directory for standard C++ compatibility snapshots. for file_path in build_dir.rglob("*_cpp.sk"): - # Avoid copying from CMakeFiles or other intermediate dirs if possible, but the pattern is specific enough shutil.copy2(file_path, output_dir) print(f"Copied: {file_path.name}") files_copied += 1 + # Count-Min test binaries are produced as `count_min-*.bin`. + # Normalize names to match the repository snapshot convention. 
+ for file_path in build_dir.rglob("count_min-*.bin"): + base_name = file_path.stem[len("count_min-"):].replace("-", "_") + output_name = f"countmin_{base_name}_cpp.sk" + shutil.copy2(file_path, output_dir / output_name) + print(f"Copied: {output_name} (from {file_path.name})") + files_copied += 1 + if files_copied == 0: print("Warning: No *_cpp.sk files were found to copy.") else: