diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index 76cb442dc..6dcda408a 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -112,8 +112,16 @@ impl App { // List the available benchmarks. Commands::Benchmarks {} => { writeln!(output, "Registered Benchmarks:")?; - for (name, method) in benchmarks.methods() { - writeln!(output, " {}: {}", name, method.signatures()[0])?; + for (name, description) in benchmarks.names() { + let mut lines = description.lines(); + if let Some(first) = lines.next() { + writeln!(output, " {}: {}", name, first)?; + for line in lines { + writeln!(output, " {}", line)?; + } + } else { + writeln!(output, " {}: ", name)?; + } } } Commands::Skeleton => { @@ -130,23 +138,10 @@ impl App { let run = Jobs::load(input_file, inputs)?; // Check if we have a match for each benchmark. for job in run.jobs().iter() { - if !benchmarks.has_match(job) { + const MAX_METHODS: usize = 3; + if let Err(mismatches) = benchmarks.debug(job, MAX_METHODS) { let repr = serde_json::to_string_pretty(&job.serialize()?)?; - const MAX_METHODS: usize = 3; - let mismatches = match benchmarks.debug(job, MAX_METHODS) { - // Debug should return `Err` if there is not a match. - // Returning `Ok(())` here indicates an internal error with the - // dispatcher. - Ok(()) => { - return Err(anyhow::Error::msg(format!( - "experienced internal error while debugging:\n{}", - repr - ))) - } - Err(m) => m, - }; - writeln!( output, "Could not find a match for the following input:\n\n{}\n", @@ -165,7 +160,7 @@ impl App { writeln!(output)?; return Err(anyhow::Error::msg( - "could not find find a benchmark for all inputs", + "could not find a benchmark for all inputs", )); } } diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs new file mode 100644 index 000000000..30eb28de3 --- /dev/null +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -0,0 +1,140 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use serde::Serialize; + +use crate::{ + dispatcher::{FailureScore, MatchScore}, + Any, Checkpoint, Input, Output, +}; + +/// A registered benchmark. +/// +/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will +/// first be validated with the benchmark using [`try_match`](Self::try_match). Only +/// successful matches will be passed to [`run`](Self::run). +pub trait Benchmark { + /// The [`Input`] type this benchmark matches against. + type Input: Input + 'static; + + /// The concrete type of the results generated by this benchmark. + type Output: Serialize; + + /// Return whether or not this benchmark is compatible with `input`. + /// + /// On success, returns `Ok(MatchScore)`. [`MatchScore`]s of all benchmarks will be + /// collected and the benchmark with the lowest final score will be selected. + /// + /// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure. + /// + /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`] + /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations + /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging. + fn try_match(input: &Self::Input) -> Result; + + /// Return descriptive information about the benchmark. + /// + /// If `input` is `None`, then high level information about the benchmark should be relayed. + /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what + /// was expected should be generated to help users. + fn description( + f: &mut std::fmt::Formatter<'_>, + input: Option<&Self::Input>, + ) -> std::fmt::Result; + + /// Run the benchmark with `input`. + /// + /// All prints should be directed to `output`. The `checkpoint` is provided so + /// long-running benchmarks can periodically save output to prevent data loss due to + /// an early error. + /// + /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`. + fn run( + input: &Self::Input, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result; +} + +////////////// +// Internal // +////////////// + +/// Object-safe trait for type-erased benchmarks stored in the registry. +pub(crate) trait DynBenchmark { + fn try_match(&self, input: &Any) -> Result; + + fn description(&self, f: &mut std::fmt::Formatter<'_>, input: Option<&Any>) + -> std::fmt::Result; + + fn run( + &self, + input: &Any, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result; +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct Wrapper(std::marker::PhantomData); + +impl Wrapper { + pub(crate) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + +/// The score given to unsuccessful downcasts in [`DynBenchmark::try_match`]. +const MATCH_FAIL: FailureScore = FailureScore(10_000); + +impl DynBenchmark for Wrapper +where + T: Benchmark, +{ + fn try_match(&self, input: &Any) -> Result { + if let Some(cast) = input.downcast_ref::() { + T::try_match(cast) + } else { + Err(MATCH_FAIL) + } + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&Any>, + ) -> std::fmt::Result { + match input { + Some(input) => match input.downcast_ref::() { + Some(cast) => T::description(f, Some(cast)), + None => write!( + f, + "expected tag \"{}\" - instead got \"{}\"", + T::Input::tag(), + input.tag(), + ), + }, + None => { + writeln!(f, "tag \"{}\"", ::tag())?; + T::description(f, None) + } + } + } + + fn run( + &self, + input: &Any, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + match input.downcast_ref::() { + Some(input) => { + let result = T::run(input, checkpoint, output)?; + Ok(serde_json::to_value(result)?) + } + None => Err(anyhow::anyhow!("INTERNAL ERROR: invalid downcast!")), + } + } +} diff --git a/diskann-benchmark-runner/src/dispatcher/api.rs b/diskann-benchmark-runner/src/dispatcher/api.rs index 560947605..aeedf784e 100644 --- a/diskann-benchmark-runner/src/dispatcher/api.rs +++ b/diskann-benchmark-runner/src/dispatcher/api.rs @@ -8,7 +8,7 @@ use std::fmt::{self, Display, Formatter}; /// Successful matches from [`DispatchRule`] will return `MatchScores`. /// /// A lower numerical value indicates a better match for purposes of overload resolution. -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct MatchScore(pub u32); impl Display for MatchScore { @@ -21,7 +21,7 @@ impl Display for MatchScore { /// /// A lower numerical value indicates a better match, which can help when compiling a /// list of considered and rejected candidates. -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct FailureScore(pub u32); impl Display for FailureScore { @@ -218,155 +218,6 @@ impl<'a, T: Sized> DispatchRule<&'a mut T> for &'a T { } } -/// # Lifetime Mapping -/// -/// The types in signatures for dispatches need to be `'static` due to Rust. -/// However, it is nice to allow objects with lifetimes to cross the dispatcher boundary. -/// -/// The `Map` trait facilitates this by allowing `'static` types to have an optional -/// lifetime attached as a generic associated type. -/// -/// This associated type is that is what is actually given to dispatcher methods. -/// -/// ## Example -/// -/// To pass a `Vec` across a dispatcher boundary, we can use the [`Type`] helper: -/// -/// ``` -/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Type}; -/// -/// let mut d = Dispatcher1::<&'static str, Type>>::new(); -/// d.register::<_, Type>>("method", |_: Vec| "called"); -/// assert_eq!(d.call(vec![1.0]), Some("called")); -/// ``` -/// -/// This is a bit tedious to write every time, so instead types can implement [`Map`] for -/// themselves: -/// -/// ``` -/// use diskann_benchmark_runner::{self_map, dispatcher::{Dispatcher1}}; -/// -/// struct MyNum(f32); -/// self_map!(MyNum); -/// -/// // Now, `MyNum` can be used directly in dispatcher signatures. -/// let mut d = Dispatcher1::::new(); -/// d.register::<_, MyNum>("method", |n: MyNum| n.0); -/// assert_eq!(d.call(MyNum(0.0)), Some(0.0)); -/// ``` -/// -/// ## See Also: -/// -/// * [`Ref`]: Mapping References -/// * [`MutRef`]: Mapping Mutable References -/// * [`Type`]: Mapper for generic types -/// * [`crate::self_map!`]: Allow types to represent themselves in dispatcher signatures. -/// -pub trait Map: 'static { - /// The actual type provided to the dispatcher, with an optional additional lifetime. - type Type<'a>; -} - -/// Allow references to cross dispatcher boundaries as shown in the following example: -/// -/// ``` -/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, Ref}; -/// -/// let mut d = Dispatcher1::<*const f32, Ref<[f32]>>::new(); -/// d.register::<_, Ref<[f32]>>("method", |data: &[f32]| data.as_ptr()); -/// -/// let v = vec![1.0, 2.0]; -/// assert_eq!(d.call(&v), Some(v.as_ptr())); -/// ``` -pub struct Ref(std::marker::PhantomData); - -impl Map for Ref { - type Type<'a> = &'a T; -} - -/// Allow mutable references to cross dispatcher boundaries as shown below. -/// -/// ``` -/// use diskann_benchmark_runner::dispatcher::{Dispatcher1, MutRef}; -/// -/// let mut d = Dispatcher1::<(), MutRef>>::new(); -/// d.register::<_, MutRef>>("method", |v: &mut Vec| v.push(0.0)); -/// -/// let mut v = Vec::new(); -/// d.call(&mut v).unwrap(); -/// assert_eq!(&v, &[0.0]); -/// ``` -pub struct MutRef(std::marker::PhantomData); -impl Map for MutRef { - type Type<'a> = &'a mut T; -} - -pub struct Type(std::marker::PhantomData); -impl Map for Type { - type Type<'a> = T; -} - -#[macro_export] -macro_rules! self_map { - ($($type:tt)*) => { - impl $crate::dispatcher::Map for $($type)* { - type Type<'a> = $($type)*; - } - } -} - -self_map!(bool); -self_map!(usize); -self_map!(u8); -self_map!(u16); -self_map!(u32); -self_map!(u64); -self_map!(u128); -self_map!(i8); -self_map!(i16); -self_map!(i32); -self_map!(i64); -self_map!(i128); -self_map!(String); -self_map!(f32); -self_map!(f64); - -/// Reasons for a method call mismatch. -/// -/// The name of the associated method can be queried using `self.method()` and reasons -/// are obtained in `self.mismatches()`. -pub struct ArgumentMismatch<'a, const N: usize> { - pub(crate) method: &'a str, - pub(crate) mismatches: [Option>; N], -} - -impl<'a, const N: usize> ArgumentMismatch<'a, N> { - /// Return the name of the associated method. - pub fn method(&self) -> &str { - self.method - } - - /// Return a slice of reasons for method match failure. - /// - /// The returned slice contains one entry per argument. An entry is `None` if that - /// argument matched the input value. - /// - /// If the argument did not match the input value, then the corresponding - /// [`std::fmt::Display`] object can be used to retrieve the reason. - pub fn mismatches(&self) -> &[Option>; N] { - &self.mismatches - } -} - -/// Return the signature for an argument type. -pub struct Signature(pub(crate) fn(&mut Formatter<'_>) -> std::fmt::Result); - -impl std::fmt::Display for Signature { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - (self.0)(f) - } -} - /////////// // Tests // /////////// diff --git a/diskann-benchmark-runner/src/dispatcher/dispatch.rs b/diskann-benchmark-runner/src/dispatcher/dispatch.rs deleted file mode 100644 index d33c8c6f3..000000000 --- a/diskann-benchmark-runner/src/dispatcher/dispatch.rs +++ /dev/null @@ -1,638 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::fmt::Formatter; - -use super::{ - ArgumentMismatch, DispatchRule, FailureScore, Map, MatchScore, Signature, TaggedFailureScore, -}; - -/// Return `Some` if all the entries in `Input` are `Ok(MatchScore)`. -/// -/// Otherwise, return None -fn coalesce( - input: &[Result; N], -) -> Option<[MatchScore; N]> { - let mut output = [MatchScore(0); N]; - for i in 0..N { - output[i] = match input[i] { - Ok(score) => score, - Err(_) => return None, - } - } - Some(output) -} - -/// Return `true` if all values in `inpu` are `Ok`. -fn all_match(input: &[Result; N]) -> bool { - input.iter().all(|i| matches!(i, Ok(MatchScore(_)))) -} - -/// A method match along with a tagged failure score. -/// -/// This is used as part of the match failure debugging process. -struct TaggedMatch<'a, const N: usize> { - method: &'a str, - score: [Result>; N], -} - -impl<'a, const N: usize> From> for ArgumentMismatch<'a, N> { - fn from(value: TaggedMatch<'a, N>) -> Self { - ArgumentMismatch { - method: value.method, - mismatches: value.score.map(|r| match r { - Ok(_) => None, - Err(tagged) => Some(tagged.why), - }), - } - } -} - -/// An ordered priority queue that keeps track of the "closest" mismatches. -struct Queue<'a, const N: usize> { - buffer: Vec>, - max_methods: usize, -} - -impl<'a, const N: usize> Queue<'a, N> { - fn new(max_methods: usize) -> Self { - Self { - buffer: Vec::with_capacity(max_methods), - max_methods, - } - } - - fn finish(self) -> Vec> { - self.buffer.into_iter().map(|m| m.into()).collect() - } - - /// Insert `r` into the queue in sorted order. - /// - /// Returns `Err(())` if all entries in `r` are matches. This provies a means for - /// algorithms reporting errors to detect if in fact the collection of arguments - /// are dispatchable and debugging is not actually needed. - fn push(&mut self, y: TaggedMatch<'a, N>) -> Result<(), ()> { - use std::cmp::Ordering; - - if all_match(&y.score) { - return Err(()); - } - - // Now we get the fun part of ranking methods. - // We rank first on `MatchScore`, then on `FailureScore`. - let lt = |x: &TaggedMatch<'a, N>| { - for i in 0..N { - let xi = &x.score[i]; - let yi = &y.score[i]; - match xi { - Ok(MatchScore(x_score)) => match yi { - Ok(MatchScore(y_score)) => match x_score.cmp(y_score) { - Ordering::Equal => {} - strict => return strict, - }, - Err(_) => { - return Ordering::Less; - } - }, - Err(TaggedFailureScore { score: x_score, .. }) => match yi { - Ok(_) => { - return Ordering::Greater; - } - Err(TaggedFailureScore { score: y_score, .. }) => { - match x_score.cmp(y_score) { - Ordering::Equal => {} - strict => return strict, - } - } - }, - } - } - Ordering::Equal - }; - - // `binary_search_by` will always return an index that will allow the key to be - // placed in sorted order. - // - // We do not care if the method is present or not, we just want the index. - let i = match self.buffer.binary_search_by(lt) { - Ok(i) => i, - Err(i) => i, - }; - - if self.buffer.len() == self.max_methods { - // No need to insert, it's greater than our worst match so far. - if i > self.buffer.len() { - return Ok(()); - } - self.buffer.insert(i, y); - self.buffer.truncate(self.max_methods); - } else { - self.buffer.insert(i, y); - } - Ok(()) - } -} - -pub trait Sealed {} - -macro_rules! implement_dispatch { - ($trait:ident, - $method:ident, - $dispatcher:ident, - $N:literal, - { $($T:ident )+ }, - { $($x:ident )+ }, - { $($A:ident )+ }, - { $($lf:lifetime )+ } - ) => { - /// A dispatchable method. - /// - /// # Macro Expansion - /// - /// Generates the code below: - /// ```text - /// pub trait DispatcherN - /// where - /// T0: Map, - /// T1: Map, - /// ..., - /// { - /// fn try_match(&self, x0: &T0::Type<'_>, x1: &T1::Type<'_>, ...); - /// - /// fn call(&self, x0: T0::Type<'_>, x1: T1::Type<'_), ...) -> R; - /// - /// fn signatures(&self) -> [Signature; N]; - /// - /// fn try_match_verbose<'a, 'a0, 'a1, ...>( - /// &'a self, - /// x0: &'a T0::Type<'a0>, - /// x1: &'a T1::Type<'a1>, - /// ... - /// ) -> [Result>; N] - /// where - /// 'a0: 'a, - /// 'a1: 'a, - /// ...; - /// } - /// ``` - pub trait $trait: Sealed - where - $($T: Map,)* - { - /// Invoke [`DispatchRule::try_match`] on each argument/type pair where the type - /// comes from the backend method. - /// - /// Return all results. - fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result; $N]; - - /// Invoke this method with the given types, invoking [`DispatchRule::convert`] - /// on each argument to the target types of the backend method. - /// - /// This function is only safe to call if [`Self::try_match`] returns a success. - /// Calling this method incorrectly may panic. - /// - /// # Panics - /// - /// Panics if any call to [`DispatchRule::convert`] fails. - fn call(&self, $($x: $T::Type<'_>,)*) -> R; - - /// Return the signatures for each back-end argument type. - fn signatures(&self) -> [Signature; $N]; - - /// The equivalent of [`Self::try_match`], but using the - /// [`DispatchRule::try_match_verbose`] interface. - /// - /// This provides a method for inspecting the reason for match failures. - fn try_match_verbose<'a, $($lf,)*>( - &'a self, - $($x: &'a $T::Type<$lf>,)* - ) -> [Result>; $N] - where - $($lf: 'a,)*; - } - - /// # Macro Expansion - /// - /// ```text - /// pub struct MethodN - /// where - /// A0: Map, - /// A1: Map, - /// ..., - /// { - /// f: Box Fn(A0::Type<'a0>, A1::Type<'a1>, ...) -> R>, - /// _types: std::marker::PhantomData<(A0, A1, ...)>, - /// } - /// ``` - pub struct $method - where - $($A: Map,)* - { - f: Box Fn($($A::Type<$lf>,)*) -> R>, - _types: std::marker::PhantomData<($($A,)*)>, - } - - /// # Macro Expansion - /// - /// ```text - /// impl MethodN - /// where - /// R: 'static, - /// A0: Map, - /// A1: Map, - /// ..., - /// { - /// pub fn new(f: F) -> Self - /// where - /// F: for<'a0, 'a1, ...> Fn(A0::Type<'a0>, A1::Type<'a1>, ...) -> R + 'static, - /// { - /// Self { - /// f: Box::new(f), - /// _types: std::marker::PhantomData, - /// } - /// } - /// } - /// ``` - impl $method - where - $($A: Map,)* - { - fn new(f: F) -> Self - where - F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static, - { - Self { - f: Box::new(f), - _types: std::marker::PhantomData, - } - } - } - - impl Sealed for $method - where - $($A: Map,)* - {} - - impl $trait for $method - where - $($T: Map,)* - $($A: Map,)* - $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)* - { - fn try_match(&self, $($x: &$T::Type<'_>,)*) -> [Result; $N] { - // Splat out all the pair-wise `try_match`es. - [$($A::Type::try_match($x),)*] - } - - fn call(&self, $($x: $T::Type<'_>,)*) -> R { - // Convert and unwrap all pair-wise matches. - (self.f)($($A::Type::convert($x).unwrap(),)*) - } - - fn signatures(&self) -> [Signature; $N] { - // The strategy here involves decaying a stateless lambda to a function - // pointer, and generating one such lambda for each input type. - // - // Note that we need to couple it with its corresponding dispatch type - // to ensure we get routed to the correct description. - [ - $(Signature(|f: &mut Formatter<'_>| { - $A::Type::description(f, None::<&$T::Type<'_>>) - }),)* - ] - } - - fn try_match_verbose<'a, $($lf,)*>( - &self, - $($x: &'a $T::Type<$lf>,)* - ) -> [Result>; $N] - where - $($lf: 'a,)* - { - // Simply construct an array by calling `try_match_verbose` on each pair. - [$($A::Type::try_match_verbose($x),)*] - } - } - - /// A central dispatcher for multi-method overloading. - pub struct $dispatcher - where - R: 'static, - $($T: Map,)* - { - pub(super) methods: Vec<(String, Box>)>, - } - - impl Default for $dispatcher - where - R: 'static, - $($T: Map,)* - { - fn default() -> Self { - Self::new() - } - } - - impl $dispatcher - where - R: 'static, - $($T: Map,)* - { - /// Construct a new, empty dispatcher. - pub fn new() -> Self { - Self { methods: Vec::new() } - } - - /// Register the new named method with the dispatcher. - pub fn register(&mut self, name: impl Into, f: F) - where - $($A: Map,)* - $(for<'a> $A::Type<'a>: DispatchRule<$T::Type<'a>>,)* - F: for<$($lf,)*> Fn($($A::Type<$lf>,)*) -> R + 'static, - { - let method = $method::::new(f); - self.methods.push((name.into(), Box::new(method))) - } - - /// Try to invoke the best fitting method with the given arguments. - /// - /// If no such method can be found, returns `None`. - pub fn call(&self, $($x: $T::Type<'_>,)*) -> Option { - let mut method: Option<(&_, [MatchScore; $N])> = None; - self.methods.iter().for_each(|m| { - match coalesce(&(m.1.try_match($(&$x,)*))) { - // Valid match - Some(score) => match method.as_mut() { - Some(method) => { - if score < method.1 { - *method = (m, score) - } - } - None => { - method.replace((m, score)); - } - }, - None => {} - } - }); - - // Invoke the best method - method.map(|(m, _)| m.1.call($($x,)*)) - } - - /// Return an iterator to the methods registered in this dispatcher. - pub fn methods( - &self - ) -> impl ExactSizeIterator>)> { - self.methods.iter() - } - - /// Query whether the combination of values has a valid matching method without - /// trying to invoke that method. - pub fn has_match(&self, $($x: &$T::Type<'_>,)*) -> bool { - for m in self.methods.iter() { - if all_match(&m.1.try_match($(&$x,)*)) { - return true; - } - } - return false; - } - - /// Check if a back-end method exists for the arguments. - /// - /// If so, returns `Ok(())`. - /// - /// Otherwise, returns a vector of `ArgumentMismatch` for the up-to - /// `max_methods` closest methods. - /// - /// In this context, "closeness" is defined by first comparing match or failure - /// scores for argument 0, followed by argument 1 if equal and so on. - pub fn debug<'a, $($lf,)*>( - &'a self, - max_methods: usize, - $($x: &'a $T::Type<$lf>,)* - ) -> Result<(), Vec>> - where - $($lf: 'a,)* - { - let mut methods = Queue::new(max_methods); - for m in self.methods.iter() { - let t = TaggedMatch { - method: &m.0, - score: m.1.try_match_verbose($($x,)*), - }; - match methods.push(t) { - Ok(()) => {}, - Err(()) => return Ok(()), - } - } - Err(methods.finish()) - } - } - } -} - -implement_dispatch!(Dispatch1, Method1, Dispatcher1, 1, { T0 }, { x0 }, { A0 }, { 'a0 }); -implement_dispatch!( - Dispatch2, Method2, Dispatcher2, 2, - { T0 T1 }, { x0 x1 }, { A0 A1 }, { 'a0 'a1 } -); -implement_dispatch!( - Dispatch3, Method3, Dispatcher3, 3, - { T0 T1 T2 }, { x0 x1 x2 }, { A0 A1 A2 }, { 'a0 'a1 'a2 } -); - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - struct Num; - - impl Map for Num { - type Type<'a> = Self; - } - - impl DispatchRule for Num { - type Error = std::convert::Infallible; - - // For testing purposes, we accept values within 2 of `N`, but with decreasing - // precedence. - fn try_match(from: &usize) -> Result { - let diff = from.abs_diff(N); - if diff <= 2 { - Ok(MatchScore(diff as u32)) - } else { - Err(FailureScore(diff as u32)) - } - } - - fn convert(from: usize) -> Result { - assert!(from.abs_diff(N) <= 2); - Ok(Self) - } - - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&usize>) -> std::fmt::Result { - match from { - None => write!(f, "{}", N), - Some(value) => { - let diff = value.abs_diff(N); - match diff { - 0 => write!(f, "success: exact match"), - 1 => write!(f, "success: off by 1"), - 2 => write!(f, "success: off by 2"), - x => write!(f, "error: off by {}", x), - } - } - } - } - } - - //////////////// - // Dispatch 1 // - //////////////// - - #[test] - fn test_dispatch_1() { - let mut x = Dispatcher1::::default(); - x.register::<_, Num<0>>("method 0", |_| 0); - x.register::<_, Num<3>>("method 3", |_| 3); - x.register::<_, Num<5>>("method 5", |_| 5); - x.register::<_, Num<8>>("method 8", |_| 8); - - { - let methods: Vec<_> = x.methods().collect(); - assert_eq!(methods.len(), 4); - assert_eq!(methods[0].0, "method 0"); - assert_eq!(methods[0].1.signatures()[0].to_string(), "0"); - - assert_eq!(methods[1].0, "method 3"); - assert_eq!(methods[1].1.signatures()[0].to_string(), "3"); - } - - // Test that dispatching works properly. - assert_eq!(x.call(0), Some(0)); - assert_eq!(x.call(1), Some(0)); - assert_eq!(x.call(2), Some(3)); - assert_eq!(x.call(3), Some(3)); - assert_eq!(x.call(4), Some(3)); - assert_eq!(x.call(5), Some(5)); - assert_eq!(x.call(6), Some(5)); - assert_eq!(x.call(7), Some(8)); - assert_eq!(x.call(8), Some(8)); - assert_eq!(x.call(11), None); - - for i in 0..11 { - assert!(x.has_match(&i)); - } - for i in 11..20 { - assert!(!x.has_match(&i)); - } - - // Make sure `Debug` works. - assert!(x.debug(3, &10).is_ok()); - - let mismatches = x.debug(3, &11).unwrap_err(); - assert_eq!(mismatches.len(), 3); - - // Method 8 is the closest. - assert_eq!(mismatches[0].method(), "method 8"); - assert_eq!( - mismatches[0].mismatches()[0].as_ref().unwrap().to_string(), - "error: off by 3" - ); - - // Method 5 is next. - assert_eq!(mismatches[1].method(), "method 5"); - assert_eq!( - mismatches[1].mismatches()[0].as_ref().unwrap().to_string(), - "error: off by 6" - ); - - // Method 3 is next. - assert_eq!(mismatches[2].method(), "method 3"); - assert_eq!( - mismatches[2].mismatches()[0].as_ref().unwrap().to_string(), - "error: off by 8" - ); - - // Make sure that if we request more than the total number of methods that it is - // capped. - assert_eq!(x.debug(10, &20).unwrap_err().len(), 4); - } - - //////////////// - // Dispatch 2 // - //////////////// - - #[test] - fn test_dispatch_2() { - let mut x = Dispatcher2::::default(); - - x.register::<_, Num<10>, Num<10>>("method 0", |_, _| 0); - x.register::<_, Num<10>, Num<13>>("method 1", |_, _| 1); - x.register::<_, Num<13>, Num<12>>("method 3", |_, _| 3); - x.register::<_, Num<12>, Num<10>>("method 2", |_, _| 2); - - { - let methods: Vec<_> = x.methods().collect(); - assert_eq!(methods.len(), 4); - assert_eq!(methods[0].0, "method 0"); - assert_eq!(methods[0].1.signatures()[0].to_string(), "10"); - assert_eq!(methods[0].1.signatures()[1].to_string(), "10"); - - assert_eq!(methods[1].0, "method 1"); - assert_eq!(methods[1].1.signatures()[0].to_string(), "10"); - assert_eq!(methods[1].1.signatures()[1].to_string(), "13"); - } - - // This is where things get weird. - assert_eq!(x.call(10, 10), Some(0)); // Match method 0 - assert_eq!(x.call(10, 11), Some(0)); // Match method 0 - assert_eq!(x.call(10, 12), Some(1)); // Match method 1 - assert_eq!(x.call(11, 12), Some(1)); // Match method 1 - assert_eq!(x.call(12, 12), Some(2)); // Match method 2 - assert_eq!(x.call(13, 12), Some(3)); // Match method 3 - - // Check error handling. - { - assert!(x.call(10, 7).is_none()); - let m = x.debug(3, &9, &7).unwrap_err(); - // The closest hit is method 0, followed by method 1. - assert_eq!(m[0].method(), "method 0"); - assert_eq!(m[1].method(), "method 1"); - assert_eq!(m[2].method(), "method 2"); - - let mismatches = m[0].mismatches(); - // The first argument is a match - the second argument is a mismatch. - assert!(mismatches[0].is_none()); - assert_eq!( - mismatches[1].as_ref().unwrap().to_string(), - "error: off by 3" - ); - - let mismatches = m[2].mismatches(); - assert_eq!( - mismatches[0].as_ref().unwrap().to_string(), - "error: off by 3" - ); - assert_eq!( - mismatches[1].as_ref().unwrap().to_string(), - "error: off by 3" - ); - } - - // Try again, but this time from the other direction. - { - let m = x.debug(4, &16, &12).unwrap_err(); - assert_eq!(m[0].method(), "method 3"); - assert_eq!(m[1].method(), "method 2"); - assert_eq!(m[2].method(), "method 1"); - } - } -} diff --git a/diskann-benchmark-runner/src/dispatcher/examples.rs b/diskann-benchmark-runner/src/dispatcher/examples.rs deleted file mode 100644 index b7b218b8e..000000000 --- a/diskann-benchmark-runner/src/dispatcher/examples.rs +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Example types for `dispatcher` module-level documentation. - -use crate::{ - dispatcher::{DispatchRule, FailureScore, Map, MatchScore}, - self_map, -}; - -/// An example type representing Rust primitive type. -#[derive(Debug, Clone, Copy)] -pub enum DataType { - Float64, - Float32, - UInt8, - UInt16, - UInt32, - UInt64, - Int8, - Int16, - Int32, - Int64, -} - -// Make `DataType` a dispatch type. -self_map!(DataType); - -/// A type-domain lifting of Rust primitive types. -pub struct Type(std::marker::PhantomData); - -/// Make `Type` reflexive to facilitate dispatch. -impl Map for Type { - type Type<'a> = Self; -} - -macro_rules! type_map { - ($variant:ident, $T:ty) => { - impl DispatchRule for Type<$T> { - type Error = std::convert::Infallible; - - fn try_match(from: &DataType) -> Result { - match from { - DataType::$variant => Ok(MatchScore(0)), - _ => Err(FailureScore(u32::MAX)), - } - } - - fn convert(from: DataType) -> Result { - assert!(matches!(from, DataType::$variant)); - Ok(Self(std::marker::PhantomData)) - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&DataType>, - ) -> std::fmt::Result { - match from { - None => write!(f, "{:?}", DataType::$variant), - Some(v) => { - if matches!(v, DataType::$variant) { - write!(f, "success") - } else { - write!(f, "expected {:?} but got {:?}", DataType::$variant, v) - } - } - } - } - } - }; -} - -type_map!(Float64, f64); -type_map!(Float32, f32); -type_map!(UInt8, u8); -type_map!(UInt16, u16); -type_map!(UInt32, u32); -type_map!(UInt64, u64); -type_map!(Int8, i8); -type_map!(Int16, i16); -type_map!(Int32, i32); -type_map!(Int64, i64); diff --git a/diskann-benchmark-runner/src/dispatcher/mod.rs b/diskann-benchmark-runner/src/dispatcher/mod.rs index 6cd76c3df..76eba7646 100644 --- a/diskann-benchmark-runner/src/dispatcher/mod.rs +++ b/diskann-benchmark-runner/src/dispatcher/mod.rs @@ -3,75 +3,20 @@ * Licensed under the MIT license. */ -//! # Dynamic Dispatcher +//! # Dispatch Rules //! -//! This crate implements a family of generic structures supporting pull-driven generic -//! multiple dispatch. +//! This module provides the [`DispatchRule`] trait and supporting types for value-to-type +//! matching and conversion. //! -//! In other words, it allows functions to be registered in a central location, enables -//! value-to-type lifting of arguments, overload resolution, and utilities for inspecting -//! failures for diagnostic reporting. -//! -//! # Quick Example -//! -//! Suppose we have a small collection of operations for which we would like to specialize -//! for some type `T`, which we may wish to alter from time to time. -//! -//! Furthermore, suppose this operation can returns a `String`. -//! -//! We can do that fairly easily: -//! ``` -//! use diskann_benchmark_runner::dispatcher::{self, examples::{DataType, Type}}; -//! -//! // A dynamic dispatcher that takes 1 argument of type `DataType` and returns a `String`. -//! let mut d = dispatcher::Dispatcher1::::new(); -//! -//! // We can register two methods with the dispatcher. -//! d.register::<_, Type>("method-a", |_: Type| "called method A".to_string()); -//! d.register::<_, Type>("method-b", |_: Type| "called method B".to_string()); -//! -//! // We can now verify that these methods are reachable. -//! assert_eq!(&d.call(DataType::Float32).unwrap(), "called method A"); -//! assert_eq!(&d.call(DataType::Float64).unwrap(), "called method B"); -//! -//! // If we try to call the dispatcher with a unregistered value for `DataType`, we -//! // get `None` as a result. -//! assert!(d.call(DataType::UInt8).is_none()); -//! -//! // But now suppose that we can implement a generic method, taking *all* data types. -//! // -//! // We can register that method and call it. -//! d.register::<_, DataType>("generic", |_: DataType| "called generic method".to_string()); -//! assert_eq!(&d.call(DataType::UInt8).unwrap(), "called generic method"); -//! -//! // However, more specific methods will be called if available. -//! assert_eq!(&d.call(DataType::Float32).unwrap(), "called method A"); -//! -//! // This is not order dependent. -//! // -//! // If we register yet another method, this time specialized for `UInt8`, it will get -//! // called when applicable. -//! d.register::<_, Type>("method-c", |_: Type| "called method C".to_string()); -//! assert_eq!(&d.call(DataType::UInt8).unwrap(), "called method C"); -//! ``` +//! [`DispatchRule`] is used by benchmark implementations to match runtime enum values +//! (e.g., `DataType::Float32`) to static Rust types (e.g., `Type`), enabling +//! type-driven overload resolution. mod api; -mod dispatch; - -pub mod examples; pub use api::{ - ArgumentMismatch, Description, DispatchRule, FailureScore, Map, MatchScore, MutRef, Ref, - Signature, TaggedFailureScore, Type, Why, IMPLICIT_MATCH_SCORE, -}; - -////////////////////// -// Dispatch Related // -////////////////////// - -pub use dispatch::{ - Dispatch1, Dispatch2, Dispatch3, Dispatcher1, Dispatcher2, Dispatcher3, Method1, Method2, - Method3, + Description, DispatchRule, FailureScore, MatchScore, TaggedFailureScore, Why, + IMPLICIT_MATCH_SCORE, }; /////////// @@ -80,119 +25,19 @@ pub use dispatch::{ #[cfg(test)] mod tests { - use std::marker::PhantomData; - use super::*; - use crate::self_map; - - /////////// - // Types // - /////////// + struct TestDescription; - struct TestType { - _phantom: PhantomData, - } - - impl TestType { - fn new() -> Self { - Self { - _phantom: PhantomData, - } - } - } - - impl Map for TestType { - type Type<'a> = Self; - } - - #[derive(Debug)] - enum TypeEnum { - Float32, - Int8, - UInt8, - } - - self_map!(TypeEnum); - - impl DispatchRule for TestType { + impl DispatchRule for TestDescription { type Error = std::convert::Infallible; - fn try_match(from: &TypeEnum) -> Result { - match from { - TypeEnum::Float32 => Ok(MatchScore(0)), - _ => Err(FailureScore(0)), - } + fn try_match(_from: &usize) -> Result { + panic!("should not be called"); } - fn convert(from: TypeEnum) -> Result { - assert!(Self::try_match(&from).is_ok()); - Ok(Self::new()) - } - } - - impl DispatchRule for TestType { - type Error = std::convert::Infallible; - - fn try_match(from: &TypeEnum) -> Result { - match from { - TypeEnum::Int8 => Ok(MatchScore(0)), - _ => Err(FailureScore(0)), - } - } - - fn convert(from: TypeEnum) -> Result { - assert!(Self::try_match(&from).is_ok()); - Ok(Self::new()) - } - } - - ////////////// - // UnaryOps // - ////////////// - - enum UnaryOp { - Square, - Double, - DoesNotExist, - } - - struct Square; - struct Double; - - self_map!(UnaryOp); - self_map!(Square); - self_map!(Double); - - impl DispatchRule for Square { - type Error = std::convert::Infallible; - - fn try_match(from: &UnaryOp) -> Result { - match from { - UnaryOp::Square => Ok(MatchScore(0)), - _ => Err(FailureScore(0)), - } - } - - fn convert(from: UnaryOp) -> Result { - assert!(Self::try_match(&from).is_ok()); - Ok(Self) - } - } - - impl DispatchRule for Double { - type Error = std::convert::Infallible; - - fn try_match(from: &UnaryOp) -> Result { - match from { - UnaryOp::Double => Ok(MatchScore(0)), - _ => Err(FailureScore(0)), - } - } - - fn convert(from: UnaryOp) -> Result { - assert!(Self::try_match(&from).is_ok()); - Ok(Self) + fn convert(_from: usize) -> Result { + panic!("should not be called"); } } @@ -203,48 +48,12 @@ mod tests { #[test] fn test_empty_description() { assert_eq!( - Description::::new().to_string(), + Description::::new().to_string(), "" ); assert_eq!( - Why::::new(&UnaryOp::Double).to_string(), + Why::::new(&0).to_string(), "" ); } - - #[test] - fn test_dispatch1() { - let mut dispatcher = Dispatcher1::<&'static str, TypeEnum>::new(); - dispatcher.register::<_, TestType>("method1", |_: TestType| "float32"); - dispatcher.register::<_, TestType>("method2", |_: TestType| "int8"); - - assert_eq!(dispatcher.call(TypeEnum::Int8), Some("int8")); - assert_eq!(dispatcher.call(TypeEnum::Float32), Some("float32")); - assert_eq!(dispatcher.call(TypeEnum::UInt8), None); - - let mut dispatcher = Dispatcher1::<(), MutRef<[f32]>>::new(); - dispatcher.register::<_, MutRef<[f32]>>("method1", |_: &mut [f32]| println!("hello world")); - } - - #[test] - fn test_dispatch2() { - let mut dispatcher = Dispatcher2::::new(); - dispatcher.register::<_, Square, u64>("square", |_: Square, x: u64| x * x); - dispatcher.register::<_, Double, u64>("double", |_: Double, x: u64| 2 * x); - - assert_eq!(dispatcher.call(UnaryOp::Square, 10).unwrap(), 100); - assert_eq!(dispatcher.call(UnaryOp::Double, 10).unwrap(), 20); - assert_eq!(dispatcher.call(UnaryOp::DoesNotExist, 0), None); - - let mut dispatcher = Dispatcher2::<(), UnaryOp, MutRef>::new(); - dispatcher.register::<_, Square, MutRef>("square", |_: Square, x: &mut u64| *x *= *x); - dispatcher.register::<_, Double, MutRef>("double", |_: Double, x: &mut u64| *x *= 2); - - let mut x: u64 = 10; - dispatcher.call(UnaryOp::Square, &mut x).unwrap(); - assert_eq!(x, 100); - - dispatcher.call(UnaryOp::Double, &mut x).unwrap(); - assert_eq!(x, 200); - } } diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs new file mode 100644 index 000000000..5f4a0dc80 --- /dev/null +++ b/diskann-benchmark-runner/src/input.rs @@ -0,0 +1,130 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use crate::{Any, Checker}; + +pub trait Input { + /// Return the discriminant associated with this type. + /// + /// This is used to map inputs types to their respective parsers. + /// + /// Well formed implementations should always return the same result. + fn tag() -> &'static str; + + /// Attempt to deserialize an opaque object from the raw `serialized` representation. + /// + /// Deserialized values can be constructed and returned via [`Checker::any`], + /// [`Any::new`] or [`Any::raw`]. + /// + /// If using the [`Any`] constructors directly, implementations should associate + /// [`Self::tag`] with the returned `Any`. If [`Checker::any`] is used - this will + /// happen automatically. + /// + /// Implementations are **strongly** encouraged to implement + /// [`CheckDeserialization`](crate::CheckDeserialization) and use this API to ensure + /// shared resources (like input files or output files) are correctly resolved and + /// properly shared among all jobs in a benchmark run. + fn try_deserialize( + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result; + + /// Print an example JSON representation of objects this input is expected to parse. + /// + /// Well-formed implementations should ensure that passing the returned + /// [`serde_json::Value`] back to [`Self::try_deserialize`] correctly deserializes, + /// though it need not necessarily pass + /// [`CheckDeserialization`](crate::CheckDeserialization). + fn example() -> anyhow::Result; +} + +/// A registered input. See [`crate::registry::Inputs::get`]. +#[derive(Clone, Copy)] +pub struct Registered<'a>(pub(crate) &'a dyn DynInput); + +impl Registered<'_> { + /// Return the input tag of the registered input. + /// + /// See: [`Input::tag`]. + pub fn tag(&self) -> &'static str { + self.0.tag() + } + + /// Try to deserialize raw JSON into the dynamic type of the input. + /// + /// See: [`Input::try_deserialize`]. + pub fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + self.0.try_deserialize(serialized, checker) + } + + /// Return an example JSON for the dynamic type of the input. + /// + /// See: [`Input::example`]. + pub fn example(&self) -> anyhow::Result { + self.0.example() + } +} + +impl std::fmt::Debug for Registered<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("input::Registered") + .field("tag", &self.tag()) + .finish() + } +} + +////////////// +// Internal // +////////////// + +#[derive(Debug)] +pub(crate) struct Wrapper(std::marker::PhantomData); + +impl Wrapper { + pub(crate) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + +impl Clone for Wrapper { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Wrapper {} + +pub(crate) trait DynInput { + fn tag(&self) -> &'static str; + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result; + fn example(&self) -> anyhow::Result; +} + +impl DynInput for Wrapper +where + T: Input, +{ + fn tag(&self) -> &'static str { + T::tag() + } + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + T::try_deserialize(serialized, checker) + } + fn example(&self) -> anyhow::Result { + T::example() + } +} diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index da71275c0..64cdf5f50 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, registry, Any, Input}; +use crate::{checker::Checker, input, registry, Any}; #[derive(Debug)] pub(crate) struct Jobs { @@ -96,7 +96,7 @@ impl Unprocessed { Self { tag, content } } - pub(crate) fn format_input(example: &dyn Input) -> anyhow::Result { + pub(crate) fn format_input(example: input::Registered<'_>) -> anyhow::Result { let tag = example.tag().to_string(); Ok(Self { tag, diff --git a/diskann-benchmark-runner/src/lib.rs b/diskann-benchmark-runner/src/lib.rs index d1a40667a..ab3f0b676 100644 --- a/diskann-benchmark-runner/src/lib.rs +++ b/diskann-benchmark-runner/src/lib.rs @@ -5,6 +5,7 @@ //! A moderately functional utility for making simple benchmarking CLI applications. +mod benchmark; mod checker; mod jobs; mod result; @@ -13,13 +14,16 @@ pub mod any; pub mod app; pub mod dispatcher; pub mod files; +pub mod input; pub mod output; pub mod registry; pub mod utils; pub use any::Any; pub use app::App; +pub use benchmark::Benchmark; pub use checker::{CheckDeserialization, Checker}; +pub use input::Input; pub use output::Output; pub use result::Checkpoint; @@ -29,41 +33,3 @@ pub mod test; #[cfg(any(test, feature = "ux-tools"))] #[doc(hidden)] pub mod ux; - -//-------// -// Input // -//-------// - -pub trait Input { - /// Return the discriminant associated with this type. - /// - /// This is used to map inputs types to their respective parsers. - /// - /// Well formed implementations should always return the same result. - fn tag(&self) -> &'static str; - - /// Attempt to deserialize an opaque object from the raw `serialized` representation. - /// - /// Deserialized values can be constructed and returned via [`Checker::any`], - /// [`Any::new`] or [`Any::raw`]. - /// - /// If using the [`Any`] constructors directly, implementations should associate - /// [`Self::tag`] with the returned `Any`. If [`Checker::any`] is used - this will - /// happen automatically. - /// - /// Implementations are **strongly** encouraged to implement [`CheckDeserialization`] - /// and use this API to ensure shared resources (like input files or output files) - /// are correctly resolved and properly shared among all jobs in a benchmark run. - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result; - - /// Print an example JSON representation of objects this input is expected to parse. - /// - /// Well formed implementations should passing the returned [`serde_json::Value`] back - /// to [`Self::try_deserialize`] correctly deserializes, though it need not necessarily - /// pass [`CheckDeserialization`]. - fn example(&self) -> anyhow::Result; -} diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 82d2712e1..3eab77fcf 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -8,15 +8,15 @@ use std::collections::HashMap; use thiserror::Error; use crate::{ - dispatcher::{DispatchRule, Map}, - output::Sink, - Any, Checkpoint, Input, Output, + benchmark::{self, Benchmark, DynBenchmark}, + dispatcher::FailureScore, + input, Any, Checkpoint, Input, Output, }; /// A collection of [`crate::Input`]. pub struct Inputs { // Inputs keyed by their tag type. - inputs: HashMap<&'static str, Box>, + inputs: HashMap<&'static str, Box>, } impl Inputs { @@ -27,28 +27,25 @@ impl Inputs { } } - /// Return the input with the registerd `tag` if present. Otherwise, return `None`. - pub fn get(&self, tag: &str) -> Option<&dyn Input> { - match self.inputs.get(tag) { - Some(v) => Some(&**v), - None => None, - } + /// Return the input with the registered `tag` if present. Otherwise, return `None`. + pub fn get(&self, tag: &str) -> Option> { + self.inputs.get(tag).map(|v| input::Registered(&**v)) } - /// Register `input` in the registry. + /// Register the [`Input`] `T` in the registry. /// /// Returns an error if any other input with the same [`Input::tag()`] has been registered /// while leaving the underlying registry unchanged. - pub fn register(&mut self, input: T) -> anyhow::Result<()> + pub fn register(&mut self) -> anyhow::Result<()> where T: Input + 'static, { use std::collections::hash_map::Entry; - let tag = input.tag(); + let tag = T::tag(); match self.inputs.entry(tag) { Entry::Vacant(entry) => { - entry.insert(Box::new(input)); + entry.insert(Box::new(crate::input::Wrapper::::new())); Ok(()) } Entry::Occupied(_) => { @@ -73,47 +70,52 @@ impl Default for Inputs { } } -/// A collection of registerd benchmarks. +/// A registered benchmark entry: a name paired with a type-erased benchmark. +struct RegisteredBenchmark { + name: String, + benchmark: Box, +} + +/// A collection of registered benchmarks. pub struct Benchmarks { - dispatcher: Dispatcher, + benchmarks: Vec, } impl Benchmarks { /// Return a new empty registry. pub fn new() -> Self { Self { - dispatcher: Dispatcher::new(), + benchmarks: Vec::new(), } } /// Register a new benchmark with the given name. - /// - /// The type parameter `T` is used to match this benchmark with a registered - /// [`crate::Any`], which is determined using by `>`. - pub fn register( - &mut self, - name: impl Into, - benchmark: impl Fn(T::Type<'_>, Checkpoint<'_>, &mut dyn Output) -> anyhow::Result - + 'static, - ) where - T: for<'a> Map: DispatchRule<&'a Any>>, + pub fn register(&mut self, name: impl Into) + where + T: Benchmark + 'static, { - self.dispatcher - .register::<_, T, CheckpointRef, DynOutput>(name.into(), benchmark) + self.benchmarks.push(RegisteredBenchmark { + name: name.into(), + benchmark: Box::new(benchmark::Wrapper::::new()), + }); } - pub(crate) fn methods(&self) -> impl ExactSizeIterator { - self.dispatcher.methods() + /// Return an iterator over registered benchmark names and their descriptions. + pub(crate) fn names(&self) -> impl ExactSizeIterator { + self.benchmarks.iter().map(|entry| { + ( + entry.name.as_str(), + Capture(&*entry.benchmark, None).to_string(), + ) + }) } - /// Return `true` if `job` matches with any registerd benchmark. Otherwise, return `false`. + /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`. pub fn has_match(&self, job: &Any) -> bool { - let sink: &mut dyn Output = &mut Sink::new(); - self.dispatcher.has_match(&job, &Checkpoint::empty(), &sink) + self.find_best_match(job).is_some() } - /// Attempt to the best matching benchmark for `job` - forwarding the `checkpoint` and - /// `output` to the benchmark. + /// Attempt to run the best matching benchmark for `job`. /// /// Returns the results of the benchmark if successful. /// @@ -124,40 +126,64 @@ impl Benchmarks { checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - self.dispatcher.call(job, checkpoint, output).unwrap() + match self.find_best_match(job) { + Some(entry) => entry.benchmark.run(job, checkpoint, output), + None => Err(anyhow::Error::msg( + "could not find a matching benchmark for the given input", + )), + } } - /// Attempt to debug reasons for a missed dispatch, returning at most `methods` reasons. + /// Attempt to debug reasons for a missed dispatch, returning at most `max_methods` + /// reasons. /// - /// This implementation works by invoking [`DispatchRule::try_match`] with - /// `job` on all registered benchmarks. If no successful matches are found, the lowest - /// ranking [`crate::dispatcher::FailureScore`]s are collected and used to report details - /// of the nearest misses using [`DispatchRule::description`]. - /// - /// Returns `Ok(())` is a match was found. - pub fn debug(&self, job: &Any, methods: usize) -> Result<(), Vec> { - let checkpoint = Checkpoint::empty(); - let sink: &mut dyn Output = &mut Sink::new(); - let mismatches = match self.dispatcher.debug(methods, &job, &checkpoint, &sink) { - Ok(()) => return Ok(()), - Err(mismatches) => mismatches, - }; - - // Just retrieve the mismatch information for the first argument since that is the - // one that does all the heavy lifting. - Err(mismatches + /// Returns `Ok(())` if a match was found. + pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec> { + if self.has_match(job) { + return Ok(()); + } + + // Collect all failures with their scores, sorted by score (best near-misses first). + let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self + .benchmarks + .iter() + .filter_map(|entry| match entry.benchmark.try_match(job) { + Ok(_) => None, + Err(score) => Some((entry, score)), + }) + .collect(); + + failures.sort_by_key(|(_, score)| *score); + failures.truncate(max_methods); + + let mismatches = failures .into_iter() - .map(|m| { - let reason = m.mismatches()[0] - .as_ref() - .map(|opt| opt.to_string()) - .unwrap_or("".into()); + .map(|(entry, _)| { + let reason = Capture(&*entry.benchmark, Some(job)).to_string(); + Mismatch { - method: m.method().to_string(), + method: entry.name.clone(), reason, } }) - .collect()) + .collect(); + + Err(mismatches) + } + + /// Find the best matching benchmark for `job` by score. + fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> { + self.benchmarks + .iter() + .filter_map(|entry| { + entry + .benchmark + .try_match(job) + .ok() + .map(|score| (entry, score)) + }) + .min_by_key(|(_, score)| *score) + .map(|(entry, _)| entry) } } @@ -167,7 +193,7 @@ impl Default for Benchmarks { } } -/// Document the reason for a method mathing failure. +/// Document the reason for a method matching failure. pub struct Mismatch { method: String, reason: String, @@ -185,39 +211,11 @@ impl Mismatch { } } -//------------------// -// Dispatch Helpers // -//------------------// - -/// A [`Map`] for `&mut dyn Output`. -pub(crate) struct DynOutput; +/// Helper to capture a `DynBenchmark::description` call into a `String` via `Display`. +struct Capture<'a>(&'a dyn DynBenchmark, Option<&'a Any>); -impl Map for DynOutput { - type Type<'a> = &'a mut dyn Output; -} - -/// A dispatcher compatible mapper for [`Checkpoint`]. -#[derive(Debug, Clone, Copy)] -pub(crate) struct CheckpointRef; - -impl Map for CheckpointRef { - type Type<'a> = Checkpoint<'a>; +impl std::fmt::Display for Capture<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.description(f, self.1) + } } - -/// The internal `Dispatcher` used for method resolution. -type Dispatcher = crate::dispatcher::Dispatcher3< - anyhow::Result, - crate::dispatcher::Ref, - CheckpointRef, - DynOutput, ->; - -/// The concrete type of a method. -type Method = Box< - dyn crate::dispatcher::Dispatch3< - anyhow::Result, - crate::dispatcher::Ref, - CheckpointRef, - DynOutput, - >, ->; diff --git a/diskann-benchmark-runner/src/result.rs b/diskann-benchmark-runner/src/result.rs index 53485f728..5be95f8db 100644 --- a/diskann-benchmark-runner/src/result.rs +++ b/diskann-benchmark-runner/src/result.rs @@ -51,11 +51,6 @@ impl<'a> Checkpoint<'a> { } } - /// Create an empty checkpointer that turns calls to `checkpoint` into a no-op. - pub(crate) fn empty() -> Self { - Self { inner: None } - } - /// Atomically save the zip of the inputs and results to the configured path. pub fn save(&self) -> anyhow::Result<()> { if let Some(inner) = &self.inner { @@ -251,15 +246,6 @@ mod tests { assert!(message.contains("already exists")); } - #[test] - fn test_empty() { - let checkpoint = Checkpoint::empty(); - - // Make sure we can still call "save" and "checkpoint". - assert!(checkpoint.save().is_ok()); - assert!(checkpoint.checkpoint("hello world").is_ok()); - } - #[test] fn test_checkpoint() { let dir = tempfile::tempdir().unwrap(); diff --git a/diskann-benchmark-runner/src/test.rs b/diskann-benchmark-runner/src/test.rs index a2977ddc1..11a08e850 100644 --- a/diskann-benchmark-runner/src/test.rs +++ b/diskann-benchmark-runner/src/test.rs @@ -8,10 +8,10 @@ use std::io::Write; use serde::{Deserialize, Serialize}; use crate::{ - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, registry, utils::datatype::{DataType, Type}, - Any, CheckDeserialization, Checker, Checkpoint, Input, Output, + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, }; ///////// @@ -19,73 +19,22 @@ use crate::{ ///////// pub fn register_inputs(inputs: &mut registry::Inputs) -> anyhow::Result<()> { - inputs.register(AsTypeInput)?; - inputs.register(AsDimInput)?; + inputs.register::()?; + inputs.register::()?; Ok(()) } pub fn register_benchmarks(benchmarks: &mut registry::Benchmarks) { - benchmarks - .register::>("type-bench-f32", TypeBench::<'static, f32>::run); - benchmarks.register::>("type-bench-i8", TypeBench::<'static, i8>::run); - - benchmarks.register::("dim-bench", DimBench::run); + benchmarks.register::>("type-bench-f32"); + benchmarks.register::>("type-bench-i8"); + benchmarks.register::>("exact-type-bench-f32-1000"); + benchmarks.register::("dim-bench"); } //////////// // Inputs // //////////// -#[derive(Debug)] -struct AsTypeInput; - -impl Input for AsTypeInput { - fn tag(&self) -> &'static str { - TypeInput::tag() - } - - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(TypeInput::deserialize(serialized)?) - } - - fn example(&self) -> anyhow::Result { - Ok(serde_json::to_value(TypeInput::new( - DataType::Float32, - 128, - false, - ))?) - } -} - -#[derive(Debug)] -struct AsDimInput; - -impl Input for AsDimInput { - fn tag(&self) -> &'static str { - "test-input-dim" - } - - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(DimInput::deserialize(serialized)?) - } - - fn example(&self) -> anyhow::Result { - Ok(serde_json::to_value(DimInput::new(Some(128)))?) - } -} - -///////////////////////// -// Deserialized Inputs // -///////////////////////// - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub(crate) struct TypeInput { pub(crate) data_type: DataType, @@ -106,10 +55,27 @@ impl TypeInput { checked: false, } } +} - const fn tag() -> &'static str { +impl Input for TypeInput { + fn tag() -> &'static str { "test-input-types" } + + fn try_deserialize( + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + checker.any(TypeInput::deserialize(serialized)?) + } + + fn example() -> anyhow::Result { + Ok(serde_json::to_value(TypeInput::new( + DataType::Float32, + 128, + false, + ))?) + } } impl CheckDeserialization for TypeInput { @@ -134,6 +100,23 @@ impl DimInput { } } +impl Input for DimInput { + fn tag() -> &'static str { + "test-input-dim" + } + + fn try_deserialize( + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + checker.any(DimInput::deserialize(serialized)?) + } + + fn example() -> anyhow::Result { + Ok(serde_json::to_value(DimInput::new(Some(128)))?) + } +} + impl CheckDeserialization for DimInput { fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { Ok(()) @@ -145,121 +128,119 @@ impl CheckDeserialization for DimInput { //////////////// #[derive(Debug)] -struct TypeBench<'a, T> { - input: &'a TypeInput, - _type: Type, -} +struct TypeBench(std::marker::PhantomData); -impl TypeBench<'static, T> { - fn run( - this: TypeBench<'_, T>, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> anyhow::Result { - write!(output, "hello: {}", this.input.data_type.as_str())?; - checkpoint.checkpoint(this.input.data_type.as_str())?; - Ok(serde_json::Value::String( - this.input.data_type.as_str().into(), - )) - } -} - -impl dispatcher::Map for TypeBench<'static, T> +impl Benchmark for TypeBench where T: 'static, -{ - type Type<'a> = TypeBench<'a, T>; -} - -impl<'a, T> DispatchRule<&'a TypeInput> for TypeBench<'a, T> -where Type: DispatchRule, { - type Error = anyhow::Error; + type Input = TypeInput; + type Output = &'static str; - fn try_match(from: &&'a TypeInput) -> Result { - Type::::try_match(&from.data_type) + fn try_match(input: &TypeInput) -> Result { + // Try to match based on data type. + // Add a small penalty so `ExactTypeBench` can be more specific if it hits. + Type::::try_match(&input.data_type).map(|m| MatchScore(m.0 + 10)) } - fn convert(from: &'a TypeInput) -> Result { - Ok(Self { - input: from, - _type: Type::::convert(from.data_type)?, - }) + + fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + Type::::description(f, input.map(|i| &i.data_type)) } - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a TypeInput>, - ) -> std::fmt::Result { - match from { - Some(v) => Type::::description(f, Some(&v.data_type)), - None => Type::::description(f, None::<&DataType>), - } + + fn run( + input: &TypeInput, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + write!(output, "hello: {}", input.data_type.as_str())?; + checkpoint.checkpoint(input.data_type.as_str())?; + Ok(input.data_type.as_str()) } } -impl<'a, T> DispatchRule<&'a Any> for TypeBench<'a, T> +#[derive(Debug)] +struct ExactTypeBench(std::marker::PhantomData); + +impl Benchmark for ExactTypeBench where - Self: DispatchRule<&'a TypeInput, Error = anyhow::Error>, + T: 'static, + Type: DispatchRule, { - type Error = anyhow::Error; + type Input = TypeInput; + type Output = String; - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - fn convert(from: &'a Any) -> Result { - from.convert::() - } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, (AsTypeInput).tag()) + fn try_match(input: &TypeInput) -> Result { + if input.dim == N { + Type::::try_match(&input.data_type) + } else { + Err(FailureScore(1000)) + } } -} -#[derive(Debug)] -struct DimBench { - dim: Option, -} + fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + match input { + None => { + write!(f, "{}, dim={}", Description::>::new(), N) + } + Some(input) => { + let type_result = Type::::try_match_verbose(&input.data_type); + let dim_ok = input.dim == N; + match (type_result, dim_ok) { + (Ok(_), true) => write!(f, "successful match"), + (Err(err), true) => write!(f, "{}", err), + (Ok(_), false) => { + write!(f, "expected dim={}, but found dim={}", N, input.dim) + } + (Err(err), false) => { + write!( + f, + "{}; expected dim={}, but found dim={}", + err, N, input.dim + ) + } + } + } + } + } -impl DimBench { fn run( - self, - _checkpoint: Checkpoint<'_>, + input: &TypeInput, + checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, - ) -> anyhow::Result { - write!(output, "dim bench: {:?}", self.dim)?; - Ok(serde_json::Value::from(self.dim.unwrap_or(usize::MAX))) + ) -> anyhow::Result { + let s = format!("hello<{}>: {}", N, input.data_type.as_str()); + write!(output, "{}", s)?; + checkpoint.checkpoint(&s)?; + Ok(s) } } -crate::self_map!(DimBench); +#[derive(Debug)] +struct DimBench; -impl DispatchRule<&DimInput> for DimBench { - type Error = std::convert::Infallible; +impl Benchmark for DimBench { + type Input = DimInput; + type Output = usize; - fn try_match(_: &&DimInput) -> Result { + fn try_match(_input: &DimInput) -> Result { Ok(MatchScore(0)) } - fn convert(from: &DimInput) -> Result { - Ok(Self { dim: from.dim }) - } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&DimInput>) -> std::fmt::Result { - if from.is_some() { + + fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result { + if input.is_some() { write!(f, "perfect match") } else { write!(f, "matches all") } } -} -impl DispatchRule<&Any> for DimBench { - type Error = anyhow::Error; - - fn try_match(from: &&Any) -> Result { - from.try_match::() - } - fn convert(from: &Any) -> Result { - from.convert::() - } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&Any>) -> std::fmt::Result { - Any::description::(f, from, (AsDimInput).tag()) + fn run( + input: &DimInput, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> anyhow::Result { + write!(output, "dim bench: {:?}", input.dim)?; + Ok(input.dim.unwrap_or(usize::MAX)) } } diff --git a/diskann-benchmark-runner/src/utils/datatype.rs b/diskann-benchmark-runner/src/utils/datatype.rs index f7ea15721..3e230290f 100644 --- a/diskann-benchmark-runner/src/utils/datatype.rs +++ b/diskann-benchmark-runner/src/utils/datatype.rs @@ -6,7 +6,7 @@ use half::f16; use serde::{Deserialize, Serialize}; -use crate::dispatcher::{DispatchRule, FailureScore, Map, MatchScore}; +use crate::dispatcher::{DispatchRule, FailureScore, MatchScore}; /// An enum representation for common DiskANN data types. /// @@ -60,11 +60,6 @@ impl std::fmt::Display for DataType { #[derive(Debug, Default, Clone, Copy)] pub struct Type(std::marker::PhantomData); -/// The `Type` meta variable maps to itself. -impl Map for Type { - type Type<'a> = Self; -} - pub const MATCH_FAIL: FailureScore = FailureScore(1000); macro_rules! dispatch_rule { diff --git a/diskann-benchmark-runner/tests/test-4/stdout.txt b/diskann-benchmark-runner/tests/test-4/stdout.txt index 305c12a2a..edb313c6b 100644 --- a/diskann-benchmark-runner/tests/test-4/stdout.txt +++ b/diskann-benchmark-runner/tests/test-4/stdout.txt @@ -1,7 +1,9 @@ Registered Benchmarks: type-bench-f32: tag "test-input-types" -float32 + float32 type-bench-i8: tag "test-input-types" -int8 + int8 + exact-type-bench-f32-1000: tag "test-input-types" + float32, dim=1000 dim-bench: tag "test-input-dim" -matches all \ No newline at end of file + matches all \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/test-mismatch-0/stdout.txt b/diskann-benchmark-runner/tests/test-mismatch-0/stdout.txt index 4089fdd3d..62e7a0ba9 100644 --- a/diskann-benchmark-runner/tests/test-mismatch-0/stdout.txt +++ b/diskann-benchmark-runner/tests/test-mismatch-0/stdout.txt @@ -8,8 +8,8 @@ Could not find a match for the following input: Closest matches: - 1. "type-bench-i8": expected "int8" but found "float16" - 2. "type-bench-f32": expected "float32" but found "float16" - 3. "dim-bench": expected tag "test-input-dim" - instead got "test-input-types" + 1. "type-bench-f32": expected "float32" but found "float16" + 2. "type-bench-i8": expected "int8" but found "float16" + 3. "exact-type-bench-f32-1000": expected "float32" but found "float16"; expected dim=1000, but found dim=128 -could not find find a benchmark for all inputs \ No newline at end of file +could not find a benchmark for all inputs \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/test-mismatch-1/README.md b/diskann-benchmark-runner/tests/test-mismatch-1/README.md new file mode 100644 index 000000000..ba21773c0 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-mismatch-1/README.md @@ -0,0 +1,4 @@ +Mismatch diagnostics for ExactTypeBench description paths. + +Job 1: float16 with dim=1000 — ExactTypeBench fails on type only (dim matches). +Job 2: float16 with dim=128 — ExactTypeBench fails on both type and dim. diff --git a/diskann-benchmark-runner/tests/test-mismatch-1/input.json b/diskann-benchmark-runner/tests/test-mismatch-1/input.json new file mode 100644 index 000000000..608b97534 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-mismatch-1/input.json @@ -0,0 +1,14 @@ +{ + "search_directories": [], + "output_directory": null, + "jobs": [ + { + "type": "test-input-types", + "content": { + "data_type": "float16", + "dim": 1000, + "error_when_checked": false + } + } + ] +} diff --git a/diskann-benchmark-runner/tests/test-mismatch-1/stdin.txt b/diskann-benchmark-runner/tests/test-mismatch-1/stdin.txt new file mode 100644 index 000000000..e14e78017 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-mismatch-1/stdin.txt @@ -0,0 +1 @@ +run --input-file $INPUT --output-file $OUTPUT diff --git a/diskann-benchmark-runner/tests/test-mismatch-1/stdout.txt b/diskann-benchmark-runner/tests/test-mismatch-1/stdout.txt new file mode 100644 index 000000000..3e4c4ca50 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-mismatch-1/stdout.txt @@ -0,0 +1,15 @@ +Could not find a match for the following input: + +{ + "data_type": "float16", + "dim": 1000, + "error_when_checked": false +} + +Closest matches: + + 1. "type-bench-f32": expected "float32" but found "float16" + 2. "type-bench-i8": expected "int8" but found "float16" + 3. "exact-type-bench-f32-1000": expected "float32" but found "float16" + +could not find a benchmark for all inputs \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/test-overload-0/README.md b/diskann-benchmark-runner/tests/test-overload-0/README.md new file mode 100644 index 000000000..0bd9bcffa --- /dev/null +++ b/diskann-benchmark-runner/tests/test-overload-0/README.md @@ -0,0 +1,6 @@ +When two benchmarks match the same input, the one with the better MatchScore wins. + +`ExactTypeBench` matches float32 + dim=1000 with MatchScore(0). +`TypeBench` matches float32 (any dim) with MatchScore(10). + +The runner should pick `ExactTypeBench`. diff --git a/diskann-benchmark-runner/tests/test-overload-0/input.json b/diskann-benchmark-runner/tests/test-overload-0/input.json new file mode 100644 index 000000000..4430e6fb1 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-overload-0/input.json @@ -0,0 +1,22 @@ +{ + "search_directories": [], + "output_directory": null, + "jobs": [ + { + "type": "test-input-types", + "content": { + "data_type": "float32", + "dim": 1000, + "error_when_checked": false + } + }, + { + "type": "test-input-types", + "content": { + "data_type": "float32", + "dim": 128, + "error_when_checked": false + } + } + ] +} diff --git a/diskann-benchmark-runner/tests/test-overload-0/output.json b/diskann-benchmark-runner/tests/test-overload-0/output.json new file mode 100644 index 000000000..8fdbaa7e4 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-overload-0/output.json @@ -0,0 +1,24 @@ +[ + { + "input": { + "content": { + "data_type": "float32", + "dim": 1000, + "error_when_checked": false + }, + "type": "test-input-types" + }, + "results": "hello<1000>: float32" + }, + { + "input": { + "content": { + "data_type": "float32", + "dim": 128, + "error_when_checked": false + }, + "type": "test-input-types" + }, + "results": "float32" + } +] \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/test-overload-0/stdin.txt b/diskann-benchmark-runner/tests/test-overload-0/stdin.txt new file mode 100644 index 000000000..e14e78017 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-overload-0/stdin.txt @@ -0,0 +1 @@ +run --input-file $INPUT --output-file $OUTPUT diff --git a/diskann-benchmark-runner/tests/test-overload-0/stdout.txt b/diskann-benchmark-runner/tests/test-overload-0/stdout.txt new file mode 100644 index 000000000..6f62db794 --- /dev/null +++ b/diskann-benchmark-runner/tests/test-overload-0/stdout.txt @@ -0,0 +1,11 @@ +###################### +# Running Job 1 of 2 # +###################### + +hello<1000>: float32 + +###################### +# Running Job 2 of 2 # +###################### + +hello: float32 \ No newline at end of file diff --git a/diskann-benchmark-simd/src/bin.rs b/diskann-benchmark-simd/src/bin.rs index c570b0111..50efff351 100644 --- a/diskann-benchmark-simd/src/bin.rs +++ b/diskann-benchmark-simd/src/bin.rs @@ -4,7 +4,7 @@ */ use diskann_benchmark_runner::{output, registry, App, Output}; -use diskann_benchmark_simd::{register, SimdInput}; +use diskann_benchmark_simd::{register, SimdOp}; pub fn main() -> anyhow::Result<()> { // Create the pocket bench application. @@ -15,7 +15,7 @@ pub fn main() -> anyhow::Result<()> { fn main_inner(app: &App, output: &mut dyn Output) -> anyhow::Result<()> { // Register inputs and benchmarks. let mut inputs = registry::Inputs::new(); - inputs.register(SimdInput)?; + inputs.register::()?; let mut benchmarks = registry::Benchmarks::new(); register(&mut benchmarks); diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 530be5803..1336933e4 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -19,21 +19,18 @@ use thiserror::Error; use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::{ datatype::{self, DataType}, percentiles, MicroSeconds, }, - Any, CheckDeserialization, Checker, + Any, Benchmark, CheckDeserialization, Checker, Input, }; //////////////// // Public API // //////////////// -#[derive(Debug)] -pub struct SimdInput; - pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) { register_benchmarks_impl(dispatcher) } @@ -112,7 +109,7 @@ pub(crate) struct Run { } #[derive(Debug, Serialize, Deserialize)] -pub(crate) struct SimdOp { +pub struct SimdOp { pub(crate) query_type: DataType, pub(crate) data_type: DataType, pub(crate) arch: Arch, @@ -132,10 +129,6 @@ macro_rules! write_field { } impl SimdOp { - pub(crate) const fn tag() -> &'static str { - "simd-op" - } - fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write_field!(f, "query type", self.query_type)?; write_field!(f, "data type", self.data_type)?; @@ -153,20 +146,19 @@ impl std::fmt::Display for SimdOp { } } -impl diskann_benchmark_runner::Input for SimdInput { - fn tag(&self) -> &'static str { +impl Input for SimdOp { + fn tag() -> &'static str { "simd-op" } fn try_deserialize( - &self, serialized: &serde_json::Value, checker: &mut Checker, ) -> anyhow::Result { - checker.any(SimdOp::deserialize(serialized)?) + checker.any(Self::deserialize(serialized)?) } - fn example(&self) -> anyhow::Result { + fn example() -> anyhow::Result { const DIM: [NonZeroUsize; 2] = [ NonZeroUsize::new(128).unwrap(), NonZeroUsize::new(150).unwrap(), @@ -197,7 +189,7 @@ impl diskann_benchmark_runner::Input for SimdInput { }, ]; - Ok(serde_json::to_value(&SimdOp { + Ok(serde_json::to_value(&Self { query_type: DataType::Float32, data_type: DataType::Float32, arch: Arch::X86_64_V3, @@ -213,16 +205,10 @@ impl diskann_benchmark_runner::Input for SimdInput { macro_rules! register { ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => { #[cfg(target_arch = $arch)] - $dispatcher.register::<$($kernel)*>( - $name, - run_benchmark, - ) + $dispatcher.register::<$($kernel)*>($name) }; ($dispatcher:ident, $name:literal, $($kernel:tt)*) => { - $dispatcher.register::<$($kernel)*>( - $name, - run_benchmark, - ) + $dispatcher.register::<$($kernel)*>($name) }; } @@ -232,25 +218,25 @@ fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry: "x86_64", dispatcher, "simd-op-f32xf32-x86_64_V4", - Kernel<'static, diskann_wide::arch::x86_64::V4, f32, f32> + Kernel ); register!( "x86_64", dispatcher, "simd-op-f16xf16-x86_64_V4", - Kernel<'static, diskann_wide::arch::x86_64::V4, f16, f16> + Kernel ); register!( "x86_64", dispatcher, "simd-op-u8xu8-x86_64_V4", - Kernel<'static, diskann_wide::arch::x86_64::V4, u8, u8> + Kernel ); register!( "x86_64", dispatcher, "simd-op-i8xi8-x86_64_V4", - Kernel<'static, diskann_wide::arch::x86_64::V4, i8, i8> + Kernel ); // x86-64-v3 @@ -258,25 +244,25 @@ fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry: "x86_64", dispatcher, "simd-op-f32xf32-x86_64_V3", - Kernel<'static, diskann_wide::arch::x86_64::V3, f32, f32> + Kernel ); register!( "x86_64", dispatcher, "simd-op-f16xf16-x86_64_V3", - Kernel<'static, diskann_wide::arch::x86_64::V3, f16, f16> + Kernel ); register!( "x86_64", dispatcher, "simd-op-u8xu8-x86_64_V3", - Kernel<'static, diskann_wide::arch::x86_64::V3, u8, u8> + Kernel ); register!( "x86_64", dispatcher, "simd-op-i8xi8-x86_64_V3", - Kernel<'static, diskann_wide::arch::x86_64::V3, i8, i8> + Kernel ); // aarch64-neon @@ -284,69 +270,69 @@ fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry: "aarch64", dispatcher, "simd-op-f32xf32-aarch64_neon", - Kernel<'static, diskann_wide::arch::aarch64::Neon, f32, f32> + Kernel ); register!( "aarch64", dispatcher, "simd-op-f16xf16-aarch64_neon", - Kernel<'static, diskann_wide::arch::aarch64::Neon, f16, f16> + Kernel ); register!( "aarch64", dispatcher, "simd-op-u8xu8-aarch64_neon", - Kernel<'static, diskann_wide::arch::aarch64::Neon, u8, u8> + Kernel ); register!( "aarch64", dispatcher, "simd-op-i8xi8-aarch64_neon", - Kernel<'static, diskann_wide::arch::aarch64::Neon, i8, i8> + Kernel ); // scalar register!( dispatcher, "simd-op-f32xf32-scalar", - Kernel<'static, diskann_wide::arch::Scalar, f32, f32> + Kernel ); register!( dispatcher, "simd-op-f16xf16-scalar", - Kernel<'static, diskann_wide::arch::Scalar, f16, f16> + Kernel ); register!( dispatcher, "simd-op-u8xu8-scalar", - Kernel<'static, diskann_wide::arch::Scalar, u8, u8> + Kernel ); register!( dispatcher, "simd-op-i8xi8-scalar", - Kernel<'static, diskann_wide::arch::Scalar, i8, i8> + Kernel ); // reference register!( dispatcher, "simd-op-f32xf32-reference", - Kernel<'static, Reference, f32, f32> + Kernel ); register!( dispatcher, "simd-op-f16xf16-reference", - Kernel<'static, Reference, f16, f16> + Kernel ); register!( dispatcher, "simd-op-u8xu8-reference", - Kernel<'static, Reference, u8, u8> + Kernel ); register!( dispatcher, "simd-op-i8xi8-reference", - Kernel<'static, Reference, i8, i8> + Kernel ); } @@ -361,38 +347,20 @@ struct Reference; #[derive(Debug)] struct Identity(T); -impl dispatcher::Map for Identity -where - T: 'static, -{ - type Type<'a> = T; -} - -struct Kernel<'a, A, Q, D> { - input: &'a SimdOp, +struct Kernel { arch: A, _type: std::marker::PhantomData<(A, Q, D)>, } -impl<'a, A, Q, D> Kernel<'a, A, Q, D> { - fn new(input: &'a SimdOp, arch: A) -> Self { +impl Kernel { + fn new(arch: A) -> Self { Self { - input, arch, _type: std::marker::PhantomData, } } } -impl dispatcher::Map for Kernel<'static, A, Q, D> -where - A: 'static, - Q: 'static, - D: 'static, -{ - type Type<'a> = Kernel<'a, A, Q, D>; -} - // Map Architectures to the enum. #[derive(Debug, Error)] #[error("architecture {0} is not supported by this CPU")] @@ -508,16 +476,18 @@ match_arch!("x86_64", diskann_wide::arch::x86_64::V4, X86_64_V4); match_arch!("x86_64", diskann_wide::arch::x86_64::V3, X86_64_V3); match_arch!("aarch64", diskann_wide::arch::aarch64::Neon, Neon); -impl<'a, A, Q, D> DispatchRule<&'a SimdOp> for Kernel<'a, A, Q, D> +impl Benchmark for Kernel where datatype::Type: DispatchRule, datatype::Type: DispatchRule, Identity: DispatchRule, + Kernel: RunBenchmark, { - type Error = ArchNotSupported; + type Input = SimdOp; + type Output = Vec; // Matching simply requires that we match the inner type. - fn try_match(from: &&'a SimdOp) -> Result { + fn try_match(from: &SimdOp) -> Result { let mut failscore: Option = None; if datatype::Type::::try_match(&from.query_type).is_err() { *failscore.get_or_insert(0) += 10; @@ -535,29 +505,36 @@ where } } - fn convert(from: &'a SimdOp) -> Result { - assert!(Self::try_match(&from).is_ok()); - let arch = Identity::::convert(from.arch)?.0; - Ok(Self::new(from, arch)) + fn run( + input: &SimdOp, + _: diskann_benchmark_runner::Checkpoint<'_>, + mut output: &mut dyn diskann_benchmark_runner::Output, + ) -> anyhow::Result { + let arch = Identity::::convert(input.arch)?.0; + let kernel = Self::new(arch); + writeln!(output, "{}", input)?; + let results = kernel.run(input)?; + writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; + Ok(results) } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a SimdOp>) -> std::fmt::Result { - match from { + fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&SimdOp>) -> std::fmt::Result { + match input { None => { describeln!( f, "- Query Type: {}", - dispatcher::Description::>::new() + Description::>::new() )?; describeln!( f, "- Data Type: {}", - dispatcher::Description::>::new() + Description::>::new() )?; describeln!( f, "- Implementation: {}", - dispatcher::Description::>::new() + Description::>::new() )?; } Some(input) => { @@ -576,50 +553,12 @@ where } } -impl<'a, A, Q, D> DispatchRule<&'a diskann_benchmark_runner::Any> for Kernel<'a, A, Q, D> -where - Kernel<'a, A, Q, D>: DispatchRule<&'a SimdOp>, - as DispatchRule<&'a SimdOp>>::Error: - std::error::Error + Send + Sync + 'static, -{ - type Error = anyhow::Error; - - fn try_match(from: &&'a diskann_benchmark_runner::Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a diskann_benchmark_runner::Any) -> Result { - from.convert::() - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a diskann_benchmark_runner::Any>, - ) -> std::fmt::Result { - Any::description::(f, from, SimdOp::tag()) - } -} - /////////////// // Benchmark // /////////////// -fn run_benchmark( - kernel: Kernel<'_, A, Q, D>, - _: diskann_benchmark_runner::Checkpoint<'_>, - mut output: &mut dyn diskann_benchmark_runner::Output, -) -> Result -where - for<'a> Kernel<'a, A, Q, D>: RunBenchmark, -{ - writeln!(output, "{}", kernel.input)?; - let results = kernel.run()?; - writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; - Ok(serde_json::to_value(results)?) -} - trait RunBenchmark { - fn run(self) -> Result, anyhow::Error>; + fn run(self, input: &SimdOp) -> Result, anyhow::Error>; } #[derive(Debug, Serialize)] @@ -745,10 +684,10 @@ impl Data { macro_rules! stamp { (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => { - impl RunBenchmark for Kernel<'_, Reference, $Q, $D> { - fn run(self) -> Result, anyhow::Error> { + impl RunBenchmark for Kernel { + fn run(self, input: &SimdOp) -> Result, anyhow::Error> { let mut results = Vec::new(); - for run in self.input.runs.iter() { + for run in input.runs.iter() { let data = Data::<$Q, $D>::new(run); let result = match run.distance { SimilarityMeasure::SquaredL2 => data.run(run, reference::$f_l2), @@ -762,15 +701,15 @@ macro_rules! stamp { } }; ($arch:path, $Q:ty, $D:ty) => { - impl RunBenchmark for Kernel<'_, $arch, $Q, $D> { - fn run(self) -> Result, anyhow::Error> { + impl RunBenchmark for Kernel<$arch, $Q, $D> { + fn run(self, input: &SimdOp) -> Result, anyhow::Error> { let mut results = Vec::new(); let l2 = &simd::L2 {}; let ip = &simd::IP {}; let cosine = &simd::CosineStateless {}; - for run in self.input.runs.iter() { + for run in input.runs.iter() { let data = Data::<$Q, $D>::new(run); // For each kernel, we need to do a two-step wrapping of closures so // the inner-most closure is executed by the architecture. diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index a14d7ea3f..71c89f846 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -8,10 +8,10 @@ use std::io::Write; use diskann::utils::VectorRepr; use diskann_benchmark_runner::{ - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, utils::datatype::{DataType, Type}, - Any, Checkpoint, + Benchmark, Checkpoint, }; use diskann_providers::storage::FileStorageProvider; use half::f16; @@ -82,39 +82,26 @@ where } } -impl dispatcher::Map for DiskIndex<'static, T> -where - T: 'static, -{ - type Type<'a> = DiskIndex<'a, T>; -} - -/// Dispatch to Disk Index operations. -impl<'a, T> DispatchRule<&'a DiskIndexOperation> for DiskIndex<'a, T> +impl Benchmark for DiskIndex<'static, T> where + T: VectorRepr + 'static, Type: DispatchRule, - T: VectorRepr, { - type Error = std::convert::Infallible; + type Input = DiskIndexOperation; + type Output = DiskIndexStats; - // Matching simply requires that we match the inner type. - fn try_match(from: &&'a DiskIndexOperation) -> Result { - match &from.source { + fn try_match(input: &DiskIndexOperation) -> Result { + match &input.source { DiskIndexSource::Load(load) => Type::::try_match(&load.data_type), DiskIndexSource::Build(build) => Type::::try_match(&build.data_type), } } - fn convert(from: &'a DiskIndexOperation) -> Result { - Ok(Self::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a DiskIndexOperation>, + input: Option<&DiskIndexOperation>, ) -> std::fmt::Result { - // At this level, we only care about the data type, so return that description. - match from { + match input { Some(arg) => match &arg.source { DiskIndexSource::Load(load) => Type::::description(f, Some(&load.data_type)), DiskIndexSource::Build(build) => Type::::description(f, Some(&build.data_type)), @@ -122,26 +109,13 @@ where None => Type::::description(f, None::<&DataType>), } } -} - -/// Central Dispatch -impl<'a, T> DispatchRule<&'a Any> for DiskIndex<'a, T> -where - Type: DispatchRule, - T: VectorRepr, -{ - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - fn convert(from: &'a Any) -> Result { - from.convert::() - } - - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, DiskIndexOperation::tag()) + fn run( + input: &DiskIndexOperation, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + DiskIndex::::new(input).run(checkpoint, output) } } @@ -149,18 +123,9 @@ where // Benchmark Registration // //////////////////////////// -macro_rules! register_disk_index { - ($registry:ident, $name:literal, $t:ty) => { - $registry.register::>($name, |object, checkpoint, output| { - let res = object.run(checkpoint, output)?; - Ok(serde_json::to_value(res)?) - }); - }; -} - pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - register_disk_index!(benchmarks, "disk-index-f32", f32); - register_disk_index!(benchmarks, "disk-index-f16", f16); - register_disk_index!(benchmarks, "disk-index-u8", u8); - register_disk_index!(benchmarks, "disk-index-i8", i8); + benchmarks.register::>("disk-index-f32"); + benchmarks.register::>("disk-index-f16"); + benchmarks.register::>("disk-index-u8"); + benchmarks.register::>("disk-index-i8"); } diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index 65e15ee68..bd8e88952 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -12,37 +12,10 @@ crate::utils::stub_impl!("minmax-quantization", inputs::exhaustive::MinMax); // MinMax - requires feature "minmax-quantization" #[cfg(feature = "minmax-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>( - NAME, - |object, _checkpoint, output| match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - }, - ); - - benchmarks.register::>( - NAME, - |object, _checkpoint, output| match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - }, - ); - - benchmarks.register::>( - NAME, - |object, _checkpoint, output| match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - }, - ); - - benchmarks.register::>( - NAME, - |object, _checkpoint, output| match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - }, - ); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); } // Stub implementation @@ -61,9 +34,9 @@ mod imp { use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, - Any, Output, + Benchmark, Output, }; use diskann_quantization::{ algorithms::transforms::Transform, @@ -125,32 +98,6 @@ mod imp { pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result where Unsigned: Representation, - MinMaxL2Squared: for<'x, 'y> PureDistanceFunction< - DataRef<'x, NBITS>, - DataRef<'y, NBITS>, - distances::Result, - > + for<'x, 'y> PureDistanceFunction< - minmax::FullQueryRef<'x>, - DataRef<'y, NBITS>, - distances::Result, - >, - MinMaxIP: for<'x, 'y> PureDistanceFunction< - DataRef<'x, NBITS>, - DataRef<'y, NBITS>, - distances::Result, - > + for<'x, 'y> PureDistanceFunction< - minmax::FullQueryRef<'x>, - DataRef<'y, NBITS>, - distances::Result, - > + for<'x, 'y> PureDistanceFunction< - DataRef<'x, NBITS>, - DataRef<'y, NBITS>, - distances::MathematicalResult, - > + for<'x, 'y> PureDistanceFunction< - minmax::FullQueryRef<'x>, - DataRef<'y, NBITS>, - distances::MathematicalResult, - >, Plan: algos::CreateQuantComputer>, { let input = &self.input; @@ -251,15 +198,16 @@ mod imp { } } - impl dispatcher::Map for MinMaxQ<'static, NBITS> { - type Type<'a> = MinMaxQ<'a, NBITS>; - } - - impl<'a, const NBITS: usize> DispatchRule<&'a inputs::exhaustive::MinMax> for MinMaxQ<'a, NBITS> { - type Error = std::convert::Infallible; + impl Benchmark for MinMaxQ<'static, NBITS> + where + Unsigned: Representation, + Plan: algos::CreateQuantComputer>, + { + type Input = inputs::exhaustive::MinMax; + type Output = Results; - fn try_match(from: &&'a inputs::exhaustive::MinMax) -> Result { - let num_bits = from.num_bits.get(); + fn try_match(input: &inputs::exhaustive::MinMax) -> Result { + let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) } else { @@ -269,20 +217,11 @@ mod imp { } } - fn convert(from: &'a inputs::exhaustive::MinMax) -> Result { - assert_eq!( - from.num_bits.get(), - NBITS, - "This should not have occurred. Please file a bug report" - ); - Ok(Self::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a inputs::exhaustive::MinMax>, + input: Option<&inputs::exhaustive::MinMax>, ) -> std::fmt::Result { - match from { + match input { None => { describeln!( f, @@ -305,28 +244,13 @@ mod imp { } Ok(()) } - } - - impl<'a, const NBITS: usize> DispatchRule<&'a Any> for MinMaxQ<'a, NBITS> { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - fn convert(from: &'a Any) -> Result { - from.convert::() - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::( - f, - from, - inputs::exhaustive::MinMax::tag(), - ) + fn run( + input: &inputs::exhaustive::MinMax, + _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + MinMaxQ::::new(input).run(output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 247711df9..ff5ab7c27 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -11,12 +11,7 @@ crate::utils::stub_impl!("product-quantization", inputs::exhaustive::Product); pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "product-quantization")] - benchmarks.register::>(NAME, |object, _checkpoint, output| match object - .run(output) - { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - }); + benchmarks.register::>(NAME); #[cfg(not(feature = "product-quantization"))] imp::register(NAME, benchmarks) @@ -32,9 +27,9 @@ mod imp { use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, - Any, Output, + Benchmark, Output, }; use diskann_quantization::{product::train::TrainQuantizer, CompressInto}; use indicatif::{ProgressBar, ProgressStyle}; @@ -196,53 +191,31 @@ mod imp { } } - impl dispatcher::Map for ProductQ<'static> { - type Type<'a> = ProductQ<'a>; - } - - impl<'a> DispatchRule<&'a inputs::exhaustive::Product> for ProductQ<'a> { - type Error = std::convert::Infallible; + impl Benchmark for ProductQ<'static> { + type Input = inputs::exhaustive::Product; + type Output = Results; - fn try_match(_from: &&'a inputs::exhaustive::Product) -> Result { + fn try_match(_input: &inputs::exhaustive::Product) -> Result { Ok(MatchScore(0)) } - fn convert(from: &'a inputs::exhaustive::Product) -> Result { - Ok(Self::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a inputs::exhaustive::Product>, + input: Option<&inputs::exhaustive::Product>, ) -> std::fmt::Result { - if from.is_none() { + if input.is_none() { describeln!(f, "- Exhaustive search for product quantization",)?; describeln!(f, "- Requires `float32` data")?; } Ok(()) } - } - - impl<'a> DispatchRule<&'a Any> for ProductQ<'a> { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a Any) -> Result { - from.convert::() - } - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::( - f, - from, - inputs::exhaustive::Product::tag(), - ) + fn run( + input: &inputs::exhaustive::Product, + _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + ProductQ::new(input).run(output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index 87c0dce51..1c0881c56 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -12,33 +12,10 @@ crate::utils::stub_impl!("spherical-quantization", inputs::exhaustive::Spherical // Spherical - requires feature "spherical-quantization" #[cfg(feature = "spherical-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>(NAME, |object, _checkpoint, output| { - match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - - benchmarks.register::>(NAME, |object, _checkpoint, output| { - match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - - benchmarks.register::>(NAME, |object, _checkpoint, output| { - match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - - benchmarks.register::>(NAME, |object, _checkpoint, output| { - match object.run(output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); } // Stub implementation @@ -57,10 +34,9 @@ mod imp { use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, - output::Output, + dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, - Any, + Benchmark, Output, }; use diskann_providers::model::graph::provider::async_::distances::UnwrapErr; use diskann_quantization::{ @@ -226,17 +202,20 @@ mod imp { } } - impl dispatcher::Map for SphericalQ<'static, NBITS> { - type Type<'a> = SphericalQ<'a, NBITS>; - } - - impl<'a, const NBITS: usize> DispatchRule<&'a inputs::exhaustive::Spherical> - for SphericalQ<'a, NBITS> + impl Benchmark for SphericalQ<'static, NBITS> + where + Unsigned: Representation, + Plan: algos::CreateQuantComputer>, + diskann_quantization::spherical::iface::Impl: + diskann_quantization::spherical::iface::Constructible, + SphericalQuantizer: + for<'x> CompressIntoWith<&'x [f32], DataMut<'x, NBITS>, ScopedAllocator<'x>>, { - type Error = std::convert::Infallible; + type Input = inputs::exhaustive::Spherical; + type Output = Results; - fn try_match(from: &&'a inputs::exhaustive::Spherical) -> Result { - let num_bits = from.num_bits.get(); + fn try_match(input: &inputs::exhaustive::Spherical) -> Result { + let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) } else { @@ -246,20 +225,11 @@ mod imp { } } - fn convert(from: &'a inputs::exhaustive::Spherical) -> Result { - assert_eq!( - from.num_bits.get(), - NBITS, - "This should not have occurred. Please file a bug report" - ); - Ok(Self::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a inputs::exhaustive::Spherical>, + input: Option<&inputs::exhaustive::Spherical>, ) -> std::fmt::Result { - match from { + match input { None => { describeln!( f, @@ -282,28 +252,13 @@ mod imp { } Ok(()) } - } - - impl<'a, const NBITS: usize> DispatchRule<&'a Any> for SphericalQ<'a, NBITS> { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - fn convert(from: &'a Any) -> Result { - from.convert::() - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::( - f, - from, - inputs::exhaustive::Spherical::tag(), - ) + fn run( + input: &inputs::exhaustive::Spherical, + _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + SphericalQ::::new(input).run(output) } } diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index ddd79e944..a90ea41ed 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -5,11 +5,11 @@ use anyhow::Result; use diskann_benchmark_runner::{ - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{FailureScore, MatchScore}, output::Output, registry::Benchmarks, utils::{percentiles, MicroSeconds}, - Any, Checkpoint, + Benchmark, Checkpoint, }; use diskann_label_filter::{ kv_index::GenericIndex, @@ -29,14 +29,7 @@ use crate::{ }; pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - // Register the metadata index job - benchmarks.register::>( - "metadata-index-build", - |job, checkpoint, out| { - let stats = job.run(checkpoint, out)?; - Ok(serde_json::to_value(stats)?) - }, - ); + benchmarks.register::>("metadata-index-build"); } // Metadata-only index job wrapper @@ -50,27 +43,17 @@ impl<'a> MetadataIndexJob<'a> { } } -impl dispatcher::Map for MetadataIndexJob<'static> { - type Type<'a> = MetadataIndexJob<'a>; -} - -// Dispatch from the concrete input type -impl<'a> DispatchRule<&'a crate::inputs::filters::MetadataIndexBuild> for MetadataIndexJob<'a> { - type Error = std::convert::Infallible; +impl Benchmark for MetadataIndexJob<'static> { + type Input = MetadataIndexBuild; + type Output = MetadataIndexBuildStats; - fn try_match( - _from: &&'a crate::inputs::filters::MetadataIndexBuild, - ) -> Result { + fn try_match(_input: &MetadataIndexBuild) -> Result { Ok(MatchScore(1)) } - fn convert(from: &'a crate::inputs::filters::MetadataIndexBuild) -> Result { - Ok(MetadataIndexJob::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - _from: Option<&&'a crate::inputs::filters::MetadataIndexBuild>, + _input: Option<&MetadataIndexBuild>, ) -> std::fmt::Result { writeln!( f, @@ -78,22 +61,13 @@ impl<'a> DispatchRule<&'a crate::inputs::filters::MetadataIndexBuild> for Metada crate::inputs::filters::MetadataIndexBuild::tag() ) } -} - -// Central dispatch mapping -impl<'a> DispatchRule<&'a Any> for MetadataIndexJob<'a> { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a Any) -> Result { - from.convert::() - } - fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, MetadataIndexBuild::tag()) + fn run( + input: &MetadataIndexBuild, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + MetadataIndexJob::new(input).run(checkpoint, output) } } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index fa4a77078..638625468 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -20,7 +20,7 @@ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, utils::datatype, - Any, Checkpoint, + Benchmark, Checkpoint, }; use diskann_providers::{ index::diskann_async, @@ -55,78 +55,18 @@ use crate::{ // Benchmark Registration // //////////////////////////// -macro_rules! register { - ($disp:ident, $name:literal, $bench_type:ty) => { - $disp.register::<$bench_type>($name, |object, checkpoint, output| { - match <_ as $crate::backend::index::benchmarks::BuildAndSearch>::run( - object, checkpoint, output, - ) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - }; -} -macro_rules! register_streaming { - ($disp:ident, $name:literal, $bench_type:ty) => { - $disp.register::<$bench_type>($name, |object, checkpoint, output| { - match <_ as $crate::backend::index::benchmarks::BuildAndDynamicRun>::run( - object, checkpoint, output, - ) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - }; -} - -#[cfg(any(feature = "product-quantization", feature = "scalar-quantization"))] -pub(super) use register; - pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { // Full Precision - register!( - benchmarks, - "async-full-precision-f32", - FullPrecision<'static, f32> - ); - register!( - benchmarks, - "async-full-precision-f16", - FullPrecision<'static, f16> - ); - register!( - benchmarks, - "async-full-precision-u8", - FullPrecision<'static, u8> - ); - register!( - benchmarks, - "async-full-precision-i8", - FullPrecision<'static, i8> - ); + benchmarks.register::>("async-full-precision-f32"); + benchmarks.register::>("async-full-precision-f16"); + benchmarks.register::>("async-full-precision-u8"); + benchmarks.register::>("async-full-precision-i8"); // Dynamic Full Precision - register_streaming!( - benchmarks, - "async-dynamic-full-precision-f32", - DynamicFullPrecision<'static, f32> - ); - register_streaming!( - benchmarks, - "async-dynamic-full-precision-f16", - DynamicFullPrecision<'static, f16> - ); - register_streaming!( - benchmarks, - "async-dynamic-full-precision-u8", - DynamicFullPrecision<'static, u8> - ); - register_streaming!( - benchmarks, - "async-dynamic-full-precision-i8", - DynamicFullPrecision<'static, i8> - ); + benchmarks.register::>("async-dynamic-full-precision-f32"); + benchmarks.register::>("async-dynamic-full-precision-f16"); + benchmarks.register::>("async-dynamic-full-precision-u8"); + benchmarks.register::>("async-dynamic-full-precision-i8"); product::register_benchmarks(benchmarks); scalar::register_benchmarks(benchmarks); @@ -176,38 +116,28 @@ impl<'a, T> FullPrecision<'a, T> { } } -impl diskann_benchmark_runner::dispatcher::Map for FullPrecision<'static, T> -where - T: 'static, -{ - type Type<'a> = FullPrecision<'a, T>; -} - -/// Dispatch to a full-precision only build. -impl<'a, T> DispatchRule<&'a IndexOperation> for FullPrecision<'a, T> +impl Benchmark for FullPrecision<'static, T> where + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, datatype::Type: DispatchRule, { - type Error = std::convert::Infallible; + type Input = IndexOperation; + type Output = BuildResult; - // Matching simply requires that we match the inner type. - fn try_match(from: &&'a IndexOperation) -> Result { - match &from.source { + fn try_match(input: &IndexOperation) -> Result { + match &input.source { IndexSource::Load(load) => datatype::Type::::try_match(&load.data_type), IndexSource::Build(build) => datatype::Type::::try_match(&build.data_type), } } - fn convert(from: &'a IndexOperation) -> Result { - Ok(Self::new(from)) - } - fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a IndexOperation>, + input: Option<&IndexOperation>, ) -> std::fmt::Result { - // At this level, we only care about the data type, so return that description. - match from { + match input { Some(arg) => match &arg.source { IndexSource::Load(load) => { datatype::Type::::description(f, Some(&load.data_type)) @@ -219,25 +149,13 @@ where None => datatype::Type::::description(f, None::<&datatype::DataType>), } } -} - -/// Central Dispatch -impl<'a, T> DispatchRule<&'a Any> for FullPrecision<'a, T> -where - datatype::Type: DispatchRule, -{ - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a Any) -> Result { - from.convert::() - } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, IndexOperation::tag()) + fn run( + input: &IndexOperation, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + BuildAndSearch::run(FullPrecision::::new(input), checkpoint, output) } } @@ -256,49 +174,33 @@ impl<'a, T> DynamicFullPrecision<'a, T> { } } -impl diskann_benchmark_runner::dispatcher::Map for DynamicFullPrecision<'static, T> -where - T: 'static, -{ - type Type<'a> = DynamicFullPrecision<'a, T>; -} - -/// Dispatch to a dynamic full-precision async index run. -impl<'a, T> DispatchRule<&'a DynamicIndexRun> for DynamicFullPrecision<'a, T> +impl Benchmark for DynamicFullPrecision<'static, T> where + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, datatype::Type: DispatchRule, { - type Error = std::convert::Infallible; - // Matching simply requires that we match the inner type. - fn try_match(from: &&'a DynamicIndexRun) -> Result { - datatype::Type::::try_match(&from.build.data_type) - } - fn convert(from: &'a DynamicIndexRun) -> Result { - Ok(Self::new(from)) + type Input = DynamicIndexRun; + type Output = Vec>; + + fn try_match(input: &DynamicIndexRun) -> Result { + datatype::Type::::try_match(&input.build.data_type) } + fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a DynamicIndexRun>, + input: Option<&DynamicIndexRun>, ) -> std::fmt::Result { - // At this level, we only care about the data type, so return that description. - datatype::Type::::description(f, from.map(|f| f.build.data_type).as_ref()) + datatype::Type::::description(f, input.map(|f| f.build.data_type).as_ref()) } -} -/// Central Dispatch -impl<'a, T> DispatchRule<&'a Any> for DynamicFullPrecision<'a, T> -where - datatype::Type: DispatchRule, -{ - type Error = anyhow::Error; - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - fn convert(from: &'a Any) -> Result { - from.convert::() - } - fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a Any>) -> std::fmt::Result { - Any::description::(f, from, DynamicIndexRun::tag()) + fn run( + input: &DynamicIndexRun, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result>> { + BuildAndDynamicRun::run(DynamicFullPrecision::::new(input), checkpoint, output) } } @@ -493,152 +395,141 @@ where } } -macro_rules! impl_build { - ($T:ty) => { - impl<'a> BuildAndSearch<'a> for FullPrecision<'a, $T> { - type Data = BuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - let (index, build_stats) = match &self.input.source { - IndexSource::Build(build) => { - let (index, build_stats) = run_build( - &build, - common::FullPrecision, - None, - output, - |data| { - let index = diskann_async::new_index::<$T, _>( - build.try_as_config()?.build()?, - build.inmem_parameters(data.nrows(), data.ncols()), - common::NoDeletes, - )?; - build::set_start_points( - index.provider(), - data.as_view(), - build.start_point_strategy, - )?; - Ok(index) - }, - single_or_multi_insert, +impl<'a, T> BuildAndSearch<'a> for FullPrecision<'a, T> +where + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, +{ + type Data = BuildResult; + fn run( + self, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + writeln!(output, "{}", self.input)?; + let (index, build_stats) = match &self.input.source { + IndexSource::Build(build) => { + let (index, build_stats) = run_build( + build, + common::FullPrecision, + None, + output, + |data| { + let index = diskann_async::new_index::( + build.try_as_config()?.build()?, + build.inmem_parameters(data.nrows(), data.ncols()), + common::NoDeletes, )?; + build::set_start_points( + index.provider(), + data.as_view(), + build.start_point_strategy, + )?; + Ok(index) + }, + single_or_multi_insert, + )?; - // save the index if requested - if let Some(save_path) = &build.save_path { - utils::tokio::block_on(save_index(index.clone(), &save_path))?; - } - - (index, Some(build_stats)) - } - IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &load.to_config()?; - - let index = { - utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? - }; + // save the index if requested + if let Some(save_path) = &build.save_path { + utils::tokio::block_on(save_index(index.clone(), save_path))?; + } - (Arc::new(index), None::) - } - }; + (index, Some(build_stats)) + } + IndexSource::Load(load) => { + let index_config: &IndexConfiguration = &load.to_config()?; - let result = run_search_outer( - &self.input.search_phase, - common::FullPrecision, - index, - build_stats, - checkpoint, - )?; + let index = + { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; - writeln!(output, "\n\n{}", result)?; - Ok(result) + (Arc::new(index), None::) } - } - }; + }; + + let result = run_search_outer( + &self.input.search_phase, + common::FullPrecision, + index, + build_stats, + checkpoint, + )?; + + writeln!(output, "\n\n{}", result)?; + Ok(result) + } } -impl_build!(f32); -impl_build!(f16); -impl_build!(u8); -impl_build!(i8); - -macro_rules! impl_dynamic_run { - ($T:ty) => { - impl<'a> BuildAndDynamicRun<'a> for DynamicFullPrecision<'a, $T> { - type Data = Vec>; - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - - let groundtruth_directory = self - .input - .runbook_params - .resolved_gt_directory - .as_ref() - .ok_or_else(|| { - anyhow::anyhow!( - "Ground truth directory path was not resolved during validation" - ) - })?; - - let mut runbook = bigann::RunBook::load( - &self.input.runbook_params.runbook_path, - &self.input.runbook_params.dataset_name, - &mut bigann::ScanDirectory::new(groundtruth_directory)?, - )?; +impl<'a, T> BuildAndDynamicRun<'a> for DynamicFullPrecision<'a, T> +where + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, +{ + type Data = Vec>; + fn run( + self, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + writeln!(output, "{}", self.input)?; - let mut streamer = full_precision_streaming(&self, runbook.max_points())?; - - let mut results = Vec::new(); - let stages = runbook.len(); - let mut i = 1; - - runbook.run_with( - &mut streamer, - |o: managed::Stats| -> anyhow::Result<()> { - if o.inner().is_maintain() { - let message = format!("Ran maintenance before stage {}", i); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - } else { - let message = - format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - i += 1; - } - writeln!(output, "{}", o)?; - results.push(o); - Ok(()) - }, - )?; + let groundtruth_directory = self + .input + .runbook_params + .resolved_gt_directory + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!("Ground truth directory path was not resolved during validation") + })?; + + let mut runbook = bigann::RunBook::load( + &self.input.runbook_params.runbook_path, + &self.input.runbook_params.dataset_name, + &mut bigann::ScanDirectory::new(groundtruth_directory)?, + )?; + + let mut streamer = full_precision_streaming(&self, runbook.max_points())?; + + let mut results = Vec::new(); + let stages = runbook.len(); + let mut i = 1; + + runbook.run_with( + &mut streamer, + |o: managed::Stats| -> anyhow::Result<()> { + if o.inner().is_maintain() { + let message = format!("Ran maintenance before stage {}", i); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + } else { + let message = + format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + i += 1; + } + writeln!(output, "{}", o)?; + results.push(o); + Ok(()) + }, + )?; - write!( - output, - "{}", - crate::utils::SmallBanner("End of Run Summary") - )?; + write!( + output, + "{}", + crate::utils::SmallBanner("End of Run Summary") + )?; - writeln!( - output, - "{}", - streaming::stats::Summary::new(results.iter().map(|r| r.inner())) - )?; + writeln!( + output, + "{}", + streaming::stats::Summary::new(results.iter().map(|r| r.inner())) + )?; - Ok(results) - } - } - }; + Ok(results) + } } -impl_dynamic_run!(f32); -impl_dynamic_run!(f16); -impl_dynamic_run!(u8); -impl_dynamic_run!(i8); - /// The stack looks like this: /// /// - Bottom: [`FullPrecisionStream`]: The core streaming index implementation. diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index fdf8d2be9..a857e4e57 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -13,18 +13,8 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { { use half::f16; - use crate::backend::index::benchmarks::register; - - register!( - benchmarks, - "async-pq-f32", - imp::ProductQuantized<'static, f32> - ); - register!( - benchmarks, - "async-pq-f16", - imp::ProductQuantized<'static, f16> - ); + benchmarks.register::>("async-pq-f32"); + benchmarks.register::>("async-pq-f16"); } // Stub implementation @@ -36,6 +26,7 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { mod imp { use std::{io::Write, sync::Arc}; + use diskann::utils::VectorRepr; use diskann_providers::{ index::diskann_async::{self}, model::{graph::provider::async_::common, IndexConfiguration}, @@ -43,11 +34,10 @@ mod imp { use diskann_utils::views::{Matrix, MatrixView}; use diskann_benchmark_runner::{ - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, - utils::MicroSeconds, - Any, Checkpoint, Output, + dispatcher::{DispatchRule, FailureScore, MatchScore}, + utils::{datatype, MicroSeconds}, + Benchmark, Checkpoint, Output, }; - use half::f16; use rand::{rngs::StdRng, SeedableRng}; use crate::{ @@ -56,7 +46,7 @@ mod imp { build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, result::QuantBuildResult, }, - inputs::async_::{IndexOperation, IndexPQOperation, IndexSource}, + inputs::async_::{IndexPQOperation, IndexSource}, utils::{self, datafiles}, }; @@ -74,174 +64,147 @@ mod imp { } } - impl dispatcher::Map for ProductQuantized<'static, T> + impl Benchmark for ProductQuantized<'static, T> where - T: 'static, + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, + datatype::Type: DispatchRule, { - type Type<'a> = ProductQuantized<'a, T>; - } - - impl<'a, T> DispatchRule<&'a IndexPQOperation> for ProductQuantized<'a, T> - where - FullPrecision<'a, T>: DispatchRule<&'a IndexOperation>, - { - type Error = std::convert::Infallible; - - // Matching simply requires that we match the inner type. - fn try_match(from: &&'a IndexPQOperation) -> Result { - FullPrecision::<'a, T>::try_match(&&from.index_operation) - } + type Input = IndexPQOperation; + type Output = QuantBuildResult; - fn convert(from: &'a IndexPQOperation) -> Result { - Ok(Self::new(from)) + fn try_match(input: &IndexPQOperation) -> Result { + as Benchmark>::try_match(&input.index_operation) } fn description( f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a IndexPQOperation>, + input: Option<&IndexPQOperation>, ) -> std::fmt::Result { - FullPrecision::<'a, T>::description(f, from.map(|f| &f.index_operation).as_ref()) + as Benchmark>::description( + f, + input.map(|f| &f.index_operation), + ) + } + + fn run( + input: &IndexPQOperation, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + let pq = ProductQuantized::::new(input); + BuildAndSearch::run(pq, checkpoint, output) } } - impl<'a, T> DispatchRule<&'a Any> for ProductQuantized<'a, T> + impl<'a, T> BuildAndSearch<'a> for ProductQuantized<'a, T> where - ProductQuantized<'a, T>: - DispatchRule<&'a IndexPQOperation, Error = std::convert::Infallible>, + T: VectorRepr + + diskann_utils::sampling::WithApproximateNorm + + diskann::graph::SampleableForStart, + datatype::Type: DispatchRule, { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } + type Data = QuantBuildResult; + fn run( + self, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + writeln!(output, "{}", self.input)?; - fn convert(from: &'a Any) -> Result { - from.convert::() - } + let hybrid = common::Hybrid::new(self.input.max_fp_vecs_per_prune); - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::(f, from, IndexPQOperation::tag()) - } - } + let (index, build_stats, quant_training_time) = match &self.input.index_operation.source + { + IndexSource::Load(load) => { + let index_config: &IndexConfiguration = &self.input.to_config()?; - macro_rules! impl_pq_build { - ($T:ty) => { - impl<'a> BuildAndSearch<'a> for ProductQuantized<'a, $T> { - type Data = QuantBuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - - let hybrid = common::Hybrid::new(self.input.max_fp_vecs_per_prune); - - let (index, build_stats, quant_training_time) = match &self - .input - .index_operation - .source - { - IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &self.input.to_config()?; - - let index = { - utils::tokio::block_on(load_index::<_>( - &load.load_path, - index_config, - ))? - }; - - (Arc::new(index), None::, MicroSeconds::new(0)) - } - IndexSource::Build(build) => { - let data: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&build.data))?); - - let start = std::time::Instant::now(); - let table = { - let train_data = Matrix::try_from( - data.as_slice().iter().copied().map(f32::from).collect(), - data.nrows(), - data.ncols(), - )?; - - diskann_async::train_pq( - train_data.as_view(), - self.input.num_pq_chunks, - &mut StdRng::seed_from_u64(self.input.seed), - build.num_threads, - )? - }; - - let create_index = |data_view: MatrixView<$T>| { - let index = diskann_async::new_quant_index::<$T, _, _>( - self.input.try_as_config()?.build()?, - self.input - .inmem_parameters(data_view.nrows(), data_view.ncols())?, - table, - common::NoDeletes, - )?; - build::set_start_points( - index.provider(), - data_view, - build.start_point_strategy, - )?; - Ok(index) - }; - let quant_training_time: MicroSeconds = start.elapsed().into(); - - let (index, build_stats) = run_build( - build, - hybrid, - None, - output, - create_index, - single_or_multi_insert, - )?; - - // Save the index if requested - if let Some(save_path) = &build.save_path { - utils::tokio::block_on(save_index(index.clone(), save_path))?; - } - - (index, Some(build_stats), quant_training_time) - } - }; + let index = + { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; - let build = if self.input.use_fp_for_search { - run_search_outer( - &self.input.index_operation.search_phase, - common::FullPrecision, - index, - build_stats, - checkpoint, - )? - } else { - run_search_outer( - &self.input.index_operation.search_phase, - hybrid, - index, - build_stats, - checkpoint, + (Arc::new(index), None::, MicroSeconds::new(0)) + } + IndexSource::Build(build) => { + let data: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&build.data))?); + + let start = std::time::Instant::now(); + let table = { + let train_data = Matrix::try_from( + (&*T::as_f32(data.as_slice())?).into(), + data.nrows(), + data.ncols(), + )?; + + diskann_async::train_pq( + train_data.as_view(), + self.input.num_pq_chunks, + &mut StdRng::seed_from_u64(self.input.seed), + build.num_threads, )? }; - let result = QuantBuildResult { - quant_training_time, - build, + let create_index = |data_view: MatrixView| { + let index = diskann_async::new_quant_index::( + self.input.try_as_config()?.build()?, + self.input + .inmem_parameters(data_view.nrows(), data_view.ncols())?, + table, + common::NoDeletes, + )?; + build::set_start_points( + index.provider(), + data_view, + build.start_point_strategy, + )?; + Ok(index) }; + let quant_training_time: MicroSeconds = start.elapsed().into(); - writeln!(output, "\n\n{}", result)?; - Ok(result) + let (index, build_stats) = run_build( + build, + hybrid, + None, + output, + create_index, + single_or_multi_insert, + )?; + + // Save the index if requested + if let Some(save_path) = &build.save_path { + utils::tokio::block_on(save_index(index.clone(), save_path))?; + } + + (index, Some(build_stats), quant_training_time) } - } - }; + }; + + let build = if self.input.use_fp_for_search { + run_search_outer( + &self.input.index_operation.search_phase, + common::FullPrecision, + index, + build_stats, + checkpoint, + )? + } else { + run_search_outer( + &self.input.index_operation.search_phase, + hybrid, + index, + build_stats, + checkpoint, + )? + }; + + let result = QuantBuildResult { + quant_training_time, + build, + }; + + writeln!(output, "\n\n{}", result)?; + Ok(result) + } } - - impl_pq_build!(f32); - impl_pq_build!(f16); } diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 4e572a6e8..b418c0d7b 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -13,56 +13,18 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { { use half::f16; - use crate::backend::index::benchmarks::register; - // f32 - register!( - benchmarks, - "async-sq-8-bit-f32", - imp::ScalarQuantized<'static, 8, f32> - ); - register!( - benchmarks, - "async-sq-4-bit-f32", - imp::ScalarQuantized<'static, 4, f32> - ); - register!( - benchmarks, - "async-sq-2-bit-f32", - imp::ScalarQuantized<'static, 2, f32> - ); - register!( - benchmarks, - "async-sq-1-bit-f32", - imp::ScalarQuantized<'static, 1, f32> - ); + benchmarks.register::>("async-sq-8-bit-f32"); + benchmarks.register::>("async-sq-4-bit-f32"); + benchmarks.register::>("async-sq-2-bit-f32"); + benchmarks.register::>("async-sq-1-bit-f32"); // f16 - register!( - benchmarks, - "async-sq-8-bit-f16", - imp::ScalarQuantized<'static, 8, f16> - ); - register!( - benchmarks, - "async-sq-4-bit-f16", - imp::ScalarQuantized<'static, 4, f16> - ); - register!( - benchmarks, - "async-sq-2-bit-f16", - imp::ScalarQuantized<'static, 2, f16> - ); - register!( - benchmarks, - "async-sq-1-bit-f16", - imp::ScalarQuantized<'static, 1, f16> - ); + benchmarks.register::>("async-sq-8-bit-f16"); + benchmarks.register::>("async-sq-4-bit-f16"); + benchmarks.register::>("async-sq-2-bit-f16"); + benchmarks.register::>("async-sq-1-bit-f16"); // i8 - register!( - benchmarks, - "async-sq-1-bit-i8", - imp::ScalarQuantized<'static, 1, i8> - ); + benchmarks.register::>("async-sq-1-bit-i8"); } // Stub implementation @@ -77,9 +39,9 @@ mod imp { use anyhow::Context; use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::{datatype, MicroSeconds}, - Any, Checkpoint, Output, + Benchmark, Checkpoint, Output, }; use diskann_providers::{ index::diskann_async::{self}, @@ -117,138 +79,109 @@ mod imp { } } - impl dispatcher::Map for ScalarQuantized<'static, NBITS, T> - where - T: 'static, - { - type Type<'a> = ScalarQuantized<'a, NBITS, T>; - } - - impl<'a, const NBITS: usize, T> DispatchRule<&'a IndexSQOperation> for ScalarQuantized<'a, NBITS, T> - where - datatype::Type: DispatchRule, - { - type Error = std::convert::Infallible; - - fn try_match(from: &&'a IndexSQOperation) -> Result { - // If this is multi-insert, return a very-close failure. - let mut failure_score: Option = None; - match from.index_operation.source { - IndexSource::Load(_) => {} - IndexSource::Build(ref build) => { - // If the build is not compatible, return a failure score. - if build.multi_insert.is_some() { - failure_score = Some(1); + macro_rules! impl_sq_build { + ($N:literal, $T: ty) => { + impl Benchmark for ScalarQuantized<'static, $N, $T> { + type Input = IndexSQOperation; + type Output = QuantBuildResult; + + fn try_match(input: &IndexSQOperation) -> Result { + let mut failure_score: Option = None; + match input.index_operation.source { + IndexSource::Load(_) => {} + IndexSource::Build(ref build) => { + if build.multi_insert.is_some() { + failure_score = Some(1); + } + } } - } - } - - // make sure the data type is correct - if let Err(FailureScore(_)) = FullPrecision::<'a, T>::try_match(&&from.index_operation) - { - *failure_score.get_or_insert(0) += 1; - } - // Make sure the number of bits is correct. - if from.num_bits != NBITS { - *failure_score.get_or_insert(0) += 10 + NBITS.abs_diff(from.num_bits) as u32; - } - - match failure_score { - None => Ok(MatchScore(0)), - Some(score) => Err(FailureScore(score)), - } - } + if as Benchmark>::try_match(&input.index_operation) + .is_err() + { + *failure_score.get_or_insert(0) += 1; + } - fn convert(from: &'a IndexSQOperation) -> Result { - Ok(Self::new(from)) - } + if input.num_bits != $N { + *failure_score.get_or_insert(0) += 10 + ($N as usize).abs_diff(input.num_bits) as u32; + } - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a IndexSQOperation>, - ) -> std::fmt::Result { - match from { - None => { - describeln!( - f, - "- Index Build and Search using {} scalar quantized bits", - NBITS - )?; - describeln!( - f, - "- Requires `{}` data", - dispatcher::Description::>::new(), - )?; - describeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; - describeln!(f, "- Does not support multi-insert")?; - } - Some(input) => { - if input.num_bits != NBITS { - describeln!( - f, - "- Expected {} bits, instead got {}", - NBITS, - input.num_bits - )?; + match failure_score { + None => Ok(MatchScore(0)), + Some(score) => Err(FailureScore(score)), } + } - let mut check_match = |data_type: &datatype::DataType| { - if datatype::Type::::try_match(data_type).is_err() { + fn description( + f: &mut std::fmt::Formatter<'_>, + input: Option<&IndexSQOperation>, + ) -> std::fmt::Result { + match input { + None => { describeln!( f, - "- Only `{}` data type is supported. Instead, got {}", - dispatcher::Description::>::new(), - data_type - ).unwrap(); - } - }; - - match &input.index_operation.source { - IndexSource::Load(load) => { - check_match(&load.data_type); + "- Index Build and Search using {} scalar quantized bits", + $N + )?; + describeln!( + f, + "- Requires `{}` data", + Description::>::new(), + )?; + describeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; + describeln!(f, "- Does not support multi-insert")?; } - IndexSource::Build(build) => { - check_match(&build.data_type); - - if build.multi_insert.is_some() { + Some(input) => { + if input.num_bits != $N { describeln!( f, - "- Scalar Quantization does not support multi-insert" + "- Expected {} bits, instead got {}", + $N, + input.num_bits )?; } + + let mut check_match = |data_type: &datatype::DataType| { + if datatype::Type::<$T>::try_match(data_type).is_err() { + describeln!( + f, + "- Only `{}` data type is supported. Instead, got {}", + Description::>::new(), + data_type + ).unwrap(); + } + }; + + match &input.index_operation.source { + IndexSource::Load(load) => { + check_match(&load.data_type); + } + IndexSource::Build(build) => { + check_match(&build.data_type); + + if build.multi_insert.is_some() { + describeln!( + f, + "- Scalar Quantization does not support multi-insert" + )?; + } + } + } } } + Ok(()) } - } - Ok(()) - } - } - impl<'a, const NBITS: usize, T> DispatchRule<&'a Any> for ScalarQuantized<'a, NBITS, T> - where - datatype::Type: DispatchRule, - { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a Any) -> Result { - from.convert::() - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::(f, from, IndexSQOperation::tag()) - } - } + fn run( + input: &IndexSQOperation, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + let sq = ScalarQuantized::<$N, $T>::new(input); + BuildAndSearch::run(sq, checkpoint, output) + } + } - macro_rules! impl_sq_build { - ($N:literal, $T: ty) => { impl<'a> BuildAndSearch<'a> for ScalarQuantized<'a, $N, $T> { type Data = QuantBuildResult; fn run( diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 82bb37dae..33cb2e2fe 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -14,36 +14,12 @@ crate::utils::stub_impl!( pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { const NAME: &str = "async-spherical-quantization"; - // Spherical - requires feature "spherical-quantization" #[cfg(feature = "spherical-quantization")] - benchmarks.register::>(NAME, |object, checkpoint, output| { - use crate::backend::index::benchmarks::BuildAndSearch; - - match object.run(checkpoint, output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - - #[cfg(feature = "spherical-quantization")] - benchmarks.register::>(NAME, |object, checkpoint, output| { - use crate::backend::index::benchmarks::BuildAndSearch; - - match object.run(checkpoint, output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); - - #[cfg(feature = "spherical-quantization")] - benchmarks.register::>(NAME, |object, checkpoint, output| { - use crate::backend::index::benchmarks::BuildAndSearch; - - match object.run(checkpoint, output) { - Ok(v) => Ok(serde_json::to_value(v)?), - Err(err) => Err(err), - } - }); + { + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + benchmarks.register::>(NAME); + } // Stub implementation #[cfg(not(feature = "spherical-quantization"))] @@ -60,9 +36,9 @@ mod imp { use diskann_benchmark_core as benchmark_core; use diskann_benchmark_runner::{ describeln, - dispatcher::{self, DispatchRule, FailureScore, MatchScore}, + dispatcher::{DispatchRule, FailureScore, MatchScore}, utils::{datatype, MicroSeconds}, - Any, Checkpoint, Output, + Benchmark, Checkpoint, Output, }; use diskann_providers::{ index::diskann_async::{self}, @@ -102,100 +78,6 @@ mod imp { } } - impl dispatcher::Map for SphericalQ<'static, NBITS> { - type Type<'a> = SphericalQ<'a, NBITS>; - } - - impl<'a, const NBITS: usize> DispatchRule<&'a SphericalQuantBuild> for SphericalQ<'a, NBITS> { - type Error = std::convert::Infallible; - - fn try_match(from: &&'a SphericalQuantBuild) -> Result { - // If this is multi-insert, return a very-close failure. - let mut failure_score: Option = None; - if from.build.multi_insert.is_some() { - failure_score = Some(1); - } - - // Ensure the data type is compatible (float32). - if let Err(FailureScore(_)) = datatype::Type::::try_match(&from.build.data_type) { - *failure_score.get_or_insert(0) += 1; - } - - // Match the number of bits. - let num_bits = from.num_bits.get(); - if num_bits != NBITS { - *failure_score.get_or_insert(0) += - NBITS.abs_diff(num_bits).try_into().unwrap_or(u32::MAX); - } - - match failure_score { - None => Ok(MatchScore(0)), - Some(score) => Err(FailureScore(score)), - } - } - - fn convert(from: &'a SphericalQuantBuild) -> Result { - assert_eq!(from.num_bits.get(), NBITS); - Ok(Self::new(from)) - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a SphericalQuantBuild>, - ) -> std::fmt::Result { - match from { - None => { - describeln!( - f, - "- Index Build and Search using {}-bit spherical quantization", - NBITS - )?; - describeln!(f, "- Requires `float32` data")?; - describeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; - describeln!(f, "- Does not support multi-insert")?; - } - Some(input) => { - let num_bits = input.num_bits.get(); - if num_bits != NBITS { - describeln!(f, "- Expected {} bits, got {}", NBITS, num_bits)?; - } - - if input.build.multi_insert.is_some() { - describeln!(f, "- Spherical Quantization does not support multi-insert")?; - } - - if datatype::Type::::try_match(&input.build.data_type).is_err() { - describeln!( - f, - "- Only `float32` data type is supported. Instead, got {}", - input.build.data_type - )?; - } - } - } - Ok(()) - } - } - - impl<'a, const NBITS: usize> DispatchRule<&'a Any> for SphericalQ<'a, NBITS> { - type Error = anyhow::Error; - - fn try_match(from: &&'a Any) -> Result { - from.try_match::() - } - - fn convert(from: &'a Any) -> Result { - from.convert::() - } - - fn description( - f: &mut std::fmt::Formatter<'_>, - from: Option<&&'a Any>, - ) -> std::fmt::Result { - Any::description::(f, from, SphericalQuantBuild::tag()) - } - } - macro_rules! write_field { ($f:ident, $field:tt, $fmt:literal, $($expr:tt)*) => { writeln!($f, concat!("{:>12}: ", $fmt), $field, $($expr)*) @@ -244,6 +126,89 @@ mod imp { macro_rules! build_and_search { ($N:literal) => { + impl Benchmark for SphericalQ<'static, $N> { + type Input = SphericalQuantBuild; + type Output = SphericalBuildResult; + + fn try_match(input: &SphericalQuantBuild) -> Result { + let mut failure_score: Option = None; + if input.build.multi_insert.is_some() { + failure_score = Some(1); + } + + if let Err(FailureScore(_)) = + datatype::Type::::try_match(&input.build.data_type) + { + *failure_score.get_or_insert(0) += 1; + } + + let num_bits = input.num_bits.get(); + if num_bits != $N { + *failure_score.get_or_insert(0) += ($N as usize) + .abs_diff(num_bits) + .try_into() + .unwrap_or(u32::MAX); + } + + match failure_score { + None => Ok(MatchScore(0)), + Some(score) => Err(FailureScore(score)), + } + } + + fn description( + f: &mut std::fmt::Formatter<'_>, + input: Option<&SphericalQuantBuild>, + ) -> std::fmt::Result { + match input { + None => { + describeln!( + f, + "- Index Build and Search using {}-bit spherical quantization", + $N + )?; + describeln!(f, "- Requires `float32` data")?; + describeln!( + f, + "- Implements `squared_l2` or `inner_product` distance", + )?; + describeln!(f, "- Does not support multi-insert")?; + } + Some(input) => { + let num_bits = input.num_bits.get(); + if num_bits != $N { + describeln!(f, "- Expected {} bits, got {}", $N, num_bits)?; + } + + if input.build.multi_insert.is_some() { + describeln!( + f, + "- Spherical Quantization does not support multi-insert" + )?; + } + + if datatype::Type::::try_match(&input.build.data_type).is_err() { + describeln!( + f, + "- Only `float32` data type is supported. Instead, got {}", + input.build.data_type + )?; + } + } + } + Ok(()) + } + + fn run( + input: &SphericalQuantBuild, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + let sq = SphericalQ::<$N>::new(input); + BuildAndSearch::run(sq, checkpoint, output) + } + } + impl<'a> BuildAndSearch<'a> for SphericalQ<'a, $N> { type Data = SphericalBuildResult; fn run( diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 19230977d..c76fdb594 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -25,7 +25,7 @@ use diskann_providers::{ use serde::{Deserialize, Serialize}; use crate::{ - inputs::{self, as_input, save_and_load, Example, Input}, + inputs::{self, as_input, save_and_load, Example}, utils::SimilarityMeasure, }; @@ -42,11 +42,11 @@ as_input!(DynamicIndexRun); pub(super) fn register_inputs( registry: &mut diskann_benchmark_runner::registry::Inputs, ) -> anyhow::Result<()> { - registry.register(Input::::new())?; - registry.register(Input::::new())?; - registry.register(Input::::new())?; - registry.register(Input::::new())?; - registry.register(Input::::new())?; + registry.register::()?; + registry.register::()?; + registry.register::()?; + registry.register::()?; + registry.register::()?; Ok(()) } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..2951d1fe4 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -15,7 +15,7 @@ use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, ge use serde::{Deserialize, Serialize}; use crate::{ - inputs::{as_input, Example, Input}, + inputs::{as_input, Example}, utils::SimilarityMeasure, }; @@ -28,7 +28,7 @@ as_input!(DiskIndexOperation); pub(super) fn register_inputs( registry: &mut diskann_benchmark_runner::registry::Inputs, ) -> anyhow::Result<()> { - registry.register(Input::::new())?; + registry.register::()?; Ok(()) } diff --git a/diskann-benchmark/src/inputs/exhaustive.rs b/diskann-benchmark/src/inputs/exhaustive.rs index 321c859af..e13837de8 100644 --- a/diskann-benchmark/src/inputs/exhaustive.rs +++ b/diskann-benchmark/src/inputs/exhaustive.rs @@ -12,7 +12,7 @@ use diskann_benchmark_runner::{ use serde::{Deserialize, Serialize}; use crate::{ - inputs::{as_input, Example, Input}, + inputs::{as_input, Example}, utils::{datafiles::ConvertingLoad, SimilarityMeasure}, }; @@ -34,9 +34,9 @@ as_input!(MinMax); pub(super) fn register_inputs( registry: &mut diskann_benchmark_runner::registry::Inputs, ) -> anyhow::Result<()> { - registry.register(Input::::new())?; - registry.register(Input::::new())?; - registry.register(Input::::new())?; + registry.register::()?; + registry.register::()?; + registry.register::()?; Ok(()) } diff --git a/diskann-benchmark/src/inputs/filters.rs b/diskann-benchmark/src/inputs/filters.rs index e63b06194..981cef3ce 100644 --- a/diskann-benchmark/src/inputs/filters.rs +++ b/diskann-benchmark/src/inputs/filters.rs @@ -6,7 +6,7 @@ use diskann_benchmark_runner::{files::InputFile, CheckDeserialization, Checker}; use serde::{Deserialize, Serialize}; -use crate::inputs::{as_input, Example, Input}; +use crate::inputs::{as_input, Example}; ////////////// // Registry // @@ -17,7 +17,7 @@ as_input!(MetadataIndexBuild); pub(super) fn register_inputs( registry: &mut diskann_benchmark_runner::registry::Inputs, ) -> anyhow::Result<()> { - registry.register(Input::::new())?; + registry.register::()?; Ok(()) } diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index a0ae1a982..f5f6c015a 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -19,40 +19,31 @@ pub(crate) fn register_inputs( Ok(()) } -/// A helper type for implementing `diskann_benchmark_runner::Input` for benchmark types. -pub(crate) struct Input { - _type: std::marker::PhantomData, -} - -impl Input { - pub(crate) fn new() -> Self { - Self { - _type: std::marker::PhantomData, - } - } -} - /// Construct an example input of type `Self`. pub(crate) trait Example { fn example() -> Self; } +// NOTE: The input registration and dispatching isn't prefect. It uses a pattern (like +// the use of `'static` on the benchmark types) as a byproduct of older ways of doing +// benchmark selection. +// +// In the future, these can be migrated to reduce this legacy cruft. macro_rules! as_input { ($T:ty) => { - impl diskann_benchmark_runner::Input for $crate::inputs::Input<$T> { - fn tag(&self) -> &'static str { + impl diskann_benchmark_runner::Input for $T { + fn tag() -> &'static str { <$T>::tag() } fn try_deserialize( - &self, serialized: &serde_json::Value, checker: &mut diskann_benchmark_runner::Checker, ) -> anyhow::Result { checker.any(<$T as serde::Deserialize>::deserialize(serialized)?) } - fn example(&self) -> anyhow::Result { + fn example() -> anyhow::Result { Ok(serde_json::to_value( <$T as $crate::inputs::Example>::example(), )?) diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index 761e5b746..ebdac8116 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -102,50 +102,44 @@ macro_rules! stub_impl { mod imp { use diskann_benchmark_runner::{ describeln, - dispatcher::{DispatchRule, FailureScore, MatchScore}, + dispatcher::{FailureScore, MatchScore}, output::Output, registry::Benchmarks, - Any, Checkpoint, + Benchmark, Checkpoint, Input, }; use crate::inputs; pub(super) fn register(name: &str, registry: &mut Benchmarks) { - registry.register::(name, run) + registry.register::(name); } - pub(super) fn run( - _: Stub, - _: Checkpoint<'_>, - _: &mut dyn Output, - ) -> anyhow::Result { - panic!("this function should not be called!"); - } - - // An empty placeholder to provide a hint for the necessary feature. + /// An empty placeholder to provide a hint for the necessary feature. pub(super) struct Stub; - diskann_benchmark_runner::self_map!(Stub); - - impl<'a> DispatchRule<&'a Any> for Stub { - type Error = anyhow::Error; - fn try_match(from: &&'a Any) -> Result { - Err(match from.downcast_ref::<$input>() { - // It's the correct type, but we do not actually have an - // implementation. - Some(_) => FailureScore(0), - None => diskann_benchmark_runner::any::MATCH_FAIL, - }) - } - fn convert(_from: &'a Any) -> Result { - panic!("This should not have been reached. Please file a bug report.") + + impl Benchmark for Stub { + type Input = $input; + type Output = serde_json::Value; + + fn try_match(_input: &$input) -> Result { + Err(FailureScore(0)) } + fn description( f: &mut std::fmt::Formatter<'_>, - _from: Option<&&'a Any>, + _input: Option<&$input>, ) -> std::fmt::Result { - writeln!(f, "tag: \"{}\"", <$input>::tag())?; + writeln!(f, "tag: \"{}\"", <$input as Input>::tag())?; describeln!(f, "{}", concat!("Requires the \"", $feature, "\" feature")) } + + fn run( + _input: &$input, + _checkpoint: Checkpoint<'_>, + _output: &mut dyn Output, + ) -> anyhow::Result { + panic!("this function should not be called!"); + } } } };