diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..faf37b5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git/ +.worktrees/ +target/ +.tmp/ +deps/LAGraph/build/ +bench_criterion/ +bench_results.json +bench_checkpoint.json diff --git a/AGENTS.md b/AGENTS.md index 223cfcb..22dc2df 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,7 @@ pathrex/ ├── Cargo.toml # Crate manifest (edition 2024) ├── build.rs # Links LAGraph + LAGraphX; optionally regenerates FFI bindings ├── src/ -│ ├── lib.rs # Modules: formats, graph, rpq, sparql, utils, lagraph_sys +│ ├── lib.rs # Modules: eval, formats, graph, rpq, sparql, utils, lagraph_sys │ ├── main.rs # Binary entry point (placeholder) │ ├── lagraph_sys.rs # FFI module — includes generated bindings │ ├── lagraph_sys_generated.rs# Bindgen output (checked in, regenerated in CI) @@ -24,8 +24,10 @@ pathrex/ │ │ ├── mod.rs # Core traits (GraphBuilder, GraphDecomposition, GraphSource, │ │ │ # Backend, Graph), error types, RAII wrappers, GrB init │ │ └── inmemory.rs # InMemory marker, InMemoryBuilder, InMemoryGraph +│ ├── eval/ +│ │ └── mod.rs # Evaluator, PreparedEvaluator, ResultCount traits │ ├── rpq/ -│ │ ├── mod.rs # RpqEvaluator (assoc. Result), RpqQuery, Endpoint, PathExpr, RpqError +│ │ ├── mod.rs # RPQ query types, RpqError, RPQ marker subtraits │ │ ├── nfarpq.rs # NfaRpqEvaluator (LAGraph_RegularPathQuery) │ │ └── rpqmatrix.rs # Matrix-plan RPQ evaluator │ ├── sparql/ @@ -187,6 +189,18 @@ pub trait Backend { - [`get_node_id(string_id)`](src/graph/mod.rs:200) / [`get_node_name(mapped_id)`](src/graph/mod.rs:203) — bidirectional string ↔ integer dictionary. - [`num_nodes()`](src/graph/mod.rs:204) — total unique nodes. +### Generic evaluator abstraction (`src/eval/`) + +[`src/eval/mod.rs`](src/eval/mod.rs) defines query-language-agnostic evaluator traits: + +- [`Evaluator`](src/eval/mod.rs) uses associated types for `Query`, `Result`, `Error`, and + `Prepared`. 
The graph backend stays a method-level generic (`G: GraphDecomposition`) so one + evaluator type can run against any graph backend selected at the call site. +- [`PreparedEvaluator`](src/eval/mod.rs) represents prepared `(query, graph)` state that can be + executed repeatedly, which is used by benchmark timing loops. +- [`ResultCount`](src/eval/mod.rs) is separate from `Evaluator::Result`; only CLI runners that + need counts require this bound, leaving room for future evaluators with richer result types. + ### InMemoryBuilder / InMemoryGraph [`InMemoryBuilder`](src/graph/inmemory.rs:36) is the primary `GraphBuilder` implementation. @@ -331,10 +345,11 @@ Key public items: - [`RpqQuery`](src/rpq/mod.rs) — `{ subject, path, object }` using the types above; [`strip_base(&mut self, base)`](src/rpq/mod.rs) removes a shared IRI prefix from named endpoints and labels. -- [`RpqEvaluator`](src/rpq/mod.rs) — trait with associated type `Result` and - [`evaluate(query, graph)`](src/rpq/mod.rs) taking `&RpqQuery` and - [`GraphDecomposition`], returning `Result`. - Each concrete evaluator exposes its own output type (see below). +- [`RpqEvaluator`](src/rpq/mod.rs) — marker subtrait over + [`Evaluator`](src/eval/mod.rs), preserving the RPQ-facing + trait name while the generic evaluator hierarchy lives in `src/eval/`. +- [`PreparedRpq`](src/rpq/mod.rs) — marker subtrait over + [`PreparedEvaluator`](src/eval/mod.rs). - [`RpqError`](src/rpq/mod.rs) — unified error type for RPQ parsing and evaluation: `Parse` (SPARQL syntax), `Extract` (query extraction), `UnsupportedPath`, `VertexNotFound`, and `Graph` (wraps [`GraphError`](src/graph/mod.rs) for @@ -378,10 +393,11 @@ Key public items: - [`RpqQuery`](src/rpq/mod.rs) — `{ subject, path, object }` using the types above; [`strip_base(&mut self, base)`](src/rpq/mod.rs) removes a shared IRI prefix from named endpoints and labels. 
-- [`RpqEvaluator`](src/rpq/mod.rs) — trait with associated type `Result` and - [`evaluate(query, graph)`](src/rpq/mod.rs) taking `&RpqQuery` and - [`GraphDecomposition`], returning `Result`. - Each concrete evaluator exposes its own output type (see below). +- [`RpqEvaluator`](src/rpq/mod.rs) — marker subtrait over + [`Evaluator`](src/eval/mod.rs), preserving the RPQ-facing + trait name while the generic evaluator hierarchy lives in `src/eval/`. +- [`PreparedRpq`](src/rpq/mod.rs) — marker subtrait over + [`PreparedEvaluator`](src/eval/mod.rs). - [`RpqError`](src/rpq/mod.rs) — unified error type for RPQ parsing and evaluation: `Parse` (SPARQL syntax), `Extract` (query extraction), `UnsupportedPath`, `VertexNotFound`, and `Graph` (wraps [`GraphError`](src/graph/mod.rs) for @@ -423,6 +439,18 @@ over label adjacency matrices and runs [`LAGraph_RPQMatrix`]. It returns Subject/object do not filter the matrix; a named subject is only validated to exist. Bound objects are not supported yet ([`RpqError::UnsupportedPath`]). +### CLI dispatch (`src/cli/dispatch.rs`) + +With the `bench` feature enabled, [`src/cli/dispatch.rs`](src/cli/dispatch.rs) is the single +mapping from [`Algo`](src/cli/args.rs) variants to concrete evaluator types. `dispatch_query` +and `dispatch_bench` each perform one exhaustive `match` per requested algorithm, then call +generic runners (`run_query_for_evaluator` and `run_bench_for_evaluator`) that are +monomorphized for the selected evaluator. + +Adding a new algorithm requires a new `Algo` variant, its `Display` arm, one `dispatch_query` +arm, one `dispatch_bench` arm, an `impl Evaluator` for the evaluator type, and an +`impl ResultCount` for any result type used by CLI count reporting. 
+ ### FFI layer [`lagraph_sys`](src/lagraph_sys.rs) exposes raw C bindings for GraphBLAS and diff --git a/Cargo.toml b/Cargo.toml index c454400..ae9cf61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,11 +15,24 @@ rustfst = "1.2" spargebra = "0.4.6" thiserror = "1.0" +clap = { version = "4", features = ["derive"], optional = true } +serde = { version = "1", features = ["derive"], optional = true } +serde_json = { version = "1", optional = true } +chrono = { version = "0.4", features = ["serde"], optional = true } +criterion = { version = "0.5", optional = true } +tempfile = { version = "3", optional = true } + [features] regenerate-bindings = ["bindgen"] +bench = ["clap", "serde", "serde_json", "chrono", "criterion", "tempfile"] [dev-dependencies] tempfile = "3" [build-dependencies] bindgen = { version = "0.71", optional = true } + +[[bin]] +name = "pathrex" +path = "src/bin/pathrex.rs" +required-features = ["bench"] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0d8f56a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,49 @@ +FROM rust:1-bookworm AS builder + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + clang \ + cmake \ + git \ + libclang-dev \ + make \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /src + +COPY . . 
+ +RUN git clone --depth 1 https://github.com/DrTimothyAldenDavis/GraphBLAS.git /tmp/GraphBLAS \ + && make -C /tmp/GraphBLAS compact \ + && make -C /tmp/GraphBLAS install \ + && mkdir -p deps/LAGraph/build \ + && make -C deps/LAGraph \ + && cargo build --release --bin pathrex --features "bench,regenerate-bindings" + +FROM debian:bookworm-slim AS runtime + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + gcc \ + libc6-dev \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /work + +COPY --from=builder /src/target/release/pathrex /usr/local/bin/pathrex +COPY --from=builder /usr/local/lib/libgraphblas.so* /usr/local/lib/ +COPY --from=builder /src/deps/LAGraph/build/src/liblagraph.so* /usr/local/lib/ +COPY --from=builder /src/deps/LAGraph/build/experimental/liblagraphx.so* /usr/local/lib/ +COPY --from=builder /src/docker/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh + +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +ENV LD_LIBRARY_PATH=/usr/local/lib + +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] +CMD ["--help"] diff --git a/build.rs b/build.rs index 243bbcb..95243fb 100644 --- a/build.rs +++ b/build.rs @@ -65,6 +65,7 @@ fn regenerate_bindings() { .allowlist_item("GrB_Info") .allowlist_function("GrB_Matrix_new") .allowlist_function("GrB_Matrix_nvals") + .allowlist_function("GrB_Matrix_dup") .allowlist_function("GrB_Matrix_free") .allowlist_function("GrB_Matrix_extractElement_BOOL") .allowlist_function("GrB_Matrix_build_BOOL") @@ -89,6 +90,7 @@ fn regenerate_bindings() { .allowlist_function("LAGraph_Cached_AT") .allowlist_function("LAGraph_MMRead") .allowlist_function("LAGraph_RPQMatrix") + .allowlist_function("LAGraph_RPQMatrix_reduce") .allowlist_function("LAGraph_DestroyRpqMatrixPlan") .allowlist_function("LAGraph_RPQMatrix_label") .allowlist_function("LAGraph_RPQMatrix_Free") diff --git a/docker/docker-entrypoint.sh b/docker/docker-entrypoint.sh new file mode 100755 index 0000000..8da1b38 --- 
/dev/null +++ b/docker/docker-entrypoint.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -euo pipefail + +pathrex_bin="${PATHREX_BIN:-/usr/local/bin/pathrex}" + +if [ "${1-}" = "bench" ]; then + shift + has_output=false + has_checkpoint=false + has_criterion_dir=false + args=("$@") + + for arg in "${args[@]}"; do + case "$arg" in + -o|-o*|--output|--output=*) + has_output=true + ;; + -c|-c*|--checkpoint|--checkpoint=*) + has_checkpoint=true + ;; + --criterion-dir|--criterion-dir=*) + has_criterion_dir=true + ;; + esac + done + + if [ "$has_output" = false ]; then + args+=(--output /results/bench_results.json) + fi + + if [ "$has_checkpoint" = false ]; then + args+=(--checkpoint /results/bench_checkpoint.json) + fi + + if [ "$has_criterion_dir" = false ]; then + args+=(--criterion-dir /results/criterion) + fi + + exec "$pathrex_bin" bench "${args[@]}" +fi + +exec "$pathrex_bin" "$@" diff --git a/src/bin/pathrex.rs b/src/bin/pathrex.rs new file mode 100644 index 0000000..69b508a --- /dev/null +++ b/src/bin/pathrex.rs @@ -0,0 +1,250 @@ +//! Entry point for the `pathrex` binary. +//! +//! Subcommands: +//! - `query` — run queries once and report result counts +//! - `bench` — benchmark RPQ evaluators with criterion +//! +//! # Examples +//! +//! ```bash +//! # Run queries once (prints per-query result counts): +//! cargo run --release --bin pathrex --features bench -- query \ +//! --graph tests/testdata/mm_graph \ +//! --queries tests/testdata/cases/any-any/queries.txt +//! +//! # Benchmark with criterion: +//! cargo run --release --bin pathrex --features bench -- bench \ +//! --graph tests/testdata/mm_graph \ +//! --queries tests/testdata/cases/any-any/queries.txt \ +//! --algo nfarpq rpqmatrix \ +//! --output results.json +//! 
``` + +use std::error::Error as StdError; +use std::path::{Path, PathBuf}; + +use chrono::Utc; +use clap::Parser; +use thiserror::Error; + +use pathrex::cli::args::{BenchArgs, Cli, Commands, QueryArgs}; +use pathrex::cli::bench::BenchError; +use pathrex::cli::checkpoint::{Checkpoint, CheckpointError, Checkpointer}; +use pathrex::cli::dispatch::{dispatch_bench, dispatch_query}; +use pathrex::cli::loader::{GraphLoadError, LoadedQuery, load_graph, load_queries}; +use pathrex::cli::output::{BenchMetadata, BenchOutput, QueryMetadata, QueryOutput}; +use pathrex::graph::{GraphDecomposition, InMemoryGraph}; + +#[derive(Debug, Error)] +enum MainError { + #[error(transparent)] + Graph(#[from] GraphLoadError), + #[error("error loading queries from '{path}': {source}")] + Queries { + path: String, + #[source] + source: std::io::Error, + }, + #[error(transparent)] + Checkpoint(#[from] CheckpointError), + #[error(transparent)] + Bench(#[from] BenchError), + #[error("error writing output to '{path}': {source}")] + Output { + path: String, + #[source] + source: std::io::Error, + }, +} + +fn main() { + if let Err(e) = run() { + eprintln!("Error: {e}"); + // Walk the source chain so users see the underlying cause. 
+ let mut cur: Option<&dyn StdError> = e.source(); + while let Some(c) = cur { + eprintln!(" caused by: {c}"); + cur = c.source(); + } + std::process::exit(1); + } +} + +fn run() -> Result<(), MainError> { + let cli = Cli::parse(); + match cli.command { + Commands::Query(args) => run_query_cmd(args), + Commands::Bench(args) => run_bench_cmd(args), + } +} + +fn load_query_file(path: &str, base_iri: Option<&str>) -> Result<Vec<LoadedQuery>, MainError> { + load_queries(Path::new(path), base_iri).map_err(|e| MainError::Queries { + path: path.to_string(), + source: e, + }) +} + +fn run_query_cmd(args: QueryArgs) -> Result<(), MainError> { + let common = &args.common; + + eprintln!("=== pathrex query ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!(); + + eprintln!("[1/2] Loading graph..."); + let graph: InMemoryGraph = + load_graph(&common.graph, common.format, common.base_iri.as_deref())?; + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + eprintln!("[2/2] Loading and running queries..."); + let queries = load_query_file(&common.queries, common.base_iri.as_deref())?; + eprintln!(" loaded {} queries", queries.len()); + + let results = dispatch_query(&args, &graph, &queries); + + let errors = results + .iter() + .flat_map(|r| r.algorithms.values()) + .filter(|a| !matches!(a.status, pathrex::cli::output::AlgoStatus::Ok)) + .count(); + eprintln!(); + eprintln!( + "Done. {} queries × {} algos. 
{errors} error(s).", + results.len(), + common.algo.len() + ); + + if let Some(ref out_path) = args.output { + let output = QueryOutput { + metadata: QueryMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.to_string(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: graph.num_labels(), + }, + results, + }; + output + .write_to_file(Path::new(out_path)) + .map_err(|e| MainError::Output { + path: out_path.clone(), + source: e, + })?; + eprintln!("Results written to: {out_path}"); + } + + Ok(()) +} + +fn build_checkpointer(args: &BenchArgs, queries_len: usize) -> Result<Checkpointer, MainError> { + let common = &args.common; + let path = PathBuf::from(&args.checkpoint); + + if args.resume { + match Checkpoint::load(&path)? { + Some(cp) => { + cp.validate(&common.graph, &common.queries, &common.algo)?; + let cper = Checkpointer::with_inner(cp, path); + eprintln!( + " resuming: {}/{} queries fully done", + cper.fully_done_count(&common.algo), + queries_len + ); + Ok(cper) + } + None => { + eprintln!(" no checkpoint file found, starting fresh"); + Ok(Checkpointer::fresh( + &common.graph, + &common.queries, + &common.algo, + path, + )) + } + } + } else { + Ok(Checkpointer::fresh( + &common.graph, + &common.queries, + &common.algo, + path, + )) + } +} + +fn run_bench_cmd(args: BenchArgs) -> Result<(), MainError> { + let common = &args.common; + + eprintln!("=== pathrex bench ==="); + eprintln!("Graph: {}", common.graph); + eprintln!("Format: {}", common.format); + eprintln!("Queries: {}", common.queries); + eprintln!("Algos: {:?}", common.algo); + eprintln!("Output: {}", args.output); + eprintln!(); + + eprintln!("[1/4] Loading graph..."); + let graph: InMemoryGraph = + load_graph(&common.graph, common.format, common.base_iri.as_deref())?; + eprintln!(" nodes: {}", graph.num_nodes()); + eprintln!(" labels: {}", graph.num_labels()); + eprintln!(); + + 
eprintln!("[2/4] Loading queries..."); + let queries = load_query_file(&common.queries, common.base_iri.as_deref())?; + eprintln!(" loaded {} queries", queries.len()); + let parse_errors = queries.iter().filter(|q| q.parsed.is_err()).count(); + if parse_errors > 0 { + eprintln!(" ({parse_errors} queries failed to parse)"); + } + eprintln!(); + + eprintln!("[3/4] Setting up checkpoint..."); + let mut checkpointer = build_checkpointer(&args, queries.len())?; + eprintln!(); + + eprintln!("[4/4] Running benchmarks..."); + eprintln!(); + let results = dispatch_bench(&args, &graph, &queries, &mut checkpointer)?; + + let output = BenchOutput { + metadata: BenchMetadata { + timestamp: Utc::now().to_rfc3339(), + graph_path: common.graph.clone(), + graph_format: common.format.to_string(), + queries_file: common.queries.clone(), + base_iri: common.base_iri.clone(), + num_nodes: graph.num_nodes(), + num_labels: graph.num_labels(), + sample_size: args.sample_size, + warm_up_secs: args.warm_up, + measurement_secs: args.measurement, + }, + results, + }; + + output + .write_to_file(Path::new(&args.output)) + .map_err(|e| MainError::Output { + path: args.output.clone(), + source: e, + })?; + + eprintln!(); + eprintln!("=== Done ==="); + eprintln!("Results written to: {}", args.output); + if let Some(dir) = &args.criterion_dir { + eprintln!("Criterion data in: {dir}") + } + + Ok(()) +} diff --git a/src/cli/args.rs b/src/cli/args.rs new file mode 100644 index 0000000..d95f478 --- /dev/null +++ b/src/cli/args.rs @@ -0,0 +1,177 @@ +//! CLI argument definitions for the `pathrex` binary. +//! +//! Structure: +//! - [`Cli`] — top-level parser with a `subcommand` field +//! - [`Commands`] — `bench` or `query` +//! - [`CommonArgs`] — args shared by both subcommands (graph, queries, algo, …) +//! - [`BenchArgs`] — bench-specific args (criterion, checkpoint, …) +//! - [`QueryArgs`] — query-specific args (optional output file) +//! - [`Algo`] — algorithm identifier enum +//! 
- [`GraphFormat`] — input graph format enum + +use clap::{Args, Parser, Subcommand, ValueEnum}; + +/// Top-level CLI for pathrex. +#[derive(Parser, Debug)] +#[command( + name = "pathrex", + about = "RPQ evaluator and benchmarking tool for edge-labeled graphs" +)] +pub struct Cli { + #[command(subcommand)] + pub command: Commands, +} + +/// Available subcommands. +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Run queries once and report result counts + Query(QueryArgs), + /// Benchmark RPQ evaluators with criterion + Bench(BenchArgs), +} + +/// Arguments shared by both subcommands. +#[derive(Args, Debug)] +pub struct CommonArgs { + /// Path to graph directory (mm) or file (csv). + #[arg(short = 'g', long)] + pub graph: String, + + /// Graph format. + #[arg(short = 'f', long, value_enum, default_value_t = GraphFormat::Mm)] + pub format: GraphFormat, + + /// Path to queries file (format: `<id>,<query>` per line). + #[arg(short = 'q', long)] + pub queries: String, + + /// Optional base IRI prepended to bare SPARQL patterns as `BASE <iri>`. + /// Pass without a value (`--base-iri`) to use the default `http://example.org/`. + /// Pass with a value (`--base-iri <iri>`) to use a custom IRI. + /// When omitted entirely, no BASE declaration is added to the query. + #[arg( + short = 'b', + long, + num_args = 0..=1, + default_missing_value = "http://example.org/", + require_equals = false + )] + pub base_iri: Option<String>, + + /// Algorithms to use. + #[arg(short = 'a', long, value_enum, num_args = 1.., required = true)] + pub algo: Vec<Algo>, +} + +/// Arguments for the `query` subcommand. +#[derive(Args, Debug)] +pub struct QueryArgs { + #[command(flatten)] + pub common: CommonArgs, + + /// Optional path to write results as JSON. + #[arg(short = 'o', long)] + pub output: Option<String>, +} + +/// Arguments for the `bench` subcommand. +#[derive(Args, Debug)] +pub struct BenchArgs { + #[command(flatten)] + pub common: CommonArgs, + + /// Output JSON file for benchmark results. 
+ #[arg(short = 'o', long, default_value = "bench_results.json")] + pub output: String, + + /// Checkpoint file path. + #[arg(short = 'c', long, default_value = "bench_checkpoint.json")] + pub checkpoint: String, + + /// Resume from checkpoint, skipping completed queries. + #[arg(long)] + pub resume: bool, + + /// Directory for criterion output. When omitted, criterion writes into a + /// per-group temporary directory that is wiped immediately after each + /// benchmark group is parsed (default behavior). + #[arg(long)] + pub criterion_dir: Option<String>, + + /// Enable criterion HTML plot generation. Requires `--criterion-dir`, + /// since plots written to a tempdir would be wiped before they could be + /// inspected. + #[arg(long)] + pub plots: bool, + + /// Criterion sample size per benchmark group. + #[arg(long, default_value_t = 10)] + pub sample_size: usize, + + /// Criterion warm-up time in seconds. + #[arg(long, default_value_t = 1)] + pub warm_up: u64, + + /// Criterion measurement time in seconds. + #[arg(long, default_value_t = 5)] + pub measurement: u64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, ValueEnum, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +#[value(rename_all = "lowercase")] +pub enum Algo { + /// NFA-based evaluator (`LAGraph_RegularPathQuery`). + NfaRpq, + /// Matrix-plan evaluator (`LAGraph_RPQMatrix`). + Rpqmatrix, +} + +impl std::fmt::Display for Algo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Algo::NfaRpq => write!(f, "nfarpq"), + Algo::Rpqmatrix => write!(f, "rpqmatrix"), + } + } +} + +/// Input graph format. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +#[value(rename_all = "lowercase")] +pub enum GraphFormat { + Mm, + Csv, + Rdf, +} + +impl std::fmt::Display for GraphFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GraphFormat::Mm => write!(f, "mm"), + GraphFormat::Csv => write!(f, "csv"), + GraphFormat::Rdf => write!(f, "rdf"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::Parser; + + #[test] + fn query_requires_at_least_one_algo() { + let result = Cli::try_parse_from([ + "pathrex", + "query", + "--graph", + "graph", + "--queries", + "queries", + ]); + + assert!(result.is_err()); + } +} diff --git a/src/cli/bench/error.rs b/src/cli/bench/error.rs new file mode 100644 index 0000000..ad4c38a --- /dev/null +++ b/src/cli/bench/error.rs @@ -0,0 +1,27 @@ +//! Error type for the bench pipeline. + +use thiserror::Error; + +use crate::cli::checkpoint::CheckpointError; + +#[derive(Debug, Error)] +pub enum BenchError { + #[error("criterion estimates missing for group '{0}' (file not found or unreadable)")] + MissingEstimates(String), + + #[error("criterion estimates parse error for group '{group}': {source}")] + EstimatesParse { + group: String, + #[source] + source: serde_json::Error, + }, + + #[error("checkpoint error: {0}")] + Checkpoint(#[from] CheckpointError), + + #[error("invalid bench arguments: {0}")] + InvalidArgs(String), + + #[error("failed to create temporary directory for criterion output: {0}")] + TempDir(#[source] std::io::Error), +} diff --git a/src/cli/bench/estimates.rs b/src/cli/bench/estimates.rs new file mode 100644 index 0000000..1cf72cb --- /dev/null +++ b/src/cli/bench/estimates.rs @@ -0,0 +1,149 @@ +//! Typed deserialization of criterion's per-benchmark JSON output. +//! +//! After `group.finish()`, criterion writes: +//! +//! ```text +//! <criterion_dir>/<group>/<bench>/new/estimates.json +//! <criterion_dir>/<group>/<bench>/new/sample.json +//! ``` +//! +//! We read both to extract a [`TimingStats`]. 
+use std::fs::File; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +use crate::cli::bench::error::BenchError; +use crate::cli::output::{AlgoTiming, TimingStats}; + +#[derive(Deserialize)] +struct Estimates { + mean: PointEstimate, + median: PointEstimate, + std_dev: PointEstimate, +} + +#[derive(Deserialize)] +struct PointEstimate { + point_estimate: f64, +} + +#[derive(Deserialize)] +struct Sample { + iters: Vec<f64>, +} + +fn estimates_path(criterion_dir: &Path, group: &str, bench: &str) -> PathBuf { + criterion_dir + .join(group) + .join(bench) + .join("new") + .join("estimates.json") +} + +fn sample_path(criterion_dir: &Path, group: &str, bench: &str) -> PathBuf { + criterion_dir + .join(group) + .join(bench) + .join("new") + .join("sample.json") +} + +/// Read one bench's timing stats out of criterion's output directory. +/// +/// Returns `Err(MissingEstimates)` if the JSON file isn't present (e.g. the +/// criterion run was interrupted) and `Err(EstimatesParse)` on shape +/// mismatches — making schema drift loud rather than silent. +pub fn read_timing_stats( + criterion_dir: &Path, + group: &str, + bench: &str, +) -> Result<TimingStats, BenchError> { + let est_path = estimates_path(criterion_dir, group, bench); + let file = File::open(&est_path) + .map_err(|_| BenchError::MissingEstimates(format!("{}/{}", group, bench)))?; + let est: Estimates = serde_json::from_reader(file).map_err(|e| BenchError::EstimatesParse { + group: format!("{}/{}", group, bench), + source: e, + })?; + + // sample.json is best-effort; if missing we report iterations=0. + let iterations = File::open(sample_path(criterion_dir, group, bench)) + .ok() + .and_then(|f| serde_json::from_reader::<_, Sample>(f).ok()) + .map(|s| s.iters.len()) + .unwrap_or(0); + + Ok(TimingStats { + mean_ns: est.mean.point_estimate, + median_ns: est.median.point_estimate, + stddev_ns: est.std_dev.point_estimate, + iterations, + }) +} + +/// Read both `eval_total` and `eval_ffi_only` benches for a single group. 
+pub fn read_algo_timing(criterion_dir: &Path, group: &str) -> Result<AlgoTiming, BenchError> { + let total = read_timing_stats(criterion_dir, group, "eval_total")?; + let ffi_only = read_timing_stats(criterion_dir, group, "eval_ffi_only")?; + Ok(AlgoTiming { total, ffi_only }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + fn write_estimate_files(base: &Path, group: &str, bench: &str, mean_ns: f64, iters: usize) { + let bench_dir = base.join(group).join(bench).join("new"); + fs::create_dir_all(&bench_dir).expect("create bench dir"); + fs::write( + bench_dir.join("estimates.json"), + format!( + r#"{{"mean":{{"point_estimate":{mean_ns}}},"median":{{"point_estimate":{mean_ns}}},"std_dev":{{"point_estimate":0.0}}}}"# + ), + ) + .expect("write estimates"); + let iters_array: Vec<&str> = vec!["1.0"; iters]; + let sample = format!("{{\"iters\":[{}]}}", iters_array.join(",")); + fs::write(bench_dir.join("sample.json"), sample).expect("write sample"); + } + + #[test] + fn reads_split_estimates() { + let dir = tempfile::tempdir().expect("tempdir"); + write_estimate_files(dir.path(), "query_0_nfa", "eval_total", 10.0, 3); + write_estimate_files(dir.path(), "query_0_nfa", "eval_ffi_only", 4.0, 5); + + let timing = read_algo_timing(dir.path(), "query_0_nfa").expect("split timing"); + assert_eq!(timing.total.mean_ns, 10.0); + assert_eq!(timing.total.iterations, 3); + assert_eq!(timing.ffi_only.mean_ns, 4.0); + assert_eq!(timing.ffi_only.iterations, 5); + } + + #[test] + fn missing_file_is_reported() { + let dir = tempfile::tempdir().expect("tempdir"); + let err = read_timing_stats(dir.path(), "missing", "eval_total").unwrap_err(); + match err { + BenchError::MissingEstimates(g) => assert!(g.contains("missing")), + other => panic!("unexpected error: {other}"), + } + } + + #[test] + fn malformed_json_is_reported() { + let dir = tempfile::tempdir().expect("tempdir"); + let bench_dir = dir.path().join("g").join("b").join("new"); + fs::create_dir_all(&bench_dir).unwrap(); + 
fs::write(bench_dir.join("estimates.json"), "{not json").unwrap(); + + let err = read_timing_stats(dir.path(), "g", "b").unwrap_err(); + match err { + BenchError::EstimatesParse { .. } => {} + other => panic!("unexpected error: {other}"), + } + } +} diff --git a/src/cli/bench/mod.rs b/src/cli/bench/mod.rs new file mode 100644 index 0000000..dc410ee --- /dev/null +++ b/src/cli/bench/mod.rs @@ -0,0 +1,7 @@ +//! Benchmarking subsystem for the `pathrex bench` subcommand. + +pub mod error; +pub mod estimates; +pub mod runner; + +pub use error::BenchError; diff --git a/src/cli/bench/runner.rs b/src/cli/bench/runner.rs new file mode 100644 index 0000000..4b5a54a --- /dev/null +++ b/src/cli/bench/runner.rs @@ -0,0 +1,177 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use criterion::{Criterion, black_box}; + +use crate::cli::args::{Algo, BenchArgs}; +use crate::cli::bench::error::BenchError; +use crate::cli::bench::estimates::read_algo_timing; +use crate::cli::checkpoint::Checkpointer; +use crate::cli::loader::LoadedQuery; +use crate::cli::output::{AlgoResult, QueryResult}; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::InMemoryGraph; +use crate::rpq::{RpqError, RpqQuery}; + +/// Per-group criterion output destination. 
+pub(crate) enum GroupOutput { + Temp(tempfile::TempDir), + Persistent(PathBuf), +} + +impl GroupOutput { + pub(crate) fn for_group(args: &BenchArgs) -> Result<Self, BenchError> { + match &args.criterion_dir { + Some(p) => Ok(Self::Persistent(PathBuf::from(p))), + None => { + let td = tempfile::tempdir().map_err(BenchError::TempDir)?; + Ok(Self::Temp(td)) + } + } + } + + pub(crate) fn path(&self) -> &Path { + match self { + Self::Temp(td) => td.path(), + Self::Persistent(p) => p.as_path(), + } + } +} + +pub(crate) fn build_criterion(args: &BenchArgs, output_dir: &Path) -> Criterion { + let c = Criterion::default() + .sample_size(args.sample_size) + .warm_up_time(Duration::from_secs(args.warm_up)) + .measurement_time(Duration::from_secs(args.measurement)) + .output_directory(output_dir); + if args.plots { + c.with_plots() + } else { + c.without_plots() + } +} + +fn group_name(query_index: usize, algo_id: &str) -> String { + format!("query_{query_index}_{algo_id}") +} + +fn run_benchmark_group<E>( + args: &BenchArgs, + algo_name: &str, + evaluator: E, + query: &RpqQuery, + graph: &InMemoryGraph, + query_index: usize, +) -> Result<Result<AlgoTiming, BenchError>, RpqError> +where + E: Evaluator<Query = RpqQuery, Error = RpqError> + Copy, + E::Result: ResultCount, +{ + let mut prepared = evaluator.prepare(query, graph)?; + let group = group_name(query_index, algo_name); + + let output = match GroupOutput::for_group(args) { + Ok(o) => o, + Err(e) => return Ok(Err(e)), + }; + let output_path = output.path().to_path_buf(); + + let mut criterion = build_criterion(args, &output_path); + + { + let mut g = criterion.benchmark_group(&group); + + g.bench_function("eval_total", |b| { + b.iter(|| { + let _ = black_box(evaluator.evaluate(query, graph)); + }); + }); + + g.bench_function("eval_ffi_only", |b| { + b.iter(|| { + let _ = black_box(prepared.execute()); + }); + }); + + g.finish(); + } + + Ok(read_algo_timing(&output_path, &group)) +} + +/// Run the bench loop for every query in `queries` for one evaluator. 
+pub fn run_bench_for_evaluator<E>( + args: &BenchArgs, + algo: &Algo, + algo_name: &str, + evaluator: E, + graph: &InMemoryGraph, + queries: &[LoadedQuery], + checkpointer: &mut Checkpointer, +) -> Result<Vec<QueryResult>, BenchError> +where + E: Evaluator<Query = RpqQuery, Error = RpqError> + Copy, + E::Result: ResultCount, +{ + let mut results = Vec::with_capacity(queries.len()); + + for (idx, loaded) in queries.iter().enumerate() { + if checkpointer.is_algo_done(idx, algo) { + eprintln!( + " [skip] query #{} id={} algo={algo_name} already done", + idx, loaded.id + ); + continue; + } + + let mut algorithms: HashMap<String, AlgoResult> = HashMap::new(); + + let query = match &loaded.parsed { + Ok(q) => q, + Err(e) => { + let msg = e.to_string(); + eprintln!( + " [error] query #{} (id={}) algo={} parse error: {}", + idx, loaded.id, algo_name, msg + ); + algorithms.insert(algo_name.to_string(), AlgoResult::error(msg)); + checkpointer.mark_and_save(idx, algo)?; + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms, + }); + continue; + } + }; + + eprintln!("[query #{}] id={}", idx, loaded.id); + eprintln!(" [bench] algo={algo_name}"); + + match run_benchmark_group(args, algo_name, evaluator, query, graph, idx) { + Ok(Ok(timing)) => { + algorithms.insert(algo_name.to_string(), AlgoResult::ok(None, Some(timing))); + } + Ok(Err(e)) => return Err(e), + Err(e) => { + eprintln!( + " [error] query #{} (id={}) algo={} prepare error: {}", + idx, loaded.id, algo_name, e + ); + algorithms.insert(algo_name.to_string(), AlgoResult::error(e.to_string())); + } + } + + checkpointer.mark_and_save(idx, algo)?; + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms, + }); + } + + Ok(results) +} diff --git a/src/cli/checkpoint.rs b/src/cli/checkpoint.rs new file mode 100644 index 0000000..f59f3bb --- /dev/null +++ b/src/cli/checkpoint.rs @@ -0,0 +1,206 @@ +//! Checkpoint read/write/validation for crash recovery. +//! 
+//! After each query-algorithm pair completes, the checkpoint file is updated +//! so that a crashed run can be resumed from the last completed point. +//! +//! Two layers: +//! - [`Checkpoint`] — pure data; serialised to disk. +//! - [`Checkpointer`] — runtime owner; pairs the data with its file path +//! and exposes a fallible `mark_and_save`. Errors propagate; no silent +//! corruption. + +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::args::Algo; + +/// Persistent checkpoint state written to disk as JSON. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Checkpoint { + pub version: u32, + pub graph_path: String, + pub queries_file: String, + pub algorithms: Vec<Algo>, + pub completed: Vec<QueryCompletion>, +} + +/// Tracks which algorithms have been completed for a single query. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QueryCompletion { + pub query_index: usize, + pub algorithms_done: Vec<Algo>, +} + +impl Checkpoint { + /// Create a fresh checkpoint for a new benchmark run. + pub fn new(graph_path: &str, queries_file: &str, algorithms: &[Algo]) -> Self { + Self { + version: 1, + graph_path: graph_path.to_string(), + queries_file: queries_file.to_string(), + algorithms: algorithms.to_vec(), + completed: Vec::new(), + } + } + + /// Load a checkpoint from disk. Returns `None` if the file doesn't exist. + pub fn load(path: &Path) -> Result<Option<Self>, CheckpointError> { + if !path.exists() { + return Ok(None); + } + let data = fs::read_to_string(path) + .map_err(|e| CheckpointError::Io(path.display().to_string(), e))?; + let cp: Self = serde_json::from_str(&data) + .map_err(|e| CheckpointError::Parse(path.display().to_string(), e))?; + Ok(Some(cp)) + } + + /// Validate that a loaded checkpoint matches the current run parameters. 
+ pub fn validate( + &self, + graph_path: &str, + queries_file: &str, + algorithms: &[Algo], + ) -> Result<(), CheckpointError> { + if self.graph_path != graph_path { + return Err(CheckpointError::Mismatch(format!( + "graph_path: checkpoint has '{}', current is '{}'", + self.graph_path, graph_path + ))); + } + if self.queries_file != queries_file { + return Err(CheckpointError::Mismatch(format!( + "queries_file: checkpoint has '{}', current is '{}'", + self.queries_file, queries_file + ))); + } + let cp_algos: HashSet<&Algo> = self.algorithms.iter().collect(); + let cur_algos: HashSet<&Algo> = algorithms.iter().collect(); + if cp_algos != cur_algos { + return Err(CheckpointError::Mismatch(format!( + "algorithms: checkpoint has {:?}, current is {:?}", + self.algorithms, algorithms + ))); + } + Ok(()) + } + + /// Save the checkpoint to disk. + pub fn save(&self, path: &Path) -> Result<(), CheckpointError> { + let json = serde_json::to_string_pretty(self).map_err(CheckpointError::Serialize)?; + + // Write to a temp file first, then rename for atomicity. + let tmp_path = path.with_extension("json.tmp"); + fs::write(&tmp_path, &json) + .map_err(|e| CheckpointError::Io(tmp_path.display().to_string(), e))?; + fs::rename(&tmp_path, path) + .map_err(|e| CheckpointError::Io(path.display().to_string(), e))?; + Ok(()) + } + + /// Check if all requested algorithms are done for a given query index. 
+ pub fn is_fully_done(&self, query_index: usize, algos: &[Algo]) -> bool { + let Some(entry) = self.completed.iter().find(|c| c.query_index == query_index) else { + return false; + }; + let done: HashSet<&Algo> = entry.algorithms_done.iter().collect(); + algos.iter().all(|a| done.contains(a)) + } + + pub fn is_algo_done(&self, query_index: usize, algo: &Algo) -> bool { + self.completed + .iter() + .find(|c| c.query_index == query_index) + .map(|c| c.algorithms_done.contains(algo)) + .unwrap_or(false) + } + + pub fn mark_algo_done(&mut self, query_index: usize, algo: &Algo) { + if let Some(entry) = self + .completed + .iter_mut() + .find(|c| c.query_index == query_index) + { + if !entry.algorithms_done.contains(algo) { + entry.algorithms_done.push(algo.clone()); + } + } else { + self.completed.push(QueryCompletion { + query_index, + algorithms_done: vec![algo.clone()], + }); + } + } +} + +/// Runtime owner for a [`Checkpoint`] paired with its on-disk path. +pub struct Checkpointer { + inner: Checkpoint, + path: PathBuf, +} + +impl Checkpointer { + /// Create a new checkpointer with no completions. + pub fn fresh(graph_path: &str, queries_file: &str, algorithms: &[Algo], path: PathBuf) -> Self { + Self { + inner: Checkpoint::new(graph_path, queries_file, algorithms), + path, + } + } + + /// Wrap an existing [`Checkpoint`] (e.g. one loaded from disk). + pub fn with_inner(inner: Checkpoint, path: PathBuf) -> Self { + Self { inner, path } + } + + /// Number of queries that have *all* requested algorithms done. 
+ pub fn fully_done_count(&self, algos: &[Algo]) -> usize { + self.inner + .completed + .iter() + .filter(|c| { + let done: HashSet<&Algo> = c.algorithms_done.iter().collect(); + algos.iter().all(|a| done.contains(a)) + }) + .count() + } + + pub fn is_fully_done(&self, query_index: usize, algos: &[Algo]) -> bool { + self.inner.is_fully_done(query_index, algos) + } + + pub fn is_algo_done(&self, query_index: usize, algo: &Algo) -> bool { + self.inner.is_algo_done(query_index, algo) + } + + /// Mark `(query_index, algo)` complete and persist atomically. + pub fn mark_and_save( + &mut self, + query_index: usize, + algo: &Algo, + ) -> Result<(), CheckpointError> { + self.inner.mark_algo_done(query_index, algo); + self.inner.save(&self.path) + } +} + +/// Errors that can occur during checkpoint operations. +#[derive(Debug, Error)] +pub enum CheckpointError { + /// I/O error reading or writing the checkpoint file. + #[error("checkpoint I/O error ({0}): {1}")] + Io(String, std::io::Error), + /// JSON parsing error. + #[error("checkpoint parse error ({0}): {1}")] + Parse(String, serde_json::Error), + /// JSON serialization error. + #[error("checkpoint serialize error: {0}")] + Serialize(serde_json::Error), + /// Checkpoint parameters don't match current run. + #[error("checkpoint mismatch: {0}")] + Mismatch(String), +} diff --git a/src/cli/dispatch.rs b/src/cli/dispatch.rs new file mode 100644 index 0000000..2701460 --- /dev/null +++ b/src/cli/dispatch.rs @@ -0,0 +1,108 @@ +//! Typed dispatch from CLI algorithm choices to concrete evaluators. 
+ +use crate::cli::args::{Algo, BenchArgs, QueryArgs}; +use crate::cli::bench::error::BenchError; +use crate::cli::bench::runner::run_bench_for_evaluator; +use crate::cli::checkpoint::Checkpointer; +use crate::cli::loader::LoadedQuery; +use crate::cli::output::QueryResult; +use crate::cli::query::run_query_for_evaluator; +use crate::graph::InMemoryGraph; +use crate::rpq::nfarpq::NfaRpqEvaluator; +use crate::rpq::rpqmatrix::RpqMatrixEvaluator; + +fn merge_results(all: &mut Vec, per_algo: Vec) { + for result in per_algo { + if let Some(existing) = all + .iter_mut() + .find(|existing| existing.query_index == result.query_index) + { + existing.algorithms.extend(result.algorithms); + } else { + all.push(result); + } + } +} + +pub fn dispatch_query( + args: &QueryArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], +) -> Vec { + let mut all = Vec::new(); + + for algo in &args.common.algo { + let name = algo.to_string(); + let per_algo = match algo { + Algo::NfaRpq => run_query_for_evaluator(&name, NfaRpqEvaluator, graph, queries), + Algo::Rpqmatrix => run_query_for_evaluator(&name, RpqMatrixEvaluator, graph, queries), + }; + merge_results(&mut all, per_algo); + } + + all +} + +pub fn dispatch_bench( + args: &BenchArgs, + graph: &InMemoryGraph, + queries: &[LoadedQuery], + checkpointer: &mut Checkpointer, +) -> Result, BenchError> { + let mut all = Vec::new(); + + for algo in &args.common.algo { + let name = algo.to_string(); + let per_algo = match algo { + Algo::NfaRpq => run_bench_for_evaluator( + args, + algo, + &name, + NfaRpqEvaluator, + graph, + queries, + checkpointer, + )?, + Algo::Rpqmatrix => run_bench_for_evaluator( + args, + algo, + &name, + RpqMatrixEvaluator, + graph, + queries, + checkpointer, + )?, + }; + merge_results(&mut all, per_algo); + } + + Ok(all) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cli::output::AlgoResult; + use std::collections::HashMap; + + #[test] + fn merge_results_combines_algorithms_by_query_index() { + let mut 
all = vec![QueryResult { + query_index: 0, + query_id: "q0".into(), + query_text: "query".into(), + algorithms: HashMap::from([("nfarpq".into(), AlgoResult::ok(Some(1), None))]), + }]; + let per_algo = vec![QueryResult { + query_index: 0, + query_id: "q0".into(), + query_text: "query".into(), + algorithms: HashMap::from([("rpqmatrix".into(), AlgoResult::ok(Some(1), None))]), + }]; + + merge_results(&mut all, per_algo); + + assert_eq!(all.len(), 1); + assert_eq!(all[0].algorithms.len(), 2); + } +} diff --git a/src/cli/loader.rs b/src/cli/loader.rs new file mode 100644 index 0000000..a652043 --- /dev/null +++ b/src/cli/loader.rs @@ -0,0 +1,135 @@ +//! Graph and query loading for the `pathrex` CLI. +//! +//! Both subcommands (`bench` and `query`) need to load a graph and a queries +//! file. + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; + +use thiserror::Error; + +use crate::formats::Csv; +use crate::formats::MatrixMarket; +use crate::formats::Rdf; +use crate::graph::{Graph, GraphError, InMemory, InMemoryGraph}; +use crate::rpq::{RpqError, RpqQuery}; +use crate::sparql::parse_rpq; + +use super::args::GraphFormat; + +#[derive(Debug, Error)] +pub enum GraphLoadError { + #[error("error opening graph at '{path}': {source}")] + Open { + path: String, + #[source] + source: std::io::Error, + }, + #[error("error loading graph from '{path}': {source}")] + Build { + path: String, + #[source] + source: GraphError, + }, +} + +/// Load an [`InMemoryGraph`] from `graph_path` in the given `format`. 
+pub fn load_graph( + graph_path: &str, + format: GraphFormat, + base_iri: Option<&str>, +) -> Result { + match format { + GraphFormat::Mm => { + let mm_base = MatrixMarket::from_dir(graph_path); + let mm = match base_iri { + Some(iri) => mm_base.with_base_iri(iri), + None => mm_base, + }; + Graph::::try_from(mm).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, + }) + } + GraphFormat::Csv => { + let file = File::open(graph_path).map_err(|e| GraphLoadError::Open { + path: graph_path.to_string(), + source: e, + })?; + let csv_source = Csv::from_reader(file).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e.into(), + })?; + Graph::::try_from(csv_source).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, + }) + } + GraphFormat::Rdf => { + let rdf = Rdf::from_path(graph_path).unwrap(); + Graph::::try_from(rdf).map_err(|e| GraphLoadError::Build { + path: graph_path.to_string(), + source: e, + }) + } + } +} + +#[derive(Debug)] +pub struct LoadedQuery { + pub id: String, + pub text: String, + pub parsed: Result, +} + +/// Load and parse queries from a file. +/// +/// Each non-empty line must have the format `,`. +/// The pattern is wrapped into a full SPARQL query before parsing: +/// - When `base_iri` is `Some(iri)`: `BASE <{iri}> SELECT * WHERE { {pattern} . }` +/// - When `base_iri` is `None`: `SELECT * WHERE { {pattern} . 
}` +pub fn load_queries( + path: &Path, + base_iri: Option<&str>, +) -> Result, std::io::Error> { + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut queries = Vec::new(); + + for line in reader.lines() { + let line = line?; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + + let (id, pattern) = match trimmed.splitn(2, ',').collect::>().as_slice() { + [id, pattern] => (id.trim().to_string(), pattern.trim().to_string()), + _ => { + queries.push(LoadedQuery { + id: "?".to_string(), + text: trimmed.to_string(), + parsed: Err(RpqError::UnsupportedPath(format!( + "query line has no comma: {trimmed:?}" + ))), + }); + continue; + } + }; + + let sparql = match base_iri { + Some(iri) => format!("BASE <{iri}> SELECT * WHERE {{ {pattern} . }}"), + None => format!("SELECT * WHERE {{ {pattern} . }}"), + }; + let parsed = parse_rpq(&sparql); + + queries.push(LoadedQuery { + id, + text: pattern, + parsed, + }); + } + + Ok(queries) +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..9e2cb96 --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,16 @@ +//! CLI layer for the `pathrex` binary. +//! +//! This module is only compiled when the `bench` feature is enabled. +//! It provides the argument definitions, graph/query loading, and the two +//! subcommand runners: +//! +//! - `bench` — criterion-based benchmarking with checkpointing +//! - `query` — single-shot query execution with result counts + +pub mod args; +pub mod bench; +pub mod checkpoint; +pub mod dispatch; +pub mod loader; +pub mod output; +pub mod query; diff --git a/src/cli/output.rs b/src/cli/output.rs new file mode 100644 index 0000000..9b568ea --- /dev/null +++ b/src/cli/output.rs @@ -0,0 +1,181 @@ +//! JSON output types and serialization for benchmark and query results. + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +use serde::Serialize; + +/// Outcome of running a single algorithm on a single query. 
+#[derive(Debug, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum AlgoStatus { + Ok, + Error, + Panic, +} + +/// Result of running a single algorithm on a single query. +#[derive(Debug, Serialize)] +pub struct AlgoResult { + pub status: AlgoStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub result_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub timing: Option, +} + +impl AlgoResult { + pub fn ok(result_count: Option, timing: Option) -> Self { + Self { + status: AlgoStatus::Ok, + error: None, + result_count, + timing, + } + } + + pub fn error(message: String) -> Self { + Self { + status: AlgoStatus::Error, + error: Some(message), + result_count: None, + timing: None, + } + } + + pub fn panic(message: String) -> Self { + Self { + status: AlgoStatus::Panic, + error: Some(message), + result_count: None, + timing: None, + } + } +} + +#[derive(Debug, Serialize)] +pub struct AlgoTiming { + pub total: TimingStats, + pub ffi_only: TimingStats, +} + +/// Timing statistics extracted from criterion estimates. 
+#[derive(Debug, Serialize)] +pub struct TimingStats { + pub mean_ns: f64, + pub median_ns: f64, + pub stddev_ns: f64, + pub iterations: usize, +} + +#[derive(Debug, Serialize)] +pub struct QueryResult { + pub query_index: usize, + pub query_id: String, + pub query_text: String, + pub algorithms: HashMap, +} + +#[derive(Debug, Serialize)] +pub struct QueryOutput { + pub metadata: QueryMetadata, + pub results: Vec, +} + +#[derive(Debug, Serialize)] +pub struct QueryMetadata { + pub timestamp: String, + pub graph_path: String, + pub graph_format: String, + pub queries_file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_iri: Option, + pub num_nodes: usize, + pub num_labels: usize, +} + +impl QueryOutput { + pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { + let json = serde_json::to_string_pretty(self) + .map_err(std::io::Error::other)?; + fs::write(path, json) + } +} + +#[derive(Debug, Serialize)] +pub struct BenchOutput { + pub metadata: BenchMetadata, + pub results: Vec, +} + +#[derive(Debug, Serialize)] +pub struct BenchMetadata { + pub timestamp: String, + pub graph_path: String, + pub graph_format: String, + pub queries_file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub base_iri: Option, + pub num_nodes: usize, + pub num_labels: usize, + pub sample_size: usize, + pub warm_up_secs: u64, + pub measurement_secs: u64, +} + +impl BenchOutput { + pub fn write_to_file(&self, path: &Path) -> Result<(), std::io::Error> { + let json = serde_json::to_string_pretty(self) + .map_err(std::io::Error::other)?; + fs::write(path, json) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn algo_result_serializes_split_timing() { + let result = AlgoResult::ok( + Some(3), + Some(AlgoTiming { + total: TimingStats { + mean_ns: 1.0, + median_ns: 1.0, + stddev_ns: 0.0, + iterations: 10, + }, + ffi_only: TimingStats { + mean_ns: 0.5, + median_ns: 0.5, + stddev_ns: 0.0, + iterations: 10, + }, + }), + ); + 
+ let value = serde_json::to_value(&result).expect("serialize"); + assert!(value["timing"]["total"].is_object()); + assert!(value["timing"]["ffi_only"].is_object()); + assert_eq!(value["status"], "ok"); + } + + #[test] + fn error_status_serializes_lowercase() { + let r = AlgoResult::error("boom".into()); + let v = serde_json::to_value(&r).expect("serialize"); + assert_eq!(v["status"], "error"); + assert_eq!(v["error"], "boom"); + } + + #[test] + fn panic_status_serializes_lowercase() { + let r = AlgoResult::panic("kaboom".into()); + let v = serde_json::to_value(&r).expect("serialize"); + assert_eq!(v["status"], "panic"); + } +} diff --git a/src/cli/query.rs b/src/cli/query.rs new file mode 100644 index 0000000..152883c --- /dev/null +++ b/src/cli/query.rs @@ -0,0 +1,114 @@ +//! Single-shot query runner for the `query` subcommand. +//! +//! Runs each query once per algorithm, prints a per-query summary to stderr, +//! and returns structured results that the binary can optionally write to JSON. + +use std::collections::HashMap; + +use crate::eval::{Evaluator, ResultCount}; +use crate::graph::InMemoryGraph; +use crate::rpq::{RpqError, RpqQuery}; + +use super::loader::LoadedQuery; +use super::output::{AlgoResult, QueryResult}; + +/// Run all queries once for one evaluator and return structured results. 
+pub fn run_query_for_evaluator( + algo_name: &str, + evaluator: E, + graph: &InMemoryGraph, + queries: &[LoadedQuery], +) -> Vec +where + E: Evaluator + Copy, + E::Result: ResultCount, +{ + let mut results = Vec::with_capacity(queries.len()); + + for (idx, loaded) in queries.iter().enumerate() { + let mut algo_results: HashMap = HashMap::new(); + + let query: &RpqQuery = match &loaded.parsed { + Ok(q) => q, + Err(e) => { + eprintln!("[query #{idx}] id={} — parse error: {e}", loaded.id); + algo_results.insert(algo_name.to_string(), AlgoResult::error(e.to_string())); + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + continue; + } + }; + + let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + evaluator + .evaluate(query, graph) + .and_then(|result| result.result_count().map_err(RpqError::Graph)) + })); + + let algo_result = match outcome { + Ok(Ok(count)) => { + eprintln!( + "[query #{idx}] id={} algo={} — {count} count", + loaded.id, algo_name + ); + AlgoResult::ok(Some(count), None) + } + Ok(Err(e)) => { + eprintln!( + "[query #{idx}] id={} algo={} — error: {e}", + loaded.id, algo_name + ); + AlgoResult::error(e.to_string()) + } + Err(panic_info) => { + let msg = format!("{:?}", panic_info); + eprintln!( + "[query #{idx}] id={} algo={} — panic: {msg}", + loaded.id, algo_name + ); + AlgoResult::panic(msg) + } + }; + + algo_results.insert(algo_name.to_string(), algo_result); + + results.push(QueryResult { + query_index: idx, + query_id: loaded.id.clone(), + query_text: loaded.text.clone(), + algorithms: algo_results, + }); + } + + results +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::rpq::nfarpq::NfaRpqEvaluator; + use crate::utils::build_graph; + + #[test] + fn generic_runner_records_result_for_one_algorithm() { + let graph = build_graph(&[("A", "B", "p")]); + let queries = vec![LoadedQuery { + id: "q1".into(), + text: "SELECT ?x 
?y WHERE { ?x

?y . }".into(), + parsed: Ok(RpqQuery { + subject: crate::rpq::Endpoint::Variable("x".into()), + path: crate::rpq::PathExpr::Label("p".into()), + object: crate::rpq::Endpoint::Variable("y".into()), + }), + }]; + + let results = run_query_for_evaluator("nfarpq", NfaRpqEvaluator, &graph, &queries); + + assert_eq!(results.len(), 1); + assert!(results[0].algorithms.contains_key("nfarpq")); + } +} diff --git a/src/eval/mod.rs b/src/eval/mod.rs new file mode 100644 index 0000000..fb652c5 --- /dev/null +++ b/src/eval/mod.rs @@ -0,0 +1,35 @@ +//! Generic abstractions over query evaluators. + +use crate::graph::{GraphDecomposition, GraphError}; + +pub trait Evaluator { + type Query; + type Result; + type Error; + type Prepared: PreparedEvaluator; + + fn prepare( + &self, + query: &Self::Query, + graph: &G, + ) -> Result; + + fn evaluate( + &self, + query: &Self::Query, + graph: &G, + ) -> Result { + self.prepare(query, graph)?.execute() + } +} + +pub trait PreparedEvaluator { + type Result; + type Error; + + fn execute(&mut self) -> Result; +} + +pub trait ResultCount { + fn result_count(&self) -> Result; +} diff --git a/src/graph/inmemory.rs b/src/graph/inmemory.rs index 5c764e4..9121101 100644 --- a/src/graph/inmemory.rs +++ b/src/graph/inmemory.rs @@ -11,8 +11,8 @@ use crate::{ }; use super::{ - compute_outer_inner, load_mm_file, Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, - LagraphGraph, ThreadScope, + Backend, Edge, GraphBuilder, GraphDecomposition, GraphError, LagraphGraph, ThreadScope, + compute_outer_inner, load_mm_file, }; /// Marker type for the in-memory GraphBLAS-backed backend. @@ -199,6 +199,13 @@ impl GraphDecomposition for InMemoryGraph { } } +impl InMemoryGraph { + /// Returns the number of distinct edge labels in the graph. 
+ pub fn num_labels(&self) -> usize { + self.graphs.len() + } +} + impl GraphSource for Csv { fn apply_to( self, diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 5b489d0..0922459 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -4,8 +4,8 @@ pub mod inmemory; pub mod wrappers; pub use inmemory::{InMemory, InMemoryBuilder, InMemoryGraph}; -pub(crate) use wrappers::{compute_outer_inner, ensure_grb_init, ThreadScope}; -pub use wrappers::{load_mm_file, GraphblasMatrix, GraphblasVector, LagraphGraph}; +pub use wrappers::{GraphblasMatrix, GraphblasVector, LagraphGraph, load_mm_file}; +pub(crate) use wrappers::{ThreadScope, compute_outer_inner, ensure_grb_init}; use std::marker::PhantomData; use std::sync::Arc; diff --git a/src/graph/wrappers.rs b/src/graph/wrappers.rs index f8128e6..e97cfc5 100644 --- a/src/graph/wrappers.rs +++ b/src/graph/wrappers.rs @@ -176,6 +176,12 @@ pub struct GraphblasVector { } impl GraphblasVector { + pub fn new_bool(n: GrB_Index) -> Result { + let mut v: GrB_Vector = std::ptr::null_mut(); + unsafe { grb_ok!(GrB_Vector_new(&mut v, GrB_BOOL, n))? }; + Ok(Self { inner: v }) + } + /// Returns the number of stored values in this vector. 
pub fn nvals(&self) -> Result { let mut nvals: GrB_Index = 0; diff --git a/src/lagraph_sys_generated.rs b/src/lagraph_sys_generated.rs index 5e02d97..1a9188f 100644 --- a/src/lagraph_sys_generated.rs +++ b/src/lagraph_sys_generated.rs @@ -155,6 +155,9 @@ unsafe extern "C" { ncols: GrB_Index, ) -> GrB_Info; } +unsafe extern "C" { + pub fn GrB_Matrix_dup(C: *mut GrB_Matrix, A: GrB_Matrix) -> GrB_Info; +} unsafe extern "C" { pub fn GrB_Matrix_nvals(nvals: *mut GrB_Index, A: GrB_Matrix) -> GrB_Info; } @@ -338,3 +341,10 @@ unsafe extern "C" { unsafe extern "C" { pub fn LAGraph_RPQMatrix_Free(mat: *mut GrB_Matrix) -> GrB_Info; } +unsafe extern "C" { + pub fn LAGraph_RPQMatrix_reduce( + res: *mut GrB_Index, + mat: GrB_Matrix, + reduce_type: u8, + ) -> GrB_Info; +} diff --git a/src/lib.rs b/src/lib.rs index 0f89008..2502767 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod eval; pub mod formats; pub mod graph; pub mod rpq; @@ -6,3 +7,6 @@ pub mod sparql; pub mod utils; pub mod lagraph_sys; + +#[cfg(feature = "bench")] +pub mod cli; diff --git a/src/rpq/mod.rs b/src/rpq/mod.rs index db2d220..48e0cc4 100644 --- a/src/rpq/mod.rs +++ b/src/rpq/mod.rs @@ -13,6 +13,7 @@ pub mod nfarpq; pub mod rpqmatrix; +use crate::eval::{Evaluator, PreparedEvaluator}; use crate::graph::{GraphDecomposition, GraphError}; use crate::sparql::ExtractError; use spargebra::SparqlSyntaxError; @@ -51,7 +52,9 @@ impl RpqQuery { } fn strip_endpoint(ep: &mut Endpoint, base: &str) { - if let Endpoint::Named(s) = ep && s.starts_with(base) { + if let Endpoint::Named(s) = ep + && s.starts_with(base) + { *s = s[base.len()..].to_owned(); } } @@ -91,13 +94,28 @@ pub enum RpqError { Graph(#[from] GraphError), } -pub trait RpqEvaluator { - /// Output of this evaluator (e.g. reachable vector vs path matrix + nnz). 
- type Result; +pub trait RpqEvaluator: Evaluator { + fn prepare( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + ::prepare(self, query, graph) + } fn evaluate( &self, query: &RpqQuery, graph: &G, - ) -> Result; + ) -> Result { + ::evaluate(self, query, graph) + } +} +impl RpqEvaluator for T where T: Evaluator {} + +pub trait PreparedRpq: PreparedEvaluator { + fn execute(&mut self) -> Result { + ::execute(self) + } } +impl PreparedRpq for T where T: PreparedEvaluator {} diff --git a/src/rpq/nfarpq.rs b/src/rpq/nfarpq.rs index a616b64..29e040b 100644 --- a/src/rpq/nfarpq.rs +++ b/src/rpq/nfarpq.rs @@ -1,11 +1,12 @@ //! NFA-based RPQ evaluation using `LAGraph_RegularPathQuery`. -use crate::graph::{GraphDecomposition, GraphblasVector, LagraphGraph}; -use crate::la_ok; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::{GraphDecomposition, GraphError, GraphblasVector, LagraphGraph}; use crate::lagraph_sys::LAGraph_Kind; use crate::lagraph_sys::*; -use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; -use rustfst::algorithms::closure::{closure, ClosureType}; +use crate::rpq::{Endpoint, PathExpr, RpqError, RpqQuery}; +use crate::{grb_ok, la_ok}; +use rustfst::algorithms::closure::{ClosureType, closure}; use rustfst::algorithms::concat::concat; use rustfst::algorithms::rm_epsilon::rm_epsilon; use rustfst::algorithms::union::union; @@ -13,6 +14,7 @@ use rustfst::prelude::*; use rustfst::semirings::TropicalWeight; use rustfst::utils::{acceptor, epsilon_machine}; use std::collections::HashMap; +use std::sync::Arc; /// Transitions for a single edge label in the NFA. /// @@ -207,60 +209,127 @@ pub struct NfaRpqResult { pub reachable: GraphblasVector, } +impl ResultCount for NfaRpqResult { + fn result_count(&self) -> Result { + Ok(self.reachable.nvals()? 
as usize) + } +} + +pub struct PreparedNfaRpq { + nfa: Nfa, + nfa_matrices: Vec<(String, LagraphGraph)>, + nfa_graph_ptrs: Vec, + _data_graphs: Vec>, + data_graph_ptrs: Vec, + source_vertices: Vec, + destination_vertex: Option, + num_nodes: usize, +} + +fn filter_reachable_by_destination( + reachable: GraphblasVector, + destination_vertex: Option, + num_nodes: usize, +) -> Result { + let Some(destination_vertex) = destination_vertex else { + return Ok(reachable); + }; + + let indices = reachable.indices().map_err(RpqError::Graph)?; + let filtered = GraphblasVector::new_bool(num_nodes as GrB_Index)?; + + if indices.contains(&(destination_vertex as GrB_Index)) { + unsafe { + grb_ok!(GrB_Vector_setElement_BOOL( + filtered.inner, + true, + destination_vertex as GrB_Index, + ))? + }; + } + + Ok(filtered) +} + +impl PreparedEvaluator for PreparedNfaRpq { + type Result = NfaRpqResult; + type Error = RpqError; + + fn execute(&mut self) -> Result { + let mut reachable: GrB_Vector = std::ptr::null_mut(); + + unsafe { + la_ok!(LAGraph_RegularPathQuery( + &mut reachable, + self.nfa_graph_ptrs.as_mut_ptr(), + self.nfa_matrices.len(), + self.nfa.start_states.as_ptr(), + self.nfa.start_states.len(), + self.nfa.final_states.as_ptr(), + self.nfa.final_states.len(), + self.data_graph_ptrs.as_mut_ptr(), + self.source_vertices.as_ptr(), + self.source_vertices.len(), + ))? + }; + + let reachable = filter_reachable_by_destination( + GraphblasVector { inner: reachable }, + self.destination_vertex, + self.num_nodes, + )?; + + Ok(NfaRpqResult { reachable }) + } +} + /// Evaluates RPQs using `LAGraph_RegularPathQuery`. 
+#[derive(Clone, Copy)] pub struct NfaRpqEvaluator; -impl RpqEvaluator for NfaRpqEvaluator { +impl Evaluator for NfaRpqEvaluator { + type Query = RpqQuery; type Result = NfaRpqResult; + type Error = RpqError; + type Prepared = PreparedNfaRpq; - fn evaluate( + fn prepare( &self, query: &RpqQuery, graph: &G, - ) -> Result { + ) -> Result { let nfa = Nfa::from_path_expr(&query.path)?; let nfa_matrices = nfa.build_lagraph_matrices()?; let src_id = resolve_endpoint(&query.subject, graph)?; - let _dst_id = resolve_endpoint(&query.object, graph)?; + let dst_id = resolve_endpoint(&query.object, graph)?; let n = graph.num_nodes(); - let source_vertices: Vec = match src_id { Some(id) => vec![id as GrB_Index], None => (0..n as GrB_Index).collect(), }; - let mut nfa_graph_ptrs: Vec = + let nfa_graph_ptrs: Vec = nfa_matrices.iter().map(|(_, lg)| lg.inner).collect(); - let mut data_graph_ptrs: Vec = Vec::with_capacity(nfa_matrices.len()); + let mut data_graphs = Vec::with_capacity(nfa_matrices.len()); + let mut data_graph_ptrs = Vec::with_capacity(nfa_matrices.len()); for (label, _) in &nfa_matrices { let lg = graph.get_graph(label)?; data_graph_ptrs.push(lg.inner); + data_graphs.push(lg); } - let mut reachable: GrB_Vector = std::ptr::null_mut(); - - unsafe { - la_ok!(LAGraph_RegularPathQuery( - &mut reachable, - nfa_graph_ptrs.as_mut_ptr(), - nfa_matrices.len(), - nfa.start_states.as_ptr(), - nfa.start_states.len(), - nfa.final_states.as_ptr(), - nfa.final_states.len(), - data_graph_ptrs.as_mut_ptr(), - source_vertices.as_ptr(), - source_vertices.len(), - ))? 
- }; - - let result_vec = GraphblasVector { inner: reachable }; - - Ok(NfaRpqResult { - reachable: result_vec, + Ok(PreparedNfaRpq { + nfa, + nfa_matrices, + nfa_graph_ptrs, + _data_graphs: data_graphs, + data_graph_ptrs, + source_vertices, + destination_vertex: dst_id, + num_nodes: n, }) } } diff --git a/src/rpq/rpqmatrix.rs b/src/rpq/rpqmatrix.rs index 72f9110..1462abc 100644 --- a/src/rpq/rpqmatrix.rs +++ b/src/rpq/rpqmatrix.rs @@ -4,11 +4,14 @@ use std::ptr::null_mut; use egg::{Id, RecExpr, define_language}; -use crate::graph::{GraphDecomposition, GraphblasMatrix}; +use crate::eval::{Evaluator, PreparedEvaluator, ResultCount}; +use crate::graph::{GraphDecomposition, GraphError, GraphblasMatrix}; use crate::lagraph_sys::*; -use crate::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use crate::rpq::{Endpoint, PathExpr, RpqError, RpqQuery}; use crate::{grb_ok, la_ok}; +const RPQMATRIX_REDUCE_BY_COL: u8 = 1; + define_language! { pub enum RpqPlan { Label(String), @@ -174,43 +177,89 @@ pub struct RpqMatrixResult { pub matrix: GraphblasMatrix, } -/// RPQ evaluator backed by `LAGraph_RPQMatrix`. -pub struct RpqMatrixEvaluator; +impl RpqMatrixResult { + /// Count distinct reachable target vertices by reducing the path relation + /// matrix to its non-empty columns. + pub fn reachable_target_count(&self) -> Result { + let mut count: GrB_Index = 0; + unsafe { + grb_ok!(LAGraph_RPQMatrix_reduce( + &mut count, + self.matrix.inner, + RPQMATRIX_REDUCE_BY_COL, + ))? + }; + Ok(count as u64) + } +} -impl RpqEvaluator for RpqMatrixEvaluator { - type Result = RpqMatrixResult; +impl ResultCount for RpqMatrixResult { + fn result_count(&self) -> Result { + Ok(self.reachable_target_count()? 
as usize) + } +} - fn evaluate( - &self, - query: &RpqQuery, - graph: &G, - ) -> Result { - let expr = query_to_expr(query)?; - let (mut plans, owned_matrices) = materialize(&expr, graph)?; +pub struct PreparedRpqMatrix { + plans: Vec, + owned_matrices: Vec, +} - let root_ptr = unsafe { plans.as_mut_ptr().add(plans.len() - 1) }; +impl PreparedEvaluator for PreparedRpqMatrix { + type Result = RpqMatrixResult; + type Error = RpqError; + + fn execute(&mut self) -> Result { + let root_ptr = unsafe { self.plans.as_mut_ptr().add(self.plans.len() - 1) }; let mut nnz: GrB_Index = 0; unsafe { la_ok!(LAGraph_RPQMatrix(&mut nnz, root_ptr))? }; - let matrix = unsafe { - let mat = (*root_ptr).res_mat; - (*root_ptr).res_mat = null_mut(); - GraphblasMatrix { inner: mat } + let mut matrix_inner: GrB_Matrix = null_mut(); + unsafe { grb_ok!(GrB_Matrix_dup(&mut matrix_inner, (*root_ptr).res_mat))? }; + let matrix = GraphblasMatrix { + inner: matrix_inner, }; unsafe { grb_ok!(LAGraph_DestroyRpqMatrixPlan(root_ptr))? }; - // Free diagonal matrices created for named vertices. - for mut mat in owned_matrices { + Ok(RpqMatrixResult { + nnz: nnz as u64, + matrix, + }) + } +} + +impl Drop for PreparedRpqMatrix { + fn drop(&mut self) { + for mat in &mut self.owned_matrices { unsafe { - LAGraph_RPQMatrix_Free(&mut mat); + LAGraph_RPQMatrix_Free(mat); } } + } +} - Ok(RpqMatrixResult { - nnz: nnz as u64, - matrix, +/// RPQ evaluator backed by `LAGraph_RPQMatrix`. 
+#[derive(Clone, Copy)] +pub struct RpqMatrixEvaluator; + +impl Evaluator for RpqMatrixEvaluator { + type Query = RpqQuery; + type Result = RpqMatrixResult; + type Error = RpqError; + type Prepared = PreparedRpqMatrix; + + fn prepare( + &self, + query: &RpqQuery, + graph: &G, + ) -> Result { + let expr = query_to_expr(query)?; + let (plans, owned_matrices) = materialize(&expr, graph)?; + + Ok(PreparedRpqMatrix { + plans, + owned_matrices, }) } } diff --git a/src/utils.rs b/src/utils.rs index a6d5d51..edf2b18 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -134,16 +134,14 @@ impl From for GrB_Info { /// ``` #[macro_export] macro_rules! grb_ok { - ($grb_func:expr) => { - { - let info: $crate::lagraph_sys::GrB_Info = $grb_func.into(); - if info == $crate::lagraph_sys::GrB_Info::GrB_SUCCESS { - Ok(()) - } else { - Err($crate::graph::GraphError::GraphBlas(info)) - } + ($grb_func:expr) => {{ + let info: $crate::lagraph_sys::GrB_Info = $grb_func.into(); + if info == $crate::lagraph_sys::GrB_Info::GrB_SUCCESS { + Ok(()) + } else { + Err($crate::graph::GraphError::GraphBlas(info)) } - }; + }}; } /// Calls a raw LAGraph function and maps its `i32` return code to diff --git a/tests/nfarpq_tests.rs b/tests/nfarpq_tests.rs index fbdf42a..0f97b09 100644 --- a/tests/nfarpq_tests.rs +++ b/tests/nfarpq_tests.rs @@ -7,7 +7,7 @@ use pathrex::formats::mm::MatrixMarket; use pathrex::graph::{Graph, GraphDecomposition, GraphError, InMemory, InMemoryGraph}; use pathrex::lagraph_sys::GrB_Index; use pathrex::rpq::nfarpq::NfaRpqEvaluator; -use pathrex::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use pathrex::rpq::{Endpoint, PathExpr, PreparedRpq, RpqError, RpqEvaluator, RpqQuery}; use pathrex::sparql::parse_rpq; use pathrex::utils::build_graph; @@ -21,13 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line.split_once(',').map(|x| x.1) + let query_str = line + .split_once(',') + .map(|x| 
x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . }}"); - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } @@ -167,6 +168,70 @@ fn test_sequence_path() { assert_eq!(count, 1); } +#[test] +fn prepared_nfa_execution_matches_evaluate() { + let graph = build_graph(&[("A", "B", "knows"), ("B", "C", "likes")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("knows")), Box::new(label("likes"))), + var("y"), + ); + + let evaluator = NfaRpqEvaluator; + let direct = evaluator.evaluate(&query, &graph).expect("direct"); + let direct_count = direct.reachable.nvals().expect("direct nvals"); + + let mut prepared = evaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("prepared execute"); + let prepared_count = prepared_result.reachable.nvals().expect("prepared nvals"); + + assert_eq!(prepared_count, direct_count); +} + +#[test] +fn test_bound_object_filters_reachable_targets() { + let graph = build_graph(&[("A", "B", "knows"), ("C", "D", "knows")]); + let evaluator = NfaRpqEvaluator; + + let result = evaluator + .evaluate(&rq(var("x"), label("knows"), named_ep("B")), &graph) + .expect("evaluate should succeed"); + + let indices = result + .reachable + .indices() + .expect("failed to extract indices"); + let b_id = graph.get_node_id("B").expect("B should exist") as GrB_Index; + let d_id = graph.get_node_id("D").expect("D should exist") as GrB_Index; + + assert_eq!( + indices, + vec![b_id], + "only the bound object should remain reachable" + ); + assert!( + !indices.contains(&d_id), + "unbound reachable targets must be filtered out" + ); +} + +#[test] +fn prepared_nfa_execution_respects_bound_object() { + let graph = build_graph(&[("A", "B", "knows"), ("C", "D", "knows")]); + let evaluator = NfaRpqEvaluator; + let query = rq(var("x"), label("knows"), named_ep("B")); + + let mut 
prepared = evaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("prepared execute"); + let prepared_indices = prepared_result + .reachable + .indices() + .expect("prepared indices"); + let b_id = graph.get_node_id("B").expect("B should exist") as GrB_Index; + + assert_eq!(prepared_indices, vec![b_id]); +} + /// Graph: A --knows--> B --likes--> C /// Query: / ?y #[test] diff --git a/tests/rpqmatrix_tests.rs b/tests/rpqmatrix_tests.rs index b23e0e9..3f84d80 100644 --- a/tests/rpqmatrix_tests.rs +++ b/tests/rpqmatrix_tests.rs @@ -7,7 +7,7 @@ use pathrex::formats::mm::MatrixMarket; use pathrex::graph::{Graph, GraphDecomposition, GraphError, InMemory, InMemoryGraph}; use pathrex::lagraph_sys::{GrB_Index, GrB_Info, GrB_Matrix_extractElement_BOOL}; use pathrex::rpq::rpqmatrix::{RpqMatrixEvaluator, RpqMatrixResult}; -use pathrex::rpq::{Endpoint, PathExpr, RpqError, RpqEvaluator, RpqQuery}; +use pathrex::rpq::{Endpoint, PathExpr, PreparedRpq, RpqError, RpqEvaluator, RpqQuery}; use pathrex::sparql::parse_rpq; use pathrex::utils::build_graph; @@ -21,13 +21,14 @@ static LA_N_EGG_GRAPH: LazyLock = LazyLock::new(|| { }); fn convert_query_line(line: &str) -> RpqQuery { - let query_str = line.split_once(',').map(|x| x.1) + let query_str = line + .split_once(',') + .map(|x| x.1) .unwrap_or_else(|| panic!("query line has no comma: {line:?}")) .trim(); let sparql = format!("BASE <{BASE_IRI}> SELECT * WHERE {{ {query_str} . 
}}"); - parse_rpq(&sparql).unwrap_or_else(|e| panic!("failed to parse query {line:?}: {e}")) } @@ -172,6 +173,38 @@ fn test_sequence_path() { assert_eq!(result.nnz, 1); } +#[test] +fn prepared_rpqmatrix_execution_matches_evaluate() { + let graph = build_graph(&[("A", "B", "p"), ("B", "C", "q")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("p")), Box::new(label("q"))), + var("y"), + ); + + let direct = RpqMatrixEvaluator.evaluate(&query, &graph).expect("direct"); + let mut prepared = RpqMatrixEvaluator.prepare(&query, &graph).expect("prepare"); + let prepared_result = prepared.execute().expect("execute"); + + assert_eq!(prepared_result.nnz, direct.nnz); +} + +#[test] +fn prepared_rpqmatrix_execution_can_run_twice() { + let graph = build_graph(&[("A", "B", "p"), ("B", "C", "q")]); + let query = rq( + var("x"), + PathExpr::Sequence(Box::new(label("p")), Box::new(label("q"))), + var("y"), + ); + + let mut prepared = RpqMatrixEvaluator.prepare(&query, &graph).expect("prepare"); + let first = prepared.execute().expect("first"); + let second = prepared.execute().expect("second"); + + assert_eq!(first.nnz, second.nnz); +} + /// Graph: A --knows--> B --likes--> C /// Query: / ?y → only A→C, nnz=1 #[test]