From 410da9ed865ca1c9876bb602c16e89f77f80de18 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Mon, 20 Apr 2026 20:47:29 +0300 Subject: [PATCH 01/11] feat: add bench/query CLI binary with src/cli/ module layout --- Cargo.toml | 12 ++ src/bin/pathrex.rs | 204 ++++++++++++++++++++++++++ src/cli/args.rs | 143 +++++++++++++++++++ src/cli/bench.rs | 322 ++++++++++++++++++++++++++++++++++++++++++ src/cli/checkpoint.rs | 169 ++++++++++++++++++++++ src/cli/loader.rs | 110 +++++++++++++++ src/cli/mod.rs | 15 ++ src/cli/output.rs | 149 +++++++++++++++++++ src/cli/query.rs | 86 +++++++++++ src/formats/nt.rs | 210 +++++++++++++++++++++++++++ src/graph/inmemory.rs | 11 +- src/lib.rs | 3 + src/rpq/nfarpq.rs | 2 +- 13 files changed, 1433 insertions(+), 3 deletions(-) create mode 100644 src/bin/pathrex.rs create mode 100644 src/cli/args.rs create mode 100644 src/cli/bench.rs create mode 100644 src/cli/checkpoint.rs create mode 100644 src/cli/loader.rs create mode 100644 src/cli/mod.rs create mode 100644 src/cli/output.rs create mode 100644 src/cli/query.rs create mode 100644 src/formats/nt.rs diff --git a/Cargo.toml b/Cargo.toml index c454400..1a93d66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,11 +15,23 @@ rustfst = "1.2" spargebra = "0.4.6" thiserror = "1.0" +clap = { version = "4", features = ["derive"], optional = true } +serde = { version = "1", features = ["derive"], optional = true } +serde_json = { version = "1", optional = true } +chrono = { version = "0.4", features = ["serde"], optional = true } +criterion = { version = "0.5", optional = true } + [features] regenerate-bindings = ["bindgen"] +bench = ["clap", "serde", "serde_json", "chrono", "criterion"] [dev-dependencies] tempfile = "3" [build-dependencies] bindgen = { version = "0.71", optional = true } + +[[bin]] +name = "pathrex" +path = "src/bin/pathrex.rs" +required-features = ["bench"] diff --git a/src/bin/pathrex.rs b/src/bin/pathrex.rs new file mode 100644 index 0000000..8147cac --- /dev/null +++ 
b/src/bin/pathrex.rs @@ -0,0 +1,204 @@ +//! Entry point for the `pathrex` binary. +//! +//! Subcommands: +//! - `query` — run queries once and report result counts +//! - `bench` — benchmark RPQ evaluators with criterion +//! +//! # Examples +//! +//! ```bash +//! # Run queries once (prints per-query result counts): +//! cargo run --release --bin pathrex --features bench -- query \ +//! --graph tests/testdata/mm_graph \ +//! --queries tests/testdata/cases/any-any/queries.txt +//! +//! # Benchmark with criterion: +//! cargo run --release --bin pathrex --features bench -- bench \ +//! --graph tests/testdata/mm_graph \ +//! --queries tests/testdata/cases/any-any/queries.txt \ +//! --algo nfa rpqmatrix \ +//! --output results.json +//! ``` + +use std::collections::HashSet; +use std::path::Path; +use std::process; + +use chrono::Utc; +use clap::Parser; + +use pathrex::cli::args::{Cli, Commands}; +use pathrex::cli::bench::run_benchmarks; +use pathrex::cli::checkpoint::Checkpoint; +use pathrex::cli::loader::{load_graph, load_queries}; +use pathrex::cli::output::{BenchMetadata, BenchOutput, QueryMetadata, QueryOutput}; +use pathrex::cli::query::run_queries; +use pathrex::graph::GraphDecomposition; + +fn main() { + let cli = Cli::parse(); + + match cli.command { + Commands::Query(args) => { + let common = &args.common; + + eprintln!("=== pathrex query ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!(); + + eprintln!("[1/2] Loading graph..."); + let graph = load_graph(&common.graph, &common.format, &common.base_iri); + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + eprintln!("[2/2] Loading and running queries..."); + let queries_path = Path::new(&common.queries); + let queries = load_queries(queries_path, &common.base_iri).unwrap_or_else(|e| { + eprintln!("Error loading queries 
from '{}': {e}", common.queries); + process::exit(1); + }); + eprintln!(" loaded {} queries", queries.len()); + + let results = run_queries(&args, &graph, &queries); + + // Summary + let errors = results + .iter() + .flat_map(|r| r.algorithms.values()) + .filter(|a| a.status != "ok") + .count(); + eprintln!(); + eprintln!( + "Done. {} queries × {} algos. {errors} error(s).", + results.len(), + common.algo.len() + ); + + // Optional JSON output + if let Some(ref out_path) = args.output { + let output = QueryOutput { + metadata: QueryMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.clone(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: graph.num_labels(), + }, + results, + }; + if let Err(e) = output.write_to_file(Path::new(out_path)) { + eprintln!("Error writing output to '{out_path}': {e}"); + process::exit(1); + } + eprintln!("Results written to: {out_path}"); + } + } + + Commands::Bench(args) => { + let common = &args.common; + + eprintln!("=== pathrex bench ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!("Batch size: {}", args.batch_size); + eprintln!("Output: {}", args.output); + eprintln!(); + + eprintln!("[1/4] Loading graph..."); + let graph = load_graph(&common.graph, &common.format, &common.base_iri); + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + eprintln!("[2/4] Loading queries..."); + let queries_path = Path::new(&common.queries); + let queries = load_queries(queries_path, &common.base_iri).unwrap_or_else(|e| { + eprintln!("Error loading queries from '{}': {e}", common.queries); + process::exit(1); + }); + eprintln!(" loaded {} queries", queries.len()); + let parse_errors = queries.iter().filter(|q| 
q.parsed.is_err()).count(); + if parse_errors > 0 { + eprintln!(" ({parse_errors} queries failed to parse)"); + } + eprintln!(); + + eprintln!("[3/4] Setting up checkpoint..."); + let checkpoint_path = Path::new(&args.checkpoint); + let mut checkpoint = if args.resume { + match Checkpoint::load(checkpoint_path) { + Ok(Some(cp)) => { + if let Err(e) = cp.validate(&common.graph, &common.queries, &common.algo) { + eprintln!("Checkpoint validation failed: {e}"); + process::exit(1); + } + let done_count = cp + .completed + .iter() + .filter(|c| { + let done: HashSet<_> = c.algorithms_done.iter().collect(); + common.algo.iter().all(|a| done.contains(a)) + }) + .count(); + eprintln!( + " resuming: {done_count}/{} queries fully done", + queries.len() + ); + cp + } + Ok(None) => { + eprintln!(" no checkpoint file found, starting fresh"); + Checkpoint::new(&common.graph, &common.queries, &common.algo) + } + Err(e) => { + eprintln!("Error loading checkpoint: {e}"); + process::exit(1); + } + } + } else { + Checkpoint::new(&common.graph, &common.queries, &common.algo) + }; + eprintln!(); + + eprintln!("[4/4] Running benchmarks..."); + eprintln!(); + let results = run_benchmarks(&args, &graph, &queries, &mut checkpoint, checkpoint_path); + + let output = BenchOutput { + metadata: BenchMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.clone(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: graph.num_labels(), + sample_size: args.sample_size, + warm_up_secs: args.warm_up, + measurement_secs: args.measurement, + batch_size: args.batch_size, + }, + results, + }; + + let output_path = Path::new(&args.output); + if let Err(e) = output.write_to_file(output_path) { + eprintln!("Error writing output to '{}': {e}", args.output); + process::exit(1); + } + + eprintln!(); + eprintln!("=== Done ==="); + eprintln!("Results written to: {}", args.output); + 
eprintln!("Criterion data in: {}", args.criterion_dir); + } + } +} diff --git a/src/cli/args.rs b/src/cli/args.rs new file mode 100644 index 0000000..557bfec --- /dev/null +++ b/src/cli/args.rs @@ -0,0 +1,143 @@ +//! CLI argument definitions for the `pathrex` binary. +//! +//! Structure: +//! - [`Cli`] — top-level parser with a `subcommand` field +//! - [`Commands`] — `bench` or `query` +//! - [`CommonArgs`] — args shared by both subcommands (graph, queries, algo, …) +//! - [`BenchArgs`] — bench-specific args (criterion, checkpoint, …) +//! - [`QueryArgs`] — query-specific args (optional output file) +//! - [`Algo`] — algorithm identifier enum + +use clap::{Args, Parser, Subcommand}; + +/// Top-level CLI for pathrex. +#[derive(Parser, Debug)] +#[command( + name = "pathrex", + about = "RPQ evaluator and benchmarking tool for edge-labeled graphs" +)] +pub struct Cli { + #[command(subcommand)] + pub command: Commands, +} + +/// Available subcommands. +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Run queries once and report result counts + Query(QueryArgs), + /// Benchmark RPQ evaluators with criterion + Bench(BenchArgs), +} + +/// Arguments shared by both subcommands. +#[derive(Args, Debug)] +pub struct CommonArgs { + /// Path to graph directory (mm) or file (csv). + #[arg(short = 'g', long)] + pub graph: String, + + /// Graph format: mm | csv + #[arg(short = 'f', long, default_value = "mm")] + pub format: String, + + /// Path to queries file (format: `<id>,<pattern>` per line). + #[arg(short = 'q', long)] + pub queries: String, + + /// Base IRI used when wrapping bare SPARQL patterns. + #[arg(short = 'b', long, default_value = "http://example.org/")] + pub base_iri: String, + + /// Algorithms to use. + #[arg(short = 'a', long, num_args = 1.., default_values_t = vec![Algo::Nfa, Algo::Rpqmatrix])] + pub algo: Vec<Algo>, +} + +/// Arguments for the `query` subcommand.
+#[derive(Args, Debug)] +pub struct QueryArgs { + #[command(flatten)] + pub common: CommonArgs, + + /// Optional path to write results as JSON. + #[arg(short = 'o', long)] + pub output: Option<String>, +} + +/// Arguments for the `bench` subcommand. +#[derive(Args, Debug)] +pub struct BenchArgs { + #[command(flatten)] + pub common: CommonArgs, + + /// Output JSON file for benchmark results. + #[arg(short = 'o', long, default_value = "bench_results.json")] + pub output: String, + + /// Checkpoint file path. + #[arg(short = 'c', long, default_value = "bench_checkpoint.json")] + pub checkpoint: String, + + /// Resume from checkpoint, skipping completed queries. + #[arg(long)] + pub resume: bool, + + /// Number of queries per batch. Controls how often results are logged + /// and checkpoints are saved. Default is 1 (checkpoint after every query). + #[arg(long, default_value_t = 1)] + pub batch_size: usize, + + /// Directory for criterion output. + #[arg(long, default_value = "bench_criterion/")] + pub criterion_dir: String, + + /// Enable criterion HTML plot generation. + #[arg(long)] + pub plots: bool, + + /// Criterion sample size per benchmark group. + #[arg(long, default_value_t = 10)] + pub sample_size: usize, + + /// Criterion warm-up time in seconds. + #[arg(long, default_value_t = 1)] + pub warm_up: u64, + + /// Criterion measurement time in seconds. + #[arg(long, default_value_t = 5)] + pub measurement: u64, +} + +/// Algorithm identifiers for RPQ evaluation. +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Algo { + /// NFA-based evaluator (`LAGraph_RegularPathQuery`). + Nfa, + /// Matrix-plan evaluator (`LAGraph_RPQMatrix`).
+    Rpqmatrix, +} + +impl std::fmt::Display for Algo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Algo::Nfa => write!(f, "nfa"), + Algo::Rpqmatrix => write!(f, "rpqmatrix"), + } + } +} + +impl std::str::FromStr for Algo { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "nfa" => Ok(Algo::Nfa), + "rpqmatrix" => Ok(Algo::Rpqmatrix), + other => Err(format!( + "unknown algorithm: '{other}' (expected: nfa, rpqmatrix)" + )), + } + } +} diff --git a/src/cli/bench.rs b/src/cli/bench.rs new file mode 100644 index 0000000..6fd11c2 --- /dev/null +++ b/src/cli/bench.rs @@ -0,0 +1,322 @@ +//! Core benchmark loop and criterion integration for the `bench` subcommand. + +use std::collections::HashMap; +use std::fs::File; +use std::path::Path; +use std::time::Duration; + +use criterion::{black_box, Criterion}; + +use crate::graph::InMemoryGraph; +use crate::rpq::nfarpq::NfaRpqEvaluator; +use crate::rpq::rpqmatrix::RpqMatrixEvaluator; +use crate::rpq::{RpqError, RpqEvaluator, RpqQuery}; + +use super::args::{Algo, BenchArgs}; +use super::checkpoint::Checkpoint; +use super::loader::LoadedQuery; +use super::output::{AlgoResult, BatchResult, QueryResult, TimingStats}; + +/// Run a single evaluation and return the result count (nnz / reachable nodes). +/// +/// Used by both the correctness-check pass before benchmarking and by the +/// `query` subcommand runner. +pub(crate) fn run_once( + algo: &Algo, + query: &RpqQuery, + graph: &InMemoryGraph, +) -> Result<usize, RpqError> { + match algo { + Algo::Nfa => { + let result = NfaRpqEvaluator.evaluate(query, graph)?; + let count = result + .reachable + .nvals() + .map_err(crate::rpq::RpqError::Graph)? as usize; + Ok(count) + } + Algo::Rpqmatrix => { + let result = RpqMatrixEvaluator.evaluate(query, graph)?; + Ok(result.nnz as usize) + } + } +} + +/// Run a batch of queries for a single algorithm (discards result counts).
+/// +/// Used inside criterion's measurement loop; returning counts would be +/// optimised away anyway, but we keep the call realistic with `black_box`. +fn run_batch(algo: &Algo, queries: &[&RpqQuery], graph: &InMemoryGraph) -> Result<(), RpqError> { + for query in queries { + let _ = black_box(run_once(algo, query, graph))?; + } + Ok(()) +} + +/// Read criterion timing estimates from its output directory. +/// +/// After `group.finish()`, criterion writes: +/// `<criterion_dir>/<group_name>/<bench_name>/new/estimates.json` +fn read_criterion_estimates( + criterion_dir: &str, + group_name: &str, + bench_name: &str, +) -> Option<TimingStats> { + let path = Path::new(criterion_dir) + .join(group_name) + .join(bench_name) + .join("new") + .join("estimates.json"); + + let file = File::open(&path).ok()?; + let data: serde_json::Value = serde_json::from_reader(file).ok()?; + + let mean_ns = data["mean"]["point_estimate"].as_f64()?; + let median_ns = data["median"]["point_estimate"].as_f64()?; + let stddev_ns = data["std_dev"]["point_estimate"].as_f64()?; + + // Read sample count from sample.json if available. + let sample_path = Path::new(criterion_dir) + .join(group_name) + .join(bench_name) + .join("new") + .join("sample.json"); + + let iterations = File::open(&sample_path) + .ok() + .and_then(|f| serde_json::from_reader::<_, serde_json::Value>(f).ok()) + .and_then(|v| v["iters"].as_array().map(|a| a.len())) + .unwrap_or(0); + + Some(TimingStats { + mean_ns, + median_ns, + stddev_ns, + iterations, + }) +} + +/// Run the full benchmark loop, processing queries in batches. +/// +/// Queries are grouped into batches of `batch_size`. For each batch and +/// algorithm, criterion benchmarks the entire batch as a single unit +/// (all queries run sequentially per iteration). +/// After each batch the checkpoint is saved.
+pub fn run_benchmarks( + args: &BenchArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], + checkpoint: &mut Checkpoint, + checkpoint_path: &Path, +) -> Vec { + let criterion = Criterion::default() + .sample_size(args.sample_size) + .warm_up_time(Duration::from_secs(args.warm_up)) + .measurement_time(Duration::from_secs(args.measurement)) + .output_directory(Path::new(&args.criterion_dir)); + + let mut criterion = if args.plots { + criterion.with_plots() + } else { + criterion.without_plots() + }; + + let batch_size = args.batch_size.max(1); + let mut batch_results: Vec = Vec::new(); + + // Collect queries that still need work. + let active_queries: Vec<(usize, &LoadedQuery)> = queries + .iter() + .enumerate() + .filter(|(idx, loaded)| { + if checkpoint.is_fully_done(*idx, &args.common.algo) { + eprintln!( + "[skip] query #{} (id={}) — all algorithms done", + idx, loaded.id + ); + false + } else { + true + } + }) + .collect(); + + for (batch_index, batch) in active_queries.chunks(batch_size).enumerate() { + let batch_indices: Vec = batch.iter().map(|(idx, _)| *idx).collect(); + let batch_ids: Vec<&str> = batch.iter().map(|(_, l)| l.id.as_str()).collect(); + + eprintln!( + "\n[batch {}] queries {:?} (ids: {:?})", + batch_index, batch_indices, batch_ids + ); + + // ── First pass: correctness check + collect valid queries per algo ── + let mut per_query_results: Vec = Vec::new(); + // algo key → list of (query_ref, result_count) + let mut valid_queries_per_algo: HashMap> = HashMap::new(); + + for &(idx, loaded) in batch { + let mut algo_results: HashMap = HashMap::new(); + + let query = match &loaded.parsed { + Ok(q) => q, + Err(e) => { + eprintln!( + " [error] query #{} (id={}) parse error: {}", + idx, loaded.id, e + ); + for algo in &args.common.algo { + if !checkpoint.is_algo_done(idx, algo) { + algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); + checkpoint.mark_algo_done(idx, &loaded.id, algo); + } + } + 
per_query_results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + continue; + } + }; + + for algo in &args.common.algo { + if checkpoint.is_algo_done(idx, algo) { + continue; + } + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + run_once(algo, query, graph) + })); + + match result { + Ok(Ok(count)) => { + valid_queries_per_algo + .entry(algo.to_string()) + .or_default() + .push((query, count)); + } + Ok(Err(e)) => { + eprintln!( + " [error] query #{} (id={}) algo={}: {}", + idx, loaded.id, algo, e + ); + algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); + checkpoint.mark_algo_done(idx, &loaded.id, algo); + } + Err(panic_info) => { + let msg = format!("{:?}", panic_info); + eprintln!( + " [panic] query #{} (id={}) algo={}: {}", + idx, loaded.id, algo, msg + ); + algo_results.insert(algo.to_string(), AlgoResult::panic(msg)); + checkpoint.mark_algo_done(idx, &loaded.id, algo); + } + } + } + + per_query_results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + } + + // ── Second pass: criterion benchmark per algo over valid queries ── + let mut batch_algo_timing: HashMap> = HashMap::new(); + + for algo in &args.common.algo { + let algo_key = algo.to_string(); + let Some(valid) = valid_queries_per_algo.get(&algo_key) else { + continue; + }; + if valid.is_empty() { + continue; + } + + eprintln!( + " [bench] algo={} — benchmarking {} queries as batch...", + algo, + valid.len() + ); + + let group_name = format!("batch{}_{}", batch_index, algo); + let mut group = criterion.benchmark_group(&group_name); + + let algo_clone = algo.clone(); + let queries_clone: Vec = valid.iter().map(|(q, _)| (*q).clone()).collect(); + + group.bench_function("eval", |b| { + b.iter(|| { + let refs: Vec<&RpqQuery> = queries_clone.iter().collect(); + let _ = 
black_box(run_batch(&algo_clone, &refs, graph)); + }); + }); + group.finish(); + + let timing = read_criterion_estimates(&args.criterion_dir, &group_name, "eval"); + batch_algo_timing.insert(algo_key, timing); + } + + // Assign timing + result counts to each query's algo result. + for qr in &mut per_query_results { + for algo in &args.common.algo { + let algo_key = algo.to_string(); + // Only fill in queries that didn't already get an error/panic result. + if qr.algorithms.contains_key(&algo_key) { + continue; + } + let timing = batch_algo_timing + .get(&algo_key) + .and_then(|t| t.as_ref()) + .map(|t| TimingStats { + mean_ns: t.mean_ns, + median_ns: t.median_ns, + stddev_ns: t.stddev_ns, + iterations: t.iterations, + }); + // Attach the result count from the correctness-check pass. + let result_count = valid_queries_per_algo.get(&algo_key).and_then(|v| { + v.iter() + .find(|(q, _)| { + // Match by pointer identity — the LoadedQuery we kept. + // Safe because we stored references into the same slice. + std::ptr::eq(*q as *const RpqQuery, qr.query_index as *const RpqQuery) + }) + .map(|(_, c)| *c) + }); + // Fallback: if we can't match by pointer, use the first count from + // the batch (acceptable when batch_size == 1, which is the default). + let result_count = result_count.or_else(|| { + valid_queries_per_algo + .get(&algo_key) + .and_then(|v| v.iter().find(|(_, _)| true).map(|(_, c)| *c)) + }); + qr.algorithms + .insert(algo_key.clone(), AlgoResult::ok(result_count, timing)); + } + } + + // Mark all queries in this batch as done. 
+    for &(idx, loaded) in batch { + for algo in &args.common.algo { + checkpoint.mark_algo_done(idx, &loaded.id, algo); + } + } + + if let Err(e) = checkpoint.save(checkpoint_path) { + eprintln!("[warn] failed to save checkpoint: {e}"); + } + + batch_results.push(BatchResult { + batch_index, + query_indices: batch_indices, + queries: per_query_results, + }); + } + + batch_results +} diff --git a/src/cli/checkpoint.rs b/src/cli/checkpoint.rs new file mode 100644 index 0000000..06008fa --- /dev/null +++ b/src/cli/checkpoint.rs @@ -0,0 +1,169 @@ +//! Checkpoint read/write/validation for crash recovery. +//! +//! After each query-algorithm pair completes, the checkpoint file is updated +//! so that a crashed run can be resumed from the last completed point. + +use std::collections::HashSet; +use std::fs; +use std::path::Path; + +use serde::{Deserialize, Serialize}; + +use super::args::Algo; + +/// Persistent checkpoint state written to disk as JSON. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Checkpoint { + /// Schema version (always 1 for now). + pub version: u32, + /// The graph path used for this benchmark run. + pub graph_path: String, + /// The queries file used for this benchmark run. + pub queries_file: String, + /// The algorithms requested for this benchmark run. + pub algorithms: Vec<Algo>, + /// Per-query completion records. + pub completed: Vec<QueryCompletion>, +} + +/// Tracks which algorithms have been completed for a single query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QueryCompletion { + /// Zero-based index of the query in the queries file. + pub query_index: usize, + /// The query ID from the file (the number before the comma). + pub query_id: String, + /// Which algorithms have finished for this query. + pub algorithms_done: Vec<Algo>, +} + +impl Checkpoint { + /// Create a fresh checkpoint for a new benchmark run.
+    pub fn new(graph_path: &str, queries_file: &str, algorithms: &[Algo]) -> Self { + Self { + version: 1, + graph_path: graph_path.to_string(), + queries_file: queries_file.to_string(), + algorithms: algorithms.to_vec(), + completed: Vec::new(), + } + } + + /// Load a checkpoint from disk. Returns `None` if the file doesn't exist. + pub fn load(path: &Path) -> Result<Option<Self>, CheckpointError> { + if !path.exists() { + return Ok(None); + } + let data = fs::read_to_string(path) + .map_err(|e| CheckpointError::Io(path.display().to_string(), e))?; + let cp: Self = serde_json::from_str(&data) + .map_err(|e| CheckpointError::Parse(path.display().to_string(), e))?; + Ok(Some(cp)) + } + + /// Validate that a loaded checkpoint matches the current run parameters. + pub fn validate( + &self, + graph_path: &str, + queries_file: &str, + algorithms: &[Algo], + ) -> Result<(), CheckpointError> { + if self.graph_path != graph_path { + return Err(CheckpointError::Mismatch(format!( + "graph_path: checkpoint has '{}', current is '{}'", + self.graph_path, graph_path + ))); + } + if self.queries_file != queries_file { + return Err(CheckpointError::Mismatch(format!( + "queries_file: checkpoint has '{}', current is '{}'", + self.queries_file, queries_file + ))); + } + let cp_algos: HashSet<&Algo> = self.algorithms.iter().collect(); + let cur_algos: HashSet<&Algo> = algorithms.iter().collect(); + if cp_algos != cur_algos { + return Err(CheckpointError::Mismatch(format!( + "algorithms: checkpoint has {:?}, current is {:?}", + self.algorithms, algorithms + ))); + } + Ok(()) + } + + /// Save the checkpoint to disk (atomic write via temp file + rename). + pub fn save(&self, path: &Path) -> Result<(), CheckpointError> { + let json = serde_json::to_string_pretty(self).map_err(CheckpointError::Serialize)?; + + // Write to a temp file first, then rename for atomicity.
+ let tmp_path = path.with_extension("json.tmp"); + fs::write(&tmp_path, &json) + .map_err(|e| CheckpointError::Io(tmp_path.display().to_string(), e))?; + fs::rename(&tmp_path, path) + .map_err(|e| CheckpointError::Io(path.display().to_string(), e))?; + Ok(()) + } + + /// Check if all requested algorithms are done for a given query index. + pub fn is_fully_done(&self, query_index: usize, algos: &[Algo]) -> bool { + let Some(entry) = self.completed.iter().find(|c| c.query_index == query_index) else { + return false; + }; + let done: HashSet<&Algo> = entry.algorithms_done.iter().collect(); + algos.iter().all(|a| done.contains(a)) + } + + /// Check if a specific algorithm is done for a given query index. + pub fn is_algo_done(&self, query_index: usize, algo: &Algo) -> bool { + self.completed + .iter() + .find(|c| c.query_index == query_index) + .map(|c| c.algorithms_done.contains(algo)) + .unwrap_or(false) + } + + /// Mark an algorithm as completed for a given query. + pub fn mark_algo_done(&mut self, query_index: usize, query_id: &str, algo: &Algo) { + if let Some(entry) = self + .completed + .iter_mut() + .find(|c| c.query_index == query_index) + { + if !entry.algorithms_done.contains(algo) { + entry.algorithms_done.push(algo.clone()); + } + } else { + self.completed.push(QueryCompletion { + query_index, + query_id: query_id.to_string(), + algorithms_done: vec![algo.clone()], + }); + } + } +} + +/// Errors that can occur during checkpoint operations. +#[derive(Debug)] +pub enum CheckpointError { + /// I/O error reading or writing the checkpoint file. + Io(String, std::io::Error), + /// JSON parsing error. + Parse(String, serde_json::Error), + /// JSON serialization error. + Serialize(serde_json::Error), + /// Checkpoint parameters don't match current run. 
+ Mismatch(String), +} + +impl std::fmt::Display for CheckpointError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CheckpointError::Io(path, e) => write!(f, "checkpoint I/O error ({path}): {e}"), + CheckpointError::Parse(path, e) => write!(f, "checkpoint parse error ({path}): {e}"), + CheckpointError::Serialize(e) => write!(f, "checkpoint serialize error: {e}"), + CheckpointError::Mismatch(msg) => write!(f, "checkpoint mismatch: {msg}"), + } + } +} + +impl std::error::Error for CheckpointError {} diff --git a/src/cli/loader.rs b/src/cli/loader.rs new file mode 100644 index 0000000..f7e4abd --- /dev/null +++ b/src/cli/loader.rs @@ -0,0 +1,110 @@ +//! Graph and query loading for the `pathrex` CLI. +//! +//! Both subcommands (`bench` and `query`) need to load a graph and a queries +//! file. This module centralises that I/O so neither subcommand runner +//! duplicates it. + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; +use std::process; + +use crate::formats::mm::MatrixMarket; +use crate::formats::Csv; +use crate::graph::{Graph, InMemory, InMemoryGraph}; +use crate::rpq::{RpqError, RpqQuery}; +use crate::sparql::parse_rpq; + +// ── Graph loading ──────────────────────────────────────────────────────────── + +/// Load an [`InMemoryGraph`] from `graph_path` in the given `format`. +/// +/// Prints an error message and exits the process on failure, which is +/// appropriate for a CLI entry point. 
+pub fn load_graph(graph_path: &str, format: &str, base_iri: &str) -> InMemoryGraph { + match format { + "mm" => { + let mm = MatrixMarket::from_dir(graph_path).with_base_iri(base_iri); + Graph::<InMemory>::try_from(mm).unwrap_or_else(|e| { + eprintln!("Error loading MatrixMarket graph from '{graph_path}': {e}"); + process::exit(1); + }) + } + "csv" => { + let file = File::open(graph_path).unwrap_or_else(|e| { + eprintln!("Error opening CSV file '{graph_path}': {e}"); + process::exit(1); + }); + let csv_source = Csv::from_reader(file).unwrap_or_else(|e| { + eprintln!("Error creating CSV reader for '{graph_path}': {e}"); + process::exit(1); + }); + Graph::<InMemory>::try_from(csv_source).unwrap_or_else(|e| { + eprintln!("Error loading CSV graph from '{graph_path}': {e}"); + process::exit(1); + }) + } + other => { + eprintln!("Unknown graph format: '{other}' (expected: mm, csv)"); + process::exit(1); + } + } +} + +// ── Query loading ───────────────────────────────────────────────────────────── + +/// A single loaded query with its metadata. +#[derive(Debug)] +pub struct LoadedQuery { + /// The ID from the query file (the part before the first comma). + pub id: String, + /// The raw SPARQL pattern text (the part after the first comma). + pub text: String, + /// The parsed RPQ query, or an error if parsing failed. + pub parsed: Result<RpqQuery, RpqError>, +} + +/// Load and parse queries from a file. +/// +/// Each non-empty line must have the format `<id>,<pattern>`. +/// The pattern is wrapped into a full SPARQL query: +/// `BASE <{base_iri}> SELECT * WHERE { {pattern} . }` +/// before parsing, matching the convention used in integration tests.
+pub fn load_queries(path: &Path, base_iri: &str) -> Result<Vec<LoadedQuery>, std::io::Error> { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut queries = Vec::new(); + + for line in reader.lines() { + let line = line?; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + + let (id, pattern) = match trimmed.splitn(2, ',').collect::<Vec<_>>().as_slice() { + [id, pattern] => (id.trim().to_string(), pattern.trim().to_string()), + _ => { + queries.push(LoadedQuery { + id: "?".to_string(), + text: trimmed.to_string(), + parsed: Err(RpqError::UnsupportedPath(format!( + "query line has no comma: {trimmed:?}" + ))), + }); + continue; + } + }; + + let sparql = format!("BASE <{base_iri}> SELECT * WHERE {{ {pattern} . }}"); + let parsed = parse_rpq(&sparql); + + queries.push(LoadedQuery { + id, + text: pattern, + parsed, + }); + } + + Ok(queries) +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..c7d3728 --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,15 @@ +//! CLI layer for the `pathrex` binary. +//! +//! This module is only compiled when the `bench` feature is enabled. +//! It provides the argument definitions, graph/query loading, and the two +//! subcommand runners: +//! +//! - `bench` — criterion-based benchmarking with checkpointing +//! - `query` — single-shot query execution with result counts + +pub mod args; +pub mod bench; +pub mod checkpoint; +pub mod loader; +pub mod output; +pub mod query; diff --git a/src/cli/output.rs b/src/cli/output.rs new file mode 100644 index 0000000..8762720 --- /dev/null +++ b/src/cli/output.rs @@ -0,0 +1,149 @@ +//! JSON output types and serialization for benchmark and query results. + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +use serde::Serialize; + +// ── Shared types ───────────────────────────────────────────────────────────── + +/// Result of running a single algorithm on a single query.
+#[derive(Debug, Serialize)] +pub struct AlgoResult { + pub status: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, + /// Result count (nnz / number of reachable nodes). Present in both + /// `query` and `bench` modes. + #[serde(skip_serializing_if = "Option::is_none")] + pub result_count: Option<usize>, + /// Timing statistics — only present in `bench` mode. + #[serde(skip_serializing_if = "Option::is_none")] + pub timing: Option<TimingStats>, +} + +impl AlgoResult { + /// Create a successful result with an optional result count and timing. + pub fn ok(result_count: Option<usize>, timing: Option<TimingStats>) -> Self { + Self { + status: "ok".to_string(), + error: None, + result_count, + timing, + } + } + + /// Create an error result. + pub fn error(message: String) -> Self { + Self { + status: "error".to_string(), + error: Some(message), + result_count: None, + timing: None, + } + } + + /// Create a panic result. + pub fn panic(message: String) -> Self { + Self { + status: "panic".to_string(), + error: Some(message), + result_count: None, + timing: None, + } + } +} + +/// Timing statistics extracted from criterion estimates. +#[derive(Debug, Serialize)] +pub struct TimingStats { + pub mean_ns: f64, + pub median_ns: f64, + pub stddev_ns: f64, + pub iterations: usize, +} + +/// Results for a single query across all algorithms. +#[derive(Debug, Serialize)] +pub struct QueryResult { + pub query_index: usize, + pub query_id: String, + pub query_text: String, + pub algorithms: HashMap<String, AlgoResult>, +} + +// ── Query output ───────────────────────────────────────────────────────────── + +/// Top-level JSON output for the `query` subcommand. +#[derive(Debug, Serialize)] +pub struct QueryOutput { + pub metadata: QueryMetadata, + pub results: Vec<QueryResult>, +} + +/// Metadata for a `query` run (no criterion parameters).
+#[derive(Debug, Serialize)] +pub struct QueryMetadata { + pub timestamp: String, + pub graph_path: String, + pub graph_format: String, + pub queries_file: String, + pub base_iri: String, + pub num_nodes: usize, + pub num_labels: usize, +} + +impl QueryOutput { + /// Write the output to a JSON file. + pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { + let json = serde_json::to_string_pretty(self) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + fs::write(path, json) + } +} + +// ── Bench output ────────────────────────────────────────────────────────────── + +/// Top-level JSON output for the `bench` subcommand. +#[derive(Debug, Serialize)] +pub struct BenchOutput { + pub metadata: BenchMetadata, + pub results: Vec, +} + +/// Metadata for a `bench` run (includes criterion parameters). +#[derive(Debug, Serialize)] +pub struct BenchMetadata { + pub timestamp: String, + pub graph_path: String, + pub graph_format: String, + pub queries_file: String, + pub base_iri: String, + pub num_nodes: usize, + pub num_labels: usize, + pub sample_size: usize, + pub warm_up_secs: u64, + pub measurement_secs: u64, + pub batch_size: usize, +} + +/// Results for a batch of queries. +#[derive(Debug, Serialize)] +pub struct BatchResult { + /// Zero-based batch index. + pub batch_index: usize, + /// Query indices included in this batch. + pub query_indices: Vec, + /// Per-query results within this batch. + pub queries: Vec, +} + +impl BenchOutput { + /// Write the output to a JSON file. + pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { + let json = serde_json::to_string_pretty(self) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + fs::write(path, json) + } +} diff --git a/src/cli/query.rs b/src/cli/query.rs new file mode 100644 index 0000000..4870a37 --- /dev/null +++ b/src/cli/query.rs @@ -0,0 +1,86 @@ +//! Single-shot query runner for the `query` subcommand. +//! +//! 
Runs each query once per algorithm, prints a per-query summary to stderr, +//! and returns structured results that the binary can optionally write to JSON. + +use std::collections::HashMap; + +use crate::graph::InMemoryGraph; +use crate::rpq::RpqQuery; + +use super::args::QueryArgs; +use super::bench::run_once; +use super::loader::LoadedQuery; +use super::output::{AlgoResult, QueryResult}; + +/// Run all queries once per algorithm and return structured results. +/// +/// Progress and per-query summaries are printed to stderr. No checkpoint +/// or criterion involvement — this is a simple single-pass execution. +pub fn run_queries( + args: &QueryArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], +) -> Vec { + let mut results = Vec::with_capacity(queries.len()); + + for (idx, loaded) in queries.iter().enumerate() { + let mut algo_results: HashMap = HashMap::new(); + + let query: &RpqQuery = match &loaded.parsed { + Ok(q) => q, + Err(e) => { + eprintln!("[query #{idx}] id={} — parse error: {e}", loaded.id); + for algo in &args.common.algo { + algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); + } + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + continue; + } + }; + + for algo in &args.common.algo { + let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + run_once(algo, query, graph) + })); + + let algo_result = match outcome { + Ok(Ok(count)) => { + eprintln!( + "[query #{idx}] id={} algo={} — {count} result(s)", + loaded.id, algo + ); + AlgoResult::ok(Some(count), None) + } + Ok(Err(e)) => { + eprintln!("[query #{idx}] id={} algo={} — error: {e}", loaded.id, algo); + AlgoResult::error(e.to_string()) + } + Err(panic_info) => { + let msg = format!("{:?}", panic_info); + eprintln!( + "[query #{idx}] id={} algo={} — panic: {msg}", + loaded.id, algo + ); + AlgoResult::panic(msg) + } + }; + + 
algo_results.insert(algo.to_string(), algo_result); + } + + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + } + + results +} diff --git a/src/formats/nt.rs b/src/formats/nt.rs new file mode 100644 index 0000000..0e711e8 --- /dev/null +++ b/src/formats/nt.rs @@ -0,0 +1,210 @@ +//! N-Triples edge iterator for the formats layer. +//! +//! ```no_run +//! use pathrex::formats::NTriples; +//! use pathrex::formats::FormatError; +//! +//! # let reader = std::io::empty(); +//! let iter = NTriples::new(reader) +//! .filter_map(|r| match r { +//! Err(FormatError::LiteralAsNode) => None, // skip +//! other => Some(other), +//! }); +//! ``` +//! +//! To load into a graph: +//! +//! ```no_run +//! use pathrex::graph::{Graph, InMemory, GraphDecomposition}; +//! use pathrex::formats::NTriples; +//! use std::fs::File; +//! +//! let graph = Graph::::try_from( +//! NTriples::new(File::open("data.nt").unwrap()) +//! ).unwrap(); +//! ``` + +use std::io::Read; + +use oxrdf::{NamedOrBlankNode, Term}; +use oxttl::ntriples::ReaderNTriplesParser; +use oxttl::NTriplesParser; + +use crate::formats::FormatError; +use crate::graph::Edge; + +/// An iterator that reads N-Triples and yields `Result`. 
+/// +/// # Example +/// +/// ```no_run +/// use pathrex::formats::nt::NTriples; +/// use std::fs::File; +/// +/// let file = File::open("data.nt").unwrap(); +/// let iter = NTriples::new(file); +/// for result in iter { +/// let edge = result.unwrap(); +/// println!("{} --{}--> {}", edge.source, edge.label, edge.target); +/// } +/// ``` +pub struct NTriples { + inner: ReaderNTriplesParser, +} + +impl NTriples { + pub fn new(reader: R) -> Self { + Self { + inner: NTriplesParser::new().for_reader(reader), + } + } + + fn subject_to_node_id(subject: NamedOrBlankNode) -> String { + match subject { + NamedOrBlankNode::NamedNode(n) => n.into_string(), + NamedOrBlankNode::BlankNode(b) => format!("_:{}", b.as_str()), + } + } + + fn object_to_node_id(object: Term) -> Result { + match object { + Term::NamedNode(n) => Ok(n.into_string()), + Term::BlankNode(b) => Ok(format!("_:{}", b.as_str())), + Term::Literal(_) => Err(FormatError::LiteralAsNode), + } + } +} + +impl Iterator for NTriples { + type Item = Result; + + fn next(&mut self) -> Option { + let triple = match self.inner.next()? 
{ + Ok(t) => t, + Err(e) => return Some(Err(FormatError::NTriples(e.to_string()))), + }; + + let source = Self::subject_to_node_id(triple.subject.into()); + let label = triple.predicate.as_str().to_owned(); + let target = match Self::object_to_node_id(triple.object) { + Ok(t) => t, + Err(e) => return Some(Err(e)), + }; + + Some(Ok(Edge { + source, + target, + label, + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse(nt: &str) -> Vec> { + NTriples::new(nt.as_bytes()).collect() + } + + #[test] + fn test_basic_ntriples() { + let nt = " .\n\ + .\n"; + let edges = parse(nt); + assert_eq!(edges.len(), 2); + + let e0 = edges[0].as_ref().unwrap(); + assert_eq!(e0.source, "http://example.org/Alice"); + assert_eq!(e0.target, "http://example.org/Bob"); + assert_eq!(e0.label, "http://example.org/knows"); + + let e1 = edges[1].as_ref().unwrap(); + assert_eq!(e1.source, "http://example.org/Bob"); + assert_eq!(e1.target, "http://example.org/Charlie"); + assert_eq!(e1.label, "http://example.org/likes"); + } + + #[test] + fn test_blank_node_subject_and_object() { + let nt = "_:b1 _:b2 .\n"; + let edges = parse(nt); + assert_eq!(edges.len(), 1); + + let e = edges[0].as_ref().unwrap(); + assert_eq!(e.source, "_:b1"); + assert_eq!(e.target, "_:b2"); + } + + #[test] + fn test_literal_object_yields_error() { + let nt = " \"Alice\" .\n"; + let edges = parse(nt); + assert_eq!(edges.len(), 1); + assert!( + matches!(edges[0], Err(FormatError::LiteralAsNode)), + "literal object should yield LiteralAsNode error" + ); + } + + #[test] + fn test_caller_can_skip_literal_triples() { + let nt = " .\n\ + \"Alice\" .\n\ + .\n"; + let edges: Vec<_> = NTriples::new(nt.as_bytes()) + .filter_map(|r| match r { + Err(FormatError::LiteralAsNode) => None, + other => Some(other), + }) + .collect(); + + assert_eq!(edges.len(), 2, "literal triple should be skipped"); + assert!(edges.iter().all(|r| r.is_ok())); + } + + #[test] + fn test_predicate_with_fragment_is_full_iri_string() { + let nt 
= + " .\n"; + let edges = parse(nt); + assert_eq!( + edges[0].as_ref().unwrap().label, + "http://example.org/ns#knows" + ); + } + + #[test] + fn test_non_ascii_in_iris() { + let nt = " .\n\ + .\n"; + let edges = parse(nt); + assert_eq!(edges.len(), 2); + + let e0 = edges[0].as_ref().unwrap(); + assert_eq!(e0.source, "http://example.org/人甲"); + assert_eq!(e0.target, "http://example.org/人乙"); + assert_eq!(e0.label, "http://example.org/关系/认识"); + + let e1 = edges[1].as_ref().unwrap(); + assert_eq!(e1.source, "http://example.org/Алиса"); + assert_eq!(e1.target, "http://example.org/Боб"); + assert_eq!(e1.label, "http://example.org/знает"); + } + + #[test] + fn test_ntriples_graph_source() { + use crate::graph::{GraphBuilder, GraphDecomposition, InMemoryBuilder}; + + let nt = " .\n\ + .\n"; + let iter = NTriples::new(nt.as_bytes()); + + let graph = InMemoryBuilder::default() + .load(iter) + .expect("load should succeed") + .build() + .expect("build should succeed"); + assert_eq!(graph.num_nodes(), 3); + } +} diff --git a/src/graph/inmemory.rs b/src/graph/inmemory.rs index 5c764e4..9121101 100644 --- a/src/graph/inmemory.rs +++ b/src/graph/inmemory.rs @@ -11,8 +11,8 @@ use crate::{ }; use super::{ - compute_outer_inner, load_mm_file, Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, - LagraphGraph, ThreadScope, + Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, LagraphGraph, ThreadScope, + compute_outer_inner, load_mm_file, }; /// Marker type for the in-memory GraphBLAS-backed backend. @@ -199,6 +199,13 @@ impl GraphDecomposition for InMemoryGraph { } } +impl InMemoryGraph { + /// Returns the number of distinct edge labels in the graph. 
+ pub fn num_labels(&self) -> usize { + self.graphs.len() + } +} + impl GraphSource for Csv { fn apply_to( self, diff --git a/src/lib.rs b/src/lib.rs index 0f89008..f87a4e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,3 +6,6 @@ pub mod sparql; pub mod utils; pub mod lagraph_sys; + +#[cfg(feature = "bench")] +pub mod cli; diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index a616b64..2170e95 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -5,7 +5,7 @@ use crate::la_ok; use crate::lagraph_sys::LAGraph_Kind; use crate::lagraph_sys::*; use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; -use rustfst::algorithms::closure::{closure, ClosureType}; +use rustfst::algorithms::closure::{ClosureType, closure}; use rustfst::algorithms::concat::concat; use rustfst::algorithms::rm_epsilon::rm_epsilon; use rustfst::algorithms::union::union; From e62bbc7f2da703ee7eda7ed1fed11c032084c617 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 21 Apr 2026 11:44:24 +0300 Subject: [PATCH 02/11] feat: split bench timing into total and ffi-only --- src/cli/bench.rs | 160 +++++++++++++++++++++++++++++++++------ src/cli/output.rs | 41 +++++++++- src/rpq/nfarpq.rs | 89 +++++++++++++++------- src/rpq/rpqmatrix.rs | 68 ++++++++++++----- tests/nfarpq_tests.rs | 20 +++++ tests/rpqmatrix_tests.rs | 32 ++++++++ 6 files changed, 335 insertions(+), 75 deletions(-) diff --git a/src/cli/bench.rs b/src/cli/bench.rs index 6fd11c2..c4c8725 100644 --- a/src/cli/bench.rs +++ b/src/cli/bench.rs @@ -5,17 +5,17 @@ use std::fs::File; use std::path::Path; use std::time::Duration; -use criterion::{black_box, Criterion}; +use criterion::{Criterion, black_box}; use crate::graph::InMemoryGraph; -use crate::rpq::nfarpq::NfaRpqEvaluator; -use crate::rpq::rpqmatrix::RpqMatrixEvaluator; +use crate::rpq::nfarpq::{NfaRpqEvaluator, PreparedNfaRpq}; +use crate::rpq::rpqmatrix::{PreparedRpqMatrix, RpqMatrixEvaluator}; use crate::rpq::{RpqError, RpqEvaluator, RpqQuery}; use 
super::args::{Algo, BenchArgs}; use super::checkpoint::Checkpoint; use super::loader::LoadedQuery; -use super::output::{AlgoResult, BatchResult, QueryResult, TimingStats}; +use super::output::{AlgoResult, AlgoTiming, BatchResult, QueryResult, TimingStats}; /// Run a single evaluation and return the result count (nnz / reachable nodes). /// @@ -46,13 +46,66 @@ pub(crate) fn run_once( /// /// Used inside criterion's measurement loop; returning counts would be /// optimised away anyway, but we keep the call realistic with `black_box`. -fn run_batch(algo: &Algo, queries: &[&RpqQuery], graph: &InMemoryGraph) -> Result<(), RpqError> { +fn run_batch_total( + algo: &Algo, + queries: &[&RpqQuery], + graph: &InMemoryGraph, +) -> Result<(), RpqError> { for query in queries { let _ = black_box(run_once(algo, query, graph))?; } Ok(()) } +enum PreparedBatch { + Nfa(Vec), + Rpqmatrix(Vec), +} + +fn prepare_batch( + algo: &Algo, + queries: &[&RpqQuery], + graph: &InMemoryGraph, +) -> Result { + match algo { + Algo::Nfa => Ok(PreparedBatch::Nfa( + queries + .iter() + .map(|query| NfaRpqEvaluator.prepare(query, graph)) + .collect::, _>>()?, + )), + Algo::Rpqmatrix => Ok(PreparedBatch::Rpqmatrix( + queries + .iter() + .map(|query| RpqMatrixEvaluator.prepare(query, graph)) + .collect::, _>>()?, + )), + } +} + +fn run_prepared_batch(prepared: &mut PreparedBatch) -> Result<(), RpqError> { + match prepared { + PreparedBatch::Nfa(items) => { + for item in items { + let result = item.execute()?; + let count = result + .reachable + .nvals() + .map_err(crate::rpq::RpqError::Graph)? as usize; + let _ = black_box(count); + } + } + PreparedBatch::Rpqmatrix(items) => { + for item in items { + let result = item.execute()?; + let _ = black_box(result.nnz); + } + } + } + + Ok(()) +} + /// Read criterion timing estimates from its output directory. 
/// /// After `group.finish()`, criterion writes: @@ -96,6 +149,13 @@ fn read_criterion_estimates( }) } +fn read_algo_timing_estimates(criterion_dir: &str, group_name: &str) -> Option { + let total = read_criterion_estimates(criterion_dir, group_name, "eval_total")?; + let ffi_only = read_criterion_estimates(criterion_dir, group_name, "eval_ffi_only")?; + + Some(AlgoTiming { total, ffi_only }) +} + /// Run the full benchmark loop, processing queries in batches. /// /// Queries are grouped into batches of `batch_size`. For each batch and @@ -152,8 +212,9 @@ pub fn run_benchmarks( // ── First pass: correctness check + collect valid queries per algo ── let mut per_query_results: Vec = Vec::new(); - // algo key → list of (query_ref, result_count) - let mut valid_queries_per_algo: HashMap> = HashMap::new(); + // algo key → list of (query_index, query_ref, result_count) + let mut valid_queries_per_algo: HashMap> = + HashMap::new(); for &(idx, loaded) in batch { let mut algo_results: HashMap = HashMap::new(); @@ -195,7 +256,7 @@ pub fn run_benchmarks( valid_queries_per_algo .entry(algo.to_string()) .or_default() - .push((query, count)); + .push((idx, query, count)); } Ok(Err(e)) => { eprintln!( @@ -226,7 +287,7 @@ pub fn run_benchmarks( } // ── Second pass: criterion benchmark per algo over valid queries ── - let mut batch_algo_timing: HashMap> = HashMap::new(); + let mut batch_algo_timing: HashMap> = HashMap::new(); for algo in &args.common.algo { let algo_key = algo.to_string(); @@ -247,17 +308,26 @@ pub fn run_benchmarks( let mut group = criterion.benchmark_group(&group_name); let algo_clone = algo.clone(); - let queries_clone: Vec = valid.iter().map(|(q, _)| (*q).clone()).collect(); + let queries_clone: Vec = valid.iter().map(|(_, q, _)| (*q).clone()).collect(); - group.bench_function("eval", |b| { + group.bench_function("eval_total", |b| { b.iter(|| { let refs: Vec<&RpqQuery> = queries_clone.iter().collect(); - let _ = black_box(run_batch(&algo_clone, &refs, graph)); 
+ let _ = black_box(run_batch_total(&algo_clone, &refs, graph)); + }); + }); + + group.bench_function("eval_ffi_only", |b| { + let refs: Vec<&RpqQuery> = queries_clone.iter().collect(); + let mut prepared = + prepare_batch(&algo_clone, &refs, graph).expect("prepare benchmark batch"); + b.iter(|| { + let _ = black_box(run_prepared_batch(&mut prepared)); }); }); group.finish(); - let timing = read_criterion_estimates(&args.criterion_dir, &group_name, "eval"); + let timing = read_algo_timing_estimates(&args.criterion_dir, &group_name); batch_algo_timing.insert(algo_key, timing); } @@ -272,28 +342,32 @@ pub fn run_benchmarks( let timing = batch_algo_timing .get(&algo_key) .and_then(|t| t.as_ref()) - .map(|t| TimingStats { - mean_ns: t.mean_ns, - median_ns: t.median_ns, - stddev_ns: t.stddev_ns, - iterations: t.iterations, + .map(|t| AlgoTiming { + total: TimingStats { + mean_ns: t.total.mean_ns, + median_ns: t.total.median_ns, + stddev_ns: t.total.stddev_ns, + iterations: t.total.iterations, + }, + ffi_only: TimingStats { + mean_ns: t.ffi_only.mean_ns, + median_ns: t.ffi_only.median_ns, + stddev_ns: t.ffi_only.stddev_ns, + iterations: t.ffi_only.iterations, + }, }); // Attach the result count from the correctness-check pass. let result_count = valid_queries_per_algo.get(&algo_key).and_then(|v| { v.iter() - .find(|(q, _)| { - // Match by pointer identity — the LoadedQuery we kept. - // Safe because we stored references into the same slice. - std::ptr::eq(*q as *const RpqQuery, qr.query_index as *const RpqQuery) - }) - .map(|(_, c)| *c) + .find(|(idx, _, _)| *idx == qr.query_index) + .map(|(_, _, c)| *c) }); // Fallback: if we can't match by pointer, use the first count from // the batch (acceptable when batch_size == 1, which is the default). 
let result_count = result_count.or_else(|| { valid_queries_per_algo .get(&algo_key) - .and_then(|v| v.iter().find(|(_, _)| true).map(|(_, c)| *c)) + .and_then(|v| v.first().map(|(_, _, c)| *c)) }); qr.algorithms .insert(algo_key.clone(), AlgoResult::ok(result_count, timing)); @@ -320,3 +394,39 @@ pub fn run_benchmarks( batch_results } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + fn write_estimate_files(base: &Path, bench_name: &str, mean_ns: f64, iterations: usize) { + let bench_dir = base.join("batch0_nfa").join(bench_name).join("new"); + fs::create_dir_all(&bench_dir).expect("create bench dir"); + fs::write( + bench_dir.join("estimates.json"), + format!( + r#"{{"mean":{{"point_estimate":{mean_ns}}},"median":{{"point_estimate":{mean_ns}}},"std_dev":{{"point_estimate":0.0}}}}"# + ), + ) + .expect("write estimates"); + let sample = format!("{{\"iters\":[{}]}}", vec!["1"; iterations].join(",")); + fs::write(bench_dir.join("sample.json"), sample).expect("write sample"); + } + + #[test] + fn read_split_criterion_estimates() { + let dir = tempfile::tempdir().expect("tempdir"); + write_estimate_files(dir.path(), "eval_total", 10.0, 3); + write_estimate_files(dir.path(), "eval_ffi_only", 4.0, 5); + + let timing = + read_algo_timing_estimates(dir.path().to_str().expect("utf8 path"), "batch0_nfa") + .expect("split timing"); + + assert_eq!(timing.total.mean_ns, 10.0); + assert_eq!(timing.total.iterations, 3); + assert_eq!(timing.ffi_only.mean_ns, 4.0); + assert_eq!(timing.ffi_only.iterations, 5); + } +} diff --git a/src/cli/output.rs b/src/cli/output.rs index 8762720..703ee90 100644 --- a/src/cli/output.rs +++ b/src/cli/output.rs @@ -20,12 +20,12 @@ pub struct AlgoResult { pub result_count: Option, /// Timing statistics — only present in `bench` mode. #[serde(skip_serializing_if = "Option::is_none")] - pub timing: Option, + pub timing: Option, } impl AlgoResult { /// Create a successful result with an optional result count and timing. 
- pub fn ok(result_count: Option, timing: Option) -> Self { + pub fn ok(result_count: Option, timing: Option) -> Self { Self { status: "ok".to_string(), error: None, @@ -55,6 +55,13 @@ impl AlgoResult { } } +/// Benchmark timings for one algorithm/result. +#[derive(Debug, Serialize)] +pub struct AlgoTiming { + pub total: TimingStats, + pub ffi_only: TimingStats, +} + /// Timing statistics extracted from criterion estimates. #[derive(Debug, Serialize)] pub struct TimingStats { @@ -147,3 +154,33 @@ impl BenchOutput { fs::write(path, json) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn algo_result_serializes_split_timing() { + let result = AlgoResult::ok( + Some(3), + Some(AlgoTiming { + total: TimingStats { + mean_ns: 1.0, + median_ns: 1.0, + stddev_ns: 0.0, + iterations: 10, + }, + ffi_only: TimingStats { + mean_ns: 0.5, + median_ns: 0.5, + stddev_ns: 0.0, + iterations: 10, + }, + }), + ); + + let value = serde_json::to_value(&result).expect("serialize"); + assert!(value["timing"]["total"].is_object()); + assert!(value["timing"]["ffi_only"].is_object()); + } +} diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index 2170e95..3dd2650 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -13,6 +13,7 @@ use rustfst::prelude::*; use rustfst::semirings::TropicalWeight; use rustfst::utils::{acceptor, epsilon_machine}; use std::collections::HashMap; +use std::sync::Arc; /// Transitions for a single edge label in the NFA. 
/// @@ -207,17 +208,49 @@ pub struct NfaRpqResult { pub reachable: GraphblasVector, } +pub struct PreparedNfaRpq { + nfa: Nfa, + nfa_matrices: Vec<(String, LagraphGraph)>, + nfa_graph_ptrs: Vec, + _data_graphs: Vec>, + data_graph_ptrs: Vec, + source_vertices: Vec, +} + +impl PreparedNfaRpq { + pub fn execute(&mut self) -> Result { + let mut reachable: GrB_Vector = std::ptr::null_mut(); + + unsafe { + la_ok!(LAGraph_RegularPathQuery( + &mut reachable, + self.nfa_graph_ptrs.as_mut_ptr(), + self.nfa_matrices.len(), + self.nfa.start_states.as_ptr(), + self.nfa.start_states.len(), + self.nfa.final_states.as_ptr(), + self.nfa.final_states.len(), + self.data_graph_ptrs.as_mut_ptr(), + self.source_vertices.as_ptr(), + self.source_vertices.len(), + ))? + }; + + Ok(NfaRpqResult { + reachable: GraphblasVector { inner: reachable }, + }) + } +} + /// Evaluates RPQs using `LAGraph_RegularPathQuery`. pub struct NfaRpqEvaluator; -impl RpqEvaluator for NfaRpqEvaluator { - type Result = NfaRpqResult; - - fn evaluate( +impl NfaRpqEvaluator { + pub fn prepare( &self, query: &RpqQuery, graph: &G, - ) -> Result { + ) -> Result { let nfa = Nfa::from_path_expr(&query.path)?; let nfa_matrices = nfa.build_lagraph_matrices()?; @@ -225,43 +258,43 @@ impl RpqEvaluator for NfaRpqEvaluator { let _dst_id = resolve_endpoint(&query.object, graph)?; let n = graph.num_nodes(); - let source_vertices: Vec = match src_id { Some(id) => vec![id as GrB_Index], None => (0..n as GrB_Index).collect(), }; - let mut nfa_graph_ptrs: Vec = + let nfa_graph_ptrs: Vec = nfa_matrices.iter().map(|(_, lg)| lg.inner).collect(); - let mut data_graph_ptrs: Vec = Vec::with_capacity(nfa_matrices.len()); + let mut data_graphs = Vec::with_capacity(nfa_matrices.len()); + let mut data_graph_ptrs = Vec::with_capacity(nfa_matrices.len()); for (label, _) in &nfa_matrices { let lg = graph.get_graph(label)?; data_graph_ptrs.push(lg.inner); + data_graphs.push(lg); } - let mut reachable: GrB_Vector = std::ptr::null_mut(); - - unsafe { 
- la_ok!(LAGraph_RegularPathQuery( - &mut reachable, - nfa_graph_ptrs.as_mut_ptr(), - nfa_matrices.len(), - nfa.start_states.as_ptr(), - nfa.start_states.len(), - nfa.final_states.as_ptr(), - nfa.final_states.len(), - data_graph_ptrs.as_mut_ptr(), - source_vertices.as_ptr(), - source_vertices.len(), - ))? - }; + Ok(PreparedNfaRpq { + nfa, + nfa_matrices, + nfa_graph_ptrs, + _data_graphs: data_graphs, + data_graph_ptrs, + source_vertices, + }) + } +} - let result_vec = GraphblasVector { inner: reachable }; +impl RpqEvaluator for NfaRpqEvaluator { + type Result = NfaRpqResult; - Ok(NfaRpqResult { - reachable: result_vec, - }) + fn evaluate( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + let mut prepared = self.prepare(query, graph)?; + prepared.execute() } } diff --git a/src/rpq/rpqmatrix.rs b/src/rpq/rpqmatrix.rs index 72f9110..67d6c9d 100644 --- a/src/rpq/rpqmatrix.rs +++ b/src/rpq/rpqmatrix.rs @@ -174,21 +174,14 @@ pub struct RpqMatrixResult { pub matrix: GraphblasMatrix, } -/// RPQ evaluator backed by `LAGraph_RPQMatrix`. -pub struct RpqMatrixEvaluator; - -impl RpqEvaluator for RpqMatrixEvaluator { - type Result = RpqMatrixResult; - - fn evaluate( - &self, - query: &RpqQuery, - graph: &G, - ) -> Result { - let expr = query_to_expr(query)?; - let (mut plans, owned_matrices) = materialize(&expr, graph)?; +pub struct PreparedRpqMatrix { + plans: Vec, + owned_matrices: Vec, +} - let root_ptr = unsafe { plans.as_mut_ptr().add(plans.len() - 1) }; +impl PreparedRpqMatrix { + pub fn execute(&mut self) -> Result { + let root_ptr = unsafe { self.plans.as_mut_ptr().add(self.plans.len() - 1) }; let mut nnz: GrB_Index = 0; unsafe { la_ok!(LAGraph_RPQMatrix(&mut nnz, root_ptr))? }; @@ -201,20 +194,55 @@ impl RpqEvaluator for RpqMatrixEvaluator { unsafe { grb_ok!(LAGraph_DestroyRpqMatrixPlan(root_ptr))? }; - // Free diagonal matrices created for named vertices. 
- for mut mat in owned_matrices { + Ok(RpqMatrixResult { + nnz: nnz as u64, + matrix, + }) + } +} + +impl Drop for PreparedRpqMatrix { + fn drop(&mut self) { + for mat in &mut self.owned_matrices { unsafe { - LAGraph_RPQMatrix_Free(&mut mat); + LAGraph_RPQMatrix_Free(mat); } } + } +} - Ok(RpqMatrixResult { - nnz: nnz as u64, - matrix, +/// RPQ evaluator backed by `LAGraph_RPQMatrix`. +pub struct RpqMatrixEvaluator; + +impl RpqMatrixEvaluator { + pub fn prepare( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + let expr = query_to_expr(query)?; + let (plans, owned_matrices) = materialize(&expr, graph)?; + + Ok(PreparedRpqMatrix { + plans, + owned_matrices, }) } } +impl RpqEvaluator for RpqMatrixEvaluator { + type Result = RpqMatrixResult; + + fn evaluate( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + let mut prepared = self.prepare(query, graph)?; + prepared.execute() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index fbdf42a..fee4ade 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -167,6 +167,26 @@ fn test_sequence_path() { assert_eq!(count, 1); } +#[test] +fn prepared_nfa_execution_matches_evaluate() { + let graph = build_graph(&[("A", "B", "knows"), ("B", "C", "likes")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("knows")), Box::new(label("likes"))), + var("y"), + ); + + let evaluator = NfaRpqEvaluator; + let direct = evaluator.evaluate(&query, &graph).expect("direct"); + let direct_count = direct.reachable.nvals().expect("direct nvals"); + + let mut prepared = evaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("prepared execute"); + let prepared_count = prepared_result.reachable.nvals().expect("prepared nvals"); + + assert_eq!(prepared_count, direct_count); +} + /// Graph: A --knows--> B --likes--> C /// Query: / ?y #[test] diff --git a/tests/rpqmatrix_tests.rs 
b/tests/rpqmatrix_tests.rs index b23e0e9..e13b8f1 100644 --- a/tests/rpqmatrix_tests.rs +++ b/tests/rpqmatrix_tests.rs @@ -172,6 +172,38 @@ fn test_sequence_path() { assert_eq!(result.nnz, 1); } +#[test] +fn prepared_rpqmatrix_execution_matches_evaluate() { + let graph = build_graph(&[("A", "B", "p"), ("B", "C", "q")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("p")), Box::new(label("q"))), + var("y"), + ); + + let direct = RpqMatrixEvaluator.evaluate(&query, &graph).expect("direct"); + let mut prepared = RpqMatrixEvaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("execute"); + + assert_eq!(prepared_result.nnz, direct.nnz); +} + +#[test] +fn prepared_rpqmatrix_execution_can_run_twice() { + let graph = build_graph(&[("A", "B", "p"), ("B", "C", "q")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("p")), Box::new(label("q"))), + var("y"), + ); + + let mut prepared = RpqMatrixEvaluator.prepare(&query, &graph).expect("prepare"); + let first = prepared.execute().expect("first"); + let second = prepared.execute().expect("second"); + + assert_eq!(first.nnz, second.nnz); +} + /// Graph: A --knows--> B --likes--> C /// Query: / ?y → only A→C, nnz=1 #[test] From 6f6866256124b62823403e9af1f4b2b89e3f8ee5 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 21 Apr 2026 13:22:23 +0300 Subject: [PATCH 03/11] feat: expand bench inputs and align RPQ result reporting --- build.rs | 2 ++ src/cli/bench.rs | 35 +++++++++++++++++++++++++++++++---- src/cli/loader.rs | 2 +- src/formats/nt.rs | 2 +- src/graph/mod.rs | 4 ++-- src/lagraph_sys_generated.rs | 10 ++++++++++ src/rpq/rpqmatrix.rs | 24 ++++++++++++++++++++---- 7 files changed, 67 insertions(+), 12 deletions(-) diff --git a/build.rs b/build.rs index 243bbcb..95243fb 100644 --- a/build.rs +++ b/build.rs @@ -65,6 +65,7 @@ fn regenerate_bindings() { .allowlist_item("GrB_Info") .allowlist_function("GrB_Matrix_new") 
.allowlist_function("GrB_Matrix_nvals") + .allowlist_function("GrB_Matrix_dup") .allowlist_function("GrB_Matrix_free") .allowlist_function("GrB_Matrix_extractElement_BOOL") .allowlist_function("GrB_Matrix_build_BOOL") @@ -89,6 +90,7 @@ fn regenerate_bindings() { .allowlist_function("LAGraph_Cached_AT") .allowlist_function("LAGraph_MMRead") .allowlist_function("LAGraph_RPQMatrix") + .allowlist_function("LAGraph_RPQMatrix_reduce") .allowlist_function("LAGraph_DestroyRpqMatrixPlan") .allowlist_function("LAGraph_RPQMatrix_label") .allowlist_function("LAGraph_RPQMatrix_Free") diff --git a/src/cli/bench.rs b/src/cli/bench.rs index c4c8725..856ba97 100644 --- a/src/cli/bench.rs +++ b/src/cli/bench.rs @@ -17,10 +17,11 @@ use super::checkpoint::Checkpoint; use super::loader::LoadedQuery; use super::output::{AlgoResult, AlgoTiming, BatchResult, QueryResult, TimingStats}; -/// Run a single evaluation and return the result count (nnz / reachable nodes). +/// Run a single evaluation and return the result count. /// /// Used by both the correctness-check pass before benchmarking and by the -/// `query` subcommand runner. +/// `query` subcommand runner. Both algorithms report the number of reachable +/// target vertices. pub(crate) fn run_once( algo: &Algo, query: &RpqQuery, @@ -37,7 +38,8 @@ pub(crate) fn run_once( } Algo::Rpqmatrix => { let result = RpqMatrixEvaluator.evaluate(query, graph)?; - Ok(result.nnz as usize) + let count = result.reachable_target_count().map_err(RpqError::Graph)? as usize; + Ok(count) } } } @@ -98,7 +100,10 @@ fn run_prepared_batch(prepared: &mut PreparedBatch) -> Result<(), RpqError> { PreparedBatch::Rpqmatrix(items) => { for item in items { let result = item.execute()?; - let _ = black_box(result.nnz); + let count = result + .reachable_target_count() + .map_err(crate::rpq::RpqError::Graph)? 
as usize; + let _ = black_box(count); } } } @@ -400,6 +405,9 @@ mod tests { use super::*; use std::fs; + use crate::rpq::{Endpoint, PathExpr, RpqQuery}; + use crate::utils::build_graph; + fn write_estimate_files(base: &Path, bench_name: &str, mean_ns: f64, iterations: usize) { let bench_dir = base.join("batch0_nfa").join(bench_name).join("new"); fs::create_dir_all(&bench_dir).expect("create bench dir"); @@ -429,4 +437,23 @@ mod tests { assert_eq!(timing.ffi_only.mean_ns, 4.0); assert_eq!(timing.ffi_only.iterations, 5); } + + #[test] + fn run_once_rpqmatrix_count_matches_nfa_reachable_targets() { + let graph = build_graph(&[("A", "B", "p"), ("C", "B", "p")]); + let query = RpqQuery { + subject: Endpoint::Variable("x".into()), + path: PathExpr::Label("p".into()), + object: Endpoint::Variable("y".into()), + }; + + let rpqmatrix_count = run_once(&Algo::Rpqmatrix, &query, &graph).expect("rpqmatrix count"); + let nfa_count = run_once(&Algo::Nfa, &query, &graph).expect("nfa count"); + + assert_eq!( + rpqmatrix_count, 1, + "shared count should report reachable target count" + ); + assert_eq!(nfa_count, rpqmatrix_count); + } } diff --git a/src/cli/loader.rs b/src/cli/loader.rs index f7e4abd..81df9fd 100644 --- a/src/cli/loader.rs +++ b/src/cli/loader.rs @@ -9,8 +9,8 @@ use std::io::{BufRead, BufReader}; use std::path::Path; use std::process; -use crate::formats::mm::MatrixMarket; use crate::formats::Csv; +use crate::formats::mm::MatrixMarket; use crate::graph::{Graph, InMemory, InMemoryGraph}; use crate::rpq::{RpqError, RpqQuery}; use crate::sparql::parse_rpq; diff --git a/src/formats/nt.rs b/src/formats/nt.rs index 0e711e8..aa08880 100644 --- a/src/formats/nt.rs +++ b/src/formats/nt.rs @@ -27,8 +27,8 @@ use std::io::Read; use oxrdf::{NamedOrBlankNode, Term}; -use oxttl::ntriples::ReaderNTriplesParser; use oxttl::NTriplesParser; +use oxttl::ntriples::ReaderNTriplesParser; use crate::formats::FormatError; use crate::graph::Edge; diff --git a/src/graph/mod.rs 
b/src/graph/mod.rs index 5b489d0..0922459 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -4,8 +4,8 @@ pub mod inmemory; pub mod wrappers; pub use inmemory::{InMemory, InMemoryBuilder, InMemoryGraph}; -pub(crate) use wrappers::{compute_outer_inner, ensure_grb_init, ThreadScope}; -pub use wrappers::{load_mm_file, GraphblasMatrix, GraphblasVector, LagraphGraph}; +pub use wrappers::{GraphblasMatrix, GraphblasVector, LagraphGraph, load_mm_file}; +pub(crate) use wrappers::{ThreadScope, compute_outer_inner, ensure_grb_init}; use std::marker::PhantomData; use std::sync::Arc; diff --git a/src/lagraph_sys_generated.rs b/src/lagraph_sys_generated.rs index 5e02d97..1a9188f 100644 --- a/src/lagraph_sys_generated.rs +++ b/src/lagraph_sys_generated.rs @@ -155,6 +155,9 @@ unsafe extern "C" { ncols: GrB_Index, ) -> GrB_Info; } +unsafe extern "C" { + pub fn GrB_Matrix_dup(C: *mut GrB_Matrix, A: GrB_Matrix) -> GrB_Info; +} unsafe extern "C" { pub fn GrB_Matrix_nvals(nvals: *mut GrB_Index, A: GrB_Matrix) -> GrB_Info; } @@ -338,3 +341,10 @@ unsafe extern "C" { unsafe extern "C" { pub fn LAGraph_RPQMatrix_Free(mat: *mut GrB_Matrix) -> GrB_Info; } +unsafe extern "C" { + pub fn LAGraph_RPQMatrix_reduce( + res: *mut GrB_Index, + mat: GrB_Matrix, + reduce_type: u8, + ) -> GrB_Info; +} diff --git a/src/rpq/rpqmatrix.rs b/src/rpq/rpqmatrix.rs index 67d6c9d..623a838 100644 --- a/src/rpq/rpqmatrix.rs +++ b/src/rpq/rpqmatrix.rs @@ -9,6 +9,8 @@ use crate::lagraph_sys::*; use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; use crate::{grb_ok, la_ok}; +const RPQMATRIX_REDUCE_BY_COL: u8 = 1; + define_language! { pub enum RpqPlan { Label(String), @@ -174,6 +176,20 @@ pub struct RpqMatrixResult { pub matrix: GraphblasMatrix, } +impl RpqMatrixResult { + /// Count distinct reachable target vertices by reducing the path relation + /// matrix to its non-empty columns. 
+ pub fn reachable_target_count(&self) -> Result { + let mut count: GrB_Index = 0; + unsafe { grb_ok!(LAGraph_RPQMatrix_reduce( + &mut count, + self.matrix.inner, + RPQMATRIX_REDUCE_BY_COL, + ))? }; + Ok(count as u64) + } +} + pub struct PreparedRpqMatrix { plans: Vec, owned_matrices: Vec, @@ -186,10 +202,10 @@ impl PreparedRpqMatrix { let mut nnz: GrB_Index = 0; unsafe { la_ok!(LAGraph_RPQMatrix(&mut nnz, root_ptr))? }; - let matrix = unsafe { - let mat = (*root_ptr).res_mat; - (*root_ptr).res_mat = null_mut(); - GraphblasMatrix { inner: mat } + let mut matrix_inner: GrB_Matrix = null_mut(); + unsafe { grb_ok!(GrB_Matrix_dup(&mut matrix_inner, (*root_ptr).res_mat))? }; + let matrix = GraphblasMatrix { + inner: matrix_inner, }; unsafe { grb_ok!(LAGraph_DestroyRpqMatrixPlan(root_ptr))? }; From 2b4cf41f38312d0ba2c57b8cb70ee525317691a9 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 21 Apr 2026 13:48:15 +0300 Subject: [PATCH 04/11] fix: respect bound objects in NFA RPQ results --- src/cli/query.rs | 2 +- src/graph/wrappers.rs | 8 +++++++ src/rpq/nfarpq.rs | 50 ++++++++++++++++++++++++++++++++++++++----- tests/nfarpq_tests.rs | 44 +++++++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 6 deletions(-) diff --git a/src/cli/query.rs b/src/cli/query.rs index 4870a37..1f4fb73 100644 --- a/src/cli/query.rs +++ b/src/cli/query.rs @@ -52,7 +52,7 @@ pub fn run_queries( let algo_result = match outcome { Ok(Ok(count)) => { eprintln!( - "[query #{idx}] id={} algo={} — {count} result(s)", + "[query #{idx}] id={} algo={} — {count} count", loaded.id, algo ); AlgoResult::ok(Some(count), None) diff --git a/src/graph/wrappers.rs b/src/graph/wrappers.rs index f8128e6..cc8d3a1 100644 --- a/src/graph/wrappers.rs +++ b/src/graph/wrappers.rs @@ -176,6 +176,12 @@ pub struct GraphblasVector { } impl GraphblasVector { + pub fn new_bool(n: GrB_Index) -> Result { + let mut v: GrB_Vector = std::ptr::null_mut(); + unsafe { grb_ok!(GrB_Vector_new(&mut v, GrB_BOOL, n))? 
}; + Ok(Self { inner: v }) + } + /// Returns the number of stored values in this vector. pub fn nvals(&self) -> Result { let mut nvals: GrB_Index = 0; @@ -206,6 +212,8 @@ impl GraphblasVector { indices.truncate(actual_nvals as usize); Ok(indices) } + + } impl Drop for GraphblasVector { diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index 3dd2650..9550396 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -1,10 +1,10 @@ //! NFA-based RPQ evaluation using `LAGraph_RegularPathQuery`. use crate::graph::{GraphDecomposition, GraphblasVector, LagraphGraph}; -use crate::la_ok; use crate::lagraph_sys::LAGraph_Kind; use crate::lagraph_sys::*; use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use crate::{grb_ok, la_ok}; use rustfst::algorithms::closure::{ClosureType, closure}; use rustfst::algorithms::concat::concat; use rustfst::algorithms::rm_epsilon::rm_epsilon; @@ -215,6 +215,33 @@ pub struct PreparedNfaRpq { _data_graphs: Vec>, data_graph_ptrs: Vec, source_vertices: Vec, + destination_vertex: Option, + num_nodes: usize, +} + +fn filter_reachable_by_destination( + reachable: GraphblasVector, + destination_vertex: Option, + num_nodes: usize, +) -> Result { + let Some(destination_vertex) = destination_vertex else { + return Ok(reachable); + }; + + let indices = reachable.indices().map_err(RpqError::Graph)?; + let filtered = GraphblasVector::new_bool(num_nodes as GrB_Index)?; + + if indices.contains(&(destination_vertex as GrB_Index)) { + unsafe { + grb_ok!(GrB_Vector_setElement_BOOL( + filtered.inner, + true, + destination_vertex as GrB_Index, + ))? + }; + } + + Ok(filtered) } impl PreparedNfaRpq { @@ -236,9 +263,13 @@ impl PreparedNfaRpq { ))? 
}; - Ok(NfaRpqResult { - reachable: GraphblasVector { inner: reachable }, - }) + let reachable = filter_reachable_by_destination( + GraphblasVector { inner: reachable }, + self.destination_vertex, + self.num_nodes, + )?; + + Ok(NfaRpqResult { reachable }) } } @@ -281,6 +312,8 @@ impl NfaRpqEvaluator { _data_graphs: data_graphs, data_graph_ptrs, source_vertices, + destination_vertex: _dst_id, + num_nodes: n, }) } } @@ -294,7 +327,14 @@ impl RpqEvaluator for NfaRpqEvaluator { graph: &G, ) -> Result { let mut prepared = self.prepare(query, graph)?; - prepared.execute() + let result = prepared.execute()?; + let destination_vertex = resolve_endpoint(&query.object, graph)?; + let reachable = filter_reachable_by_destination( + result.reachable, + destination_vertex, + graph.num_nodes(), + )?; + Ok(NfaRpqResult { reachable }) } } diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index fee4ade..3269cb5 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -187,6 +187,50 @@ fn prepared_nfa_execution_matches_evaluate() { assert_eq!(prepared_count, direct_count); } +#[test] +fn test_bound_object_filters_reachable_targets() { + let graph = build_graph(&[("A", "B", "knows"), ("C", "D", "knows")]); + let evaluator = NfaRpqEvaluator; + + let result = evaluator + .evaluate(&rq(var("x"), label("knows"), named_ep("B")), &graph) + .expect("evaluate should succeed"); + + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); + let b_id = graph.get_node_id("B").expect("B should exist") as GrB_Index; + let d_id = graph.get_node_id("D").expect("D should exist") as GrB_Index; + + assert_eq!( + indices, + vec![b_id], + "only the bound object should remain reachable" + ); + assert!( + !indices.contains(&d_id), + "unbound reachable targets must be filtered out" + ); +} + +#[test] +fn prepared_nfa_execution_respects_bound_object() { + let graph = build_graph(&[("A", "B", "knows"), ("C", "D", "knows")]); + let evaluator = NfaRpqEvaluator; + 
let query = rq(var("x"), label("knows"), named_ep("B")); + + let mut prepared = evaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("prepared execute"); + let prepared_indices = prepared_result + .reachable + .indices() + .expect("prepared indices"); + let b_id = graph.get_node_id("B").expect("B should exist") as GrB_Index; + + assert_eq!(prepared_indices, vec![b_id]); +} + /// Graph: A --knows--> B --likes--> C /// Query: / ?y #[test] From e8232fe57b8e2bc765c99bb322f1f7a0fc6fc77e Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 21 Apr 2026 15:30:13 +0300 Subject: [PATCH 05/11] build: add dockerfile --- .dockerignore | 8 ++++++ Dockerfile | 49 +++++++++++++++++++++++++++++++++++++ docker/docker-entrypoint.sh | 42 +++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100755 docker/docker-entrypoint.sh diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..faf37b5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git/ +.worktrees/ +target/ +.tmp/ +deps/LAGraph/build/ +bench_criterion/ +bench_results.json +bench_checkpoint.json diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0d8f56a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,49 @@ +FROM rust:1-bookworm AS builder + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + clang \ + cmake \ + git \ + libclang-dev \ + make \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /src + +COPY . . 
+ +RUN git clone --depth 1 https://github.com/DrTimothyAldenDavis/GraphBLAS.git /tmp/GraphBLAS \ + && make -C /tmp/GraphBLAS compact \ + && make -C /tmp/GraphBLAS install \ + && mkdir -p deps/LAGraph/build \ + && make -C deps/LAGraph \ + && cargo build --release --bin pathrex --features "bench,regenerate-bindings" + +FROM debian:bookworm-slim AS runtime + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + gcc \ + libc6-dev \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /work + +COPY --from=builder /src/target/release/pathrex /usr/local/bin/pathrex +COPY --from=builder /usr/local/lib/libgraphblas.so* /usr/local/lib/ +COPY --from=builder /src/deps/LAGraph/build/src/liblagraph.so* /usr/local/lib/ +COPY --from=builder /src/deps/LAGraph/build/experimental/liblagraphx.so* /usr/local/lib/ +COPY --from=builder /src/docker/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh + +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +ENV LD_LIBRARY_PATH=/usr/local/lib + +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +CMD ["--help"] diff --git a/docker/docker-entrypoint.sh b/docker/docker-entrypoint.sh new file mode 100755 index 0000000..8da1b38 --- /dev/null +++ b/docker/docker-entrypoint.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -euo pipefail + +pathrex_bin="${PATHREX_BIN:-/usr/local/bin/pathrex}" + +if [ "${1-}" = "bench" ]; then + shift + has_output=false + has_checkpoint=false + has_criterion_dir=false + args=("$@") + + for arg in "${args[@]}"; do + case "$arg" in + -o|-o*|--output|--output=*) + has_output=true + ;; + -c|-c*|--checkpoint|--checkpoint=*) + has_checkpoint=true + ;; + --criterion-dir|--criterion-dir=*) + has_criterion_dir=true + ;; + esac + done + + if [ "$has_output" = false ]; then + args+=(--output /results/bench_results.json) + fi + + if [ "$has_checkpoint" = false ]; then + args+=(--checkpoint /results/bench_checkpoint.json) + fi + + if [ "$has_criterion_dir" = false ]; then + 
args+=(--criterion-dir /results/criterion) + fi + + exec "$pathrex_bin" bench "${args[@]}" +fi + +exec "$pathrex_bin" "$@" From c5f58be3c7a206c06d0dbd678e3b8e6bf44248a9 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 28 Apr 2026 02:23:56 +0300 Subject: [PATCH 06/11] ref: get rid of correctness pass in bench loop --- src/cli/bench.rs | 73 ++++++++++-------------------------------- 1 file changed, 15 insertions(+), 58 deletions(-) diff --git a/src/cli/bench.rs b/src/cli/bench.rs index 856ba97..3ce802d 100644 --- a/src/cli/bench.rs +++ b/src/cli/bench.rs @@ -5,7 +5,7 @@ use std::fs::File; use std::path::Path; use std::time::Duration; -use criterion::{Criterion, black_box}; +use criterion::{black_box, Criterion} use crate::graph::InMemoryGraph; use crate::rpq::nfarpq::{NfaRpqEvaluator, PreparedNfaRpq}; @@ -19,9 +19,7 @@ use super::output::{AlgoResult, AlgoTiming, BatchResult, QueryResult, TimingStat /// Run a single evaluation and return the result count. /// -/// Used by both the correctness-check pass before benchmarking and by the -/// `query` subcommand runner. Both algorithms report the number of reachable -/// target vertices. +/// Used by the `query` subcommand runner. pub(crate) fn run_once( algo: &Algo, query: &RpqQuery, @@ -215,11 +213,10 @@ pub fn run_benchmarks( batch_index, batch_indices, batch_ids ); - // ── First pass: correctness check + collect valid queries per algo ──
let mut per_query_results: Vec = Vec::new(); - // algo key → list of (query_index, query_ref, result_count) - let mut valid_queries_per_algo: HashMap> = - HashMap::new(); + // algo key → list of (query_index, query_ref) + let mut valid_queries_per_algo: HashMap> = HashMap::new(); for &(idx, loaded) in batch { let mut algo_results: HashMap = HashMap::new(); @@ -248,38 +245,11 @@ pub fn run_benchmarks( }; for algo in &args.common.algo { - if checkpoint.is_algo_done(idx, algo) { - continue; - } - - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - run_once(algo, query, graph) - })); - - match result { - Ok(Ok(count)) => { - valid_queries_per_algo - .entry(algo.to_string()) - .or_default() - .push((idx, query, count)); - } - Ok(Err(e)) => { - eprintln!( - " [error] query #{} (id={}) algo={}: {}", - idx, loaded.id, algo, e - ); - algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); - checkpoint.mark_algo_done(idx, &loaded.id, algo); - } - Err(panic_info) => { - let msg = format!("{:?}", panic_info); - eprintln!( - " [panic] query #{} (id={}) algo={}: {}", - idx, loaded.id, algo, msg - ); - algo_results.insert(algo.to_string(), AlgoResult::panic(msg)); - checkpoint.mark_algo_done(idx, &loaded.id, algo); - } + if !checkpoint.is_algo_done(idx, algo) { + valid_queries_per_algo + .entry(algo.to_string()) + .or_default() + .push((idx, query)); } } @@ -291,7 +261,7 @@ pub fn run_benchmarks( }); } - // ── Second pass: criterion benchmark per algo over valid queries ── + // ── Criterion benchmark per algo over valid queries ── let mut batch_algo_timing: HashMap> = HashMap::new(); for algo in &args.common.algo { @@ -313,7 +283,7 @@ pub fn run_benchmarks( let mut group = criterion.benchmark_group(&group_name); let algo_clone = algo.clone(); - let queries_clone: Vec = valid.iter().map(|(_, q, _)| (*q).clone()).collect(); + let queries_clone: Vec = valid.iter().map(|(_, q)| (*q).clone()).collect(); group.bench_function("eval_total", |b| { 
b.iter(|| { @@ -336,11 +306,11 @@ pub fn run_benchmarks( batch_algo_timing.insert(algo_key, timing); } - // Assign timing + result counts to each query's algo result. + // Assign timing to each query's algo result. for qr in &mut per_query_results { for algo in &args.common.algo { let algo_key = algo.to_string(); - // Only fill in queries that didn't already get an error/panic result. + // Only fill in queries that didn't already get a parse error result. if qr.algorithms.contains_key(&algo_key) { continue; } @@ -361,21 +331,8 @@ pub fn run_benchmarks( iterations: t.ffi_only.iterations, }, }); - // Attach the result count from the correctness-check pass. - let result_count = valid_queries_per_algo.get(&algo_key).and_then(|v| { - v.iter() - .find(|(idx, _, _)| *idx == qr.query_index) - .map(|(_, _, c)| *c) - }); - // Fallback: if we can't match by pointer, use the first count from - // the batch (acceptable when batch_size == 1, which is the default). - let result_count = result_count.or_else(|| { - valid_queries_per_algo - .get(&algo_key) - .and_then(|v| v.first().map(|(_, _, c)| *c)) - }); qr.algorithms - .insert(algo_key.clone(), AlgoResult::ok(result_count, timing)); + .insert(algo_key.clone(), AlgoResult::ok(None, timing)); } } From 7d54d2dfedde8f6e9ee506a40a56fc73c03a2922 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 28 Apr 2026 02:24:33 +0300 Subject: [PATCH 07/11] ref: cargo fmt --- src/cli/bench.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cli/bench.rs b/src/cli/bench.rs index 3ce802d..6c37bf7 100644 --- a/src/cli/bench.rs +++ b/src/cli/bench.rs @@ -5,7 +5,7 @@ use std::fs::File; use std::path::Path; use std::time::Duration; -use criterion::{black_box, Criterion}; +use criterion::{Criterion, black_box}; use crate::graph::InMemoryGraph; use crate::rpq::nfarpq::{NfaRpqEvaluator, PreparedNfaRpq}; From 68833b78381d47562646b95d2b7d0f6345098fa9 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Tue, 28 Apr 2026 18:04:03 +0300 
Subject: [PATCH 08/11] ref: generalize evaluator --- AGENTS.md | 48 ++++- src/bin/pathrex.rs | 378 ++++++++++++++++++--------------- src/cli/args.rs | 82 +++++--- src/cli/bench.rs | 416 ------------------------------------- src/cli/bench/error.rs | 21 ++ src/cli/bench/estimates.rs | 149 +++++++++++++ src/cli/bench/mod.rs | 7 + src/cli/bench/runner.rs | 163 +++++++++++++++ src/cli/checkpoint.rs | 95 ++++++--- src/cli/dispatch.rs | 111 ++++++++++ src/cli/loader.rs | 101 +++++---- src/cli/mod.rs | 1 + src/cli/output.rs | 71 +++---- src/cli/query.rs | 108 ++++++---- src/eval/mod.rs | 35 ++++ src/lib.rs | 1 + src/rpq/mod.rs | 24 ++- src/rpq/nfarpq.rs | 52 +++-- src/rpq/rpqmatrix.rs | 53 ++--- tests/nfarpq_tests.rs | 2 +- tests/rpqmatrix_tests.rs | 2 +- 21 files changed, 1094 insertions(+), 826 deletions(-) delete mode 100644 src/cli/bench.rs create mode 100644 src/cli/bench/error.rs create mode 100644 src/cli/bench/estimates.rs create mode 100644 src/cli/bench/mod.rs create mode 100644 src/cli/bench/runner.rs create mode 100644 src/cli/dispatch.rs create mode 100644 src/eval/mod.rs diff --git a/AGENTS.md b/AGENTS.md index 223cfcb..22dc2df 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,7 @@ pathrex/ ├── Cargo.toml # Crate manifest (edition 2024) ├── build.rs # Links LAGraph + LAGraphX; optionally regenerates FFI bindings ├── src/ -│ ├── lib.rs # Modules: formats, graph, rpq, sparql, utils, lagraph_sys +│ ├── lib.rs # Modules: eval, formats, graph, rpq, sparql, utils, lagraph_sys │ ├── main.rs # Binary entry point (placeholder) │ ├── lagraph_sys.rs # FFI module — includes generated bindings │ ├── lagraph_sys_generated.rs# Bindgen output (checked in, regenerated in CI) @@ -24,8 +24,10 @@ pathrex/ │ │ ├── mod.rs # Core traits (GraphBuilder, GraphDecomposition, GraphSource, │ │ │ # Backend, Graph), error types, RAII wrappers, GrB init │ │ └── inmemory.rs # InMemory marker, InMemoryBuilder, InMemoryGraph +│ ├── eval/ +│ │ └── mod.rs # Evaluator, PreparedEvaluator, 
ResultCount traits │ ├── rpq/ -│ │ ├── mod.rs # RpqEvaluator (assoc. Result), RpqQuery, Endpoint, PathExpr, RpqError +│ │ ├── mod.rs # RPQ query types, RpqError, RPQ marker subtraits │ │ ├── nfarpq.rs # NfaRpqEvaluator (LAGraph_RegularPathQuery) │ │ └── rpqmatrix.rs # Matrix-plan RPQ evaluator │ ├── sparql/ @@ -187,6 +189,18 @@ pub trait Backend { - [`get_node_id(string_id)`](src/graph/mod.rs:200) / [`get_node_name(mapped_id)`](src/graph/mod.rs:203) — bidirectional string ↔ integer dictionary. - [`num_nodes()`](src/graph/mod.rs:204) — total unique nodes. +### Generic evaluator abstraction (`src/eval/`) + +[`src/eval/mod.rs`](src/eval/mod.rs) defines query-language-agnostic evaluator traits: + +- [`Evaluator`](src/eval/mod.rs) uses associated types for `Query`, `Result`, `Error`, and + `Prepared`. The graph backend stays a method-level generic (`G: GraphDecomposition`) so one + evaluator type can run against any graph backend selected at the call site. +- [`PreparedEvaluator`](src/eval/mod.rs) represents prepared `(query, graph)` state that can be + executed repeatedly, which is used by benchmark timing loops. +- [`ResultCount`](src/eval/mod.rs) is separate from `Evaluator::Result`; only CLI runners that + need counts require this bound, leaving room for future evaluators with richer result types. + ### InMemoryBuilder / InMemoryGraph [`InMemoryBuilder`](src/graph/inmemory.rs:36) is the primary `GraphBuilder` implementation. @@ -331,10 +345,11 @@ Key public items: - [`RpqQuery`](src/rpq/mod.rs) — `{ subject, path, object }` using the types above; [`strip_base(&mut self, base)`](src/rpq/mod.rs) removes a shared IRI prefix from named endpoints and labels. -- [`RpqEvaluator`](src/rpq/mod.rs) — trait with associated type `Result` and - [`evaluate(query, graph)`](src/rpq/mod.rs) taking `&RpqQuery` and - [`GraphDecomposition`], returning `Result`. - Each concrete evaluator exposes its own output type (see below). 
+- [`RpqEvaluator`](src/rpq/mod.rs) — marker subtrait over + [`Evaluator`](src/eval/mod.rs), preserving the RPQ-facing + trait name while the generic evaluator hierarchy lives in `src/eval/`. +- [`PreparedRpq`](src/rpq/mod.rs) — marker subtrait over + [`PreparedEvaluator`](src/eval/mod.rs). - [`RpqError`](src/rpq/mod.rs) — unified error type for RPQ parsing and evaluation: `Parse` (SPARQL syntax), `Extract` (query extraction), `UnsupportedPath`, `VertexNotFound`, and `Graph` (wraps [`GraphError`](src/graph/mod.rs) for @@ -378,10 +393,11 @@ Key public items: - [`RpqQuery`](src/rpq/mod.rs) — `{ subject, path, object }` using the types above; [`strip_base(&mut self, base)`](src/rpq/mod.rs) removes a shared IRI prefix from named endpoints and labels. -- [`RpqEvaluator`](src/rpq/mod.rs) — trait with associated type `Result` and - [`evaluate(query, graph)`](src/rpq/mod.rs) taking `&RpqQuery` and - [`GraphDecomposition`], returning `Result`. - Each concrete evaluator exposes its own output type (see below). +- [`RpqEvaluator`](src/rpq/mod.rs) — marker subtrait over + [`Evaluator`](src/eval/mod.rs), preserving the RPQ-facing + trait name while the generic evaluator hierarchy lives in `src/eval/`. +- [`PreparedRpq`](src/rpq/mod.rs) — marker subtrait over + [`PreparedEvaluator`](src/eval/mod.rs). - [`RpqError`](src/rpq/mod.rs) — unified error type for RPQ parsing and evaluation: `Parse` (SPARQL syntax), `Extract` (query extraction), `UnsupportedPath`, `VertexNotFound`, and `Graph` (wraps [`GraphError`](src/graph/mod.rs) for @@ -423,6 +439,18 @@ over label adjacency matrices and runs [`LAGraph_RPQMatrix`]. It returns Subject/object do not filter the matrix; a named subject is only validated to exist. Bound objects are not supported yet ([`RpqError::UnsupportedPath`]). 
+### CLI dispatch (`src/cli/dispatch.rs`) + +With the `bench` feature enabled, [`src/cli/dispatch.rs`](src/cli/dispatch.rs) is the single +mapping from [`Algo`](src/cli/args.rs) variants to concrete evaluator types. `dispatch_query` +and `dispatch_bench` each perform one exhaustive `match` per requested algorithm, then call +generic runners (`run_query_for_evaluator` and `run_bench_for_evaluator`) that are +monomorphized for the selected evaluator. + +Adding a new algorithm requires a new `Algo` variant, its `Display` arm, one `dispatch_query` +arm, one `dispatch_bench` arm, an `impl Evaluator` for the evaluator type, and an +`impl ResultCount` for any result type used by CLI count reporting. + ### FFI layer [`lagraph_sys`](src/lagraph_sys.rs) exposes raw C bindings for GraphBLAS and diff --git a/src/bin/pathrex.rs b/src/bin/pathrex.rs index 8147cac..91194f0 100644 --- a/src/bin/pathrex.rs +++ b/src/bin/pathrex.rs @@ -20,185 +20,229 @@ //! --output results.json //! ``` -use std::collections::HashSet; -use std::path::Path; -use std::process; +use std::error::Error as StdError; +use std::path::{Path, PathBuf}; use chrono::Utc; use clap::Parser; +use thiserror::Error; -use pathrex::cli::args::{Cli, Commands}; -use pathrex::cli::bench::run_benchmarks; -use pathrex::cli::checkpoint::Checkpoint; -use pathrex::cli::loader::{load_graph, load_queries}; +use pathrex::cli::args::{BenchArgs, Cli, Commands, QueryArgs}; +use pathrex::cli::bench::BenchError; +use pathrex::cli::checkpoint::{Checkpoint, CheckpointError, Checkpointer}; +use pathrex::cli::dispatch::{dispatch_bench, dispatch_query}; +use pathrex::cli::loader::{GraphLoadError, LoadedQuery, load_graph, load_queries}; use pathrex::cli::output::{BenchMetadata, BenchOutput, QueryMetadata, QueryOutput}; -use pathrex::cli::query::run_queries; -use pathrex::graph::GraphDecomposition; +use pathrex::graph::{GraphDecomposition, InMemoryGraph}; + +#[derive(Debug, Error)] +enum MainError { + #[error(transparent)] + Graph(#[from] 
GraphLoadError), + #[error("error loading queries from '{path}': {source}")] + Queries { + path: String, + #[source] + source: std::io::Error, + }, + #[error(transparent)] + Checkpoint(#[from] CheckpointError), + #[error(transparent)] + Bench(#[from] BenchError), + #[error("error writing output to '{path}': {source}")] + Output { + path: String, + #[source] + source: std::io::Error, + }, +} fn main() { - let cli = Cli::parse(); + if let Err(e) = run() { + eprintln!("Error: {e}"); + // Walk the source chain so users see the underlying cause. + let mut cur: Option<&dyn StdError> = e.source(); + while let Some(c) = cur { + eprintln!(" caused by: {c}"); + cur = c.source(); + } + std::process::exit(1); + } +} +fn run() -> Result<(), MainError> { + let cli = Cli::parse(); match cli.command { - Commands::Query(args) => { - let common = &args.common; - - eprintln!("=== pathrex query ==="); - eprintln!("Graph: {}", common.graph); - eprintln!("Format: {}", common.format); - eprintln!("Queries: {}", common.queries); - eprintln!("Algos: {:?}", common.algo); - eprintln!(); - - eprintln!("[1/2] Loading graph..."); - let graph = load_graph(&common.graph, &common.format, &common.base_iri); - eprintln!(" nodes: {}", graph.num_nodes()); - eprintln!(" labels: {}", graph.num_labels()); - eprintln!(); - - eprintln!("[2/2] Loading and running queries..."); - let queries_path = Path::new(&common.queries); - let queries = load_queries(queries_path, &common.base_iri).unwrap_or_else(|e| { - eprintln!("Error loading queries from '{}': {e}", common.queries); - process::exit(1); - }); - eprintln!(" loaded {} queries", queries.len()); - - let results = run_queries(&args, &graph, &queries); - - // Summary - let errors = results - .iter() - .flat_map(|r| r.algorithms.values()) - .filter(|a| a.status != "ok") - .count(); - eprintln!(); - eprintln!( - "Done. {} queries × {} algos. 
{errors} error(s).", - results.len(), - common.algo.len() - ); - - // Optional JSON output - if let Some(ref out_path) = args.output { - let output = QueryOutput { - metadata: QueryMetadata { - timestamp: Utc::now().to_rfc3339(), - graph_path: common.graph.clone(), - graph_format: common.format.clone(), - queries_file: common.queries.clone(), - base_iri: common.base_iri.clone(), - num_nodes: graph.num_nodes(), - num_labels: graph.num_labels(), - }, - results, - }; - if let Err(e) = output.write_to_file(Path::new(out_path)) { - eprintln!("Error writing output to '{out_path}': {e}"); - process::exit(1); - } - eprintln!("Results written to: {out_path}"); - } - } + Commands::Query(args) => run_query_cmd(args), + Commands::Bench(args) => run_bench_cmd(args), + } +} - Commands::Bench(args) => { - let common = &args.common; - - eprintln!("=== pathrex bench ==="); - eprintln!("Graph: {}", common.graph); - eprintln!("Format: {}", common.format); - eprintln!("Queries: {}", common.queries); - eprintln!("Algos: {:?}", common.algo); - eprintln!("Batch size: {}", args.batch_size); - eprintln!("Output: {}", args.output); - eprintln!(); - - eprintln!("[1/4] Loading graph..."); - let graph = load_graph(&common.graph, &common.format, &common.base_iri); - eprintln!(" nodes: {}", graph.num_nodes()); - eprintln!(" labels: {}", graph.num_labels()); - eprintln!(); - - eprintln!("[2/4] Loading queries..."); - let queries_path = Path::new(&common.queries); - let queries = load_queries(queries_path, &common.base_iri).unwrap_or_else(|e| { - eprintln!("Error loading queries from '{}': {e}", common.queries); - process::exit(1); - }); - eprintln!(" loaded {} queries", queries.len()); - let parse_errors = queries.iter().filter(|q| q.parsed.is_err()).count(); - if parse_errors > 0 { - eprintln!(" ({parse_errors} queries failed to parse)"); +fn load_query_file(path: &str, base_iri: Option<&str>) -> Result, MainError> { + load_queries(Path::new(path), base_iri).map_err(|e| MainError::Queries { + 
path: path.to_string(), + source: e, + }) +} + +fn run_query_cmd(args: QueryArgs) -> Result<(), MainError> { + let common = &args.common; + + eprintln!("=== pathrex query ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!(); + + eprintln!("[1/2] Loading graph..."); + let graph: InMemoryGraph = + load_graph(&common.graph, common.format, common.base_iri.as_deref())?; + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + eprintln!("[2/2] Loading and running queries..."); + let queries = load_query_file(&common.queries, common.base_iri.as_deref())?; + eprintln!(" loaded {} queries", queries.len()); + + let results = dispatch_query(&args, &graph, &queries); + + let errors = results + .iter() + .flat_map(|r| r.algorithms.values()) + .filter(|a| !matches!(a.status, pathrex::cli::output::AlgoStatus::Ok)) + .count(); + eprintln!(); + eprintln!( + "Done. {} queries × {} algos. {errors} error(s).", + results.len(), + common.algo.len() + ); + + if let Some(ref out_path) = args.output { + let output = QueryOutput { + metadata: QueryMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.to_string(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: graph.num_labels(), + }, + results, + }; + output + .write_to_file(Path::new(out_path)) + .map_err(|e| MainError::Output { + path: out_path.clone(), + source: e, + })?; + eprintln!("Results written to: {out_path}"); + } + + Ok(()) +} + +fn build_checkpointer(args: &BenchArgs, queries_len: usize) -> Result { + let common = &args.common; + let path = PathBuf::from(&args.checkpoint); + + if args.resume { + match Checkpoint::load(&path)? 
{ + Some(cp) => { + cp.validate(&common.graph, &common.queries, &common.algo)?; + let cper = Checkpointer::with_inner(cp, path); + eprintln!( + " resuming: {}/{} queries fully done", + cper.fully_done_count(&common.algo), + queries_len + ); + Ok(cper) } - eprintln!(); - - eprintln!("[3/4] Setting up checkpoint..."); - let checkpoint_path = Path::new(&args.checkpoint); - let mut checkpoint = if args.resume { - match Checkpoint::load(checkpoint_path) { - Ok(Some(cp)) => { - if let Err(e) = cp.validate(&common.graph, &common.queries, &common.algo) { - eprintln!("Checkpoint validation failed: {e}"); - process::exit(1); - } - let done_count = cp - .completed - .iter() - .filter(|c| { - let done: HashSet<_> = c.algorithms_done.iter().collect(); - common.algo.iter().all(|a| done.contains(a)) - }) - .count(); - eprintln!( - " resuming: {done_count}/{} queries fully done", - queries.len() - ); - cp - } - Ok(None) => { - eprintln!(" no checkpoint file found, starting fresh"); - Checkpoint::new(&common.graph, &common.queries, &common.algo) - } - Err(e) => { - eprintln!("Error loading checkpoint: {e}"); - process::exit(1); - } - } - } else { - Checkpoint::new(&common.graph, &common.queries, &common.algo) - }; - eprintln!(); - - eprintln!("[4/4] Running benchmarks..."); - eprintln!(); - let results = run_benchmarks(&args, &graph, &queries, &mut checkpoint, checkpoint_path); - - let output = BenchOutput { - metadata: BenchMetadata { - timestamp: Utc::now().to_rfc3339(), - graph_path: common.graph.clone(), - graph_format: common.format.clone(), - queries_file: common.queries.clone(), - base_iri: common.base_iri.clone(), - num_nodes: graph.num_nodes(), - num_labels: graph.num_labels(), - sample_size: args.sample_size, - warm_up_secs: args.warm_up, - measurement_secs: args.measurement, - batch_size: args.batch_size, - }, - results, - }; - - let output_path = Path::new(&args.output); - if let Err(e) = output.write_to_file(output_path) { - eprintln!("Error writing output to '{}': 
{e}", args.output); - process::exit(1); + None => { + eprintln!(" no checkpoint file found, starting fresh"); + Ok(Checkpointer::fresh( + &common.graph, + &common.queries, + &common.algo, + path, + )) } - - eprintln!(); - eprintln!("=== Done ==="); - eprintln!("Results written to: {}", args.output); - eprintln!("Criterion data in: {}", args.criterion_dir); } + } else { + Ok(Checkpointer::fresh( + &common.graph, + &common.queries, + &common.algo, + path, + )) + } +} + +fn run_bench_cmd(args: BenchArgs) -> Result<(), MainError> { + let common = &args.common; + + eprintln!("=== pathrex bench ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!("Output: {}", args.output); + eprintln!(); + + eprintln!("[1/4] Loading graph..."); + let graph: InMemoryGraph = + load_graph(&common.graph, common.format, common.base_iri.as_deref())?; + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + eprintln!("[2/4] Loading queries..."); + let queries = load_query_file(&common.queries, common.base_iri.as_deref())?; + eprintln!(" loaded {} queries", queries.len()); + let parse_errors = queries.iter().filter(|q| q.parsed.is_err()).count(); + if parse_errors > 0 { + eprintln!(" ({parse_errors} queries failed to parse)"); } + eprintln!(); + + eprintln!("[3/4] Setting up checkpoint..."); + let mut checkpointer = build_checkpointer(&args, queries.len())?; + eprintln!(); + + eprintln!("[4/4] Running benchmarks..."); + eprintln!(); + let results = dispatch_bench(&args, &graph, &queries, &mut checkpointer)?; + + let output = BenchOutput { + metadata: BenchMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.to_string(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: 
graph.num_labels(), + sample_size: args.sample_size, + warm_up_secs: args.warm_up, + measurement_secs: args.measurement, + }, + results, + }; + + output + .write_to_file(Path::new(&args.output)) + .map_err(|e| MainError::Output { + path: args.output.clone(), + source: e, + })?; + + eprintln!(); + eprintln!("=== Done ==="); + eprintln!("Results written to: {}", args.output); + eprintln!("Criterion data in: {}", args.criterion_dir); + + Ok(()) } diff --git a/src/cli/args.rs b/src/cli/args.rs index 557bfec..0a8528a 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -7,8 +7,9 @@ //! - [`BenchArgs`] — bench-specific args (criterion, checkpoint, …) //! - [`QueryArgs`] — query-specific args (optional output file) //! - [`Algo`] — algorithm identifier enum +//! - [`GraphFormat`] — input graph format enum -use clap::{Args, Parser, Subcommand}; +use clap::{Args, Parser, Subcommand, ValueEnum}; /// Top-level CLI for pathrex. #[derive(Parser, Debug)] @@ -37,20 +38,29 @@ pub struct CommonArgs { #[arg(short = 'g', long)] pub graph: String, - /// Graph format: mm | csv - #[arg(short = 'f', long, default_value = "mm")] - pub format: String, + /// Graph format. + #[arg(short = 'f', long, value_enum, default_value_t = GraphFormat::Mm)] + pub format: GraphFormat, /// Path to queries file (format: `,` per line). #[arg(short = 'q', long)] pub queries: String, - /// Base IRI used when wrapping bare SPARQL patterns. - #[arg(short = 'b', long, default_value = "http://example.org/")] - pub base_iri: String, + /// Optional base IRI prepended to bare SPARQL patterns as `BASE `. + /// Pass without a value (`--base-iri`) to use the default `http://example.org/`. + /// Pass with a value (`--base-iri `) to use a custom IRI. + /// When omitted entirely, no BASE declaration is added to the query. + #[arg( + short = 'b', + long, + num_args = 0..=1, + default_missing_value = "http://example.org/", + require_equals = false + )] + pub base_iri: Option, /// Algorithms to use. 
- #[arg(short = 'a', long, num_args = 1.., default_values_t = vec![Algo::Nfa, Algo::Rpqmatrix])] + #[arg(short = 'a', long, value_enum, num_args = 1.., required = true)] pub algo: Vec, } @@ -83,11 +93,6 @@ pub struct BenchArgs { #[arg(long)] pub resume: bool, - /// Number of queries per batch. Controls how often results are logged - /// and checkpoints are saved. Default is 1 (checkpoint after every query). - #[arg(long, default_value_t = 1)] - pub batch_size: usize, - /// Directory for criterion output. #[arg(long, default_value = "bench_criterion/")] pub criterion_dir: String, @@ -109,12 +114,12 @@ pub struct BenchArgs { pub measurement: u64, } -/// Algorithm identifiers for RPQ evaluation. -#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, ValueEnum, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "lowercase")] +#[value(rename_all = "lowercase")] pub enum Algo { /// NFA-based evaluator (`LAGraph_RegularPathQuery`). - Nfa, + NfaRpq, /// Matrix-plan evaluator (`LAGraph_RPQMatrix`). Rpqmatrix, } @@ -122,22 +127,47 @@ pub enum Algo { impl std::fmt::Display for Algo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Algo::Nfa => write!(f, "nfa"), + Algo::NfaRpq => write!(f, "nfarpq"), Algo::Rpqmatrix => write!(f, "rpqmatrix"), } } } -impl std::str::FromStr for Algo { - type Err = String; +/// Input graph format. +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[value(rename_all = "lowercase")] +pub enum GraphFormat { + /// MatrixMarket directory layout (vertices.txt, edges.txt, *.txt). + Mm, + /// CSV file with source/target/label columns. 
+ Csv, +} - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "nfa" => Ok(Algo::Nfa), - "rpqmatrix" => Ok(Algo::Rpqmatrix), - other => Err(format!( - "unknown algorithm: '{other}' (expected: nfa, rpqmatrix)" - )), +impl std::fmt::Display for GraphFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GraphFormat::Mm => write!(f, "mm"), + GraphFormat::Csv => write!(f, "csv"), } } } + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + + #[test] + fn query_requires_at_least_one_algo() { + let result = Cli::try_parse_from([ + "pathrex", + "query", + "--graph", + "graph", + "--queries", + "queries", + ]); + + assert!(result.is_err()); + } +} diff --git a/src/cli/bench.rs b/src/cli/bench.rs deleted file mode 100644 index 6c37bf7..0000000 --- a/src/cli/bench.rs +++ /dev/null @@ -1,416 +0,0 @@ -//! Core benchmark loop and criterion integration for the `bench` subcommand. - -use std::collections::HashMap; -use std::fs::File; -use std::path::Path; -use std::time::Duration; - -use criterion::{Criterion, black_box}; - -use crate::graph::InMemoryGraph; -use crate::rpq::nfarpq::{NfaRpqEvaluator, PreparedNfaRpq}; -use crate::rpq::rpqmatrix::{PreparedRpqMatrix, RpqMatrixEvaluator}; -use crate::rpq::{RpqError, RpqEvaluator, RpqQuery}; - -use super::args::{Algo, BenchArgs}; -use super::checkpoint::Checkpoint; -use super::loader::LoadedQuery; -use super::output::{AlgoResult, AlgoTiming, BatchResult, QueryResult, TimingStats}; - -/// Run a single evaluation and return the result count. -/// -/// Used by the `query` subcommand runner. -pub(crate) fn run_once( - algo: &Algo, - query: &RpqQuery, - graph: &InMemoryGraph, -) -> Result { - match algo { - Algo::Nfa => { - let result = NfaRpqEvaluator.evaluate(query, graph)?; - let count = result - .reachable - .nvals() - .map_err(crate::rpq::RpqError::Graph)? 
as usize; - Ok(count) - } - Algo::Rpqmatrix => { - let result = RpqMatrixEvaluator.evaluate(query, graph)?; - let count = result.reachable_target_count().map_err(RpqError::Graph)? as usize; - Ok(count) - } - } -} - -/// Run a batch of queries for a single algorithm (discards result counts). -/// -/// Used inside criterion's measurement loop; returning counts would be -/// optimised away anyway, but we keep the call realistic with `black_box`. -fn run_batch_total( - algo: &Algo, - queries: &[&RpqQuery], - graph: &InMemoryGraph, -) -> Result<(), RpqError> { - for query in queries { - let _ = black_box(run_once(algo, query, graph))?; - } - Ok(()) -} - -enum PreparedBatch { - Nfa(Vec), - Rpqmatrix(Vec), -} - -fn prepare_batch( - algo: &Algo, - queries: &[&RpqQuery], - graph: &InMemoryGraph, -) -> Result { - match algo { - Algo::Nfa => Ok(PreparedBatch::Nfa( - queries - .iter() - .map(|query| NfaRpqEvaluator.prepare(query, graph)) - .collect::, _>>()?, - )), - Algo::Rpqmatrix => Ok(PreparedBatch::Rpqmatrix( - queries - .iter() - .map(|query| RpqMatrixEvaluator.prepare(query, graph)) - .collect::, _>>()?, - )), - } -} - -fn run_prepared_batch(prepared: &mut PreparedBatch) -> Result<(), RpqError> { - match prepared { - PreparedBatch::Nfa(items) => { - for item in items { - let result = item.execute()?; - let count = result - .reachable - .nvals() - .map_err(crate::rpq::RpqError::Graph)? as usize; - let _ = black_box(count); - } - } - PreparedBatch::Rpqmatrix(items) => { - for item in items { - let result = item.execute()?; - let count = result - .reachable_target_count() - .map_err(crate::rpq::RpqError::Graph)? as usize; - let _ = black_box(count); - } - } - } - - Ok(()) -} - -/// Read criterion timing estimates from its output directory. 
-/// -/// After `group.finish()`, criterion writes: -/// `///new/estimates.json` -fn read_criterion_estimates( - criterion_dir: &str, - group_name: &str, - bench_name: &str, -) -> Option { - let path = Path::new(criterion_dir) - .join(group_name) - .join(bench_name) - .join("new") - .join("estimates.json"); - - let file = File::open(&path).ok()?; - let data: serde_json::Value = serde_json::from_reader(file).ok()?; - - let mean_ns = data["mean"]["point_estimate"].as_f64()?; - let median_ns = data["median"]["point_estimate"].as_f64()?; - let stddev_ns = data["std_dev"]["point_estimate"].as_f64()?; - - // Read sample count from sample.json if available. - let sample_path = Path::new(criterion_dir) - .join(group_name) - .join(bench_name) - .join("new") - .join("sample.json"); - - let iterations = File::open(&sample_path) - .ok() - .and_then(|f| serde_json::from_reader::<_, serde_json::Value>(f).ok()) - .and_then(|v| v["iters"].as_array().map(|a| a.len())) - .unwrap_or(0); - - Some(TimingStats { - mean_ns, - median_ns, - stddev_ns, - iterations, - }) -} - -fn read_algo_timing_estimates(criterion_dir: &str, group_name: &str) -> Option { - let total = read_criterion_estimates(criterion_dir, group_name, "eval_total")?; - let ffi_only = read_criterion_estimates(criterion_dir, group_name, "eval_ffi_only")?; - - Some(AlgoTiming { total, ffi_only }) -} - -/// Run the full benchmark loop, processing queries in batches. -/// -/// Queries are grouped into batches of `batch_size`. For each batch and -/// algorithm, criterion benchmarks the entire batch as a single unit -/// (all queries run sequentially per iteration). -/// After each batch the checkpoint is saved. 
-pub fn run_benchmarks( - args: &BenchArgs, - graph: &InMemoryGraph, - queries: &[LoadedQuery], - checkpoint: &mut Checkpoint, - checkpoint_path: &Path, -) -> Vec { - let criterion = Criterion::default() - .sample_size(args.sample_size) - .warm_up_time(Duration::from_secs(args.warm_up)) - .measurement_time(Duration::from_secs(args.measurement)) - .output_directory(Path::new(&args.criterion_dir)); - - let mut criterion = if args.plots { - criterion.with_plots() - } else { - criterion.without_plots() - }; - - let batch_size = args.batch_size.max(1); - let mut batch_results: Vec = Vec::new(); - - // Collect queries that still need work. - let active_queries: Vec<(usize, &LoadedQuery)> = queries - .iter() - .enumerate() - .filter(|(idx, loaded)| { - if checkpoint.is_fully_done(*idx, &args.common.algo) { - eprintln!( - "[skip] query #{} (id={}) — all algorithms done", - idx, loaded.id - ); - false - } else { - true - } - }) - .collect(); - - for (batch_index, batch) in active_queries.chunks(batch_size).enumerate() { - let batch_indices: Vec = batch.iter().map(|(idx, _)| *idx).collect(); - let batch_ids: Vec<&str> = batch.iter().map(|(_, l)| l.id.as_str()).collect(); - - eprintln!( - "\n[batch {}] queries {:?} (ids: {:?})", - batch_index, batch_indices, batch_ids - ); - - // Separate parse errors from valid queries. 
- let mut per_query_results: Vec = Vec::new(); - // algo key → list of (query_index, query_ref) - let mut valid_queries_per_algo: HashMap> = HashMap::new(); - - for &(idx, loaded) in batch { - let mut algo_results: HashMap = HashMap::new(); - - let query = match &loaded.parsed { - Ok(q) => q, - Err(e) => { - eprintln!( - " [error] query #{} (id={}) parse error: {}", - idx, loaded.id, e - ); - for algo in &args.common.algo { - if !checkpoint.is_algo_done(idx, algo) { - algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); - checkpoint.mark_algo_done(idx, &loaded.id, algo); - } - } - per_query_results.push(QueryResult { - query_index: idx, - query_id: loaded.id.clone(), - query_text: loaded.text.clone(), - algorithms: algo_results, - }); - continue; - } - }; - - for algo in &args.common.algo { - if !checkpoint.is_algo_done(idx, algo) { - valid_queries_per_algo - .entry(algo.to_string()) - .or_default() - .push((idx, query)); - } - } - - per_query_results.push(QueryResult { - query_index: idx, - query_id: loaded.id.clone(), - query_text: loaded.text.clone(), - algorithms: algo_results, - }); - } - - // ── Criterion benchmark per algo over valid queries ── - let mut batch_algo_timing: HashMap> = HashMap::new(); - - for algo in &args.common.algo { - let algo_key = algo.to_string(); - let Some(valid) = valid_queries_per_algo.get(&algo_key) else { - continue; - }; - if valid.is_empty() { - continue; - } - - eprintln!( - " [bench] algo={} — benchmarking {} queries as batch...", - algo, - valid.len() - ); - - let group_name = format!("batch{}_{}", batch_index, algo); - let mut group = criterion.benchmark_group(&group_name); - - let algo_clone = algo.clone(); - let queries_clone: Vec = valid.iter().map(|(_, q)| (*q).clone()).collect(); - - group.bench_function("eval_total", |b| { - b.iter(|| { - let refs: Vec<&RpqQuery> = queries_clone.iter().collect(); - let _ = black_box(run_batch_total(&algo_clone, &refs, graph)); - }); - }); - - 
group.bench_function("eval_ffi_only", |b| { - let refs: Vec<&RpqQuery> = queries_clone.iter().collect(); - let mut prepared = - prepare_batch(&algo_clone, &refs, graph).expect("prepare benchmark batch"); - b.iter(|| { - let _ = black_box(run_prepared_batch(&mut prepared)); - }); - }); - group.finish(); - - let timing = read_algo_timing_estimates(&args.criterion_dir, &group_name); - batch_algo_timing.insert(algo_key, timing); - } - - // Assign timing to each query's algo result. - for qr in &mut per_query_results { - for algo in &args.common.algo { - let algo_key = algo.to_string(); - // Only fill in queries that didn't already get a parse error result. - if qr.algorithms.contains_key(&algo_key) { - continue; - } - let timing = batch_algo_timing - .get(&algo_key) - .and_then(|t| t.as_ref()) - .map(|t| AlgoTiming { - total: TimingStats { - mean_ns: t.total.mean_ns, - median_ns: t.total.median_ns, - stddev_ns: t.total.stddev_ns, - iterations: t.total.iterations, - }, - ffi_only: TimingStats { - mean_ns: t.ffi_only.mean_ns, - median_ns: t.ffi_only.median_ns, - stddev_ns: t.ffi_only.stddev_ns, - iterations: t.ffi_only.iterations, - }, - }); - qr.algorithms - .insert(algo_key.clone(), AlgoResult::ok(None, timing)); - } - } - - // Mark all queries in this batch as done. 
- for &(idx, loaded) in batch { - for algo in &args.common.algo { - checkpoint.mark_algo_done(idx, &loaded.id, algo); - } - } - - if let Err(e) = checkpoint.save(checkpoint_path) { - eprintln!("[warn] failed to save checkpoint: {e}"); - } - - batch_results.push(BatchResult { - batch_index, - query_indices: batch_indices, - queries: per_query_results, - }); - } - - batch_results -} - -#[cfg(test)] -mod tests { - use super::*; - use std::fs; - - use crate::rpq::{Endpoint, PathExpr, RpqQuery}; - use crate::utils::build_graph; - - fn write_estimate_files(base: &Path, bench_name: &str, mean_ns: f64, iterations: usize) { - let bench_dir = base.join("batch0_nfa").join(bench_name).join("new"); - fs::create_dir_all(&bench_dir).expect("create bench dir"); - fs::write( - bench_dir.join("estimates.json"), - format!( - r#"{{"mean":{{"point_estimate":{mean_ns}}},"median":{{"point_estimate":{mean_ns}}},"std_dev":{{"point_estimate":0.0}}}}"# - ), - ) - .expect("write estimates"); - let sample = format!("{{\"iters\":[{}]}}", vec!["1"; iterations].join(",")); - fs::write(bench_dir.join("sample.json"), sample).expect("write sample"); - } - - #[test] - fn read_split_criterion_estimates() { - let dir = tempfile::tempdir().expect("tempdir"); - write_estimate_files(dir.path(), "eval_total", 10.0, 3); - write_estimate_files(dir.path(), "eval_ffi_only", 4.0, 5); - - let timing = - read_algo_timing_estimates(dir.path().to_str().expect("utf8 path"), "batch0_nfa") - .expect("split timing"); - - assert_eq!(timing.total.mean_ns, 10.0); - assert_eq!(timing.total.iterations, 3); - assert_eq!(timing.ffi_only.mean_ns, 4.0); - assert_eq!(timing.ffi_only.iterations, 5); - } - - #[test] - fn run_once_rpqmatrix_count_matches_nfa_reachable_targets() { - let graph = build_graph(&[("A", "B", "p"), ("C", "B", "p")]); - let query = RpqQuery { - subject: Endpoint::Variable("x".into()), - path: PathExpr::Label("p".into()), - object: Endpoint::Variable("y".into()), - }; - - let rpqmatrix_count = 
run_once(&Algo::Rpqmatrix, &query, &graph).expect("rpqmatrix count"); - let nfa_count = run_once(&Algo::Nfa, &query, &graph).expect("nfa count"); - - assert_eq!( - rpqmatrix_count, 1, - "shared count should report reachable target count" - ); - assert_eq!(nfa_count, rpqmatrix_count); - } -} diff --git a/src/cli/bench/error.rs b/src/cli/bench/error.rs new file mode 100644 index 0000000..3514b11 --- /dev/null +++ b/src/cli/bench/error.rs @@ -0,0 +1,21 @@ +//! Error type for the bench pipeline. + +use thiserror::Error; + +use crate::cli::checkpoint::CheckpointError; + +#[derive(Debug, Error)] +pub enum BenchError { + #[error("criterion estimates missing for group '{0}' (file not found or unreadable)")] + MissingEstimates(String), + + #[error("criterion estimates parse error for group '{group}': {source}")] + EstimatesParse { + group: String, + #[source] + source: serde_json::Error, + }, + + #[error("checkpoint error: {0}")] + Checkpoint(#[from] CheckpointError), +} diff --git a/src/cli/bench/estimates.rs b/src/cli/bench/estimates.rs new file mode 100644 index 0000000..1cf72cb --- /dev/null +++ b/src/cli/bench/estimates.rs @@ -0,0 +1,149 @@ +//! Typed deserialization of criterion's per-benchmark JSON output. +//! +//! After `group.finish()`, criterion writes: +//! +//! ```text +//! ///new/estimates.json +//! ///new/sample.json +//! ``` +//! +//! We read both to extract a [`TimingStats`]. 
+ +use std::fs::File; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +use crate::cli::bench::error::BenchError; +use crate::cli::output::{AlgoTiming, TimingStats}; + +#[derive(Deserialize)] +struct Estimates { + mean: PointEstimate, + median: PointEstimate, + std_dev: PointEstimate, +} + +#[derive(Deserialize)] +struct PointEstimate { + point_estimate: f64, +} + +#[derive(Deserialize)] +struct Sample { + iters: Vec, +} + +fn estimates_path(criterion_dir: &Path, group: &str, bench: &str) -> PathBuf { + criterion_dir + .join(group) + .join(bench) + .join("new") + .join("estimates.json") +} + +fn sample_path(criterion_dir: &Path, group: &str, bench: &str) -> PathBuf { + criterion_dir + .join(group) + .join(bench) + .join("new") + .join("sample.json") +} + +/// Read one bench's timing stats out of criterion's output directory. +/// +/// Returns `Err(MissingEstimates)` if the JSON file isn't present (e.g. the +/// criterion run was interrupted) and `Err(EstimatesParse)` on shape +/// mismatches — making schema drift loud rather than silent. +pub fn read_timing_stats( + criterion_dir: &Path, + group: &str, + bench: &str, +) -> Result { + let est_path = estimates_path(criterion_dir, group, bench); + let file = File::open(&est_path) + .map_err(|_| BenchError::MissingEstimates(format!("{}/{}", group, bench)))?; + let est: Estimates = serde_json::from_reader(file).map_err(|e| BenchError::EstimatesParse { + group: format!("{}/{}", group, bench), + source: e, + })?; + + // sample.json is best-effort; if missing we report iterations=0. + let iterations = File::open(sample_path(criterion_dir, group, bench)) + .ok() + .and_then(|f| serde_json::from_reader::<_, Sample>(f).ok()) + .map(|s| s.iters.len()) + .unwrap_or(0); + + Ok(TimingStats { + mean_ns: est.mean.point_estimate, + median_ns: est.median.point_estimate, + stddev_ns: est.std_dev.point_estimate, + iterations, + }) +} + +/// Read both `eval_total` and `eval_ffi_only` benches for a single group. 
+pub fn read_algo_timing(criterion_dir: &Path, group: &str) -> Result { + let total = read_timing_stats(criterion_dir, group, "eval_total")?; + let ffi_only = read_timing_stats(criterion_dir, group, "eval_ffi_only")?; + Ok(AlgoTiming { total, ffi_only }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + fn write_estimate_files(base: &Path, group: &str, bench: &str, mean_ns: f64, iters: usize) { + let bench_dir = base.join(group).join(bench).join("new"); + fs::create_dir_all(&bench_dir).expect("create bench dir"); + fs::write( + bench_dir.join("estimates.json"), + format!( + r#"{{"mean":{{"point_estimate":{mean_ns}}},"median":{{"point_estimate":{mean_ns}}},"std_dev":{{"point_estimate":0.0}}}}"# + ), + ) + .expect("write estimates"); + let iters_array: Vec<&str> = vec!["1.0"; iters]; + let sample = format!("{{\"iters\":[{}]}}", iters_array.join(",")); + fs::write(bench_dir.join("sample.json"), sample).expect("write sample"); + } + + #[test] + fn reads_split_estimates() { + let dir = tempfile::tempdir().expect("tempdir"); + write_estimate_files(dir.path(), "query_0_nfa", "eval_total", 10.0, 3); + write_estimate_files(dir.path(), "query_0_nfa", "eval_ffi_only", 4.0, 5); + + let timing = read_algo_timing(dir.path(), "query_0_nfa").expect("split timing"); + assert_eq!(timing.total.mean_ns, 10.0); + assert_eq!(timing.total.iterations, 3); + assert_eq!(timing.ffi_only.mean_ns, 4.0); + assert_eq!(timing.ffi_only.iterations, 5); + } + + #[test] + fn missing_file_is_reported() { + let dir = tempfile::tempdir().expect("tempdir"); + let err = read_timing_stats(dir.path(), "missing", "eval_total").unwrap_err(); + match err { + BenchError::MissingEstimates(g) => assert!(g.contains("missing")), + other => panic!("unexpected error: {other}"), + } + } + + #[test] + fn malformed_json_is_reported() { + let dir = tempfile::tempdir().expect("tempdir"); + let bench_dir = dir.path().join("g").join("b").join("new"); + fs::create_dir_all(&bench_dir).unwrap(); + 
fs::write(bench_dir.join("estimates.json"), "{not json").unwrap(); + + let err = read_timing_stats(dir.path(), "g", "b").unwrap_err(); + match err { + BenchError::EstimatesParse { .. } => {} + other => panic!("unexpected error: {other}"), + } + } +} diff --git a/src/cli/bench/mod.rs b/src/cli/bench/mod.rs new file mode 100644 index 0000000..dc410ee --- /dev/null +++ b/src/cli/bench/mod.rs @@ -0,0 +1,7 @@ +//! Benchmarking subsystem for the `pathrex bench` subcommand. + +pub mod error; +pub mod estimates; +pub mod runner; + +pub use error::BenchError; diff --git a/src/cli/bench/runner.rs b/src/cli/bench/runner.rs new file mode 100644 index 0000000..b16a9c1 --- /dev/null +++ b/src/cli/bench/runner.rs @@ -0,0 +1,163 @@ +use std::collections::HashMap; +use std::path::Path; +use std::time::Duration; + +use criterion::{Criterion, black_box}; + +use crate::cli::args::{Algo, BenchArgs}; +use crate::cli::bench::error::BenchError; +use crate::cli::bench::estimates::read_algo_timing; +use crate::cli::checkpoint::Checkpointer; +use crate::cli::loader::LoadedQuery; +use crate::cli::output::{AlgoResult, QueryResult}; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::InMemoryGraph; +use crate::rpq::{RpqError, RpqQuery}; + +/// Build a criterion instance from CLI bench args. 
+pub(crate) fn build_criterion(args: &BenchArgs) -> Criterion { + let c = Criterion::default() + .sample_size(args.sample_size) + .warm_up_time(Duration::from_secs(args.warm_up)) + .measurement_time(Duration::from_secs(args.measurement)) + .output_directory(Path::new(&args.criterion_dir)); + if args.plots { + c.with_plots() + } else { + c.without_plots() + } +} + +fn group_name(query_index: usize, algo_id: &str) -> String { + format!("query_{query_index}_{algo_id}") +} + +fn run_benchmark_group( + criterion: &mut Criterion, + args: &BenchArgs, + algo_name: &str, + evaluator: E, + query: &RpqQuery, + graph: &InMemoryGraph, + query_index: usize, +) -> Result, RpqError> +where + E: Evaluator + Copy, + E::Result: ResultCount, +{ + let mut prepared = evaluator.prepare(query, graph)?; + let group = group_name(query_index, algo_name); + + { + let mut g = criterion.benchmark_group(&group); + + g.bench_function("eval_total", |b| { + b.iter(|| { + let _ = black_box(evaluator.evaluate(query, graph)); + }); + }); + + g.bench_function("eval_ffi_only", |b| { + b.iter(|| { + let _ = black_box(prepared.execute()); + }); + }); + + g.finish(); + } + + Ok(read_algo_timing(Path::new(&args.criterion_dir), &group)) +} + +/// Run the bench loop for every query in `queries` for one evaluator. +/// +/// One criterion group is created per `(query, algo)` pair; checkpoint persists +/// after every algo completion. 
+pub fn run_bench_for_evaluator( + args: &BenchArgs, + algo: &Algo, + algo_name: &str, + evaluator: E, + graph: &InMemoryGraph, + queries: &[LoadedQuery], + checkpointer: &mut Checkpointer, + criterion: &mut Criterion, +) -> Result, BenchError> +where + E: Evaluator + Copy, + E::Result: ResultCount, +{ + let mut results = Vec::with_capacity(queries.len()); + + for (idx, loaded) in queries.iter().enumerate() { + if checkpointer.is_algo_done(idx, algo) { + eprintln!( + " [skip] query #{} id={} algo={algo_name} already done", + idx, loaded.id + ); + continue; + } + + let mut algorithms: HashMap = HashMap::new(); + + let query = match &loaded.parsed { + Ok(q) => q, + Err(e) => { + let msg = e.to_string(); + eprintln!( + " [error] query #{} (id={}) algo={} parse error: {}", + idx, loaded.id, algo_name, msg + ); + algorithms.insert(algo_name.to_string(), AlgoResult::error(msg)); + checkpointer.mark_and_save(idx, algo)?; + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms, + }); + continue; + } + }; + + eprintln!("[query #{}] id={}", idx, loaded.id); + eprintln!(" [bench] algo={algo_name}"); + + match run_benchmark_group(criterion, args, algo_name, evaluator, query, graph, idx) { + Ok(Ok(timing)) => { + algorithms.insert(algo_name.to_string(), AlgoResult::ok(None, Some(timing))); + } + Ok(Err(e)) => return Err(e), + Err(e) => { + eprintln!( + " [error] query #{} (id={}) algo={} prepare error: {}", + idx, loaded.id, algo_name, e + ); + algorithms.insert(algo_name.to_string(), AlgoResult::error(e.to_string())); + } + } + + checkpointer.mark_and_save(idx, algo)?; + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms, + }); + } + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn group_name_is_filesystem_safe() { + let g = group_name(7, "nfa"); + assert_eq!(g, "query_7_nfa"); + 
assert!(!g.contains('/')); + assert!(!g.contains(' ')); + } +} diff --git a/src/cli/checkpoint.rs b/src/cli/checkpoint.rs index 06008fa..f59f3bb 100644 --- a/src/cli/checkpoint.rs +++ b/src/cli/checkpoint.rs @@ -2,38 +2,36 @@ //! //! After each query-algorithm pair completes, the checkpoint file is updated //! so that a crashed run can be resumed from the last completed point. +//! +//! Two layers: +//! - [`Checkpoint`] — pure data; serialised to disk. +//! - [`Checkpointer`] — runtime owner; pairs the data with its file path +//! and exposes a fallible `mark_and_save`. Errors propagate; no silent +//! corruption. use std::collections::HashSet; use std::fs; -use std::path::Path; +use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; +use thiserror::Error; use super::args::Algo; /// Persistent checkpoint state written to disk as JSON. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Checkpoint { - /// Schema version (always 1 for now). pub version: u32, - /// The graph path used for this benchmark run. pub graph_path: String, - /// The queries file used for this benchmark run. pub queries_file: String, - /// The algorithms requested for this benchmark run. pub algorithms: Vec, - /// Per-query completion records. pub completed: Vec, } /// Tracks which algorithms have been completed for a single query. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryCompletion { - /// Zero-based index of the query in the queries file. pub query_index: usize, - /// The query ID from the file (the number before the comma). - pub query_id: String, - /// Which algorithms have finished for this query. pub algorithms_done: Vec, } @@ -91,7 +89,7 @@ impl Checkpoint { Ok(()) } - /// Save the checkpoint to disk (atomic write via temp file + rename). + /// Save the checkpoint to disk. 
pub fn save(&self, path: &Path) -> Result<(), CheckpointError> { let json = serde_json::to_string_pretty(self).map_err(CheckpointError::Serialize)?; @@ -113,7 +111,6 @@ impl Checkpoint { algos.iter().all(|a| done.contains(a)) } - /// Check if a specific algorithm is done for a given query index. pub fn is_algo_done(&self, query_index: usize, algo: &Algo) -> bool { self.completed .iter() @@ -122,8 +119,7 @@ impl Checkpoint { .unwrap_or(false) } - /// Mark an algorithm as completed for a given query. - pub fn mark_algo_done(&mut self, query_index: usize, query_id: &str, algo: &Algo) { + pub fn mark_algo_done(&mut self, query_index: usize, algo: &Algo) { if let Some(entry) = self .completed .iter_mut() @@ -135,35 +131,76 @@ impl Checkpoint { } else { self.completed.push(QueryCompletion { query_index, - query_id: query_id.to_string(), algorithms_done: vec![algo.clone()], }); } } } +/// Runtime owner for a [`Checkpoint`] paired with its on-disk path. +pub struct Checkpointer { + inner: Checkpoint, + path: PathBuf, +} + +impl Checkpointer { + /// Create a new checkpointer with no completions. + pub fn fresh(graph_path: &str, queries_file: &str, algorithms: &[Algo], path: PathBuf) -> Self { + Self { + inner: Checkpoint::new(graph_path, queries_file, algorithms), + path, + } + } + + /// Wrap an existing [`Checkpoint`] (e.g. one loaded from disk). + pub fn with_inner(inner: Checkpoint, path: PathBuf) -> Self { + Self { inner, path } + } + + /// Number of queries that have *all* requested algorithms done. 
+ pub fn fully_done_count(&self, algos: &[Algo]) -> usize { + self.inner + .completed + .iter() + .filter(|c| { + let done: HashSet<&Algo> = c.algorithms_done.iter().collect(); + algos.iter().all(|a| done.contains(a)) + }) + .count() + } + + pub fn is_fully_done(&self, query_index: usize, algos: &[Algo]) -> bool { + self.inner.is_fully_done(query_index, algos) + } + + pub fn is_algo_done(&self, query_index: usize, algo: &Algo) -> bool { + self.inner.is_algo_done(query_index, algo) + } + + /// Mark `(query_index, algo)` complete and persist atomically. + pub fn mark_and_save( + &mut self, + query_index: usize, + algo: &Algo, + ) -> Result<(), CheckpointError> { + self.inner.mark_algo_done(query_index, algo); + self.inner.save(&self.path) + } +} + /// Errors that can occur during checkpoint operations. -#[derive(Debug)] +#[derive(Debug, Error)] pub enum CheckpointError { /// I/O error reading or writing the checkpoint file. + #[error("checkpoint I/O error ({0}): {1}")] Io(String, std::io::Error), /// JSON parsing error. + #[error("checkpoint parse error ({0}): {1}")] Parse(String, serde_json::Error), /// JSON serialization error. + #[error("checkpoint serialize error: {0}")] Serialize(serde_json::Error), /// Checkpoint parameters don't match current run. 
+ #[error("checkpoint mismatch: {0}")] Mismatch(String), } - -impl std::fmt::Display for CheckpointError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CheckpointError::Io(path, e) => write!(f, "checkpoint I/O error ({path}): {e}"), - CheckpointError::Parse(path, e) => write!(f, "checkpoint parse error ({path}): {e}"), - CheckpointError::Serialize(e) => write!(f, "checkpoint serialize error: {e}"), - CheckpointError::Mismatch(msg) => write!(f, "checkpoint mismatch: {msg}"), - } - } -} - -impl std::error::Error for CheckpointError {} diff --git a/src/cli/dispatch.rs b/src/cli/dispatch.rs new file mode 100644 index 0000000..7263e83 --- /dev/null +++ b/src/cli/dispatch.rs @@ -0,0 +1,111 @@ +//! Typed dispatch from CLI algorithm choices to concrete evaluators. + +use crate::cli::args::{Algo, BenchArgs, QueryArgs}; +use crate::cli::bench::error::BenchError; +use crate::cli::bench::runner::{build_criterion, run_bench_for_evaluator}; +use crate::cli::checkpoint::Checkpointer; +use crate::cli::loader::LoadedQuery; +use crate::cli::output::QueryResult; +use crate::cli::query::run_query_for_evaluator; +use crate::graph::InMemoryGraph; +use crate::rpq::nfarpq::NfaRpqEvaluator; +use crate::rpq::rpqmatrix::RpqMatrixEvaluator; + +fn merge_results(all: &mut Vec, per_algo: Vec) { + for result in per_algo { + if let Some(existing) = all + .iter_mut() + .find(|existing| existing.query_index == result.query_index) + { + existing.algorithms.extend(result.algorithms); + } else { + all.push(result); + } + } +} + +pub fn dispatch_query( + args: &QueryArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], +) -> Vec { + let mut all = Vec::new(); + + for algo in &args.common.algo { + let name = algo.to_string(); + let per_algo = match algo { + Algo::NfaRpq => run_query_for_evaluator(&name, NfaRpqEvaluator, graph, queries), + Algo::Rpqmatrix => run_query_for_evaluator(&name, RpqMatrixEvaluator, graph, queries), + }; + merge_results(&mut all, 
per_algo); + } + + all +} + +pub fn dispatch_bench( + args: &BenchArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], + checkpointer: &mut Checkpointer, +) -> Result, BenchError> { + let mut criterion = build_criterion(args); + let mut all = Vec::new(); + + for algo in &args.common.algo { + let name = algo.to_string(); + let per_algo = match algo { + Algo::NfaRpq => run_bench_for_evaluator( + args, + algo, + &name, + NfaRpqEvaluator, + graph, + queries, + checkpointer, + &mut criterion, + )?, + Algo::Rpqmatrix => run_bench_for_evaluator( + args, + algo, + &name, + RpqMatrixEvaluator, + graph, + queries, + checkpointer, + &mut criterion, + )?, + }; + merge_results(&mut all, per_algo); + } + + Ok(all) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cli::output::AlgoResult; + use std::collections::HashMap; + + #[test] + fn merge_results_combines_algorithms_by_query_index() { + let mut all = vec![QueryResult { + query_index: 0, + query_id: "q0".into(), + query_text: "query".into(), + algorithms: HashMap::from([("nfarpq".into(), AlgoResult::ok(Some(1), None))]), + }]; + let per_algo = vec![QueryResult { + query_index: 0, + query_id: "q0".into(), + query_text: "query".into(), + algorithms: HashMap::from([("rpqmatrix".into(), AlgoResult::ok(Some(1), None))]), + }]; + + merge_results(&mut all, per_algo); + + assert_eq!(all.len(), 1); + assert_eq!(all[0].algorithms.len(), 2); + } +} diff --git a/src/cli/loader.rs b/src/cli/loader.rs index 81df9fd..938d060 100644 --- a/src/cli/loader.rs +++ b/src/cli/loader.rs @@ -1,76 +1,90 @@ //! Graph and query loading for the `pathrex` CLI. //! //! Both subcommands (`bench` and `query`) need to load a graph and a queries -//! file. This module centralises that I/O so neither subcommand runner -//! duplicates it. +//! file. 
use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::Path; -use std::process; -use crate::formats::Csv; +use thiserror::Error; + use crate::formats::mm::MatrixMarket; -use crate::graph::{Graph, InMemory, InMemoryGraph}; +use crate::formats::Csv; +use crate::graph::{Graph, GraphError, InMemory, InMemoryGraph}; use crate::rpq::{RpqError, RpqQuery}; use crate::sparql::parse_rpq; -// ── Graph loading ──────────────────────────────────────────────────────────── +use super::args::GraphFormat; + +#[derive(Debug, Error)] +pub enum GraphLoadError { + #[error("error opening graph at '{path}': {source}")] + Open { + path: String, + #[source] + source: std::io::Error, + }, + #[error("error loading graph from '{path}': {source}")] + Build { + path: String, + #[source] + source: GraphError, + }, +} /// Load an [`InMemoryGraph`] from `graph_path` in the given `format`. -/// -/// Prints an error message and exits the process on failure, which is -/// appropriate for a CLI entry point. -pub fn load_graph(graph_path: &str, format: &str, base_iri: &str) -> InMemoryGraph { +pub fn load_graph( + graph_path: &str, + format: GraphFormat, + base_iri: Option<&str>, +) -> Result { match format { - "mm" => { - let mm = MatrixMarket::from_dir(graph_path).with_base_iri(base_iri); - Graph::::try_from(mm).unwrap_or_else(|e| { - eprintln!("Error loading MatrixMarket graph from '{graph_path}': {e}"); - process::exit(1); + GraphFormat::Mm => { + let mm_base = MatrixMarket::from_dir(graph_path); + let mm = match base_iri { + Some(iri) => mm_base.with_base_iri(iri), + None => mm_base, + }; + Graph::::try_from(mm).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, }) } - "csv" => { - let file = File::open(graph_path).unwrap_or_else(|e| { - eprintln!("Error opening CSV file '{graph_path}': {e}"); - process::exit(1); - }); - let csv_source = Csv::from_reader(file).unwrap_or_else(|e| { - eprintln!("Error creating CSV reader for '{graph_path}': {e}"); - 
process::exit(1); - }); - Graph::::try_from(csv_source).unwrap_or_else(|e| { - eprintln!("Error loading CSV graph from '{graph_path}': {e}"); - process::exit(1); + GraphFormat::Csv => { + let file = File::open(graph_path).map_err(|e| GraphLoadError::Open { + path: graph_path.to_string(), + source: e, + })?; + let csv_source = Csv::from_reader(file).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e.into(), + })?; + Graph::::try_from(csv_source).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, }) } - other => { - eprintln!("Unknown graph format: '{other}' (expected: mm, csv)"); - process::exit(1); - } } } -// ── Query loading ───────────────────────────────────────────────────────────── - -/// A single loaded query with its metadata. #[derive(Debug)] pub struct LoadedQuery { - /// The ID from the query file (the part before the first comma). pub id: String, - /// The raw SPARQL pattern text (the part after the first comma). pub text: String, - /// The parsed RPQ query, or an error if parsing failed. pub parsed: Result, } /// Load and parse queries from a file. /// /// Each non-empty line must have the format `,`. -/// The pattern is wrapped into a full SPARQL query: -/// `BASE <{base_iri}> SELECT * WHERE { {pattern} . }` -/// before parsing, matching the convention used in integration tests. -pub fn load_queries(path: &Path, base_iri: &str) -> Result, std::io::Error> { +/// The pattern is wrapped into a full SPARQL query before parsing: +/// - When `base_iri` is `Some(iri)`: `BASE <{iri}> SELECT * WHERE { {pattern} . }` +/// - When `base_iri` is `None`: `SELECT * WHERE { {pattern} . 
}` +pub fn load_queries( + path: &Path, + base_iri: Option<&str>, +) -> Result, std::io::Error> { let file = File::open(path)?; let reader = BufReader::new(file); let mut queries = Vec::new(); @@ -96,7 +110,10 @@ pub fn load_queries(path: &Path, base_iri: &str) -> Result, std } }; - let sparql = format!("BASE <{base_iri}> SELECT * WHERE {{ {pattern} . }}"); + let sparql = match base_iri { + Some(iri) => format!("BASE <{iri}> SELECT * WHERE {{ {pattern} . }}"), + None => format!("SELECT * WHERE {{ {pattern} . }}"), + }; let parsed = parse_rpq(&sparql); queries.push(LoadedQuery { diff --git a/src/cli/mod.rs b/src/cli/mod.rs index c7d3728..9e2cb96 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -10,6 +10,7 @@ pub mod args; pub mod bench; pub mod checkpoint; +pub mod dispatch; pub mod loader; pub mod output; pub mod query; diff --git a/src/cli/output.rs b/src/cli/output.rs index 703ee90..fabb2a4 100644 --- a/src/cli/output.rs +++ b/src/cli/output.rs @@ -6,48 +6,49 @@ use std::path::Path; use serde::Serialize; -// ── Shared types ───────────────────────────────────────────────────────────── +/// Outcome of running a single algorithm on a single query. +#[derive(Debug, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum AlgoStatus { + Ok, + Error, + Panic, +} /// Result of running a single algorithm on a single query. #[derive(Debug, Serialize)] pub struct AlgoResult { - pub status: String, + pub status: AlgoStatus, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, - /// Result count (nnz / number of reachable nodes). Present in both - /// `query` and `bench` modes. #[serde(skip_serializing_if = "Option::is_none")] pub result_count: Option, - /// Timing statistics — only present in `bench` mode. #[serde(skip_serializing_if = "Option::is_none")] pub timing: Option, } impl AlgoResult { - /// Create a successful result with an optional result count and timing. 
pub fn ok(result_count: Option, timing: Option) -> Self { Self { - status: "ok".to_string(), + status: AlgoStatus::Ok, error: None, result_count, timing, } } - /// Create an error result. pub fn error(message: String) -> Self { Self { - status: "error".to_string(), + status: AlgoStatus::Error, error: Some(message), result_count: None, timing: None, } } - /// Create a panic result. pub fn panic(message: String) -> Self { Self { - status: "panic".to_string(), + status: AlgoStatus::Panic, error: Some(message), result_count: None, timing: None, @@ -55,7 +56,6 @@ impl AlgoResult { } } -/// Benchmark timings for one algorithm/result. #[derive(Debug, Serialize)] pub struct AlgoTiming { pub total: TimingStats, @@ -71,7 +71,6 @@ pub struct TimingStats { pub iterations: usize, } -/// Results for a single query across all algorithms. #[derive(Debug, Serialize)] pub struct QueryResult { pub query_index: usize, @@ -80,29 +79,25 @@ pub struct QueryResult { pub algorithms: HashMap, } -// ── Query output ───────────────────────────────────────────────────────────── - -/// Top-level JSON output for the `query` subcommand. #[derive(Debug, Serialize)] pub struct QueryOutput { pub metadata: QueryMetadata, pub results: Vec, } -/// Metadata for a `query` run (no criterion parameters). #[derive(Debug, Serialize)] pub struct QueryMetadata { pub timestamp: String, pub graph_path: String, pub graph_format: String, pub queries_file: String, - pub base_iri: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_iri: Option, pub num_nodes: usize, pub num_labels: usize, } impl QueryOutput { - /// Write the output to a JSON file. 
pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { let json = serde_json::to_string_pretty(self) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; @@ -110,44 +105,28 @@ impl QueryOutput { } } -// ── Bench output ────────────────────────────────────────────────────────────── - -/// Top-level JSON output for the `bench` subcommand. #[derive(Debug, Serialize)] pub struct BenchOutput { pub metadata: BenchMetadata, - pub results: Vec, + pub results: Vec, } -/// Metadata for a `bench` run (includes criterion parameters). #[derive(Debug, Serialize)] pub struct BenchMetadata { pub timestamp: String, pub graph_path: String, pub graph_format: String, pub queries_file: String, - pub base_iri: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_iri: Option, pub num_nodes: usize, pub num_labels: usize, pub sample_size: usize, pub warm_up_secs: u64, pub measurement_secs: u64, - pub batch_size: usize, -} - -/// Results for a batch of queries. -#[derive(Debug, Serialize)] -pub struct BatchResult { - /// Zero-based batch index. - pub batch_index: usize, - /// Query indices included in this batch. - pub query_indices: Vec, - /// Per-query results within this batch. - pub queries: Vec, } impl BenchOutput { - /// Write the output to a JSON file. 
pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { let json = serde_json::to_string_pretty(self) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; @@ -182,5 +161,21 @@ mod tests { let value = serde_json::to_value(&result).expect("serialize"); assert!(value["timing"]["total"].is_object()); assert!(value["timing"]["ffi_only"].is_object()); + assert_eq!(value["status"], "ok"); + } + + #[test] + fn error_status_serializes_lowercase() { + let r = AlgoResult::error("boom".into()); + let v = serde_json::to_value(&r).expect("serialize"); + assert_eq!(v["status"], "error"); + assert_eq!(v["error"], "boom"); + } + + #[test] + fn panic_status_serializes_lowercase() { + let r = AlgoResult::panic("kaboom".into()); + let v = serde_json::to_value(&r).expect("serialize"); + assert_eq!(v["status"], "panic"); } } diff --git a/src/cli/query.rs b/src/cli/query.rs index 1f4fb73..152883c 100644 --- a/src/cli/query.rs +++ b/src/cli/query.rs @@ -5,23 +5,24 @@ use std::collections::HashMap; +use crate::eval::{Evaluator, ResultCount}; use crate::graph::InMemoryGraph; -use crate::rpq::RpqQuery; +use crate::rpq::{RpqError, RpqQuery}; -use super::args::QueryArgs; -use super::bench::run_once; use super::loader::LoadedQuery; use super::output::{AlgoResult, QueryResult}; -/// Run all queries once per algorithm and return structured results. -/// -/// Progress and per-query summaries are printed to stderr. No checkpoint -/// or criterion involvement — this is a simple single-pass execution. -pub fn run_queries( - args: &QueryArgs, +/// Run all queries once for one evaluator and return structured results. 
+pub fn run_query_for_evaluator( + algo_name: &str, + evaluator: E, graph: &InMemoryGraph, queries: &[LoadedQuery], -) -> Vec { +) -> Vec +where + E: Evaluator + Copy, + E::Result: ResultCount, +{ let mut results = Vec::with_capacity(queries.len()); for (idx, loaded) in queries.iter().enumerate() { @@ -31,9 +32,7 @@ pub fn run_queries( Ok(q) => q, Err(e) => { eprintln!("[query #{idx}] id={} — parse error: {e}", loaded.id); - for algo in &args.common.algo { - algo_results.insert(algo.to_string(), AlgoResult::error(e.to_string())); - } + algo_results.insert(algo_name.to_string(), AlgoResult::error(e.to_string())); results.push(QueryResult { query_index: idx, query_id: loaded.id.clone(), @@ -44,35 +43,38 @@ pub fn run_queries( } }; - for algo in &args.common.algo { - let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - run_once(algo, query, graph) - })); + let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + evaluator + .evaluate(query, graph) + .and_then(|result| result.result_count().map_err(RpqError::Graph)) + })); - let algo_result = match outcome { - Ok(Ok(count)) => { - eprintln!( - "[query #{idx}] id={} algo={} — {count} count", - loaded.id, algo - ); - AlgoResult::ok(Some(count), None) - } - Ok(Err(e)) => { - eprintln!("[query #{idx}] id={} algo={} — error: {e}", loaded.id, algo); - AlgoResult::error(e.to_string()) - } - Err(panic_info) => { - let msg = format!("{:?}", panic_info); - eprintln!( - "[query #{idx}] id={} algo={} — panic: {msg}", - loaded.id, algo - ); - AlgoResult::panic(msg) - } - }; + let algo_result = match outcome { + Ok(Ok(count)) => { + eprintln!( + "[query #{idx}] id={} algo={} — {count} count", + loaded.id, algo_name + ); + AlgoResult::ok(Some(count), None) + } + Ok(Err(e)) => { + eprintln!( + "[query #{idx}] id={} algo={} — error: {e}", + loaded.id, algo_name + ); + AlgoResult::error(e.to_string()) + } + Err(panic_info) => { + let msg = format!("{:?}", panic_info); + eprintln!( + "[query 
#{idx}] id={} algo={} — panic: {msg}", + loaded.id, algo_name + ); + AlgoResult::panic(msg) + } + }; - algo_results.insert(algo.to_string(), algo_result); - } + algo_results.insert(algo_name.to_string(), algo_result); results.push(QueryResult { query_index: idx, @@ -84,3 +86,29 @@ pub fn run_queries( results } + +#[cfg(test)] +mod tests { + use super::*; + use crate::rpq::nfarpq::NfaRpqEvaluator; + use crate::utils::build_graph; + + #[test] + fn generic_runner_records_result_for_one_algorithm() { + let graph = build_graph(&[("A", "B", "p")]); + let queries = vec![LoadedQuery { + id: "q1".into(), + text: "SELECT ?x ?y WHERE { ?x

?y . }".into(), + parsed: Ok(RpqQuery { + subject: crate::rpq::Endpoint::Variable("x".into()), + path: crate::rpq::PathExpr::Label("p".into()), + object: crate::rpq::Endpoint::Variable("y".into()), + }), + }]; + + let results = run_query_for_evaluator("nfarpq", NfaRpqEvaluator, &graph, &queries); + + assert_eq!(results.len(), 1); + assert!(results[0].algorithms.contains_key("nfarpq")); + } +} diff --git a/src/eval/mod.rs b/src/eval/mod.rs new file mode 100644 index 0000000..fb652c5 --- /dev/null +++ b/src/eval/mod.rs @@ -0,0 +1,35 @@ +//! Generic abstractions over query evaluators. + +use crate::graph::{GraphDecomposition, GraphError}; + +pub trait Evaluator { + type Query; + type Result; + type Error; + type Prepared: PreparedEvaluator; + + fn prepare( + &self, + query: &Self::Query, + graph: &G, + ) -> Result; + + fn evaluate( + &self, + query: &Self::Query, + graph: &G, + ) -> Result { + self.prepare(query, graph)?.execute() + } +} + +pub trait PreparedEvaluator { + type Result; + type Error; + + fn execute(&mut self) -> Result; +} + +pub trait ResultCount { + fn result_count(&self) -> Result; +} diff --git a/src/lib.rs b/src/lib.rs index f87a4e7..2502767 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod eval; pub mod formats; pub mod graph; pub mod rpq; diff --git a/src/rpq/mod.rs b/src/rpq/mod.rs index db2d220..d7d430a 100644 --- a/src/rpq/mod.rs +++ b/src/rpq/mod.rs @@ -13,6 +13,7 @@ pub mod nfarpq; pub mod rpqmatrix; +use crate::eval::{Evaluator, PreparedEvaluator}; use crate::graph::{GraphDecomposition, GraphError}; use crate::sparql::ExtractError; use spargebra::SparqlSyntaxError; @@ -91,13 +92,28 @@ pub enum RpqError { Graph(#[from] GraphError), } -pub trait RpqEvaluator { - /// Output of this evaluator (e.g. reachable vector vs path matrix + nnz). 
- type Result; +pub trait RpqEvaluator: Evaluator { + fn prepare( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + ::prepare(self, query, graph) + } fn evaluate( &self, query: &RpqQuery, graph: &G, - ) -> Result; + ) -> Result { + ::evaluate(self, query, graph) + } +} +impl RpqEvaluator for T where T: Evaluator {} + +pub trait PreparedRpq: PreparedEvaluator { + fn execute(&mut self) -> Result { + ::execute(self) + } } +impl PreparedRpq for T where T: PreparedEvaluator {} diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index 9550396..29e040b 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -1,9 +1,10 @@ //! NFA-based RPQ evaluation using `LAGraph_RegularPathQuery`. -use crate::graph::{GraphDecomposition, GraphblasVector, LagraphGraph}; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::{GraphDecomposition, GraphError, GraphblasVector, LagraphGraph}; use crate::lagraph_sys::LAGraph_Kind; use crate::lagraph_sys::*; -use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use crate::rpq::{Endpoint, PathExpr, RpqError, RpqQuery}; use crate::{grb_ok, la_ok}; use rustfst::algorithms::closure::{ClosureType, closure}; use rustfst::algorithms::concat::concat; @@ -208,6 +209,12 @@ pub struct NfaRpqResult { pub reachable: GraphblasVector, } +impl ResultCount for NfaRpqResult { + fn result_count(&self) -> Result { + Ok(self.reachable.nvals()? as usize) + } +} + pub struct PreparedNfaRpq { nfa: Nfa, nfa_matrices: Vec<(String, LagraphGraph)>, @@ -244,8 +251,11 @@ fn filter_reachable_by_destination( Ok(filtered) } -impl PreparedNfaRpq { - pub fn execute(&mut self) -> Result { +impl PreparedEvaluator for PreparedNfaRpq { + type Result = NfaRpqResult; + type Error = RpqError; + + fn execute(&mut self) -> Result { let mut reachable: GrB_Vector = std::ptr::null_mut(); unsafe { @@ -274,10 +284,16 @@ impl PreparedNfaRpq { } /// Evaluates RPQs using `LAGraph_RegularPathQuery`. 
+#[derive(Clone, Copy)] pub struct NfaRpqEvaluator; -impl NfaRpqEvaluator { - pub fn prepare( +impl Evaluator for NfaRpqEvaluator { + type Query = RpqQuery; + type Result = NfaRpqResult; + type Error = RpqError; + type Prepared = PreparedNfaRpq; + + fn prepare( &self, query: &RpqQuery, graph: &G, @@ -286,7 +302,7 @@ impl NfaRpqEvaluator { let nfa_matrices = nfa.build_lagraph_matrices()?; let src_id = resolve_endpoint(&query.subject, graph)?; - let _dst_id = resolve_endpoint(&query.object, graph)?; + let dst_id = resolve_endpoint(&query.object, graph)?; let n = graph.num_nodes(); let source_vertices: Vec = match src_id { @@ -312,32 +328,12 @@ impl NfaRpqEvaluator { _data_graphs: data_graphs, data_graph_ptrs, source_vertices, - destination_vertex: _dst_id, + destination_vertex: dst_id, num_nodes: n, }) } } -impl RpqEvaluator for NfaRpqEvaluator { - type Result = NfaRpqResult; - - fn evaluate( - &self, - query: &RpqQuery, - graph: &G, - ) -> Result { - let mut prepared = self.prepare(query, graph)?; - let result = prepared.execute()?; - let destination_vertex = resolve_endpoint(&query.object, graph)?; - let reachable = filter_reachable_by_destination( - result.reachable, - destination_vertex, - graph.num_nodes(), - )?; - Ok(NfaRpqResult { reachable }) - } -} - fn resolve_endpoint( term: &Endpoint, graph: &G, diff --git a/src/rpq/rpqmatrix.rs b/src/rpq/rpqmatrix.rs index 623a838..1462abc 100644 --- a/src/rpq/rpqmatrix.rs +++ b/src/rpq/rpqmatrix.rs @@ -4,9 +4,10 @@ use std::ptr::null_mut; use egg::{Id, RecExpr, define_language}; -use crate::graph::{GraphDecomposition, GraphblasMatrix}; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::{GraphDecomposition, GraphError, GraphblasMatrix}; use crate::lagraph_sys::*; -use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use crate::rpq::{Endpoint, PathExpr, RpqError, RpqQuery}; use crate::{grb_ok, la_ok}; const RPQMATRIX_REDUCE_BY_COL: u8 = 1; @@ -181,22 +182,33 @@ impl 
RpqMatrixResult { /// matrix to its non-empty columns. pub fn reachable_target_count(&self) -> Result { let mut count: GrB_Index = 0; - unsafe { grb_ok!(LAGraph_RPQMatrix_reduce( - &mut count, - self.matrix.inner, - RPQMATRIX_REDUCE_BY_COL, - ))? }; + unsafe { + grb_ok!(LAGraph_RPQMatrix_reduce( + &mut count, + self.matrix.inner, + RPQMATRIX_REDUCE_BY_COL, + ))? + }; Ok(count as u64) } } +impl ResultCount for RpqMatrixResult { + fn result_count(&self) -> Result { + Ok(self.reachable_target_count()? as usize) + } +} + pub struct PreparedRpqMatrix { plans: Vec, owned_matrices: Vec, } -impl PreparedRpqMatrix { - pub fn execute(&mut self) -> Result { +impl PreparedEvaluator for PreparedRpqMatrix { + type Result = RpqMatrixResult; + type Error = RpqError; + + fn execute(&mut self) -> Result { let root_ptr = unsafe { self.plans.as_mut_ptr().add(self.plans.len() - 1) }; let mut nnz: GrB_Index = 0; @@ -228,10 +240,16 @@ impl Drop for PreparedRpqMatrix { } /// RPQ evaluator backed by `LAGraph_RPQMatrix`. 
+#[derive(Clone, Copy)] pub struct RpqMatrixEvaluator; -impl RpqMatrixEvaluator { - pub fn prepare( +impl Evaluator for RpqMatrixEvaluator { + type Query = RpqQuery; + type Result = RpqMatrixResult; + type Error = RpqError; + type Prepared = PreparedRpqMatrix; + + fn prepare( &self, query: &RpqQuery, graph: &G, @@ -246,19 +264,6 @@ impl RpqMatrixEvaluator { } } -impl RpqEvaluator for RpqMatrixEvaluator { - type Result = RpqMatrixResult; - - fn evaluate( - &self, - query: &RpqQuery, - graph: &G, - ) -> Result { - let mut prepared = self.prepare(query, graph)?; - prepared.execute() - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index 3269cb5..4963868 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -7,7 +7,7 @@ use pathrex::formats::mm::MatrixMarket; use pathrex::graph::{Graph, GraphDecomposition, GraphError, InMemory, InMemoryGraph}; use pathrex::lagraph_sys::GrB_Index; use pathrex::rpq::nfarpq::NfaRpqEvaluator; -use pathrex::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use pathrex::rpq::{Endpoint, PathExpr, PreparedRpq, RpqError, RpqEvaluator, RpqQuery}; use pathrex::sparql::parse_rpq; use pathrex::utils::build_graph; diff --git a/tests/rpqmatrix_tests.rs b/tests/rpqmatrix_tests.rs index e13b8f1..e04379d 100644 --- a/tests/rpqmatrix_tests.rs +++ b/tests/rpqmatrix_tests.rs @@ -7,7 +7,7 @@ use pathrex::formats::mm::MatrixMarket; use pathrex::graph::{Graph, GraphDecomposition, GraphError, InMemory, InMemoryGraph}; use pathrex::lagraph_sys::{GrB_Index, GrB_Info, GrB_Matrix_extractElement_BOOL}; use pathrex::rpq::rpqmatrix::{RpqMatrixEvaluator, RpqMatrixResult}; -use pathrex::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use pathrex::rpq::{Endpoint, PathExpr, PreparedRpq, RpqError, RpqEvaluator, RpqQuery}; use pathrex::sparql::parse_rpq; use pathrex::utils::build_graph; From 521ce56c8deb27c5b4502a633995b206e9d210f2 Mon Sep 17 00:00:00 2001 From: Ivan Glazunov 
Date: Wed, 29 Apr 2026 13:25:30 +0300 Subject: [PATCH 09/11] ref: cargo fmt --- src/cli/loader.rs | 2 +- src/formats/nt.rs | 210 --------------------------------------- src/graph/wrappers.rs | 2 - src/rpq/mod.rs | 4 +- src/utils.rs | 16 ++- tests/nfarpq_tests.rs | 5 +- tests/rpqmatrix_tests.rs | 5 +- 7 files changed, 17 insertions(+), 227 deletions(-) delete mode 100644 src/formats/nt.rs diff --git a/src/cli/loader.rs b/src/cli/loader.rs index 938d060..096185a 100644 --- a/src/cli/loader.rs +++ b/src/cli/loader.rs @@ -9,8 +9,8 @@ use std::path::Path; use thiserror::Error; -use crate::formats::mm::MatrixMarket; use crate::formats::Csv; +use crate::formats::mm::MatrixMarket; use crate::graph::{Graph, GraphError, InMemory, InMemoryGraph}; use crate::rpq::{RpqError, RpqQuery}; use crate::sparql::parse_rpq; diff --git a/src/formats/nt.rs b/src/formats/nt.rs deleted file mode 100644 index aa08880..0000000 --- a/src/formats/nt.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! N-Triples edge iterator for the formats layer. -//! -//! ```no_run -//! use pathrex::formats::NTriples; -//! use pathrex::formats::FormatError; -//! -//! # let reader = std::io::empty(); -//! let iter = NTriples::new(reader) -//! .filter_map(|r| match r { -//! Err(FormatError::LiteralAsNode) => None, // skip -//! other => Some(other), -//! }); -//! ``` -//! -//! To load into a graph: -//! -//! ```no_run -//! use pathrex::graph::{Graph, InMemory, GraphDecomposition}; -//! use pathrex::formats::NTriples; -//! use std::fs::File; -//! -//! let graph = Graph::::try_from( -//! NTriples::new(File::open("data.nt").unwrap()) -//! ).unwrap(); -//! ``` - -use std::io::Read; - -use oxrdf::{NamedOrBlankNode, Term}; -use oxttl::NTriplesParser; -use oxttl::ntriples::ReaderNTriplesParser; - -use crate::formats::FormatError; -use crate::graph::Edge; - -/// An iterator that reads N-Triples and yields `Result`. 
-/// -/// # Example -/// -/// ```no_run -/// use pathrex::formats::nt::NTriples; -/// use std::fs::File; -/// -/// let file = File::open("data.nt").unwrap(); -/// let iter = NTriples::new(file); -/// for result in iter { -/// let edge = result.unwrap(); -/// println!("{} --{}--> {}", edge.source, edge.label, edge.target); -/// } -/// ``` -pub struct NTriples { - inner: ReaderNTriplesParser, -} - -impl NTriples { - pub fn new(reader: R) -> Self { - Self { - inner: NTriplesParser::new().for_reader(reader), - } - } - - fn subject_to_node_id(subject: NamedOrBlankNode) -> String { - match subject { - NamedOrBlankNode::NamedNode(n) => n.into_string(), - NamedOrBlankNode::BlankNode(b) => format!("_:{}", b.as_str()), - } - } - - fn object_to_node_id(object: Term) -> Result { - match object { - Term::NamedNode(n) => Ok(n.into_string()), - Term::BlankNode(b) => Ok(format!("_:{}", b.as_str())), - Term::Literal(_) => Err(FormatError::LiteralAsNode), - } - } -} - -impl Iterator for NTriples { - type Item = Result; - - fn next(&mut self) -> Option { - let triple = match self.inner.next()? 
{ - Ok(t) => t, - Err(e) => return Some(Err(FormatError::NTriples(e.to_string()))), - }; - - let source = Self::subject_to_node_id(triple.subject.into()); - let label = triple.predicate.as_str().to_owned(); - let target = match Self::object_to_node_id(triple.object) { - Ok(t) => t, - Err(e) => return Some(Err(e)), - }; - - Some(Ok(Edge { - source, - target, - label, - })) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn parse(nt: &str) -> Vec> { - NTriples::new(nt.as_bytes()).collect() - } - - #[test] - fn test_basic_ntriples() { - let nt = " .\n\ - .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 2); - - let e0 = edges[0].as_ref().unwrap(); - assert_eq!(e0.source, "http://example.org/Alice"); - assert_eq!(e0.target, "http://example.org/Bob"); - assert_eq!(e0.label, "http://example.org/knows"); - - let e1 = edges[1].as_ref().unwrap(); - assert_eq!(e1.source, "http://example.org/Bob"); - assert_eq!(e1.target, "http://example.org/Charlie"); - assert_eq!(e1.label, "http://example.org/likes"); - } - - #[test] - fn test_blank_node_subject_and_object() { - let nt = "_:b1 _:b2 .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 1); - - let e = edges[0].as_ref().unwrap(); - assert_eq!(e.source, "_:b1"); - assert_eq!(e.target, "_:b2"); - } - - #[test] - fn test_literal_object_yields_error() { - let nt = " \"Alice\" .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 1); - assert!( - matches!(edges[0], Err(FormatError::LiteralAsNode)), - "literal object should yield LiteralAsNode error" - ); - } - - #[test] - fn test_caller_can_skip_literal_triples() { - let nt = " .\n\ - \"Alice\" .\n\ - .\n"; - let edges: Vec<_> = NTriples::new(nt.as_bytes()) - .filter_map(|r| match r { - Err(FormatError::LiteralAsNode) => None, - other => Some(other), - }) - .collect(); - - assert_eq!(edges.len(), 2, "literal triple should be skipped"); - assert!(edges.iter().all(|r| r.is_ok())); - } - - #[test] - fn test_predicate_with_fragment_is_full_iri_string() { - let nt 
= - " .\n"; - let edges = parse(nt); - assert_eq!( - edges[0].as_ref().unwrap().label, - "http://example.org/ns#knows" - ); - } - - #[test] - fn test_non_ascii_in_iris() { - let nt = " .\n\ - .\n"; - let edges = parse(nt); - assert_eq!(edges.len(), 2); - - let e0 = edges[0].as_ref().unwrap(); - assert_eq!(e0.source, "http://example.org/人甲"); - assert_eq!(e0.target, "http://example.org/人乙"); - assert_eq!(e0.label, "http://example.org/关系/认识"); - - let e1 = edges[1].as_ref().unwrap(); - assert_eq!(e1.source, "http://example.org/Алиса"); - assert_eq!(e1.target, "http://example.org/Боб"); - assert_eq!(e1.label, "http://example.org/знает"); - } - - #[test] - fn test_ntriples_graph_source() { - use crate::graph::{GraphBuilder, GraphDecomposition, InMemoryBuilder}; - - let nt = " .\n\ - .\n"; - let iter = NTriples::new(nt.as_bytes()); - - let graph = InMemoryBuilder::default() - .load(iter) - .expect("load should succeed") - .build() - .expect("build should succeed"); - assert_eq!(graph.num_nodes(), 3); - } -} diff --git a/src/graph/wrappers.rs b/src/graph/wrappers.rs index cc8d3a1..e97cfc5 100644 --- a/src/graph/wrappers.rs +++ b/src/graph/wrappers.rs @@ -212,8 +212,6 @@ impl GraphblasVector { indices.truncate(actual_nvals as usize); Ok(indices) } - - } impl Drop for GraphblasVector { diff --git a/src/rpq/mod.rs b/src/rpq/mod.rs index d7d430a..48e0cc4 100644 --- a/src/rpq/mod.rs +++ b/src/rpq/mod.rs @@ -52,7 +52,9 @@ impl RpqQuery { } fn strip_endpoint(ep: &mut Endpoint, base: &str) { - if let Endpoint::Named(s) = ep && s.starts_with(base) { + if let Endpoint::Named(s) = ep + && s.starts_with(base) + { *s = s[base.len()..].to_owned(); } } diff --git a/src/utils.rs b/src/utils.rs index a6d5d51..edf2b18 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -134,16 +134,14 @@ impl From for GrB_Info { /// ``` #[macro_export] macro_rules! 
grb_ok { - ($grb_func:expr) => { - { - let info: $crate::lagraph_sys::GrB_Info = $grb_func.into(); - if info == $crate::lagraph_sys::GrB_Info::GrB_SUCCESS { - Ok(()) - } else { - Err($crate::graph::GraphError::GraphBlas(info)) - } + ($grb_func:expr) => {{ + let info: $crate::lagraph_sys::GrB_Info = $grb_func.into(); + if info == $crate::lagraph_sys::GrB_Info::GrB_SUCCESS { + Ok(()) + } else { + Err($crate::graph::GraphError::GraphBlas(info)) } - }; + }}; } /// Calls a raw LAGraph function and maps its `i32` return code to diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index 4963868..0f97b09 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -21,13 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line.split_once(',').map(|x| x.1) + let query_str = line + .split_once(',') + .map(|x| x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . }}"); - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } diff --git a/tests/rpqmatrix_tests.rs b/tests/rpqmatrix_tests.rs index e04379d..3f84d80 100644 --- a/tests/rpqmatrix_tests.rs +++ b/tests/rpqmatrix_tests.rs @@ -21,13 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line.split_once(',').map(|x| x.1) + let query_str = line + .split_once(',') + .map(|x| x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . 
}}"); - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } From eeb4b24e6751a75618afd0063de4c03eb491d34d Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Wed, 29 Apr 2026 14:31:01 +0300 Subject: [PATCH 10/11] fix: remove criterion dirs after each group run --- Cargo.toml | 3 +- src/bin/pathrex.rs | 4 ++- src/cli/args.rs | 12 +++++--- src/cli/bench/error.rs | 6 ++++ src/cli/bench/runner.rs | 62 +++++++++++++++++++++++++---------------- src/cli/dispatch.rs | 7 ++--- 6 files changed, 60 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a93d66..ae9cf61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,10 +20,11 @@ serde = { version = "1", features = ["derive"], optional = true } serde_json = { version = "1", optional = true } chrono = { version = "0.4", features = ["serde"], optional = true } criterion = { version = "0.5", optional = true } +tempfile = { version = "3", optional = true } [features] regenerate-bindings = ["bindgen"] -bench = ["clap", "serde", "serde_json", "chrono", "criterion"] +bench = ["clap", "serde", "serde_json", "chrono", "criterion", "tempfile"] [dev-dependencies] tempfile = "3" diff --git a/src/bin/pathrex.rs b/src/bin/pathrex.rs index 91194f0..69b508a 100644 --- a/src/bin/pathrex.rs +++ b/src/bin/pathrex.rs @@ -242,7 +242,9 @@ fn run_bench_cmd(args: BenchArgs) -> Result<(), MainError> { eprintln!(); eprintln!("=== Done ==="); eprintln!("Results written to: {}", args.output); - eprintln!("Criterion data in: {}", args.criterion_dir); + if let Some(dir) = &args.criterion_dir { + eprintln!("Criterion data in: {dir}") + } Ok(()) } diff --git a/src/cli/args.rs b/src/cli/args.rs index 0a8528a..2ac9eb3 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -93,11 +93,15 @@ pub struct BenchArgs { #[arg(long)] pub resume: bool, - /// Directory for criterion output. - #[arg(long, default_value = "bench_criterion/")] - pub criterion_dir: String, + /// Directory for criterion output. 
When omitted, criterion writes into a + /// per-group temporary directory that is wiped immediately after each + /// benchmark group is parsed (default behavior). + #[arg(long)] + pub criterion_dir: Option, - /// Enable criterion HTML plot generation. + /// Enable criterion HTML plot generation. Requires `--criterion-dir`, + /// since plots written to a tempdir would be wiped before they could be + /// inspected. #[arg(long)] pub plots: bool, diff --git a/src/cli/bench/error.rs b/src/cli/bench/error.rs index 3514b11..ad4c38a 100644 --- a/src/cli/bench/error.rs +++ b/src/cli/bench/error.rs @@ -18,4 +18,10 @@ pub enum BenchError { #[error("checkpoint error: {0}")] Checkpoint(#[from] CheckpointError), + + #[error("invalid bench arguments: {0}")] + InvalidArgs(String), + + #[error("failed to create temporary directory for criterion output: {0}")] + TempDir(#[source] std::io::Error), } diff --git a/src/cli/bench/runner.rs b/src/cli/bench/runner.rs index b16a9c1..4b5a54a 100644 --- a/src/cli/bench/runner.rs +++ b/src/cli/bench/runner.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::time::Duration; use criterion::{Criterion, black_box}; @@ -14,13 +14,37 @@ use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; use crate::graph::InMemoryGraph; use crate::rpq::{RpqError, RpqQuery}; -/// Build a criterion instance from CLI bench args. -pub(crate) fn build_criterion(args: &BenchArgs) -> Criterion { +/// Per-group criterion output destination. 
+pub(crate) enum GroupOutput { + Temp(tempfile::TempDir), + Persistent(PathBuf), +} + +impl GroupOutput { + pub(crate) fn for_group(args: &BenchArgs) -> Result { + match &args.criterion_dir { + Some(p) => Ok(Self::Persistent(PathBuf::from(p))), + None => { + let td = tempfile::tempdir().map_err(BenchError::TempDir)?; + Ok(Self::Temp(td)) + } + } + } + + pub(crate) fn path(&self) -> &Path { + match self { + Self::Temp(td) => td.path(), + Self::Persistent(p) => p.as_path(), + } + } +} + +pub(crate) fn build_criterion(args: &BenchArgs, output_dir: &Path) -> Criterion { let c = Criterion::default() .sample_size(args.sample_size) .warm_up_time(Duration::from_secs(args.warm_up)) .measurement_time(Duration::from_secs(args.measurement)) - .output_directory(Path::new(&args.criterion_dir)); + .output_directory(output_dir); if args.plots { c.with_plots() } else { @@ -33,7 +57,6 @@ fn group_name(query_index: usize, algo_id: &str) -> String { } fn run_benchmark_group( - criterion: &mut Criterion, args: &BenchArgs, algo_name: &str, evaluator: E, @@ -48,6 +71,14 @@ where let mut prepared = evaluator.prepare(query, graph)?; let group = group_name(query_index, algo_name); + let output = match GroupOutput::for_group(args) { + Ok(o) => o, + Err(e) => return Ok(Err(e)), + }; + let output_path = output.path().to_path_buf(); + + let mut criterion = build_criterion(args, &output_path); + { let mut g = criterion.benchmark_group(&group); @@ -66,13 +97,10 @@ where g.finish(); } - Ok(read_algo_timing(Path::new(&args.criterion_dir), &group)) + Ok(read_algo_timing(&output_path, &group)) } /// Run the bench loop for every query in `queries` for one evaluator. -/// -/// One criterion group is created per `(query, algo)` pair; checkpoint persists -/// after every algo completion. 
pub fn run_bench_for_evaluator( args: &BenchArgs, algo: &Algo, @@ -81,7 +109,6 @@ pub fn run_bench_for_evaluator( graph: &InMemoryGraph, queries: &[LoadedQuery], checkpointer: &mut Checkpointer, - criterion: &mut Criterion, ) -> Result, BenchError> where E: Evaluator + Copy, @@ -123,7 +150,7 @@ where eprintln!("[query #{}] id={}", idx, loaded.id); eprintln!(" [bench] algo={algo_name}"); - match run_benchmark_group(criterion, args, algo_name, evaluator, query, graph, idx) { + match run_benchmark_group(args, algo_name, evaluator, query, graph, idx) { Ok(Ok(timing)) => { algorithms.insert(algo_name.to_string(), AlgoResult::ok(None, Some(timing))); } @@ -148,16 +175,3 @@ where Ok(results) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn group_name_is_filesystem_safe() { - let g = group_name(7, "nfa"); - assert_eq!(g, "query_7_nfa"); - assert!(!g.contains('/')); - assert!(!g.contains(' ')); - } -} diff --git a/src/cli/dispatch.rs b/src/cli/dispatch.rs index 7263e83..581fe77 100644 --- a/src/cli/dispatch.rs +++ b/src/cli/dispatch.rs @@ -2,7 +2,7 @@ use crate::cli::args::{Algo, BenchArgs, QueryArgs}; use crate::cli::bench::error::BenchError; -use crate::cli::bench::runner::{build_criterion, run_bench_for_evaluator}; +use crate::cli::bench::runner::run_bench_for_evaluator; use crate::cli::checkpoint::Checkpointer; use crate::cli::loader::LoadedQuery; use crate::cli::output::QueryResult; @@ -49,7 +49,8 @@ pub fn dispatch_bench( queries: &[LoadedQuery], checkpointer: &mut Checkpointer, ) -> Result, BenchError> { - let mut criterion = build_criterion(args); + validate_bench_args(args)?; + let mut all = Vec::new(); for algo in &args.common.algo { @@ -63,7 +64,6 @@ pub fn dispatch_bench( graph, queries, checkpointer, - &mut criterion, )?, Algo::Rpqmatrix => run_bench_for_evaluator( args, @@ -73,7 +73,6 @@ pub fn dispatch_bench( graph, queries, checkpointer, - &mut criterion, )?, }; merge_results(&mut all, per_algo); From cd5f603bc75632acf9411735546fe995a4d381ea 
Mon Sep 17 00:00:00 2001 From: Ivan Glazunov Date: Wed, 29 Apr 2026 15:01:55 +0300 Subject: [PATCH 11/11] feat: add rdf format to cli --- src/cli/args.rs | 4 ++-- src/cli/dispatch.rs | 2 -- src/cli/loader.rs | 10 +++++++++- src/cli/output.rs | 4 ++-- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/cli/args.rs b/src/cli/args.rs index 2ac9eb3..d95f478 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -141,10 +141,9 @@ impl std::fmt::Display for Algo { #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] #[value(rename_all = "lowercase")] pub enum GraphFormat { - /// MatrixMarket directory layout (vertices.txt, edges.txt, *.txt). Mm, - /// CSV file with source/target/label columns. Csv, + Rdf, } impl std::fmt::Display for GraphFormat { @@ -152,6 +151,7 @@ impl std::fmt::Display for GraphFormat { match self { GraphFormat::Mm => write!(f, "mm"), GraphFormat::Csv => write!(f, "csv"), + GraphFormat::Rdf => write!(f, "rdf"), } } } diff --git a/src/cli/dispatch.rs b/src/cli/dispatch.rs index 581fe77..2701460 100644 --- a/src/cli/dispatch.rs +++ b/src/cli/dispatch.rs @@ -49,8 +49,6 @@ pub fn dispatch_bench( queries: &[LoadedQuery], checkpointer: &mut Checkpointer, ) -> Result, BenchError> { - validate_bench_args(args)?; - let mut all = Vec::new(); for algo in &args.common.algo { diff --git a/src/cli/loader.rs b/src/cli/loader.rs index 096185a..a652043 100644 --- a/src/cli/loader.rs +++ b/src/cli/loader.rs @@ -10,7 +10,8 @@ use std::path::Path; use thiserror::Error; use crate::formats::Csv; -use crate::formats::mm::MatrixMarket; +use crate::formats::MatrixMarket; +use crate::formats::Rdf; use crate::graph::{Graph, GraphError, InMemory, InMemoryGraph}; use crate::rpq::{RpqError, RpqQuery}; use crate::sparql::parse_rpq; @@ -65,6 +66,13 @@ pub fn load_graph( source: e, }) } + GraphFormat::Rdf => { + let rdf = Rdf::from_path(graph_path).unwrap(); + Graph::::try_from(rdf).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, + 
}) + } } } diff --git a/src/cli/output.rs b/src/cli/output.rs index fabb2a4..9b568ea 100644 --- a/src/cli/output.rs +++ b/src/cli/output.rs @@ -100,7 +100,7 @@ pub struct QueryMetadata { impl QueryOutput { pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { let json = serde_json::to_string_pretty(self) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + .map_err(std::io::Error::other)?; fs::write(path, json) } } @@ -129,7 +129,7 @@ pub struct BenchMetadata { impl BenchOutput { pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { let json = serde_json::to_string_pretty(self) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + .map_err(std::io::Error::other)?; fs::write(path, json) } }