Skip to content
Open
23 changes: 7 additions & 16 deletions datasketches/src/bloom/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use std::hash::Hasher;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in_range;
use crate::codec::utility::ensure_serial_version_is;
use crate::error::Error;
use crate::hash::XxHash64;

Expand Down Expand Up @@ -412,22 +414,11 @@ impl BloomFilter {

// Validate
Family::BLOOMFILTER.validate_id(family_id)?;
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
if !(Family::BLOOMFILTER.min_pre_longs..=Family::BLOOMFILTER.max_pre_longs)
.contains(&preamble_longs)
{
return Err(Error::deserial(format!(
"invalid preamble longs: expected [{}, {}], got {}",
Family::BLOOMFILTER.min_pre_longs,
Family::BLOOMFILTER.max_pre_longs,
preamble_longs
)));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
ensure_preamble_longs_in_range(
Family::BLOOMFILTER.min_pre_longs..=Family::BLOOMFILTER.max_pre_longs,
preamble_longs,
)?;

let is_empty = (flags & EMPTY_FLAG_MASK) != 0;

Expand Down
8 changes: 8 additions & 0 deletions datasketches/src/codec/family.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ pub struct Family {
}

impl Family {
/// The Theta family of sketches, used for cardinality estimation.
///
/// Family ID is 3; the accepted preamble size ranges from 1 to 3 longs.
pub const THETA: Family = Family {
id: 3,
name: "THETA",
min_pre_longs: 1,
max_pre_longs: 3,
};

/// The HLL family of sketches.
pub const HLL: Family = Family {
id: 7,
Expand Down
1 change: 1 addition & 0 deletions datasketches/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ pub use self::encode::SketchBytes;

// private to datasketches crate
pub(crate) mod family;
pub(crate) mod utility;
65 changes: 65 additions & 0 deletions datasketches/src/codec/utility.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::Bound;
use std::ops::RangeBounds;

use crate::error::Error;

pub(crate) fn ensure_serial_version_is(expected: u8, actual: u8) -> Result<(), Error> {
if expected == actual {
Ok(())
} else {
Err(Error::deserial(format!(
"unsupported serial version: expected {expected}, got {actual}"
)))
}
}

pub(crate) fn ensure_preamble_longs_in(expected: &[u8], actual: u8) -> Result<(), Error> {
if expected.contains(&actual) {
Ok(())
} else {
Err(Error::invalid_preamble_longs(expected, actual))
}
}

/// Checks that a preamble-longs value lies inside the accepted range.
///
/// # Errors
///
/// Returns a deserialization error when `actual` is outside `expected`. The message renders the
/// range in interval notation when both ends are bounded (`[a, b]`, `[a, b)`, `(a, b]`,
/// `(a, b)`) and as a phrase (`at most`, `less than`, `at least`, `greater than`) when only one
/// end is bounded.
pub(crate) fn ensure_preamble_longs_in_range(
    expected: impl RangeBounds<u8>,
    actual: u8,
) -> Result<(), Error> {
    if expected.contains(&actual) {
        return Ok(());
    }

    let described = match (expected.start_bound(), expected.end_bound()) {
        (Bound::Included(lo), Bound::Included(hi)) => format!("[{lo}, {hi}]"),
        (Bound::Included(lo), Bound::Excluded(hi)) => format!("[{lo}, {hi})"),
        (Bound::Excluded(lo), Bound::Included(hi)) => format!("({lo}, {hi}]"),
        (Bound::Excluded(lo), Bound::Excluded(hi)) => format!("({lo}, {hi})"),
        (Bound::Unbounded, Bound::Included(hi)) => format!("at most {hi}"),
        (Bound::Unbounded, Bound::Excluded(hi)) => format!("less than {hi}"),
        (Bound::Included(lo), Bound::Unbounded) => format!("at least {lo}"),
        (Bound::Excluded(lo), Bound::Unbounded) => format!("greater than {lo}"),
        // With the default `RangeBounds::contains`, a range open on both ends accepts every
        // value, so the early return above has already fired.
        (Bound::Unbounded, Bound::Unbounded) => unreachable!("unbounded range"),
    };

    Err(Error::deserial(format!(
        "invalid preamble longs: expected {described}, got {actual}"
    )))
}
4 changes: 2 additions & 2 deletions datasketches/src/common/binomial_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ pub(crate) fn lower_bound(
/// # Arguments
///
/// * `num_samples` - The number of samples in the sample set.
/// * `theta` - The sampling probability. Must be in the range (0.0, 1.0].
/// * `theta` - The sampling probability. Must be in the range `(0.0, 1.0]`.
/// * `num_std_dev` - The number of standard deviations for confidence bounds.
/// * `no_data_seen` - This is normally false. However, in the case where you have zero samples and
/// a theta < 1.0, this flag enables the distinction between a virgin case when no actual data has
Expand All @@ -315,7 +315,7 @@ pub(crate) fn lower_bound(
///
/// # Errors
///
/// Returns an error if `theta` is not in the range (0.0, 1.0].
/// Returns an error if `theta` is not in the range `(0.0, 1.0]`.
pub(crate) fn upper_bound(
num_samples: u64,
theta: f64,
Expand Down
16 changes: 4 additions & 12 deletions datasketches/src/countmin/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use std::hash::Hasher;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in;
use crate::codec::utility::ensure_serial_version_is;
use crate::countmin::CountMinValue;
use crate::countmin::UnsignedCountMinValue;
use crate::countmin::serialization::FLAGS_IS_EMPTY;
Expand Down Expand Up @@ -350,18 +352,8 @@ impl<T: CountMinValue> CountMinSketch<T> {
cursor.read_u32_le().map_err(make_error("<unused>"))?;

Family::COUNTMIN.validate_id(family_id)?;
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
if preamble_longs != PREAMBLE_LONGS_SHORT {
return Err(Error::invalid_preamble_longs(
PREAMBLE_LONGS_SHORT,
preamble_longs,
));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
ensure_preamble_longs_in(&[PREAMBLE_LONGS_SHORT], preamble_longs)?;

let num_buckets = cursor.read_u32_le().map_err(make_error("num_buckets"))?;
let num_hashes = cursor.read_u8().map_err(make_error("num_hashes"))?;
Expand Down
16 changes: 4 additions & 12 deletions datasketches/src/cpc/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::hash::Hash;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in;
use crate::codec::utility::ensure_serial_version_is;
use crate::common::NumStdDev;
use crate::common::canonical_double;
use crate::common::inv_pow2_table::INVERSE_POWERS_OF_2;
Expand Down Expand Up @@ -518,12 +520,7 @@ impl CpcSketch {
let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
Family::CPC.validate_id(family_id)?;
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;

let lg_k = cursor.read_u8().map_err(make_error("lg_k"))?;
let first_interesting_column = cursor
Expand Down Expand Up @@ -594,12 +591,7 @@ impl CpcSketch {

let expected_preamble_ints =
make_preamble_ints(num_coupons, has_hip, has_table, has_window);
if preamble_ints != expected_preamble_ints {
return Err(Error::invalid_preamble_longs(
expected_preamble_ints,
preamble_ints,
));
}
ensure_preamble_longs_in(&[expected_preamble_ints], preamble_ints)?;
if seed_hash != compute_seed_hash(seed) {
return Err(Error::new(
ErrorKind::InvalidData,
Expand Down
12 changes: 3 additions & 9 deletions datasketches/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,9 @@ impl Error {
))
}

pub(crate) fn unsupported_serial_version(expected: u8, actual: u8) -> Self {
Self::deserial(format!(
"unsupported serial version: expected {expected}, got {actual}"
))
}

pub(crate) fn invalid_preamble_longs(expected: u8, actual: u8) -> Self {
Self::deserial(format!(
"invalid preamble longs: expected {expected}, got {actual}"
/// Builds a deserialization error for a preamble-longs value that is not among the accepted
/// ones.
///
/// `expected` is rendered with its `Debug` form (for example `[1, 2]`), followed by the
/// `actual` value.
pub(crate) fn invalid_preamble_longs(expected: &[u8], actual: u8) -> Self {
Error::deserial(format!(
"invalid preamble longs: expected {expected:?}, got {actual}"
))
}
}
Expand Down
29 changes: 7 additions & 22 deletions datasketches/src/frequencies/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::hash::Hash;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in;
use crate::codec::utility::ensure_serial_version_is;
use crate::error::Error;
use crate::frequencies::reverse_purge_item_hash_map::ReversePurgeItemHashMap;
use crate::frequencies::serialization::*;
Expand Down Expand Up @@ -141,7 +143,7 @@ impl<T: Eq + Hash> FrequentItemsSketch<T> {

/// Returns the estimated frequency for an item.
///
/// If the item is tracked, this is `item_count + offset`. Otherwise it is zero.
/// If the item is tracked, this is `item_count + offset`. Otherwise, it is zero.
///
/// # Examples
///
Expand Down Expand Up @@ -464,35 +466,18 @@ impl<T: Eq + Hash> FrequentItemsSketch<T> {
cursor.read_u16_le().map_err(make_error("<unused>"))?;

Family::FREQUENCY.validate_id(family)?;
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
if lg_cur > lg_max {
return Err(Error::deserial("lg_cur_map_size exceeds lg_max_map_size"));
}

let is_empty = (flags & EMPTY_FLAG_MASK) != 0;
if is_empty {
return if pre_longs != PREAMBLE_LONGS_EMPTY {
Err(Error::invalid_preamble_longs(
PREAMBLE_LONGS_EMPTY,
pre_longs,
))
} else {
Ok(Self::with_lg_map_sizes(lg_max, lg_cur))
};
}

if pre_longs != PREAMBLE_LONGS_NONEMPTY {
return Err(Error::invalid_preamble_longs(
PREAMBLE_LONGS_NONEMPTY,
pre_longs,
));
ensure_preamble_longs_in(&[PREAMBLE_LONGS_EMPTY], pre_longs)?;
return Ok(Self::with_lg_map_sizes(lg_max, lg_cur));
}

ensure_preamble_longs_in(&[PREAMBLE_LONGS_NONEMPTY], pre_longs)?;
let active_items = cursor.read_u32_le().map_err(make_error("active_items"))?;
let active_items = active_items as usize;
cursor.read_u32_le().map_err(make_error("<unused>"))?;
Expand Down
8 changes: 2 additions & 6 deletions datasketches/src/hll/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::hash::Hash;

use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_serial_version_is;
use crate::common::NumStdDev;
use crate::error::Error;
use crate::hll::HllType;
Expand Down Expand Up @@ -281,12 +282,7 @@ impl HllSketch {
Family::HLL.validate_id(family_id)?;

// Verify serialization version
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;

// Verify lg_k range (4-21 are valid)
if !(4..=21).contains(&lg_config_k) {
Expand Down
16 changes: 4 additions & 12 deletions datasketches/src/tdigest/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use std::num::NonZeroU64;
use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in;
use crate::codec::utility::ensure_serial_version_is;
use crate::error::Error;
use crate::tdigest::serialization::*;

Expand Down Expand Up @@ -501,12 +503,7 @@ impl TDigestMut {
Err(err)
};
}
if serial_version != SERIAL_VERSION {
return Err(Error::unsupported_serial_version(
SERIAL_VERSION,
serial_version,
));
}
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
let k = cursor.read_u16_le().map_err(make_error("k"))?;
if k < 10 {
return Err(Error::deserial(format!("k must be at least 10, got {k}")));
Expand All @@ -519,12 +516,7 @@ impl TDigestMut {
} else {
PREAMBLE_LONGS_MULTIPLE
};
if preamble_longs != expected_preamble_longs {
return Err(Error::invalid_preamble_longs(
expected_preamble_longs,
preamble_longs,
));
}
ensure_preamble_longs_in(&[expected_preamble_longs], preamble_longs)?;
cursor.read_u16_le().map_err(make_error("<unused>"))?; // unused
if is_empty {
return Ok(TDigestMut::new(k));
Expand Down
Loading