From 27e3d34bc3a38962ab45f9fb9c3079f15d87ed26 Mon Sep 17 00:00:00 2001 From: "John C. Burnham" Date: Fri, 6 Feb 2026 01:26:09 -0500 Subject: [PATCH] Add DAG-based kernel typechecker Implement a Lean 4 kernel typechecker using a DAG representation with BUBS (Bottom-Up Beta Substitution) for efficient reduction. The kernel operates on a mutable DAG rather than tree-based expressions, enabling in-place substitution and shared subterm reduction. 12 modules: doubly-linked list, DAG nodes with 10 pointer variants, BUBS upcopy with 12 parent cases, Expr/DAG conversion, universe level operations, WHNF via trail algorithm, definitional equality with lazy delta/proof irrelevance/eta, type inference, and checking for quotients and inductives. --- src/ix.rs | 1 + src/ix/kernel/convert.rs | 813 +++++++++++++++++ src/ix/kernel/dag.rs | 527 +++++++++++ src/ix/kernel/def_eq.rs | 1298 +++++++++++++++++++++++++++ src/ix/kernel/dll.rs | 214 +++++ src/ix/kernel/error.rs | 59 ++ src/ix/kernel/inductive.rs | 772 ++++++++++++++++ src/ix/kernel/level.rs | 393 +++++++++ src/ix/kernel/mod.rs | 11 + src/ix/kernel/quot.rs | 291 +++++++ src/ix/kernel/tc.rs | 1694 ++++++++++++++++++++++++++++++++++++ src/ix/kernel/upcopy.rs | 659 ++++++++++++++ src/ix/kernel/whnf.rs | 1420 ++++++++++++++++++++++++++++++ 13 files changed, 8152 insertions(+) create mode 100644 src/ix/kernel/convert.rs create mode 100644 src/ix/kernel/dag.rs create mode 100644 src/ix/kernel/def_eq.rs create mode 100644 src/ix/kernel/dll.rs create mode 100644 src/ix/kernel/error.rs create mode 100644 src/ix/kernel/inductive.rs create mode 100644 src/ix/kernel/level.rs create mode 100644 src/ix/kernel/mod.rs create mode 100644 src/ix/kernel/quot.rs create mode 100644 src/ix/kernel/tc.rs create mode 100644 src/ix/kernel/upcopy.rs create mode 100644 src/ix/kernel/whnf.rs diff --git a/src/ix.rs b/src/ix.rs index f200d81b..42d298c2 100644 --- a/src/ix.rs +++ b/src/ix.rs @@ -12,6 +12,7 @@ pub mod env; pub mod graph; pub mod ground; pub mod ixon; +pub mod kernel; pub mod mutual; pub mod store; pub mod strong_ordering; diff --git a/src/ix/kernel/convert.rs b/src/ix/kernel/convert.rs new file mode 100644 index 00000000..90811948 --- /dev/null +++ b/src/ix/kernel/convert.rs @@ -0,0 +1,813 @@ +use core::ptr::NonNull; +use std::collections::BTreeMap; + +use crate::ix::env::{Expr, ExprData, Level, Name}; +use crate::lean::nat::Nat; + +use super::dag::*; +use super::dll::DLL; + +// ============================================================================ +// Expr -> DAG +// ============================================================================ + +pub fn from_expr(expr: &Expr) -> DAG { + let root_parents = DLL::alloc(ParentPtr::Root); + let head = from_expr_go(expr, 0, &BTreeMap::new(), Some(root_parents)); + DAG { head } +} + +fn from_expr_go( + expr: &Expr, + depth: u64, + ctx: &BTreeMap>, + parents: Option>, +) -> DAGPtr { + match expr.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 < depth { + let level = depth - 1 - idx_u64; + match ctx.get(&level) { + Some(&var_ptr) => { + if let Some(parent_link) = parents { + add_to_parents(DAGPtr::Var(var_ptr), parent_link); + } + DAGPtr::Var(var_ptr) + }, + None => { + let var = alloc_val(Var { + depth: level, + binder: BinderPtr::Free, + parents, + }); + DAGPtr::Var(var) + }, + } + } else { + // Free bound variable (dangling de Bruijn index) + let var = + alloc_val(Var { depth: idx_u64, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + } + }, + + ExprData::Fvar(_name, _) => { + // Encode fvar name into depth as a unique ID. + // We'll recover it during to_expr using a side table. + let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + // Store name→var mapping (caller should manage the side table) + DAGPtr::Var(var) + }, + + ExprData::Sort(level, _) => { + let sort = alloc_val(Sort { level: level.clone(), parents }); + DAGPtr::Sort(sort) + }, + + ExprData::Const(name, levels, _) => { + let cnst = alloc_val(Cnst { + name: name.clone(), + levels: levels.clone(), + parents, + }); + DAGPtr::Cnst(cnst) + }, + + ExprData::Lit(lit, _) => { + let lit_node = alloc_val(LitNode { val: lit.clone(), parents }); + DAGPtr::Lit(lit_node) + }, + + ExprData::App(fun_expr, arg_expr, _) => { + let app_ptr = alloc_app( + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let app = &mut *app_ptr.as_ptr(); + let fun_ref_ptr = + NonNull::new(&mut app.fun_ref as *mut Parents).unwrap(); + let arg_ref_ptr = + NonNull::new(&mut app.arg_ref as *mut Parents).unwrap(); + app.fun = from_expr_go(fun_expr, depth, ctx, Some(fun_ref_ptr)); + app.arg = from_expr_go(arg_expr, depth, ctx, Some(arg_ref_ptr)); + } + DAGPtr::App(app_ptr) + }, + + ExprData::Lam(name, typ, body, bi, _) => { + // Lean Lam → DAG Fun(dom, Lam(bod, var)) + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let fun_ptr = alloc_fun( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut fun.dom_ref as *mut Parents).unwrap(); + fun.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + // Set Lam's parent to FunImg + let img_ref_ptr = + NonNull::new(&mut fun.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Fun(fun_ptr) + }, + + ExprData::ForallE(name, typ, body, bi, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let pi_ptr = alloc_pi( + name.clone(), + bi.clone(), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + let dom_ref_ptr = + NonNull::new(&mut pi.dom_ref as *mut Parents).unwrap(); + pi.dom = from_expr_go(typ, depth, ctx, Some(dom_ref_ptr)); + + let img_ref_ptr = + NonNull::new(&mut pi.img_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), img_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(bod_ref_ptr)); + } + DAGPtr::Pi(pi_ptr) + }, + + ExprData::LetE(name, typ, val, body, non_dep, _) => { + let lam_ptr = + alloc_lam(depth, DAGPtr::Var(NonNull::dangling()), None); + let let_ptr = alloc_let( + name.clone(), + *non_dep, + DAGPtr::Var(NonNull::dangling()), + DAGPtr::Var(NonNull::dangling()), + lam_ptr, + parents, + ); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let typ_ref_ptr = + NonNull::new(&mut let_node.typ_ref as *mut Parents).unwrap(); + let val_ref_ptr = + NonNull::new(&mut let_node.val_ref as *mut Parents).unwrap(); + let_node.typ = from_expr_go(typ, depth, ctx, Some(typ_ref_ptr)); + let_node.val = from_expr_go(val, depth, ctx, Some(val_ref_ptr)); + + let bod_ref_ptr = + NonNull::new(&mut let_node.bod_ref as *mut Parents).unwrap(); + add_to_parents(DAGPtr::Lam(lam_ptr), bod_ref_ptr); + + let lam = &mut *lam_ptr.as_ptr(); + let var_ptr = NonNull::new(&mut lam.var as *mut Var).unwrap(); + let mut new_ctx = ctx.clone(); + new_ctx.insert(depth, var_ptr); + let inner_bod_ref_ptr = + NonNull::new(&mut lam.bod_ref as *mut Parents).unwrap(); + lam.bod = + from_expr_go(body, depth + 1, &new_ctx, Some(inner_bod_ref_ptr)); + } + DAGPtr::Let(let_ptr) + }, + + ExprData::Proj(type_name, idx, structure, _) => { + let proj_ptr = alloc_proj( + type_name.clone(), + idx.clone(), + DAGPtr::Var(NonNull::dangling()), + parents, + ); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + let expr_ref_ptr = + NonNull::new(&mut proj.expr_ref as *mut Parents).unwrap(); + proj.expr = + from_expr_go(structure, depth, ctx, Some(expr_ref_ptr)); + } + DAGPtr::Proj(proj_ptr) + }, + + // Mdata: strip metadata, convert inner expression + ExprData::Mdata(_, inner, _) => from_expr_go(inner, depth, ctx, parents), + + // Mvar: treat as terminal (shouldn't appear in well-typed terms) + ExprData::Mvar(_name, _) => { + let var = alloc_val(Var { depth: 0, binder: BinderPtr::Free, parents }); + DAGPtr::Var(var) + }, + } +} + +// ============================================================================ +// Literal clone +// ============================================================================ + +impl Clone for crate::ix::env::Literal { + fn clone(&self) -> Self { + match self { + crate::ix::env::Literal::NatVal(n) => { + crate::ix::env::Literal::NatVal(n.clone()) + }, + crate::ix::env::Literal::StrVal(s) => { + crate::ix::env::Literal::StrVal(s.clone()) + }, + } + } +} + +// ============================================================================ +// DAG -> Expr +// ============================================================================ + +pub fn to_expr(dag: &DAG) -> Expr { + let mut var_map: BTreeMap<*const Var, u64> = BTreeMap::new(); + to_expr_go(dag.head, &mut var_map, 0) +} + +fn to_expr_go( + node: DAGPtr, + var_map: &mut BTreeMap<*const Var, u64>, + depth: u64, +) -> Expr { + unsafe { + match node { + DAGPtr::Var(link) => { + let var = link.as_ptr(); + let var_key = var as *const Var; + if let Some(&bind_depth) = var_map.get(&var_key) { + let idx = depth - bind_depth - 1; + Expr::bvar(Nat::from(idx)) + } else { + // Free variable + Expr::bvar(Nat::from((*var).depth)) + } + }, + + DAGPtr::Sort(link) => { + let sort = &*link.as_ptr(); + Expr::sort(sort.level.clone()) + }, + + DAGPtr::Cnst(link) => { + let cnst = &*link.as_ptr(); + Expr::cnst(cnst.name.clone(), cnst.levels.clone()) + }, + + DAGPtr::Lit(link) => { + let lit = &*link.as_ptr(); + Expr::lit(lit.val.clone()) + }, + + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun = to_expr_go(app.fun, var_map, depth); + let arg = to_expr_go(app.arg, var_map, depth); + Expr::app(fun, arg) + }, + + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let lam = &*fun.img.as_ptr(); + let dom = to_expr_go(fun.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::lam( + fun.binder_name.clone(), + dom, + bod, + fun.binder_info.clone(), + ) + }, + + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let lam = &*pi.img.as_ptr(); + let dom = to_expr_go(pi.dom, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::all( + pi.binder_name.clone(), + dom, + bod, + pi.binder_info.clone(), + ) + }, + + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let lam = &*let_node.bod.as_ptr(); + let typ = to_expr_go(let_node.typ, var_map, depth); + let val = to_expr_go(let_node.val, var_map, depth); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + Expr::letE( + let_node.binder_name.clone(), + typ, + val, + bod, + let_node.non_dep, + ) + }, + + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let structure = to_expr_go(proj.expr, var_map, depth); + Expr::proj(proj.type_name.clone(), proj.idx.clone(), structure) + }, + + DAGPtr::Lam(link) => { + // Standalone Lam shouldn't appear at the top level, + // but handle it gracefully for completeness. + let lam = &*link.as_ptr(); + let var_ptr = &lam.var as *const Var; + var_map.insert(var_ptr, depth); + let bod = to_expr_go(lam.bod, var_map, depth + 1); + var_map.remove(&var_ptr); + // Wrap in a lambda with anonymous name and default binder info + Expr::lam( + Name::anon(), + Expr::sort(Level::zero()), + bod, + crate::ix::env::BinderInfo::Default, + ) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::env::{BinderInfo, Literal}; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + // ========================================================================== + // Terminal roundtrips + // ========================================================================== + + #[test] + fn roundtrip_sort() { + let e = Expr::sort(Level::succ(Level::zero())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_sort_param() { + let e = Expr::sort(Level::param(mk_name("u"))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const() { + let e = Expr::cnst( + mk_name("Foo"), + vec![Level::zero(), Level::succ(Level::zero())], + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_nat_lit() { + let e = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_string_lit() { + let e = Expr::lit(Literal::StrVal("hello world".into())); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Binder roundtrips + // ========================================================================== + + #[test] + fn roundtrip_identity_lambda() { + // fun (x : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_const_lambda() { + // fun (x : Nat) (y : Nat) => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_pi() { + // (x : Nat) → Nat + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_dependent_pi() { + // (A : Sort 0) → A → A + let sort0 = Expr::sort(Level::zero()); + let e = Expr::all( + mk_name("A"), + sort0, + Expr::all( + mk_name("x"), + Expr::bvar(Nat::from(0u64)), // A + Expr::bvar(Nat::from(1u64)), // A + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // App roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app() { + // f a + let e = Expr::app( + Expr::cnst(mk_name("f"), vec![]), + nat_zero(), + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_nested_app() { + // f a b + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let e = Expr::app(Expr::app(f, a), b); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Let roundtrips + // ========================================================================== + + #[test] + fn roundtrip_let() { + // let x : Nat := Nat.zero in x + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_let_non_dep() { + // let x : Nat := Nat.zero in Nat.zero (non_dep = true) + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + nat_zero(), + true, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Proj roundtrips + // ========================================================================== + + #[test] + fn roundtrip_proj() { + let e = Expr::proj(mk_name("Prod"), Nat::from(0u64), nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Complex roundtrips + // ========================================================================== + + #[test] + fn roundtrip_app_of_lambda() { + // (fun x : Nat => x) Nat.zero + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_lambda_in_lambda() { + // fun (f : Nat → Nat) (x : Nat) => f x + let nat_to_nat = Expr::all( + mk_name("_"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e = Expr::lam( + mk_name("f"), + nat_to_nat, + Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(1u64)), // f + Expr::bvar(Nat::from(0u64)), // x + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_bvar_sharing() { + // fun (x : Nat) => App(x, x) + // Both bvar(0) should map to the same Var in DAG + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app( + Expr::bvar(Nat::from(0u64)), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_free_bvar() { + // Bvar(5) with no enclosing binder — should survive roundtrip + let e = Expr::bvar(Nat::from(5u64)); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + #[test] + fn roundtrip_implicit_binder() { + // fun {x : Nat} => x + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Implicit, + ); + let dag = from_expr(&e); + let result = to_expr(&dag); + assert_eq!(result, e); + } + + // ========================================================================== + // Property tests (quickcheck) + // ========================================================================== + + /// Generate a random well-formed Expr with bound variables properly scoped. + /// `depth` tracks how many binders are in scope (for valid bvar generation). + fn arb_expr(g: &mut Gen, depth: u64, size: usize) -> Expr { + if size == 0 { + // Terminal: pick among Sort, Const, Lit, or Bvar (if depth > 0) + let choices = if depth > 0 { 5 } else { 4 }; + match usize::arbitrary(g) % choices { + 0 => Expr::sort(arb_level(g, 2)), + 1 => { + let names = ["Nat", "Bool", "String", "Unit", "Int"]; + let idx = usize::arbitrary(g) % names.len(); + Expr::cnst(mk_name(names[idx]), vec![]) + }, + 2 => { + let n = u64::arbitrary(g) % 100; + Expr::lit(Literal::NatVal(Nat::from(n))) + }, + 3 => { + let s: String = String::arbitrary(g); + // Truncate at a char boundary to avoid panics + let s: String = s.chars().take(10).collect(); + Expr::lit(Literal::StrVal(s)) + }, + 4 => { + // Bvar within scope + let idx = u64::arbitrary(g) % depth; + Expr::bvar(Nat::from(idx)) + }, + _ => unreachable!(), + } + } else { + let next = size / 2; + match usize::arbitrary(g) % 5 { + 0 => { + // App + let f = arb_expr(g, depth, next); + let a = arb_expr(g, depth, next); + Expr::app(f, a) + }, + 1 => { + // Lam + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::lam(mk_name("x"), dom, bod, BinderInfo::Default) + }, + 2 => { + // Pi + let dom = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next); + Expr::all(mk_name("a"), dom, bod, BinderInfo::Default) + }, + 3 => { + // Let + let typ = arb_expr(g, depth, next); + let val = arb_expr(g, depth, next); + let bod = arb_expr(g, depth + 1, next / 2); + Expr::letE(mk_name("v"), typ, val, bod, bool::arbitrary(g)) + }, + 4 => { + // Proj + let idx = u64::arbitrary(g) % 4; + let structure = arb_expr(g, depth, next); + Expr::proj(mk_name("S"), Nat::from(idx), structure) + }, + _ => unreachable!(), + } + } + } + + fn arb_level(g: &mut Gen, size: usize) -> Level { + if size == 0 { + match usize::arbitrary(g) % 3 { + 0 => Level::zero(), + 1 => { + let params = ["u", "v", "w"]; + let idx = usize::arbitrary(g) % params.len(); + Level::param(mk_name(params[idx])) + }, + 2 => Level::succ(Level::zero()), + _ => unreachable!(), + } + } else { + match usize::arbitrary(g) % 3 { + 0 => Level::succ(arb_level(g, size - 1)), + 1 => Level::max(arb_level(g, size / 2), arb_level(g, size / 2)), + 2 => Level::imax(arb_level(g, size / 2), arb_level(g, size / 2)), + _ => unreachable!(), + } + } + } + + /// Newtype wrapper for quickcheck Arbitrary derivation. + #[derive(Clone, Debug)] + struct ArbExpr(Expr); + + impl Arbitrary for ArbExpr { + fn arbitrary(g: &mut Gen) -> Self { + let size = usize::arbitrary(g) % 5; + ArbExpr(arb_expr(g, 0, size)) + } + } + + #[quickcheck] + fn prop_roundtrip(e: ArbExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } + + /// Same test but with expressions generated inside binders. + #[derive(Clone, Debug)] + struct ArbBinderExpr(Expr); + + impl Arbitrary for ArbBinderExpr { + fn arbitrary(g: &mut Gen) -> Self { + let inner_size = usize::arbitrary(g) % 4; + let body = arb_expr(g, 1, inner_size); + let dom = arb_expr(g, 0, 0); + ArbBinderExpr(Expr::lam( + mk_name("x"), + dom, + body, + BinderInfo::Default, + )) + } + } + + #[quickcheck] + fn prop_roundtrip_binder(e: ArbBinderExpr) -> bool { + let dag = from_expr(&e.0); + let result = to_expr(&dag); + result == e.0 + } +} diff --git a/src/ix/kernel/dag.rs b/src/ix/kernel/dag.rs new file mode 100644 index 00000000..9837405f --- /dev/null +++ b/src/ix/kernel/dag.rs @@ -0,0 +1,527 @@ +use core::ptr::NonNull; + +use crate::ix::env::{BinderInfo, Level, Literal, Name}; +use crate::lean::nat::Nat; +use rustc_hash::FxHashSet; + +use super::dll::DLL; + +pub type Parents = DLL; + +// ============================================================================ +// Pointer types +// ============================================================================ + +#[derive(Debug)] +pub enum DAGPtr { + Var(NonNull), + Sort(NonNull), + Cnst(NonNull), + Lit(NonNull), + Lam(NonNull), + Fun(NonNull), + Pi(NonNull), + App(NonNull), + Let(NonNull), + Proj(NonNull), +} + +impl Copy for DAGPtr {} +impl Clone for DAGPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for DAGPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (DAGPtr::Var(a), DAGPtr::Var(b)) => a == b, + (DAGPtr::Sort(a), DAGPtr::Sort(b)) => a == b, + (DAGPtr::Cnst(a), DAGPtr::Cnst(b)) => a == b, + (DAGPtr::Lit(a), DAGPtr::Lit(b)) => a == b, + (DAGPtr::Lam(a), DAGPtr::Lam(b)) => a == b, + (DAGPtr::Fun(a), DAGPtr::Fun(b)) => a == b, + (DAGPtr::Pi(a), DAGPtr::Pi(b)) => a == b, + (DAGPtr::App(a), DAGPtr::App(b)) => a == b, + (DAGPtr::Let(a), DAGPtr::Let(b)) => a == b, + (DAGPtr::Proj(a), DAGPtr::Proj(b)) => a == b, + _ => false, + } + } +} +impl Eq for DAGPtr {} + +#[derive(Debug)] +pub enum ParentPtr { + Root, + LamBod(NonNull), + FunDom(NonNull), + FunImg(NonNull), + PiDom(NonNull), + PiImg(NonNull), + AppFun(NonNull), + AppArg(NonNull), + LetTyp(NonNull), + LetVal(NonNull), + LetBod(NonNull), + ProjExpr(NonNull), +} + +impl Copy for ParentPtr {} +impl Clone for ParentPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for ParentPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ParentPtr::Root, ParentPtr::Root) => true, + (ParentPtr::LamBod(a), ParentPtr::LamBod(b)) => a == b, + (ParentPtr::FunDom(a), ParentPtr::FunDom(b)) => a == b, + (ParentPtr::FunImg(a), ParentPtr::FunImg(b)) => a == b, + (ParentPtr::PiDom(a), ParentPtr::PiDom(b)) => a == b, + (ParentPtr::PiImg(a), ParentPtr::PiImg(b)) => a == b, + (ParentPtr::AppFun(a), ParentPtr::AppFun(b)) => a == b, + (ParentPtr::AppArg(a), ParentPtr::AppArg(b)) => a == b, + (ParentPtr::LetTyp(a), ParentPtr::LetTyp(b)) => a == b, + (ParentPtr::LetVal(a), ParentPtr::LetVal(b)) => a == b, + (ParentPtr::LetBod(a), ParentPtr::LetBod(b)) => a == b, + (ParentPtr::ProjExpr(a), ParentPtr::ProjExpr(b)) => a == b, + _ => false, + } + } +} +impl Eq for ParentPtr {} + +/// Binder pointer: from a Var to its binding Lam, or Free. +#[derive(Debug)] +pub enum BinderPtr { + Free, + Lam(NonNull), +} + +impl Copy for BinderPtr {} +impl Clone for BinderPtr { + fn clone(&self) -> Self { + *self + } +} + +impl PartialEq for BinderPtr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (BinderPtr::Free, BinderPtr::Free) => true, + (BinderPtr::Lam(a), BinderPtr::Lam(b)) => a == b, + _ => false, + } + } +} + +// ============================================================================ +// Node structs +// ============================================================================ + +/// Bound or free variable. +#[repr(C)] +pub struct Var { + /// De Bruijn level (used during from_expr/to_expr conversion). + pub depth: u64, + /// Points to the binding Lam, or Free for free variables. + pub binder: BinderPtr, + /// Parent pointers. + pub parents: Option>, +} + +impl Copy for Var {} +impl Clone for Var { + fn clone(&self) -> Self { + *self + } +} + +/// Sort node (universe). +#[repr(C)] +pub struct Sort { + pub level: Level, + pub parents: Option>, +} + +/// Constant reference. +#[repr(C)] +pub struct Cnst { + pub name: Name, + pub levels: Vec, + pub parents: Option>, +} + +/// Literal value (Nat or String). +#[repr(C)] +pub struct LitNode { + pub val: Literal, + pub parents: Option>, +} + +/// Internal binding node (spine). Carries an embedded Var. +/// Always appears as the img/bod of Fun/Pi/Let. +#[repr(C)] +pub struct Lam { + pub bod: DAGPtr, + pub bod_ref: Parents, + pub var: Var, + pub parents: Option>, +} + +/// Lean lambda: `fun (name : dom) => bod`. +/// Branch node wrapping a Lam for the body. +#[repr(C)] +pub struct Fun { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Lean Pi/ForallE: `(name : dom) → bod`. +/// Branch node wrapping a Lam for the body. +#[repr(C)] +pub struct Pi { + pub binder_name: Name, + pub binder_info: BinderInfo, + pub dom: DAGPtr, + pub img: NonNull, + pub dom_ref: Parents, + pub img_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Application node. +#[repr(C)] +pub struct App { + pub fun: DAGPtr, + pub arg: DAGPtr, + pub fun_ref: Parents, + pub arg_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Let binding: `let name : typ := val in bod`. +#[repr(C)] +pub struct LetNode { + pub binder_name: Name, + pub non_dep: bool, + pub typ: DAGPtr, + pub val: DAGPtr, + pub bod: NonNull, + pub typ_ref: Parents, + pub val_ref: Parents, + pub bod_ref: Parents, + pub copy: Option>, + pub parents: Option>, +} + +/// Projection from a structure. +#[repr(C)] +pub struct ProjNode { + pub type_name: Name, + pub idx: Nat, + pub expr: DAGPtr, + pub expr_ref: Parents, + pub parents: Option>, +} + +/// A DAG with a head node. +pub struct DAG { + pub head: DAGPtr, +} + +// ============================================================================ +// Allocation helpers +// ============================================================================ + +#[inline] +pub fn alloc_val(val: T) -> NonNull { + NonNull::new(Box::into_raw(Box::new(val))).unwrap() +} + +pub fn alloc_lam( + depth: u64, + bod: DAGPtr, + parents: Option>, +) -> NonNull { + let lam_ptr = alloc_val(Lam { + bod, + bod_ref: DLL::singleton(ParentPtr::Root), + var: Var { depth, binder: BinderPtr::Free, parents: None }, + parents, + }); + unsafe { + let lam = &mut *lam_ptr.as_ptr(); + lam.bod_ref = DLL::singleton(ParentPtr::LamBod(lam_ptr)); + lam.var.binder = BinderPtr::Lam(lam_ptr); + } + lam_ptr +} + +pub fn alloc_app( + fun: DAGPtr, + arg: DAGPtr, + parents: Option>, +) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +pub fn alloc_fun( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr)); + fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr)); + } + fun_ptr +} + +pub fn alloc_pi( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, + parents: Option>, +) -> NonNull { + let pi_ptr = alloc_val(Pi { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr)); + pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr)); + } + pi_ptr +} + +pub fn alloc_let( + binder_name: Name, + non_dep: bool, + typ: DAGPtr, + val: DAGPtr, + bod: NonNull, + parents: Option>, +) -> NonNull { + let let_ptr = alloc_val(LetNode { + binder_name, + non_dep, + typ, + val, + bod, + typ_ref: DLL::singleton(ParentPtr::Root), + val_ref: DLL::singleton(ParentPtr::Root), + bod_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents, + }); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr)); + let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr)); + let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr)); + } + let_ptr +} + +pub fn alloc_proj( + type_name: Name, + idx: Nat, + expr: DAGPtr, + parents: Option>, +) -> NonNull { + let proj_ptr = alloc_val(ProjNode { + type_name, + idx, + expr, + expr_ref: DLL::singleton(ParentPtr::Root), + parents, + }); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr)); + } + proj_ptr +} + +// ============================================================================ +// Parent pointer helpers +// ============================================================================ + +pub fn get_parents(node: DAGPtr) -> Option> { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents, + DAGPtr::App(p) => (*p.as_ptr()).parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents, + } + } +} + +pub fn set_parents(node: DAGPtr, parents: Option>) { + unsafe { + match node { + DAGPtr::Var(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Sort(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Cnst(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lit(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Lam(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Fun(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Pi(p) => (*p.as_ptr()).parents = parents, + DAGPtr::App(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Let(p) => (*p.as_ptr()).parents = parents, + DAGPtr::Proj(p) => (*p.as_ptr()).parents = parents, + } + } +} + +pub fn add_to_parents(node: DAGPtr, parent_link: NonNull) { + unsafe { + match get_parents(node) { + None => set_parents(node, Some(parent_link)), + Some(parents) => { + (*parents.as_ptr()).merge(parent_link); + }, + } + } +} + +// ============================================================================ +// DAG-level helpers +// ============================================================================ + +/// Get a unique key for a DAG node pointer (for use in hash sets). +pub fn dag_ptr_key(node: DAGPtr) -> usize { + match node { + DAGPtr::Var(p) => p.as_ptr() as usize, + DAGPtr::Sort(p) => p.as_ptr() as usize, + DAGPtr::Cnst(p) => p.as_ptr() as usize, + DAGPtr::Lit(p) => p.as_ptr() as usize, + DAGPtr::Lam(p) => p.as_ptr() as usize, + DAGPtr::Fun(p) => p.as_ptr() as usize, + DAGPtr::Pi(p) => p.as_ptr() as usize, + DAGPtr::App(p) => p.as_ptr() as usize, + DAGPtr::Let(p) => p.as_ptr() as usize, + DAGPtr::Proj(p) => p.as_ptr() as usize, + } +} + +/// Free all DAG nodes reachable from the head. +/// Only frees the node structs themselves; DLL parent entries that are +/// inline in parent structs are freed with those structs. The root_parents +/// DLL node (heap-allocated in from_expr) is a small accepted leak. +pub fn free_dag(dag: DAG) { + let mut visited = FxHashSet::default(); + free_dag_nodes(dag.head, &mut visited); +} + +fn free_dag_nodes(node: DAGPtr, visited: &mut FxHashSet) { + let key = dag_ptr_key(node); + if !visited.insert(key) { + return; + } + unsafe { + match node { + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + // Only free separately-allocated free vars; bound vars are + // embedded in their Lam struct and freed with it. + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + free_dag_nodes(lam.bod, visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + free_dag_nodes(fun.dom, visited); + free_dag_nodes(DAGPtr::Lam(fun.img), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + free_dag_nodes(pi.dom, visited); + free_dag_nodes(DAGPtr::Lam(pi.img), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + free_dag_nodes(app.fun, visited); + free_dag_nodes(app.arg, visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + free_dag_nodes(let_node.typ, visited); + free_dag_nodes(let_node.val, visited); + free_dag_nodes(DAGPtr::Lam(let_node.bod), visited); + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + free_dag_nodes(proj.expr, visited); + drop(Box::from_raw(link.as_ptr())); + }, + } + } +} diff --git a/src/ix/kernel/def_eq.rs b/src/ix/kernel/def_eq.rs new file mode 100644 index 00000000..c2110381 --- /dev/null +++ b/src/ix/kernel/def_eq.rs @@ -0,0 +1,1298 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; + +use super::level::{eq_antisymm, eq_antisymm_many}; +use super::tc::TypeChecker; +use super::whnf::*; + +/// Result of lazy delta reduction. +enum DeltaResult { + Found(bool), + Exhausted(Expr, Expr), +} + +/// Check definitional equality of two expressions. +pub fn def_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + if let Some(quick) = def_eq_quick_check(x, y) { + return quick; + } + + let x_n = tc.whnf(x); + let y_n = tc.whnf(y); + + if let Some(quick) = def_eq_quick_check(&x_n, &y_n) { + return quick; + } + + if proof_irrel_eq(&x_n, &y_n, tc) { + return true; + } + + match lazy_delta_step(&x_n, &y_n, tc) { + DeltaResult::Found(result) => result, + DeltaResult::Exhausted(x_e, y_e) => { + def_eq_const(&x_e, &y_e) + || def_eq_proj(&x_e, &y_e, tc) + || def_eq_app(&x_e, &y_e, tc) + || def_eq_binder_full(&x_e, &y_e, tc) + || try_eta_expansion(&x_e, &y_e, tc) + || try_eta_struct(&x_e, &y_e, tc) + || is_def_eq_unit_like(&x_e, &y_e, tc) + }, + } +} + +/// Quick syntactic checks. +fn def_eq_quick_check(x: &Expr, y: &Expr) -> Option { + if x == y { + return Some(true); + } + if let Some(r) = def_eq_sort(x, y) { + return Some(r); + } + if let Some(r) = def_eq_binder(x, y) { + return Some(r); + } + None +} + +fn def_eq_sort(x: &Expr, y: &Expr) -> Option { + match (x.as_data(), y.as_data()) { + (ExprData::Sort(l, _), ExprData::Sort(r, _)) => { + Some(eq_antisymm(l, r)) + }, + _ => None, + } +} + +/// Check if two binder expressions (Pi/Lam) are definitionally equal. +/// Always defers to full checking after WHNF, since binder types could be +/// definitionally equal without being syntactically identical. +fn def_eq_binder(_x: &Expr, _y: &Expr) -> Option { + None +} + +fn def_eq_const(x: &Expr, y: &Expr) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Const(xn, xl, _), + ExprData::Const(yn, yl, _), + ) => xn == yn && eq_antisymm_many(xl, yl), + _ => false, + } +} + +fn def_eq_proj(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::Proj(_, idx_l, structure_l, _), + ExprData::Proj(_, idx_r, structure_r, _), + ) => idx_l == idx_r && def_eq(structure_l, structure_r, tc), + _ => false, + } +} + +fn def_eq_app(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let (f1, args1) = unfold_apps(x); + if args1.is_empty() { + return false; + } + let (f2, args2) = unfold_apps(y); + if args2.is_empty() { + return false; + } + if args1.len() != args2.len() { + return false; + } + + if !def_eq(&f1, &f2, tc) { + return false; + } + args1.iter().zip(args2.iter()).all(|(a, b)| def_eq(a, b, tc)) +} + +/// Full recursive binder comparison: two Pi or two Lam types with +/// definitionally equal domain types and bodies (ignoring binder names). +fn def_eq_binder_full( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> bool { + match (x.as_data(), y.as_data()) { + ( + ExprData::ForallE(_, t1, b1, _, _), + ExprData::ForallE(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + ( + ExprData::Lam(_, t1, b1, _, _), + ExprData::Lam(_, t2, b2, _, _), + ) => def_eq(t1, t2, tc) && def_eq(b1, b2, tc), + _ => false, + } +} + +/// Proof irrelevance: if both x and y are proofs of the same proposition, +/// they are definitionally equal. +fn proof_irrel_eq(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let x_ty = match tc.infer(x) { + Ok(ty) => ty, + Err(_) => return false, + }; + if !is_proposition(&x_ty, tc) { + return false; + } + let y_ty = match tc.infer(y) { + Ok(ty) => ty, + Err(_) => return false, + }; + if !is_proposition(&y_ty, tc) { + return false; + } + def_eq(&x_ty, &y_ty, tc) +} + +/// Check if an expression's type is Prop (Sort 0). +fn is_proposition(ty: &Expr, tc: &mut TypeChecker) -> bool { + let ty_of_ty = match tc.infer(ty) { + Ok(t) => t, + Err(_) => return false, + }; + let whnfd = tc.whnf(&ty_of_ty); + matches!(whnfd.as_data(), ExprData::Sort(l, _) if super::level::is_zero(l)) +} + +/// Eta expansion: `fun x => f x` ≡ `f` when `f : (x : A) → B`. +fn try_eta_expansion(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + try_eta_expansion_aux(x, y, tc) || try_eta_expansion_aux(y, x, tc) +} + +fn try_eta_expansion_aux( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> bool { + if let ExprData::Lam(_, _, _, _, _) = x.as_data() { + let y_ty = match tc.infer(y) { + Ok(t) => t, + Err(_) => return false, + }; + let y_ty_whnf = tc.whnf(&y_ty); + if let ExprData::ForallE(name, binder_type, _, bi, _) = + y_ty_whnf.as_data() + { + // eta-expand y: fun x => y x + let body = Expr::app(y.clone(), Expr::bvar(crate::lean::nat::Nat::from(0))); + let expanded = Expr::lam( + name.clone(), + binder_type.clone(), + body, + bi.clone(), + ); + return def_eq(x, &expanded, tc); + } + } + false +} + +/// Check if a name refers to a structure-like inductive: +/// exactly 1 constructor, not recursive, no indices. +fn is_structure_like(name: &Name, env: &Env) -> bool { + match env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + iv.ctors.len() == 1 && !iv.is_rec && iv.num_indices == Nat::ZERO + }, + _ => false, + } +} + +/// Structure eta: `p =def= S.mk (S.1 p) (S.2 p)` when S is a +/// single-constructor non-recursive inductive with no indices. +fn try_eta_struct(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + try_eta_struct_core(x, y, tc) || try_eta_struct_core(y, x, tc) +} + +/// Try to decompose `s` as a constructor application for a structure-like +/// type, then check that each field matches the corresponding projection of `t`. +fn try_eta_struct_core( + t: &Expr, + s: &Expr, + tc: &mut TypeChecker, +) -> bool { + let (head, args) = unfold_apps(s); + let ctor_name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return false, + }; + + let ctor_info = match tc.env.get(ctor_name) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => return false, + }; + + if !is_structure_like(&ctor_info.induct, tc.env) { + return false; + } + + let num_params = ctor_info.num_params.to_u64().unwrap() as usize; + let num_fields = ctor_info.num_fields.to_u64().unwrap() as usize; + + if args.len() != num_params + num_fields { + return false; + } + + for i in 0..num_fields { + let field = &args[num_params + i]; + let proj = Expr::proj( + ctor_info.induct.clone(), + Nat::from(i as u64), + t.clone(), + ); + if !def_eq(field, &proj, tc) { + return false; + } + } + + true +} + +/// Unit-like equality: types with a single zero-field constructor have all +/// inhabitants definitionally equal. +fn is_def_eq_unit_like(x: &Expr, y: &Expr, tc: &mut TypeChecker) -> bool { + let x_ty = match tc.infer(x) { + Ok(ty) => ty, + Err(_) => return false, + }; + let y_ty = match tc.infer(y) { + Ok(ty) => ty, + Err(_) => return false, + }; + // Types must be def-eq + if !def_eq(&x_ty, &y_ty, tc) { + return false; + } + // Check if the type is a unit-like inductive + let whnf_ty = tc.whnf(&x_ty); + let (head, _) = unfold_apps(&whnf_ty); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return false, + }; + match tc.env.get(name) { + Some(ConstantInfo::InductInfo(iv)) => { + if iv.ctors.len() != 1 { + return false; + } + // Check single constructor has zero fields + if let Some(ConstantInfo::CtorInfo(c)) = tc.env.get(&iv.ctors[0]) { + c.num_fields == Nat::ZERO + } else { + false + } + }, + _ => false, + } +} + +/// Lazy delta reduction: unfold definitions step by step. +fn lazy_delta_step( + x: &Expr, + y: &Expr, + tc: &mut TypeChecker, +) -> DeltaResult { + let mut x = x.clone(); + let mut y = y.clone(); + + loop { + let x_def = get_applied_def(&x, tc.env); + let y_def = get_applied_def(&y, tc.env); + + match (&x_def, &y_def) { + (None, None) => return DeltaResult::Exhausted(x, y), + (Some(_), None) => { + x = delta(&x, tc); + }, + (None, Some(_)) => { + y = delta(&y, tc); + }, + (Some((x_name, x_hint)), Some((y_name, y_hint))) => { + // Same name and same height: try congruence first + if x_name == y_name && x_hint == y_hint { + if def_eq_app(&x, &y, tc) { + return DeltaResult::Found(true); + } + x = delta(&x, tc); + y = delta(&y, tc); + } else if hint_lt(x_hint, y_hint) { + y = delta(&y, tc); + } else { + x = delta(&x, tc); + } + }, + } + + if let Some(quick) = def_eq_quick_check(&x, &y) { + return DeltaResult::Found(quick); + } + } +} + +/// Get the name and reducibility hint of an applied definition. +fn get_applied_def( + e: &Expr, + env: &Env, +) -> Option<(Name, ReducibilityHints)> { + let (head, _) = unfold_apps(e); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + let ci = env.get(name)?; + match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + None + } else { + Some((name.clone(), d.hints)) + } + }, + ConstantInfo::ThmInfo(_) => { + Some((name.clone(), ReducibilityHints::Opaque)) + }, + _ => None, + } +} + +/// Unfold a definition and do cheap WHNF. +fn delta(e: &Expr, tc: &mut TypeChecker) -> Expr { + match try_unfold_def(e, tc.env) { + Some(unfolded) => tc.whnf(&unfolded), + None => e.clone(), + } +} + +/// Compare reducibility hints for ordering. +fn hint_lt(a: &ReducibilityHints, b: &ReducibilityHints) -> bool { + match (a, b) { + (ReducibilityHints::Opaque, _) => true, + (_, ReducibilityHints::Opaque) => false, + (ReducibilityHints::Abbrev, _) => false, + (_, ReducibilityHints::Abbrev) => true, + (ReducibilityHints::Regular(ha), ReducibilityHints::Regular(hb)) => { + ha < hb + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::kernel::tc::TypeChecker; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + /// Minimal env with Nat, Nat.zero, Nat.succ. + fn mk_nat_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + let zero_name = mk_name2("Nat", "zero"); + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let succ_name = mk_name2("Nat", "succ"); + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + env + } + + // ========================================================================== + // Reflexivity + // ========================================================================== + + #[test] + fn def_eq_reflexive_sort() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::sort(Level::zero()); + assert!(tc.def_eq(&e, &e)); + } + + #[test] + fn def_eq_reflexive_const() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = nat_zero(); + assert!(tc.def_eq(&e, &e)); + } + + #[test] + fn def_eq_reflexive_lambda() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert!(tc.def_eq(&e, &e)); + } + + // ========================================================================== + // Sort equality + // ========================================================================== + + #[test] + fn def_eq_sort_max_comm() { + // Sort(max u v) =def= Sort(max v u) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let u = Level::param(mk_name("u")); + let v = Level::param(mk_name("v")); + let s1 = Expr::sort(Level::max(u.clone(), v.clone())); + let s2 = Expr::sort(Level::max(v, u)); + assert!(tc.def_eq(&s1, &s2)); + } + + #[test] + fn def_eq_sort_not_equal() { + // Sort(0) ≠ Sort(1) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let s0 = Expr::sort(Level::zero()); + let s1 = Expr::sort(Level::succ(Level::zero())); + assert!(!tc.def_eq(&s0, &s1)); + } + + // ========================================================================== + // Alpha equivalence (same structure, different binder names) + // ========================================================================== + + #[test] + fn def_eq_alpha_lambda() { + // fun (x : Nat) => x =def= fun (y : Nat) => y + // (de Bruijn indices are the same, so this is syntactic equality) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e1 = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e2 = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + assert!(tc.def_eq(&e1, &e2)); + } + + #[test] + fn def_eq_alpha_pi() { + // (x : Nat) → Nat =def= (y : Nat) → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e1 = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e2 = Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert!(tc.def_eq(&e1, &e2)); + } + + // ========================================================================== + // Beta equivalence + // ========================================================================== + + #[test] + fn def_eq_beta() { + // (fun x : Nat => x) Nat.zero =def= Nat.zero + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let lhs = Expr::app(id_fn, nat_zero()); + let rhs = nat_zero(); + assert!(tc.def_eq(&lhs, &rhs)); + } + + #[test] + fn def_eq_beta_nested() { + // (fun x y : Nat => x) Nat.zero Nat.zero =def= Nat.zero + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let inner = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), // x + BinderInfo::Default, + ); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + inner, + BinderInfo::Default, + ); + let lhs = Expr::app(Expr::app(k_fn, nat_zero()), nat_zero()); + assert!(tc.def_eq(&lhs, &nat_zero())); + } + + // ========================================================================== + // Delta equivalence (definition unfolding) + // ========================================================================== + + #[test] + fn def_eq_delta() { + // def myZero := Nat.zero + // myZero =def= Nat.zero + let mut env = mk_nat_env(); + let my_zero = mk_name("myZero"); + env.insert( + my_zero.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_zero.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_zero.clone()], + }), + ); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::cnst(my_zero, vec![]); + assert!(tc.def_eq(&lhs, &nat_zero())); + } + + #[test] + fn def_eq_delta_both_sides() { + // def a := Nat.zero, def b := Nat.zero + // a =def= b + let mut env = mk_nat_env(); + let a = mk_name("a"); + let b = mk_name("b"); + env.insert( + a.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: a.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![a.clone()], + }), + ); + env.insert( + b.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: b.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![b.clone()], + }), + ); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::cnst(a, vec![]); + let rhs = Expr::cnst(b, vec![]); + assert!(tc.def_eq(&lhs, &rhs)); + } + + // ========================================================================== + // Zeta equivalence (let unfolding) + // ========================================================================== + + #[test] + fn def_eq_zeta() { + // (let x : Nat := Nat.zero in x) =def= Nat.zero + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(tc.def_eq(&lhs, &nat_zero())); + } + + // ========================================================================== + // Negative tests + // ========================================================================== + + #[test] + fn def_eq_different_consts() { + // Nat ≠ String + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let nat = nat_type(); + let string = Expr::cnst(mk_name("String"), vec![]); + assert!(!tc.def_eq(&nat, &string)); + } + + #[test] + fn def_eq_different_nat_levels() { + // Nat.zero ≠ Nat.succ + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let zero = nat_zero(); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + assert!(!tc.def_eq(&zero, &succ)); + } + + #[test] + fn def_eq_app_congruence() { + // f a =def= f a (for same f, same a) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let f = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let a = nat_zero(); + let lhs = Expr::app(f.clone(), a.clone()); + let rhs = Expr::app(f, a); + assert!(tc.def_eq(&lhs, &rhs)); + } + + #[test] + fn def_eq_app_different_args() { + // Nat.succ Nat.zero ≠ Nat.succ (Nat.succ Nat.zero) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let lhs = Expr::app(succ.clone(), nat_zero()); + let rhs = + Expr::app(succ.clone(), Expr::app(succ, nat_zero())); + assert!(!tc.def_eq(&lhs, &rhs)); + } + + // ========================================================================== + // Const-level equality + // ========================================================================== + + #[test] + fn def_eq_const_levels() { + // A.{max u v} =def= A.{max v u} + let mut env = Env::default(); + let a_name = mk_name("A"); + let u_name = mk_name("u"); + let v_name = mk_name("v"); + env.insert( + a_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: a_name.clone(), + level_params: vec![u_name.clone(), v_name.clone()], + typ: Expr::sort(Level::max( + Level::param(u_name.clone()), + Level::param(v_name.clone()), + )), + }, + is_unsafe: false, + }), + ); + let mut tc = TypeChecker::new(&env); + let u = Level::param(mk_name("u")); + let v = Level::param(mk_name("v")); + let lhs = Expr::cnst(a_name.clone(), vec![Level::max(u.clone(), v.clone()), Level::zero()]); + let rhs = Expr::cnst(a_name, vec![Level::max(v, u), Level::zero()]); + assert!(tc.def_eq(&lhs, &rhs)); + } + + // ========================================================================== + // Hint ordering + // ========================================================================== + + #[test] + fn hint_lt_opaque_less_than_all() { + assert!(hint_lt(&ReducibilityHints::Opaque, &ReducibilityHints::Abbrev)); + assert!(hint_lt( + &ReducibilityHints::Opaque, + &ReducibilityHints::Regular(0) + )); + } + + #[test] + fn hint_lt_abbrev_greatest() { + assert!(!hint_lt( + &ReducibilityHints::Abbrev, + &ReducibilityHints::Opaque + )); + assert!(!hint_lt( + &ReducibilityHints::Abbrev, + &ReducibilityHints::Regular(100) + )); + } + + #[test] + fn hint_lt_regular_ordering() { + assert!(hint_lt( + &ReducibilityHints::Regular(1), + &ReducibilityHints::Regular(2) + )); + assert!(!hint_lt( + &ReducibilityHints::Regular(2), + &ReducibilityHints::Regular(1) + )); + } + + // ========================================================================== + // Eta expansion + // ========================================================================== + + #[test] + fn def_eq_eta_lam_vs_const() { + // fun x : Nat => Nat.succ x =def= Nat.succ + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(tc.def_eq(&eta_expanded, &succ)); + } + + #[test] + fn def_eq_eta_symmetric() { + // Nat.succ =def= fun x : Nat => Nat.succ x + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let succ = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let eta_expanded = Expr::lam( + mk_name("x"), + nat_type(), + Expr::app(succ.clone(), Expr::bvar(Nat::from(0u64))), + BinderInfo::Default, + ); + assert!(tc.def_eq(&succ, &eta_expanded)); + } + + // ========================================================================== + // Lazy delta step with different heights + // ========================================================================== + + #[test] + fn def_eq_lazy_delta_higher_unfolds_first() { + // def a := Nat.zero (height 1) + // def b := a (height 2) + // b =def= Nat.zero should work by unfolding b first (higher height) + let mut env = mk_nat_env(); + let a = mk_name("a"); + let b = mk_name("b"); + env.insert( + a.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: a.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Regular(1), + safety: DefinitionSafety::Safe, + all: vec![a.clone()], + }), + ); + env.insert( + b.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: b.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(a, vec![]), + hints: ReducibilityHints::Regular(2), + safety: DefinitionSafety::Safe, + all: vec![b.clone()], + }), + ); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::cnst(b, vec![]); + assert!(tc.def_eq(&lhs, &nat_zero())); + } + + // ========================================================================== + // Transitivity through delta + // ========================================================================== + + #[test] + fn def_eq_transitive_delta() { + // def a := Nat.zero, def b := Nat.zero + // def c := Nat.zero + // a =def= b, a =def= c, b =def= c + let mut env = mk_nat_env(); + for name_str in &["a", "b", "c"] { + let n = mk_name(name_str); + env.insert( + n.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: n.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![n], + }), + ); + } + let mut tc = TypeChecker::new(&env); + let a = Expr::cnst(mk_name("a"), vec![]); + let b = Expr::cnst(mk_name("b"), vec![]); + let c = Expr::cnst(mk_name("c"), vec![]); + assert!(tc.def_eq(&a, &b)); + assert!(tc.def_eq(&a, &c)); + assert!(tc.def_eq(&b, &c)); + } + + // ========================================================================== + // Nat literal equality through WHNF + // ========================================================================== + + #[test] + fn def_eq_nat_lit_same() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let a = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let b = Expr::lit(Literal::NatVal(Nat::from(42u64))); + assert!(tc.def_eq(&a, &b)); + } + + #[test] + fn def_eq_nat_lit_different() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let a = Expr::lit(Literal::NatVal(Nat::from(1u64))); + let b = Expr::lit(Literal::NatVal(Nat::from(2u64))); + assert!(!tc.def_eq(&a, &b)); + } + + // ========================================================================== + // Beta-delta combined + // ========================================================================== + + #[test] + fn def_eq_beta_delta_combined() { + // def myId := fun x : Nat => x + // myId Nat.zero =def= Nat.zero + let mut env = mk_nat_env(); + let my_id = mk_name("myId"); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + env.insert( + my_id.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_id.clone(), + level_params: vec![], + typ: fun_ty, + }, + value: Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_id.clone()], + }), + ); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::app(Expr::cnst(my_id, vec![]), nat_zero()); + assert!(tc.def_eq(&lhs, &nat_zero())); + } + + // ========================================================================== + // Structure eta + // ========================================================================== + + /// Build an env with Nat + Prod.{u,v} structure type. + fn mk_prod_env() -> Env { + let mut env = mk_nat_env(); + let u_name = mk_name("u"); + let v_name = mk_name("v"); + let prod_name = mk_name("Prod"); + let mk_ctor_name = mk_name2("Prod", "mk"); + + // Prod.{u,v} (α : Sort u) (β : Sort v) : Sort (max u v) + let prod_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::sort(Level::max( + Level::param(u_name.clone()), + Level::param(v_name.clone()), + )), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + prod_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: prod_name.clone(), + level_params: vec![u_name.clone(), v_name.clone()], + typ: prod_type, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![prod_name.clone()], + ctors: vec![mk_ctor_name.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod α β + let ctor_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::all( + mk_name("fst"), + Expr::bvar(Nat::from(1u64)), // α + Expr::all( + mk_name("snd"), + Expr::bvar(Nat::from(1u64)), // β + Expr::app( + Expr::app( + Expr::cnst( + prod_name.clone(), + vec![ + Level::param(u_name.clone()), + Level::param(v_name.clone()), + ], + ), + Expr::bvar(Nat::from(3u64)), // α + ), + Expr::bvar(Nat::from(2u64)), // β + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + mk_ctor_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_ctor_name, + level_params: vec![u_name, v_name], + typ: ctor_type, + }, + induct: prod_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(2u64), + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn eta_struct_ctor_eq_proj() { + // Prod.mk Nat Nat (Prod.1 p) (Prod.2 p) =def= p + // where p is a free variable of type Prod Nat Nat + let env = mk_prod_env(); + let mut tc = TypeChecker::new(&env); + + let one = Level::succ(Level::zero()); + let prod_nat_nat = Expr::app( + Expr::app( + Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]), + nat_type(), + ), + nat_type(), + ); + let p = tc.mk_local(&mk_name("p"), &prod_nat_nat); + + let ctor_app = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Prod", "mk"), + vec![one.clone(), one.clone()], + ), + nat_type(), + ), + nat_type(), + ), + Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()), + ), + Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()), + ); + + assert!(tc.def_eq(&ctor_app, &p)); + } + + #[test] + fn eta_struct_symmetric() { + // p =def= Prod.mk Nat Nat (Prod.1 p) (Prod.2 p) + let env = mk_prod_env(); + let mut tc = TypeChecker::new(&env); + + let one = Level::succ(Level::zero()); + let prod_nat_nat = Expr::app( + Expr::app( + Expr::cnst(mk_name("Prod"), vec![one.clone(), one.clone()]), + nat_type(), + ), + nat_type(), + ); + let p = tc.mk_local(&mk_name("p"), &prod_nat_nat); + + let ctor_app = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Prod", "mk"), + vec![one.clone(), one.clone()], + ), + nat_type(), + ), + nat_type(), + ), + Expr::proj(mk_name("Prod"), Nat::from(0u64), p.clone()), + ), + Expr::proj(mk_name("Prod"), Nat::from(1u64), p.clone()), + ); + + assert!(tc.def_eq(&p, &ctor_app)); + } + + #[test] + fn eta_struct_nat_not_structure_like() { + // Nat has 2 constructors, so it is NOT structure-like + let env = mk_nat_env(); + assert!(!super::is_structure_like(&mk_name("Nat"), &env)); + } + + // ========================================================================== + // Binder full comparison + // ========================================================================== + + #[test] + fn def_eq_binder_full_different_domains() { + // (x : myNat) → Nat =def= (x : Nat) → Nat + // where myNat unfolds to Nat + let mut env = mk_nat_env(); + let my_nat = mk_name("myNat"); + env.insert( + my_nat.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: my_nat.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + value: nat_type(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![my_nat.clone()], + }), + ); + let mut tc = TypeChecker::new(&env); + let lhs = Expr::all( + mk_name("x"), + Expr::cnst(my_nat, vec![]), + nat_type(), + BinderInfo::Default, + ); + let rhs = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert!(tc.def_eq(&lhs, &rhs)); + } + + // ========================================================================== + // Proj congruence + // ========================================================================== + + #[test] + fn def_eq_proj_congruence() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(0u64), s); + assert!(tc.def_eq(&lhs, &rhs)); + } + + #[test] + fn def_eq_proj_different_idx() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let s = nat_zero(); + let lhs = Expr::proj(mk_name("S"), Nat::from(0u64), s.clone()); + let rhs = Expr::proj(mk_name("S"), Nat::from(1u64), s); + assert!(!tc.def_eq(&lhs, &rhs)); + } + + // ========================================================================== + // Unit-like equality + // ========================================================================== + + #[test] + fn def_eq_unit_like() { + // Unit-type: single ctor, zero fields + // Any two inhabitants should be def-eq + let mut env = mk_nat_env(); + let unit_name = mk_name("Unit"); + let unit_star = mk_name2("Unit", "star"); + + env.insert( + unit_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: unit_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![unit_name.clone()], + ctors: vec![unit_star.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + unit_star.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: unit_star.clone(), + level_params: vec![], + typ: Expr::cnst(unit_name.clone(), vec![]), + }, + induct: unit_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let mut tc = TypeChecker::new(&env); + + // Two distinct fvars of type Unit should be def-eq + let unit_ty = Expr::cnst(unit_name, vec![]); + let x = tc.mk_local(&mk_name("x"), &unit_ty); + let y = tc.mk_local(&mk_name("y"), &unit_ty); + assert!(tc.def_eq(&x, &y)); + } +} diff --git a/src/ix/kernel/dll.rs b/src/ix/kernel/dll.rs new file mode 100644 index 00000000..07dfe135 --- /dev/null +++ b/src/ix/kernel/dll.rs @@ -0,0 +1,214 @@ +use core::marker::PhantomData; +use core::ptr::NonNull; + +#[derive(Debug)] +#[allow(clippy::upper_case_acronyms)] +pub struct DLL { + pub next: Option>>, + pub prev: Option>>, + pub elem: T, +} + +pub struct Iter<'a, T> { + next: Option>>, + marker: PhantomData<&'a mut DLL>, +} + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + + #[inline] + fn next(&mut self) -> Option { + self.next.map(|node| { + let deref = unsafe { &*node.as_ptr() }; + self.next = deref.next; + &deref.elem + }) + } +} + +pub struct IterMut<'a, T> { + next: Option>>, + marker: PhantomData<&'a mut DLL>, +} + +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + #[inline] + fn next(&mut self) -> Option { + self.next.map(|node| { + let deref = unsafe { &mut *node.as_ptr() }; + self.next = deref.next; + &mut deref.elem + }) + } +} + +impl DLL { + #[inline] + pub fn singleton(elem: T) -> Self { + DLL { next: None, prev: None, elem } + } + + #[inline] + pub fn alloc(elem: T) -> NonNull { + NonNull::new(Box::into_raw(Box::new(Self::singleton(elem)))).unwrap() + } + + #[inline] + pub fn is_singleton(dll: Option>) -> bool { + dll.is_some_and(|dll| unsafe { + let dll = &*dll.as_ptr(); + dll.prev.is_none() && dll.next.is_none() + }) + } + + #[inline] + pub fn is_empty(dll: Option>) -> bool { + dll.is_none() + } + + pub fn merge(&mut self, node: NonNull) { + unsafe { + (*node.as_ptr()).prev = self.prev; + (*node.as_ptr()).next = NonNull::new(self); + if let Some(ptr) = self.prev { + (*ptr.as_ptr()).next = Some(node); + } + self.prev = Some(node); + } + } + + pub fn unlink_node(&self) -> Option> { + unsafe { + let next = self.next; + let prev = self.prev; + if let Some(next) = next { + (*next.as_ptr()).prev = prev; + } + if let Some(prev) = prev { + (*prev.as_ptr()).next = next; + } + prev.or(next) + } + } + + pub fn first(mut node: NonNull) -> NonNull { + loop { + let prev = unsafe { (*node.as_ptr()).prev }; + match prev { + None => break, + Some(ptr) => node = ptr, + } + } + node + } + + pub fn last(mut node: NonNull) -> NonNull { + loop { + let next = unsafe { (*node.as_ptr()).next }; + match next { + None => break, + Some(ptr) => node = ptr, + } + } + node + } + + pub fn concat(dll: NonNull, rest: Option>) { + let last = DLL::last(dll); + let first = rest.map(DLL::first); + unsafe { + (*last.as_ptr()).next = first; + } + if let Some(first) = first { + unsafe { + (*first.as_ptr()).prev = Some(last); + } + } + } + + #[inline] + pub fn iter_option(dll: Option>) -> Iter<'static, T> { + Iter { next: dll.map(DLL::first), marker: PhantomData } + } + + #[inline] + #[allow(dead_code)] + pub fn iter_mut_option(dll: Option>) -> IterMut<'static, T> { + IterMut { next: dll.map(DLL::first), marker: PhantomData } + } + + #[allow(unsafe_op_in_unsafe_fn)] + pub unsafe fn free_all(dll: Option>) { + if let Some(start) = dll { + let first = DLL::first(start); + let mut current = Some(first); + while let Some(node) = current { + let next = (*node.as_ptr()).next; + drop(Box::from_raw(node.as_ptr())); + current = next; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn to_vec(dll: Option>>) -> Vec { + DLL::iter_option(dll).copied().collect() + } + + #[test] + fn test_singleton() { + let dll = DLL::alloc(42); + assert!(DLL::is_singleton(Some(dll))); + unsafe { + assert_eq!((*dll.as_ptr()).elem, 42); + drop(Box::from_raw(dll.as_ptr())); + } + } + + #[test] + fn test_is_empty() { + assert!(DLL::::is_empty(None)); + let dll = DLL::alloc(1); + assert!(!DLL::is_empty(Some(dll))); + unsafe { DLL::free_all(Some(dll)) }; + } + + #[test] + fn test_merge() { + unsafe { + let a = DLL::alloc(1); + let b = DLL::alloc(2); + (*a.as_ptr()).merge(b); + assert_eq!(to_vec(Some(a)), vec![2, 1]); + DLL::free_all(Some(a)); + } + } + + #[test] + fn test_concat() { + unsafe { + let a = DLL::alloc(1); + let b = DLL::alloc(2); + DLL::concat(a, Some(b)); + assert_eq!(to_vec(Some(a)), vec![1, 2]); + DLL::free_all(Some(a)); + } + } + + #[test] + fn test_unlink_singleton() { + unsafe { + let dll = DLL::alloc(42); + let remaining = (*dll.as_ptr()).unlink_node(); + assert!(remaining.is_none()); + drop(Box::from_raw(dll.as_ptr())); + } + } +} diff --git a/src/ix/kernel/error.rs b/src/ix/kernel/error.rs new file mode 100644 index 00000000..33816246 --- /dev/null +++ b/src/ix/kernel/error.rs @@ -0,0 +1,59 @@ +use crate::ix::env::{Expr, Name}; + +#[derive(Debug)] +pub enum TcError { + TypeExpected { + expr: Expr, + inferred: Expr, + }, + FunctionExpected { + expr: Expr, + inferred: Expr, + }, + TypeMismatch { + expected: Expr, + found: Expr, + expr: Expr, + }, + DefEqFailure { + lhs: Expr, + rhs: Expr, + }, + UnknownConst { + name: Name, + }, + DuplicateUniverse { + name: Name, + }, + FreeBoundVariable { + idx: u64, + }, + KernelException { + msg: String, + }, +} + +impl std::fmt::Display for TcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TcError::TypeExpected { .. } => write!(f, "type expected"), + TcError::FunctionExpected { .. } => write!(f, "function expected"), + TcError::TypeMismatch { .. } => write!(f, "type mismatch"), + TcError::DefEqFailure { .. } => { + write!(f, "definitional equality failure") + }, + TcError::UnknownConst { name } => { + write!(f, "unknown constant: {}", name.pretty()) + }, + TcError::DuplicateUniverse { name } => { + write!(f, "duplicate universe: {}", name.pretty()) + }, + TcError::FreeBoundVariable { idx } => { + write!(f, "free bound variable at index {}", idx) + }, + TcError::KernelException { msg } => write!(f, "{}", msg), + } + } +} + +impl std::error::Error for TcError {} diff --git a/src/ix/kernel/inductive.rs b/src/ix/kernel/inductive.rs new file mode 100644 index 00000000..a06ed819 --- /dev/null +++ b/src/ix/kernel/inductive.rs @@ -0,0 +1,772 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; + +use super::error::TcError; +use super::level; +use super::tc::TypeChecker; +use super::whnf::{inst, unfold_apps}; + +type TcResult = Result; + +/// Validate an inductive type declaration. +/// Performs structural checks: constructors exist, belong to this inductive, +/// and have well-formed types. Mutual types are verified to exist. +pub fn check_inductive( + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + // Verify the type is well-formed + tc.check_declar_info(&ind.cnst)?; + + // Verify all constructors exist and belong to this inductive + for ctor_name in &ind.ctors { + let ctor_ci = tc.env.get(ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + let ctor = match ctor_ci { + ConstantInfo::CtorInfo(c) => c, + _ => { + return Err(TcError::KernelException { + msg: format!( + "{} is not a constructor", + ctor_name.pretty() + ), + }) + }, + }; + // Verify constructor's induct field matches + if ctor.induct != ind.cnst.name { + return Err(TcError::KernelException { + msg: format!( + "constructor {} belongs to {} but expected {}", + ctor_name.pretty(), + ctor.induct.pretty(), + ind.cnst.name.pretty() + ), + }); + } + // Verify constructor type is well-formed + tc.check_declar_info(&ctor.cnst)?; + } + + // Verify constructor return types and positivity + for ctor_name in &ind.ctors { + let ctor = match tc.env.get(ctor_name) { + Some(ConstantInfo::CtorInfo(c)) => c, + _ => continue, // already checked above + }; + check_ctor_return_type(ctor, ind, tc)?; + if !ind.is_unsafe { + check_ctor_positivity(ctor, ind, tc)?; + check_field_universe_constraints(ctor, ind, tc)?; + } + } + + // Verify all mutual types exist + for name in &ind.all { + if tc.env.get(name).is_none() { + return Err(TcError::UnknownConst { name: name.clone() }); + } + } + + Ok(()) +} + +/// Validate that a recursor's K flag is consistent with the inductive's structure. +/// K-target requires: non-mutual, in Prop, single constructor, zero fields. +/// If `rec.k == true` but conditions don't hold, reject. +pub fn validate_k_flag( + rec: &RecursorVal, + env: &Env, +) -> TcResult<()> { + if !rec.k { + return Ok(()); // conservative false is always fine + } + + // Must be non-mutual: `rec.all` should have exactly 1 inductive + if rec.all.len() != 1 { + return Err(TcError::KernelException { + msg: "recursor claims K but inductive is mutual".into(), + }); + } + + let ind_name = &rec.all[0]; + let ind = match env.get(ind_name) { + Some(ConstantInfo::InductInfo(iv)) => iv, + _ => { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not an inductive", + ind_name.pretty() + ), + }) + }, + }; + + // Must be in Prop (Sort 0) + // Walk type telescope past all binders to get the sort + let mut ty = ind.cnst.typ.clone(); + loop { + match ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ty = body.clone(); + }, + _ => break, + } + } + let is_prop = match ty.as_data() { + ExprData::Sort(l, _) => level::is_zero(l), + _ => false, + }; + if !is_prop { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} is not in Prop", + ind_name.pretty() + ), + }); + } + + // Must have single constructor + if ind.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but {} has {} constructors (need 1)", + ind_name.pretty(), + ind.ctors.len() + ), + }); + } + + // Constructor must have zero fields (all args are params) + let ctor_name = &ind.ctors[0]; + if let Some(ConstantInfo::CtorInfo(c)) = env.get(ctor_name) { + if c.num_fields != Nat::ZERO { + return Err(TcError::KernelException { + msg: format!( + "recursor claims K but constructor {} has {} fields (need 0)", + ctor_name.pretty(), + c.num_fields + ), + }); + } + } + + Ok(()) +} + +/// Check if an expression mentions a constant by name. +fn expr_mentions_const(e: &Expr, name: &Name) -> bool { + match e.as_data() { + ExprData::Const(n, _, _) => n == name, + ExprData::App(f, a, _) => { + expr_mentions_const(f, name) || expr_mentions_const(a, name) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + expr_mentions_const(t, name) || expr_mentions_const(b, name) + }, + ExprData::LetE(_, t, v, b, _, _) => { + expr_mentions_const(t, name) + || expr_mentions_const(v, name) + || expr_mentions_const(b, name) + }, + ExprData::Proj(_, _, s, _) => expr_mentions_const(s, name), + ExprData::Mdata(_, inner, _) => expr_mentions_const(inner, name), + _ => false, + } +} + +/// Check that no inductive name from `ind.all` appears in a negative position +/// in the constructor's field types. +fn check_ctor_positivity( + ctor: &ConstructorVal, + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + let num_params = ind.num_params.to_u64().unwrap() as usize; + let mut ty = ctor.cnst.typ.clone(); + + // Skip parameter binders + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => return Ok(()), // fewer binders than params — odd but not our problem + } + } + + // For each remaining field, check its domain for positivity + loop { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + // The domain is the field type — check strict positivity + check_strict_positivity(binder_type, &ind.all, tc)?; + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => break, + } + } + + Ok(()) +} + +/// Check strict positivity of a field type w.r.t. a set of inductive names. +/// +/// Strict positivity for `T` w.r.t. `I`: +/// - If `T` doesn't mention `I`, OK. +/// - If `T = I args...`, OK (the inductive itself at the head). +/// - If `T = (x : A) → B`, then `A` must NOT mention `I` at all, +/// and `B` must satisfy strict positivity w.r.t. `I`. +/// - Otherwise (I appears but not at head and not in Pi), reject. +fn check_strict_positivity( + ty: &Expr, + ind_names: &[Name], + tc: &mut TypeChecker, +) -> TcResult<()> { + let whnf_ty = tc.whnf(ty); + + // If no inductive name is mentioned, we're fine + if !ind_names.iter().any(|n| expr_mentions_const(&whnf_ty, n)) { + return Ok(()); + } + + match whnf_ty.as_data() { + ExprData::ForallE(_, domain, body, _, _) => { + // Domain must NOT mention any inductive name + for ind_name in ind_names { + if expr_mentions_const(domain, ind_name) { + return Err(TcError::KernelException { + msg: format!( + "inductive {} occurs in negative position (strict positivity violation)", + ind_name.pretty() + ), + }); + } + } + // Recurse into body + check_strict_positivity(body, ind_names, tc) + }, + _ => { + // The inductive is mentioned and we're not in a Pi — check if + // it's simply an application `I args...` (which is OK). + let (head, _) = unfold_apps(&whnf_ty); + match head.as_data() { + ExprData::Const(name, _, _) + if ind_names.iter().any(|n| n == name) => + { + Ok(()) + }, + _ => Err(TcError::KernelException { + msg: "inductive type occurs in a non-positive position".into(), + }), + } + }, + } +} + +/// Check that constructor field types live in universes ≤ the inductive's universe. +fn check_field_universe_constraints( + ctor: &ConstructorVal, + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + // Walk the inductive type telescope past num_params binders to find the sort level. + let num_params = ind.num_params.to_u64().unwrap() as usize; + let mut ind_ty = ind.cnst.typ.clone(); + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ind_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ind_ty = inst(body, &[local]); + }, + _ => return Ok(()), + } + } + // Skip remaining binders (indices) to get to the target sort + loop { + let whnf_ty = tc.whnf(&ind_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ind_ty = inst(body, &[local]); + }, + _ => { + ind_ty = whnf_ty; + break; + }, + } + } + let ind_level = match ind_ty.as_data() { + ExprData::Sort(l, _) => l.clone(), + _ => return Ok(()), // can't extract sort, skip + }; + + // Walk ctor type, skip params, then check each field + let mut ctor_ty = ctor.cnst.typ.clone(); + for _ in 0..num_params { + let whnf_ty = tc.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ctor_ty = inst(body, &[local]); + }, + _ => return Ok(()), + } + } + + // For each remaining field binder, check its sort level ≤ ind_level + loop { + let whnf_ty = tc.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + // Infer the sort of the binder_type + if let Ok(field_level) = tc.infer_sort_of(binder_type) { + if !level::leq(&field_level, &ind_level) { + return Err(TcError::KernelException { + msg: format!( + "constructor {} field type lives in a universe larger than the inductive's universe", + ctor.cnst.name.pretty() + ), + }); + } + } + let local = tc.mk_local(name, binder_type); + ctor_ty = inst(body, &[local]); + }, + _ => break, + } + } + + Ok(()) +} + +/// Verify that a constructor's return type targets the parent inductive. +/// Walks the constructor type telescope, then checks that the resulting +/// type is an application of the parent inductive with at least `num_params` args. +fn check_ctor_return_type( + ctor: &ConstructorVal, + ind: &InductiveVal, + tc: &mut TypeChecker, +) -> TcResult<()> { + let mut ty = ctor.cnst.typ.clone(); + + // Walk past all Pi binders + loop { + let whnf_ty = tc.whnf(&ty); + match whnf_ty.as_data() { + ExprData::ForallE(name, binder_type, body, _, _) => { + let local = tc.mk_local(name, binder_type); + ty = inst(body, &[local]); + }, + _ => { + ty = whnf_ty; + break; + }, + } + } + + // The return type should be `I args...` + let (head, args) = unfold_apps(&ty); + let head_name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => { + return Err(TcError::KernelException { + msg: format!( + "constructor {} return type head is not a constant", + ctor.cnst.name.pretty() + ), + }) + }, + }; + + if head_name != &ind.cnst.name { + return Err(TcError::KernelException { + msg: format!( + "constructor {} returns {} but should return {}", + ctor.cnst.name.pretty(), + head_name.pretty(), + ind.cnst.name.pretty() + ), + }); + } + + let num_params = ind.num_params.to_u64().unwrap() as usize; + if args.len() < num_params { + return Err(TcError::KernelException { + msg: format!( + "constructor {} return type has {} args but inductive has {} params", + ctor.cnst.name.pretty(), + args.len(), + num_params + ), + }); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ix::kernel::tc::TypeChecker; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn mk_nat_env() -> Env { + let mut env = Env::default(); + let nat_name = mk_name("Nat"); + env.insert( + nat_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }), + ); + let zero_name = mk_name2("Nat", "zero"); + env.insert( + zero_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + let succ_name = mk_name2("Nat", "succ"); + env.insert( + succ_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn check_nat_inductive_passes() { + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn check_ctor_wrong_return_type() { + let mut env = mk_nat_env(); + let bool_name = mk_name("Bool"); + env.insert( + bool_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bool_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![bool_name.clone()], + ctors: vec![mk_name2("Bool", "bad")], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + // Constructor returns Nat instead of Bool + let bad_ctor_name = mk_name2("Bool", "bad"); + env.insert( + bad_ctor_name.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: bad_ctor_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: bool_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&bool_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // Positivity checking + // ========================================================================== + + fn bool_type() -> Expr { + Expr::cnst(mk_name("Bool"), vec![]) + } + + /// Helper to make a simple inductive + ctor env for positivity tests. + fn mk_single_ctor_env( + ind_name: &str, + ctor_name: &str, + ctor_typ: Expr, + num_fields: u64, + ) -> Env { + let mut env = mk_nat_env(); + // Bool + let bool_name = mk_name("Bool"); + env.insert( + bool_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: bool_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![bool_name], + ctors: vec![mk_name2("Bool", "true"), mk_name2("Bool", "false")], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + let iname = mk_name(ind_name); + let cname = mk_name2(ind_name, ctor_name); + env.insert( + iname.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: iname.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![iname.clone()], + ctors: vec![cname.clone()], + num_nested: Nat::from(0u64), + is_rec: num_fields > 0, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + cname.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: cname, + level_params: vec![], + typ: ctor_typ, + }, + induct: iname, + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(num_fields), + is_unsafe: false, + }), + ); + env + } + + #[test] + fn positivity_bad_negative() { + // inductive Bad | mk : (Bad → Bool) → Bad + let bad = mk_name("Bad"); + let ctor_ty = Expr::all( + mk_name("f"), + Expr::all(mk_name("x"), Expr::cnst(bad, vec![]), bool_type(), BinderInfo::Default), + Expr::cnst(mk_name("Bad"), vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Bad", "mk", ctor_ty, 1); + let ind = match env.get(&mk_name("Bad")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + #[test] + fn positivity_nat_succ_ok() { + // Nat.succ : Nat → Nat (positive) + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn positivity_tree_positive_function() { + // inductive Tree | node : (Nat → Tree) → Tree + // Tree appears positive in `Nat → Tree` + let tree = mk_name("Tree"); + let ctor_ty = Expr::all( + mk_name("f"), + Expr::all(mk_name("n"), nat_type(), Expr::cnst(tree.clone(), vec![]), BinderInfo::Default), + Expr::cnst(tree, vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Tree", "node", ctor_ty, 1); + let ind = match env.get(&mk_name("Tree")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn positivity_depth2_negative() { + // inductive Bad2 | mk : ((Bad2 → Nat) → Nat) → Bad2 + // Bad2 appears in negative position at depth 2 + let bad2 = mk_name("Bad2"); + let inner = Expr::all( + mk_name("g"), + Expr::all(mk_name("x"), Expr::cnst(bad2.clone(), vec![]), nat_type(), BinderInfo::Default), + nat_type(), + BinderInfo::Default, + ); + let ctor_ty = Expr::all( + mk_name("f"), + inner, + Expr::cnst(bad2, vec![]), + BinderInfo::Default, + ); + let env = mk_single_ctor_env("Bad2", "mk", ctor_ty, 1); + let ind = match env.get(&mk_name("Bad2")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } + + // ========================================================================== + // Field universe constraints + // ========================================================================== + + #[test] + fn field_universe_nat_field_in_type1_ok() { + // Nat : Sort 1, Nat.succ field is Nat : Sort 1 — leq(1, 1) passes + let env = mk_nat_env(); + let ind = match env.get(&mk_name("Nat")).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_ok()); + } + + #[test] + fn field_universe_prop_inductive_with_type_field_fails() { + // inductive PropBad : Prop | mk : Nat → PropBad + // PropBad lives in Sort 0, Nat lives in Sort 1 — leq(1, 0) fails + let mut env = mk_nat_env(); + let pb_name = mk_name("PropBad"); + let pb_mk = mk_name2("PropBad", "mk"); + env.insert( + pb_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: pb_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::zero()), // Prop + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![pb_name.clone()], + ctors: vec![pb_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + pb_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: pb_mk, + level_params: vec![], + typ: Expr::all( + mk_name("n"), + nat_type(), // Nat : Sort 1 + Expr::cnst(pb_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: pb_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + let ind = match env.get(&pb_name).unwrap() { + ConstantInfo::InductInfo(v) => v, + _ => panic!(), + }; + let mut tc = TypeChecker::new(&env); + assert!(check_inductive(ind, &mut tc).is_err()); + } +} diff --git a/src/ix/kernel/level.rs b/src/ix/kernel/level.rs new file mode 100644 index 00000000..90931ca6 --- /dev/null +++ b/src/ix/kernel/level.rs @@ -0,0 +1,393 @@ +use crate::ix::env::{Expr, ExprData, Level, LevelData, Name}; + +/// Simplify a universe level expression. +pub fn simplify(l: &Level) -> Level { + match l.as_data() { + LevelData::Zero(_) | LevelData::Param(..) | LevelData::Mvar(..) => { + l.clone() + }, + LevelData::Succ(inner, _) => { + let inner_s = simplify(inner); + Level::succ(inner_s) + }, + LevelData::Max(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + combining(&a_s, &b_s) + }, + LevelData::Imax(a, b, _) => { + let a_s = simplify(a); + let b_s = simplify(b); + if is_zero(&a_s) || is_one(&a_s) { + b_s + } else { + match b_s.as_data() { + LevelData::Zero(_) => b_s, + LevelData::Succ(..) => combining(&a_s, &b_s), + _ => Level::imax(a_s, b_s), + } + } + }, + } +} + +/// Combine two levels, simplifying Max(Zero, x) = x and +/// Max(Succ a, Succ b) = Succ(Max(a, b)). +fn combining(l: &Level, r: &Level) -> Level { + match (l.as_data(), r.as_data()) { + (LevelData::Zero(_), _) => r.clone(), + (_, LevelData::Zero(_)) => l.clone(), + (LevelData::Succ(a, _), LevelData::Succ(b, _)) => { + let inner = combining(a, b); + Level::succ(inner) + }, + _ => Level::max(l.clone(), r.clone()), + } +} + +fn is_one(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Succ(inner, _) if is_zero(inner)) +} + +/// Check if a level is definitionally zero: l <= 0. +pub fn is_zero(l: &Level) -> bool { + leq(l, &Level::zero()) +} + +/// Check if `l <= r`. +pub fn leq(l: &Level, r: &Level) -> bool { + let l_s = simplify(l); + let r_s = simplify(r); + leq_core(&l_s, &r_s, 0) +} + +/// Check `l <= r + diff`. +fn leq_core(l: &Level, r: &Level, diff: isize) -> bool { + match (l.as_data(), r.as_data()) { + (LevelData::Zero(_), _) if diff >= 0 => true, + (_, LevelData::Zero(_)) if diff < 0 => false, + (LevelData::Param(a, _), LevelData::Param(b, _)) => a == b && diff >= 0, + (LevelData::Param(..), LevelData::Zero(_)) => false, + (LevelData::Zero(_), LevelData::Param(..)) => diff >= 0, + (LevelData::Succ(s, _), _) => leq_core(s, r, diff - 1), + (_, LevelData::Succ(s, _)) => leq_core(l, s, diff + 1), + (LevelData::Max(a, b, _), _) => { + leq_core(a, r, diff) && leq_core(b, r, diff) + }, + (LevelData::Param(..) | LevelData::Zero(_), LevelData::Max(x, y, _)) => { + leq_core(l, x, diff) || leq_core(l, y, diff) + }, + (LevelData::Imax(a, b, _), LevelData::Imax(x, y, _)) + if a == x && b == y => + { + true + }, + (LevelData::Imax(_, b, _), _) if is_param(b) => { + leq_imax_by_cases(b, l, r, diff) + }, + (_, LevelData::Imax(_, y, _)) if is_param(y) => { + leq_imax_by_cases(y, l, r, diff) + }, + (LevelData::Imax(a, b, _), _) if is_any_max(b) => { + match b.as_data() { + LevelData::Imax(x, y, _) => { + let new_lhs = Level::imax(a.clone(), y.clone()); + let new_rhs = Level::imax(x.clone(), y.clone()); + let new_max = Level::max(new_lhs, new_rhs); + leq_core(&new_max, r, diff) + }, + LevelData::Max(x, y, _) => { + let new_lhs = Level::imax(a.clone(), x.clone()); + let new_rhs = Level::imax(a.clone(), y.clone()); + let new_max = Level::max(new_lhs, new_rhs); + let simplified = simplify(&new_max); + leq_core(&simplified, r, diff) + }, + _ => unreachable!(), + } + }, + (_, LevelData::Imax(x, y, _)) if is_any_max(y) => { + match y.as_data() { + LevelData::Imax(j, k, _) => { + let new_lhs = Level::imax(x.clone(), k.clone()); + let new_rhs = Level::imax(j.clone(), k.clone()); + let new_max = Level::max(new_lhs, new_rhs); + leq_core(l, &new_max, diff) + }, + LevelData::Max(j, k, _) => { + let new_lhs = Level::imax(x.clone(), j.clone()); + let new_rhs = Level::imax(x.clone(), k.clone()); + let new_max = Level::max(new_lhs, new_rhs); + let simplified = simplify(&new_max); + leq_core(l, &simplified, diff) + }, + _ => unreachable!(), + } + }, + _ => false, + } +} + +/// Test l <= r by substituting param with 0 and Succ(param) and checking both. +fn leq_imax_by_cases( + param: &Level, + lhs: &Level, + rhs: &Level, + diff: isize, +) -> bool { + let zero = Level::zero(); + let succ_param = Level::succ(param.clone()); + + let lhs_0 = subst_and_simplify(lhs, param, &zero); + let rhs_0 = subst_and_simplify(rhs, param, &zero); + let lhs_s = subst_and_simplify(lhs, param, &succ_param); + let rhs_s = subst_and_simplify(rhs, param, &succ_param); + + leq_core(&lhs_0, &rhs_0, diff) && leq_core(&lhs_s, &rhs_s, diff) +} + +fn subst_and_simplify(level: &Level, from: &Level, to: &Level) -> Level { + let substituted = subst_single_level(level, from, to); + simplify(&substituted) +} + +/// Substitute a single level parameter. +fn subst_single_level(level: &Level, from: &Level, to: &Level) -> Level { + if level == from { + return to.clone(); + } + match level.as_data() { + LevelData::Zero(_) | LevelData::Mvar(..) => level.clone(), + LevelData::Param(..) => { + if level == from { + to.clone() + } else { + level.clone() + } + }, + LevelData::Succ(inner, _) => { + Level::succ(subst_single_level(inner, from, to)) + }, + LevelData::Max(a, b, _) => Level::max( + subst_single_level(a, from, to), + subst_single_level(b, from, to), + ), + LevelData::Imax(a, b, _) => Level::imax( + subst_single_level(a, from, to), + subst_single_level(b, from, to), + ), + } +} + +fn is_param(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Param(..)) +} + +fn is_any_max(l: &Level) -> bool { + matches!(l.as_data(), LevelData::Max(..) | LevelData::Imax(..)) +} + +/// Check universe level equality via antisymmetry: l == r iff l <= r && r <= l. +pub fn eq_antisymm(l: &Level, r: &Level) -> bool { + leq(l, r) && leq(r, l) +} + +/// Check that two lists of levels are pointwise equal. +pub fn eq_antisymm_many(ls: &[Level], rs: &[Level]) -> bool { + ls.len() == rs.len() + && ls.iter().zip(rs.iter()).all(|(l, r)| eq_antisymm(l, r)) +} + +/// Substitute universe parameters: `level[params[i] := values[i]]`. +pub fn subst_level( + level: &Level, + params: &[Name], + values: &[Level], +) -> Level { + match level.as_data() { + LevelData::Zero(_) => level.clone(), + LevelData::Succ(inner, _) => { + Level::succ(subst_level(inner, params, values)) + }, + LevelData::Max(a, b, _) => Level::max( + subst_level(a, params, values), + subst_level(b, params, values), + ), + LevelData::Imax(a, b, _) => Level::imax( + subst_level(a, params, values), + subst_level(b, params, values), + ), + LevelData::Param(name, _) => { + for (i, p) in params.iter().enumerate() { + if name == p { + return values[i].clone(); + } + } + level.clone() + }, + LevelData::Mvar(..) => level.clone(), + } +} + +/// Check that all universe parameters in `level` are contained in `params`. +pub fn all_uparams_defined(level: &Level, params: &[Name]) -> bool { + match level.as_data() { + LevelData::Zero(_) => true, + LevelData::Succ(inner, _) => all_uparams_defined(inner, params), + LevelData::Max(a, b, _) | LevelData::Imax(a, b, _) => { + all_uparams_defined(a, params) && all_uparams_defined(b, params) + }, + LevelData::Param(name, _) => params.iter().any(|p| p == name), + LevelData::Mvar(..) => true, + } +} + +/// Check that all universe parameters in an expression are contained in `params`. +/// Recursively walks the Expr, checking all Levels in Sort and Const nodes. +pub fn all_expr_uparams_defined(e: &Expr, params: &[Name]) -> bool { + match e.as_data() { + ExprData::Sort(level, _) => all_uparams_defined(level, params), + ExprData::Const(_, levels, _) => { + levels.iter().all(|l| all_uparams_defined(l, params)) + }, + ExprData::App(f, a, _) => { + all_expr_uparams_defined(f, params) + && all_expr_uparams_defined(a, params) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::LetE(_, t, v, b, _, _) => { + all_expr_uparams_defined(t, params) + && all_expr_uparams_defined(v, params) + && all_expr_uparams_defined(b, params) + }, + ExprData::Proj(_, _, s, _) => all_expr_uparams_defined(s, params), + ExprData::Mdata(_, inner, _) => all_expr_uparams_defined(inner, params), + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => true, + } +} + +/// Check that a list of levels are all Params with no duplicates. +pub fn no_dupes_all_params(levels: &[Name]) -> bool { + for (i, a) in levels.iter().enumerate() { + for b in &levels[i + 1..] { + if a == b { + return false; + } + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simplify_zero() { + let z = Level::zero(); + assert_eq!(simplify(&z), z); + } + + #[test] + fn test_simplify_max_zero() { + let z = Level::zero(); + let p = Level::param(Name::str(Name::anon(), "u".into())); + let m = Level::max(z, p.clone()); + assert_eq!(simplify(&m), p); + } + + #[test] + fn test_simplify_imax_zero_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let z = Level::zero(); + let im = Level::imax(p, z.clone()); + assert_eq!(simplify(&im), z); + } + + #[test] + fn test_simplify_imax_succ_right() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let one = Level::succ(Level::zero()); + let im = Level::imax(p.clone(), one.clone()); + let simplified = simplify(&im); + // imax(p, 1) where p is nonzero → combining(p, 1) + // Actually: imax(u, 1) simplifies since a_s = u, b_s = 1 = Succ(0) + // → combining(u, 1) = max(u, 1) since u is Param, 1 is Succ + let expected = Level::max(p, one); + assert_eq!(simplified, expected); + } + + #[test] + fn test_simplify_idempotent() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let l = Level::max( + Level::imax(p.clone(), q.clone()), + Level::succ(Level::zero()), + ); + let s1 = simplify(&l); + let s2 = simplify(&s1); + assert_eq!(s1, s2); + } + + #[test] + fn test_leq_reflexive() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&p, &p)); + assert!(leq(&Level::zero(), &Level::zero())); + } + + #[test] + fn test_leq_zero_anything() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(leq(&Level::zero(), &p)); + assert!(leq(&Level::zero(), &Level::succ(Level::zero()))); + } + + #[test] + fn test_leq_succ_not_zero() { + let one = Level::succ(Level::zero()); + assert!(!leq(&one, &Level::zero())); + } + + #[test] + fn test_eq_antisymm_identity() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + assert!(eq_antisymm(&p, &p)); + } + + #[test] + fn test_eq_antisymm_max_comm() { + let p = Level::param(Name::str(Name::anon(), "u".into())); + let q = Level::param(Name::str(Name::anon(), "v".into())); + let m1 = Level::max(p.clone(), q.clone()); + let m2 = Level::max(q, p); + assert!(eq_antisymm(&m1, &m2)); + } + + #[test] + fn test_subst_level() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let one = Level::succ(Level::zero()); + let result = subst_level(&p, &[u_name], &[one.clone()]); + assert_eq!(result, one); + } + + #[test] + fn test_subst_level_nested() { + let u_name = Name::str(Name::anon(), "u".into()); + let p = Level::param(u_name.clone()); + let l = Level::succ(p); + let zero = Level::zero(); + let result = subst_level(&l, &[u_name], &[zero]); + let expected = Level::succ(Level::zero()); + assert_eq!(result, expected); + } +} diff --git a/src/ix/kernel/mod.rs b/src/ix/kernel/mod.rs new file mode 100644 index 00000000..d6a5750e --- /dev/null +++ b/src/ix/kernel/mod.rs @@ -0,0 +1,11 @@ +pub mod convert; +pub mod dag; +pub mod def_eq; +pub mod dll; +pub mod error; +pub mod inductive; +pub mod level; +pub mod quot; +pub mod tc; +pub mod upcopy; +pub mod whnf; diff --git a/src/ix/kernel/quot.rs b/src/ix/kernel/quot.rs new file mode 100644 index 00000000..51a1e070 --- /dev/null +++ b/src/ix/kernel/quot.rs @@ -0,0 +1,291 @@ +use crate::ix::env::*; + +use super::error::TcError; + +type TcResult = Result; + +/// Verify that the quotient declarations are consistent with the environment. +/// Checks that Quot is an inductive, Quot.mk is its constructor, and +/// Quot.lift and Quot.ind exist. +pub fn check_quot(env: &Env) -> TcResult<()> { + let quot_name = Name::str(Name::anon(), "Quot".into()); + let quot_mk_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "mk".into()); + let quot_lift_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "lift".into()); + let quot_ind_name = + Name::str(Name::str(Name::anon(), "Quot".into()), "ind".into()); + + // Check Quot exists and is an inductive + let quot = + env.get("_name).ok_or(TcError::UnknownConst { name: quot_name })?; + match quot { + ConstantInfo::InductInfo(_) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot is not an inductive type".into(), + }) + }, + } + + // Check Quot.mk exists and is a constructor of Quot + let mk = env + .get("_mk_name) + .ok_or(TcError::UnknownConst { name: quot_mk_name })?; + match mk { + ConstantInfo::CtorInfo(c) + if c.induct + == Name::str(Name::anon(), "Quot".into()) => {}, + _ => { + return Err(TcError::KernelException { + msg: "Quot.mk is not a constructor of Quot".into(), + }) + }, + } + + // Check Eq exists as an inductive with exactly 1 universe param and 1 ctor + let eq_name = Name::str(Name::anon(), "Eq".into()); + if let Some(eq_ci) = env.get(&eq_name) { + match eq_ci { + ConstantInfo::InductInfo(iv) => { + if iv.cnst.level_params.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }); + } + if iv.ctors.len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Eq should have 1 constructor, found {}", + iv.ctors.len() + ), + }); + } + }, + _ => { + return Err(TcError::KernelException { + msg: "Eq is not an inductive type".into(), + }) + }, + } + } else { + return Err(TcError::KernelException { + msg: "Eq not found in environment (required for quotient types)".into(), + }); + } + + // Check Quot has exactly 1 level param + match quot { + ConstantInfo::InductInfo(iv) if iv.cnst.level_params.len() != 1 => { + return Err(TcError::KernelException { + msg: format!( + "Quot should have 1 universe parameter, found {}", + iv.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.mk has 1 level param + match mk { + ConstantInfo::CtorInfo(c) if c.cnst.level_params.len() != 1 => { + return Err(TcError::KernelException { + msg: format!( + "Quot.mk should have 1 universe parameter, found {}", + c.cnst.level_params.len() + ), + }) + }, + _ => {}, + } + + // Check Quot.lift exists and has 2 level params + let lift = env + .get("_lift_name) + .ok_or(TcError::UnknownConst { name: quot_lift_name })?; + if lift.get_level_params().len() != 2 { + return Err(TcError::KernelException { + msg: format!( + "Quot.lift should have 2 universe parameters, found {}", + lift.get_level_params().len() + ), + }); + } + + // Check Quot.ind exists and has 1 level param + let ind = env + .get("_ind_name) + .ok_or(TcError::UnknownConst { name: quot_ind_name })?; + if ind.get_level_params().len() != 1 { + return Err(TcError::KernelException { + msg: format!( + "Quot.ind should have 1 universe parameter, found {}", + ind.get_level_params().len() + ), + }); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + /// Build a well-formed quotient environment. + fn mk_quot_env() -> Env { + let mut env = Env::default(); + let u = mk_name("u"); + let v = mk_name("v"); + let dummy_ty = Expr::sort(Level::param(u.clone())); + + // Eq.{u} — 1 uparam, 1 ctor + let eq_name = mk_name("Eq"); + let eq_refl = mk_name2("Eq", "refl"); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Eq"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + // Quot.{u} — 1 uparam + let quot_name = mk_name("Quot"); + let quot_mk = mk_name2("Quot", "mk"); + env.insert( + quot_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: quot_name.clone(), + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![quot_name], + ctors: vec![quot_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + env.insert( + quot_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: quot_mk, + level_params: vec![u.clone()], + typ: dummy_ty.clone(), + }, + induct: mk_name("Quot"), + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Quot.lift.{u,v} — 2 uparams + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_lift, + level_params: vec![u.clone(), v.clone()], + typ: dummy_ty.clone(), + }, + is_unsafe: false, + }), + ); + + // Quot.ind.{u} — 1 uparam + let quot_ind = mk_name2("Quot", "ind"); + env.insert( + quot_ind.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_ind, + level_params: vec![u], + typ: dummy_ty, + }, + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn check_quot_well_formed() { + let env = mk_quot_env(); + assert!(check_quot(&env).is_ok()); + } + + #[test] + fn check_quot_missing_eq() { + let mut env = mk_quot_env(); + env.remove(&mk_name("Eq")); + assert!(check_quot(&env).is_err()); + } + + #[test] + fn check_quot_wrong_lift_levels() { + let mut env = mk_quot_env(); + // Replace Quot.lift with 1 level param instead of 2 + let quot_lift = mk_name2("Quot", "lift"); + env.insert( + quot_lift.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: quot_lift, + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + is_unsafe: false, + }), + ); + assert!(check_quot(&env).is_err()); + } +} diff --git a/src/ix/kernel/tc.rs b/src/ix/kernel/tc.rs new file mode 100644 index 00000000..e80416fd --- /dev/null +++ b/src/ix/kernel/tc.rs @@ -0,0 +1,1694 @@ +use crate::ix::env::*; +use crate::lean::nat::Nat; +use rustc_hash::FxHashMap; + +use super::def_eq::def_eq; +use super::error::TcError; +use super::level::{all_expr_uparams_defined, no_dupes_all_params}; +use super::whnf::*; + +type TcResult = Result; + +/// The kernel type checker. +pub struct TypeChecker<'env> { + pub env: &'env Env, + pub whnf_cache: FxHashMap, + pub infer_cache: FxHashMap, + pub local_counter: u64, + pub local_types: FxHashMap, +} + +impl<'env> TypeChecker<'env> { + pub fn new(env: &'env Env) -> Self { + TypeChecker { + env, + whnf_cache: FxHashMap::default(), + infer_cache: FxHashMap::default(), + local_counter: 0, + local_types: FxHashMap::default(), + } + } + + // ========================================================================== + // WHNF with caching + // ========================================================================== + + pub fn whnf(&mut self, e: &Expr) -> Expr { + if let Some(cached) = self.whnf_cache.get(e) { + return cached.clone(); + } + let result = whnf(e, self.env); + self.whnf_cache.insert(e.clone(), result.clone()); + result + } + + // ========================================================================== + // Local context management + // ========================================================================== + + /// Create a fresh free variable for entering a binder. + pub fn mk_local(&mut self, name: &Name, ty: &Expr) -> Expr { + let id = self.local_counter; + self.local_counter += 1; + let local_name = Name::num(name.clone(), Nat::from(id)); + self.local_types.insert(local_name.clone(), ty.clone()); + Expr::fvar(local_name) + } + + // ========================================================================== + // Ensure helpers + // ========================================================================== + + pub fn ensure_sort(&mut self, e: &Expr) -> TcResult { + if let ExprData::Sort(level, _) = e.as_data() { + return Ok(level.clone()); + } + let whnfd = self.whnf(e); + match whnfd.as_data() { + ExprData::Sort(level, _) => Ok(level.clone()), + _ => Err(TcError::TypeExpected { + expr: e.clone(), + inferred: whnfd, + }), + } + } + + pub fn ensure_pi(&mut self, e: &Expr) -> TcResult { + if let ExprData::ForallE(..) = e.as_data() { + return Ok(e.clone()); + } + let whnfd = self.whnf(e); + match whnfd.as_data() { + ExprData::ForallE(..) => Ok(whnfd), + _ => Err(TcError::FunctionExpected { + expr: e.clone(), + inferred: whnfd, + }), + } + } + + /// Infer the type of `e` and ensure it's a sort; return the universe level. + pub fn infer_sort_of(&mut self, e: &Expr) -> TcResult { + let ty = self.infer(e)?; + let whnfd = self.whnf(&ty); + self.ensure_sort(&whnfd) + } + + // ========================================================================== + // Type inference + // ========================================================================== + + pub fn infer(&mut self, e: &Expr) -> TcResult { + if let Some(cached) = self.infer_cache.get(e) { + return Ok(cached.clone()); + } + let result = self.infer_core(e)?; + self.infer_cache.insert(e.clone(), result.clone()); + Ok(result) + } + + fn infer_core(&mut self, e: &Expr) -> TcResult { + match e.as_data() { + ExprData::Sort(level, _) => self.infer_sort(level), + ExprData::Const(name, levels, _) => self.infer_const(name, levels), + ExprData::App(..) => self.infer_app(e), + ExprData::Lam(..) => self.infer_lambda(e), + ExprData::ForallE(..) => self.infer_pi(e), + ExprData::LetE(_, typ, val, body, _, _) => { + self.infer_let(typ, val, body) + }, + ExprData::Lit(lit, _) => self.infer_lit(lit), + ExprData::Proj(type_name, idx, structure, _) => { + self.infer_proj(type_name, idx, structure) + }, + ExprData::Mdata(_, inner, _) => self.infer(inner), + ExprData::Fvar(name, _) => { + match self.local_types.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(TcError::KernelException { + msg: "cannot infer type of free variable without context".into(), + }), + } + }, + ExprData::Bvar(idx, _) => Err(TcError::FreeBoundVariable { + idx: idx.to_u64().unwrap_or(u64::MAX), + }), + ExprData::Mvar(..) => Err(TcError::KernelException { + msg: "cannot infer type of metavariable".into(), + }), + } + } + + fn infer_sort(&mut self, level: &Level) -> TcResult { + Ok(Expr::sort(Level::succ(level.clone()))) + } + + fn infer_const( + &mut self, + name: &Name, + levels: &[Level], + ) -> TcResult { + let ci = self + .env + .get(name) + .ok_or_else(|| TcError::UnknownConst { name: name.clone() })?; + + let decl_params = ci.get_level_params(); + if levels.len() != decl_params.len() { + return Err(TcError::KernelException { + msg: format!( + "universe parameter count mismatch for {}", + name.pretty() + ), + }); + } + + let ty = ci.get_type(); + Ok(subst_expr_levels(ty, decl_params, levels)) + } + + fn infer_app(&mut self, e: &Expr) -> TcResult { + let (fun, args) = unfold_apps(e); + let mut fun_ty = self.infer(&fun)?; + + for arg in &args { + let pi = self.ensure_pi(&fun_ty)?; + match pi.as_data() { + ExprData::ForallE(_, binder_type, body, _, _) => { + // Check argument type matches binder + let arg_ty = self.infer(arg)?; + self.assert_def_eq(&arg_ty, binder_type)?; + fun_ty = inst(body, &[arg.clone()]); + }, + _ => unreachable!(), + } + } + + Ok(fun_ty) + } + + fn infer_lambda(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut binder_types = Vec::new(); + let mut binder_infos = Vec::new(); + let mut binder_names = Vec::new(); + + while let ExprData::Lam(name, binder_type, body, bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + self.infer_sort_of(&binder_type_inst)?; + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + binder_types.push(binder_type_inst); + binder_infos.push(bi.clone()); + binder_names.push(name.clone()); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let body_ty = self.infer(&body_inst)?; + + // Abstract back: build Pi telescope + let mut result = abstr(&body_ty, &locals); + for i in (0..locals.len()).rev() { + let binder_type_abstrd = abstr(&binder_types[i], &locals[..i]); + result = Expr::all( + binder_names[i].clone(), + binder_type_abstrd, + result, + binder_infos[i].clone(), + ); + } + + Ok(result) + } + + fn infer_pi(&mut self, e: &Expr) -> TcResult { + let mut cursor = e.clone(); + let mut locals = Vec::new(); + let mut universes = Vec::new(); + + while let ExprData::ForallE(name, binder_type, body, _bi, _) = + cursor.as_data() + { + let binder_type_inst = inst(binder_type, &locals); + let dom_univ = self.infer_sort_of(&binder_type_inst)?; + universes.push(dom_univ); + + let local = self.mk_local(name, &binder_type_inst); + locals.push(local); + cursor = body.clone(); + } + + let body_inst = inst(&cursor, &locals); + let mut result_level = self.infer_sort_of(&body_inst)?; + + for univ in universes.into_iter().rev() { + result_level = Level::imax(univ, result_level); + } + + Ok(Expr::sort(result_level)) + } + + fn infer_let( + &mut self, + typ: &Expr, + val: &Expr, + body: &Expr, + ) -> TcResult { + // Verify value matches declared type + let val_ty = self.infer(val)?; + self.assert_def_eq(&val_ty, typ)?; + let body_inst = inst(body, &[val.clone()]); + self.infer(&body_inst) + } + + fn infer_lit(&mut self, lit: &Literal) -> TcResult { + match lit { + Literal::NatVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "Nat".into()), vec![])) + }, + Literal::StrVal(_) => { + Ok(Expr::cnst(Name::str(Name::anon(), "String".into()), vec![])) + }, + } + } + + fn infer_proj( + &mut self, + type_name: &Name, + idx: &Nat, + structure: &Expr, + ) -> TcResult { + let structure_ty = self.infer(structure)?; + let structure_ty_whnf = self.whnf(&structure_ty); + + let (_, struct_ty_args) = unfold_apps(&structure_ty_whnf); + let struct_ty_head = match unfold_apps(&structure_ty_whnf).0.as_data() { + ExprData::Const(name, levels, _) => (name.clone(), levels.clone()), + _ => { + return Err(TcError::KernelException { + msg: "projection structure type is not a constant".into(), + }) + }, + }; + + let ind = self.env.get(&struct_ty_head.0).ok_or_else(|| { + TcError::UnknownConst { name: struct_ty_head.0.clone() } + })?; + + let (num_params, ctor_name) = match ind { + ConstantInfo::InductInfo(iv) => { + let ctor = iv.ctors.first().ok_or_else(|| { + TcError::KernelException { + msg: "inductive has no constructors".into(), + } + })?; + (iv.num_params.to_u64().unwrap(), ctor.clone()) + }, + _ => { + return Err(TcError::KernelException { + msg: "projection type is not an inductive".into(), + }) + }, + }; + + let ctor_ci = self.env.get(&ctor_name).ok_or_else(|| { + TcError::UnknownConst { name: ctor_name.clone() } + })?; + + let mut ctor_ty = subst_expr_levels( + ctor_ci.get_type(), + ctor_ci.get_level_params(), + &struct_ty_head.1, + ); + + // Skip params + for i in 0..num_params as usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + ctor_ty = inst(body, &[struct_ty_args[i].clone()]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (params)".into(), + }) + }, + } + } + + // Walk to the idx-th field + let idx_usize = idx.to_u64().unwrap() as usize; + for i in 0..idx_usize { + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, _, body, _, _) => { + let proj = + Expr::proj(type_name.clone(), Nat::from(i as u64), structure.clone()); + ctor_ty = inst(body, &[proj]); + }, + _ => { + return Err(TcError::KernelException { + msg: "ran out of constructor telescope (fields)".into(), + }) + }, + } + } + + let whnf_ty = self.whnf(&ctor_ty); + match whnf_ty.as_data() { + ExprData::ForallE(_, binder_type, _, _, _) => { + Ok(binder_type.clone()) + }, + _ => Err(TcError::KernelException { + msg: "ran out of constructor telescope (target field)".into(), + }), + } + } + + // ========================================================================== + // Definitional equality (delegated to def_eq module) + // ========================================================================== + + pub fn def_eq(&mut self, x: &Expr, y: &Expr) -> bool { + def_eq(x, y, self) + } + + pub fn assert_def_eq(&mut self, x: &Expr, y: &Expr) -> TcResult<()> { + if self.def_eq(x, y) { + Ok(()) + } else { + Err(TcError::DefEqFailure { lhs: x.clone(), rhs: y.clone() }) + } + } + + // ========================================================================== + // Declaration checking + // ========================================================================== + + /// Check that a declaration's type is well-formed. + pub fn check_declar_info( + &mut self, + info: &ConstantVal, + ) -> TcResult<()> { + // Check for duplicate universe params + if !no_dupes_all_params(&info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "duplicate universe parameters in {}", + info.name.pretty() + ), + }); + } + + // Check that the type has no loose bound variables + if has_loose_bvars(&info.typ) { + return Err(TcError::KernelException { + msg: format!( + "free bound variables in type of {}", + info.name.pretty() + ), + }); + } + + // Check that all universe parameters in the type are declared + if !all_expr_uparams_defined(&info.typ, &info.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in type of {}", + info.name.pretty() + ), + }); + } + + // Check that the type is a type (infers to a Sort) + let inferred = self.infer(&info.typ)?; + self.ensure_sort(&inferred)?; + + Ok(()) + } + + /// Check a single declaration. + pub fn check_declar( + &mut self, + ci: &ConstantInfo, + ) -> TcResult<()> { + match ci { + ConstantInfo::AxiomInfo(v) => { + self.check_declar_info(&v.cnst)?; + }, + ConstantInfo::DefnInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::ThmInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::OpaqueInfo(v) => { + self.check_declar_info(&v.cnst)?; + if !all_expr_uparams_defined(&v.value, &v.cnst.level_params) { + return Err(TcError::KernelException { + msg: format!( + "undeclared universe parameters in value of {}", + v.cnst.name.pretty() + ), + }); + } + let inferred_type = self.infer(&v.value)?; + self.assert_def_eq(&inferred_type, &v.cnst.typ)?; + }, + ConstantInfo::QuotInfo(v) => { + self.check_declar_info(&v.cnst)?; + super::quot::check_quot(self.env)?; + }, + ConstantInfo::InductInfo(v) => { + super::inductive::check_inductive(v, self)?; + }, + ConstantInfo::CtorInfo(v) => { + self.check_declar_info(&v.cnst)?; + // Verify the parent inductive exists + if self.env.get(&v.induct).is_none() { + return Err(TcError::UnknownConst { + name: v.induct.clone(), + }); + } + }, + ConstantInfo::RecInfo(v) => { + self.check_declar_info(&v.cnst)?; + for ind_name in &v.all { + if self.env.get(ind_name).is_none() { + return Err(TcError::UnknownConst { + name: ind_name.clone(), + }); + } + } + super::inductive::validate_k_flag(v, self.env)?; + }, + } + Ok(()) + } +} + +/// Check all declarations in an environment. +pub fn check_env(env: &Env) -> Vec<(Name, TcError)> { + let mut errors = Vec::new(); + for (name, ci) in env.iter() { + let mut tc = TypeChecker::new(env); + if let Err(e) = tc.check_declar(ci) { + errors.push((name.clone(), e)); + } + } + errors +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lean::nat::Nat; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + fn prop() -> Expr { + Expr::sort(Level::zero()) + } + + fn type_u() -> Expr { + Expr::sort(Level::param(mk_name("u"))) + } + + /// Build a minimal environment with Nat, Nat.zero, and Nat.succ. + fn mk_nat_env() -> Env { + let mut env = Env::default(); + + let nat_name = mk_name("Nat"); + // Nat : Sort 1 + let nat = ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: nat_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![nat_name.clone()], + ctors: vec![mk_name2("Nat", "zero"), mk_name2("Nat", "succ")], + num_nested: Nat::from(0u64), + is_rec: true, + is_unsafe: false, + is_reflexive: false, + }); + env.insert(nat_name, nat); + + // Nat.zero : Nat + let zero_name = mk_name2("Nat", "zero"); + let zero = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: zero_name.clone(), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + env.insert(zero_name, zero); + + // Nat.succ : Nat → Nat + let succ_name = mk_name2("Nat", "succ"); + let succ_ty = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let succ = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: succ_name.clone(), + level_params: vec![], + typ: succ_ty, + }, + induct: mk_name("Nat"), + cidx: Nat::from(1u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }); + env.insert(succ_name, succ); + + env + } + + // ========================================================================== + // Infer: Sort + // ========================================================================== + + #[test] + fn infer_sort_zero() { + // Sort(0) : Sort(1) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = prop(); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn infer_sort_succ() { + // Sort(1) : Sort(2) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::sort(Level::succ(Level::zero())); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::succ(Level::zero())))); + } + + #[test] + fn infer_sort_param() { + // Sort(u) : Sort(u+1) + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let u = Level::param(mk_name("u")); + let e = Expr::sort(u.clone()); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(u))); + } + + // ========================================================================== + // Infer: Const + // ========================================================================== + + #[test] + fn infer_const_nat() { + // Nat : Sort 1 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("Nat"), vec![]); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::succ(Level::zero()))); + } + + #[test] + fn infer_const_nat_zero() { + // Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = nat_zero(); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_const_nat_succ() { + // Nat.succ : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let ty = tc.infer(&e).unwrap(); + let expected = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn infer_const_unknown() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("NonExistent"), vec![]); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_const_universe_mismatch() { + // Nat has 0 universe params; passing 1 should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::cnst(mk_name("Nat"), vec![Level::zero()]); + assert!(tc.infer(&e).is_err()); + } + + // ========================================================================== + // Infer: Lit + // ========================================================================== + + #[test] + fn infer_nat_lit() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::lit(Literal::NatVal(Nat::from(42u64))); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_string_lit() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::lit(Literal::StrVal("hello".into())); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::cnst(mk_name("String"), vec![])); + } + + // ========================================================================== + // Infer: Lambda + // ========================================================================== + + #[test] + fn infer_identity_lambda() { + // fun (x : Nat) => x : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let ty = tc.infer(&id_fn).unwrap(); + let expected = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + #[test] + fn infer_const_lambda() { + // fun (x : Nat) (y : Nat) => x : Nat → Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let body = Expr::lam( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(1u64)), // x + BinderInfo::Default, + ); + let k_fn = Expr::lam( + mk_name("x"), + nat_type(), + body, + BinderInfo::Default, + ); + let ty = tc.infer(&k_fn).unwrap(); + // Nat → Nat → Nat + let expected = Expr::all( + mk_name("x"), + nat_type(), + Expr::all( + mk_name("y"), + nat_type(), + nat_type(), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + assert_eq!(ty, expected); + } + + // ========================================================================== + // Infer: App + // ========================================================================== + + #[test] + fn infer_app_succ_zero() { + // Nat.succ Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + #[test] + fn infer_app_identity() { + // (fun x : Nat => x) Nat.zero : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let id_fn = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let e = Expr::app(id_fn, nat_zero()); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: Pi + // ========================================================================== + + #[test] + fn infer_pi_nat_to_nat() { + // (Nat → Nat) : Sort 1 + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let pi = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let ty = tc.infer(&pi).unwrap(); + // Sort(imax(1, 1)) which simplifies to Sort(1) + if let ExprData::Sort(level, _) = ty.as_data() { + assert!( + super::super::level::eq_antisymm( + level, + &Level::succ(Level::zero()) + ), + "Nat → Nat should live in Sort 1, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + #[test] + fn infer_pi_prop_to_prop() { + // (Prop → Prop) : Sort 1 + // An axiom P : Prop, then P → P : Sort 1 + let mut env = Env::default(); + let p_name = mk_name("P"); + env.insert( + p_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: p_name.clone(), + level_params: vec![], + typ: prop(), + }, + is_unsafe: false, + }), + ); + + let mut tc = TypeChecker::new(&env); + let p = Expr::cnst(p_name, vec![]); + let pi = Expr::all(mk_name("x"), p.clone(), p.clone(), BinderInfo::Default); + let ty = tc.infer(&pi).unwrap(); + // Sort(imax(0, 0)) = Sort(0) = Prop + if let ExprData::Sort(level, _) = ty.as_data() { + assert!( + super::super::level::is_zero(level), + "Prop → Prop should live in Prop, got {:?}", + level + ); + } else { + panic!("Expected Sort, got {:?}", ty); + } + } + + // ========================================================================== + // Infer: Let + // ========================================================================== + + #[test] + fn infer_let_simple() { + // let x : Nat := Nat.zero in x : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: errors + // ========================================================================== + + #[test] + fn infer_free_bvar_fails() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::bvar(Nat::from(0u64)); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_fvar_fails() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::fvar(mk_name("x")); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_app_wrong_arg_type() { + // Nat.succ expects Nat, but we pass Sort(0) — should fail with DefEqFailure + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name2("Nat", "succ"), vec![]), + prop(), // Sort(0), not Nat + ); + assert!(tc.infer(&e).is_err()); + } + + #[test] + fn infer_let_type_mismatch() { + // let x : Nat → Nat := Nat.zero in x + // Nat.zero : Nat, but annotation says Nat → Nat — should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let nat_to_nat = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let e = Expr::letE( + mk_name("x"), + nat_to_nat, + nat_zero(), + Expr::bvar(Nat::from(0u64)), + false, + ); + assert!(tc.infer(&e).is_err()); + } + + // ========================================================================== + // check_declar + // ========================================================================== + + #[test] + fn check_axiom_declar() { + // axiom myAxiom : Nat → Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax_ty = Expr::all( + mk_name("n"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("myAxiom"), + level_params: vec![], + typ: ax_ty, + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_ok()); + } + + #[test] + fn check_defn_declar() { + // def myId : Nat → Nat := fun x => x + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("myId"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("myId")], + }); + assert!(tc.check_declar(&defn).is_ok()); + } + + #[test] + fn check_defn_type_mismatch() { + // def bad : Nat := Nat.succ (wrong: Nat.succ : Nat → Nat, not Nat) + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(mk_name2("Nat", "succ"), vec![]), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("bad")], + }); + assert!(tc.check_declar(&defn).is_err()); + } + + #[test] + fn check_declar_loose_bvar() { + // Type with a dangling bound variable should fail + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![], + typ: Expr::bvar(Nat::from(0u64)), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + #[test] + fn check_declar_duplicate_uparams() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let u = mk_name("u"); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![u.clone(), u], + typ: type_u(), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + // ========================================================================== + // check_env + // ========================================================================== + + #[test] + fn check_nat_env() { + let env = mk_nat_env(); + let errors = check_env(&env); + assert!(errors.is_empty(), "Expected no errors, got: {:?}", errors); + } + + // ========================================================================== + // Polymorphic constants + // ========================================================================== + + #[test] + fn infer_polymorphic_const() { + // axiom A.{u} : Sort u + // A.{0} should give Sort(0) + let mut env = Env::default(); + let a_name = mk_name("A"); + let u_name = mk_name("u"); + env.insert( + a_name.clone(), + ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: a_name.clone(), + level_params: vec![u_name.clone()], + typ: Expr::sort(Level::param(u_name)), + }, + is_unsafe: false, + }), + ); + let mut tc = TypeChecker::new(&env); + // A.{0} : Sort(0) + let e = Expr::cnst(a_name, vec![Level::zero()]); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, Expr::sort(Level::zero())); + } + + // ========================================================================== + // Infer: whnf caching + // ========================================================================== + + #[test] + fn whnf_cache_works() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let e = Expr::sort(Level::zero()); + let r1 = tc.whnf(&e); + let r2 = tc.whnf(&e); + assert_eq!(r1, r2); + } + + // ========================================================================== + // check_declar: Theorem + // ========================================================================== + + #[test] + fn check_theorem_declar() { + // theorem myThm : Nat → Nat := fun x => x + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + let thm = ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: mk_name("myThm"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + all: vec![mk_name("myThm")], + }); + assert!(tc.check_declar(&thm).is_ok()); + } + + #[test] + fn check_theorem_type_mismatch() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let thm = ConstantInfo::ThmInfo(TheoremVal { + cnst: ConstantVal { + name: mk_name("badThm"), + level_params: vec![], + typ: nat_type(), // claims : Nat + }, + value: Expr::cnst(mk_name2("Nat", "succ"), vec![]), // but is : Nat → Nat + all: vec![mk_name("badThm")], + }); + assert!(tc.check_declar(&thm).is_err()); + } + + // ========================================================================== + // check_declar: Opaque + // ========================================================================== + + #[test] + fn check_opaque_declar() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let opaque = ConstantInfo::OpaqueInfo(OpaqueVal { + cnst: ConstantVal { + name: mk_name("myOpaque"), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + is_unsafe: false, + all: vec![mk_name("myOpaque")], + }); + assert!(tc.check_declar(&opaque).is_ok()); + } + + // ========================================================================== + // check_declar: Ctor (parent existence check) + // ========================================================================== + + #[test] + fn check_ctor_missing_parent() { + // A constructor whose parent inductive doesn't exist + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let ctor = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_name2("Fake", "mk"), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + induct: mk_name("Fake"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + assert!(tc.check_declar(&ctor).is_err()); + } + + #[test] + fn check_ctor_with_parent() { + // Nat.zero : Nat, with Nat in env + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ctor = ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "zero"), + level_params: vec![], + typ: nat_type(), + }, + induct: mk_name("Nat"), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }); + assert!(tc.check_declar(&ctor).is_ok()); + } + + // ========================================================================== + // check_declar: Rec (mutual reference check) + // ========================================================================== + + #[test] + fn check_rec_missing_inductive() { + let env = Env::default(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Fake", "rec"), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + all: vec![mk_name("Fake")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(0u64), + rules: vec![], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } + + #[test] + fn check_rec_with_inductive() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: false, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + // ========================================================================== + // Infer: App with delta (definition in head) + // ========================================================================== + + #[test] + fn infer_app_through_delta() { + // def myId : Nat → Nat := fun x => x + // myId Nat.zero : Nat + let mut env = mk_nat_env(); + let fun_ty = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let body = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0u64)), + BinderInfo::Default, + ); + env.insert( + mk_name("myId"), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("myId"), + level_params: vec![], + typ: fun_ty, + }, + value: body, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("myId")], + }), + ); + let mut tc = TypeChecker::new(&env); + let e = Expr::app( + Expr::cnst(mk_name("myId"), vec![]), + nat_zero(), + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: Proj + // ========================================================================== + + /// Build an env with a simple Prod.{u,v} structure type. + fn mk_prod_env() -> Env { + let mut env = mk_nat_env(); + let u_name = mk_name("u"); + let v_name = mk_name("v"); + let prod_name = mk_name("Prod"); + let mk_name_prod = mk_name2("Prod", "mk"); + + // Prod.{u,v} : Sort u → Sort v → Sort (max u v) + // Simplified: Prod (α : Sort u) (β : Sort v) : Sort (max u v) + let prod_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::sort(Level::max( + Level::param(u_name.clone()), + Level::param(v_name.clone()), + )), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + prod_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: prod_name.clone(), + level_params: vec![u_name.clone(), v_name.clone()], + typ: prod_type, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(0u64), + all: vec![prod_name.clone()], + ctors: vec![mk_name_prod.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + + // Prod.mk.{u,v} (α : Sort u) (β : Sort v) (fst : α) (snd : β) : Prod α β + // Type: (α : Sort u) → (β : Sort v) → α → β → Prod α β + let ctor_type = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u_name.clone())), + Expr::all( + mk_name("β"), + Expr::sort(Level::param(v_name.clone())), + Expr::all( + mk_name("fst"), + Expr::bvar(Nat::from(1u64)), // α + Expr::all( + mk_name("snd"), + Expr::bvar(Nat::from(1u64)), // β + Expr::app( + Expr::app( + Expr::cnst( + prod_name.clone(), + vec![ + Level::param(u_name.clone()), + Level::param(v_name.clone()), + ], + ), + Expr::bvar(Nat::from(3u64)), // α + ), + Expr::bvar(Nat::from(2u64)), // β + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + + env.insert( + mk_name_prod.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: mk_name_prod, + level_params: vec![u_name, v_name], + typ: ctor_type, + }, + induct: prod_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(2u64), + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn infer_proj_fst() { + // Given p : Prod Nat Nat, (Prod.1 p) : Nat + // Build: Prod.mk Nat Nat Nat.zero Nat.zero, then project field 0 + let env = mk_prod_env(); + let mut tc = TypeChecker::new(&env); + + let one = Level::succ(Level::zero()); + let pair = Expr::app( + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + mk_name2("Prod", "mk"), + vec![one.clone(), one.clone()], + ), + nat_type(), + ), + nat_type(), + ), + nat_zero(), + ), + nat_zero(), + ); + + let proj = Expr::proj(mk_name("Prod"), Nat::from(0u64), pair); + let ty = tc.infer(&proj).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer: nested let + // ========================================================================== + + #[test] + fn infer_nested_let() { + // let x := Nat.zero in let y := x in y : Nat + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let inner = Expr::letE( + mk_name("y"), + nat_type(), + Expr::bvar(Nat::from(0u64)), // x + Expr::bvar(Nat::from(0u64)), // y + false, + ); + let e = Expr::letE( + mk_name("x"), + nat_type(), + nat_zero(), + inner, + false, + ); + let ty = tc.infer(&e).unwrap(); + assert_eq!(ty, nat_type()); + } + + // ========================================================================== + // Infer caching + // ========================================================================== + + #[test] + fn infer_cache_hit() { + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let e = nat_zero(); + let ty1 = tc.infer(&e).unwrap(); + let ty2 = tc.infer(&e).unwrap(); + assert_eq!(ty1, ty2); + assert_eq!(tc.infer_cache.len(), 1); + } + + // ========================================================================== + // Universe parameter validation + // ========================================================================== + + #[test] + fn check_axiom_undeclared_uparam_in_type() { + // axiom bad.{u} : Sort v — v is not declared + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("v"))), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_err()); + } + + #[test] + fn check_axiom_declared_uparam_in_type() { + // axiom good.{u} : Sort u — u is declared + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let ax = ConstantInfo::AxiomInfo(AxiomVal { + cnst: ConstantVal { + name: mk_name("good"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + is_unsafe: false, + }); + assert!(tc.check_declar(&ax).is_ok()); + } + + #[test] + fn check_defn_undeclared_uparam_in_value() { + // def bad.{u} : Sort 1 := Sort v — v not declared, in value + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let defn = ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: mk_name("bad"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::succ(Level::zero())), + }, + value: Expr::sort(Level::param(mk_name("v"))), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![mk_name("bad")], + }); + assert!(tc.check_declar(&defn).is_err()); + } + + // ========================================================================== + // K-flag validation + // ========================================================================== + + /// Build an env with a Prop inductive + single zero-field ctor (Eq-like). + fn mk_eq_like_env() -> Env { + let mut env = mk_nat_env(); + let u = mk_name("u"); + let eq_name = mk_name("MyEq"); + let eq_refl = mk_name2("MyEq", "refl"); + + // MyEq.{u} (α : Sort u) (a : α) : α → Prop + // Simplified: type lives in Prop (Sort 0) + let eq_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::all( + mk_name("b"), + Expr::bvar(Nat::from(1u64)), + Expr::sort(Level::zero()), + BinderInfo::Default, + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: eq_name.clone(), + level_params: vec![u.clone()], + typ: eq_ty, + }, + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + all: vec![eq_name.clone()], + ctors: vec![eq_refl.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: true, + }), + ); + // MyEq.refl.{u} (α : Sort u) (a : α) : MyEq α a a + // zero fields + let refl_ty = Expr::all( + mk_name("α"), + Expr::sort(Level::param(u.clone())), + Expr::all( + mk_name("a"), + Expr::bvar(Nat::from(0u64)), + Expr::app( + Expr::app( + Expr::app( + Expr::cnst(eq_name.clone(), vec![Level::param(u.clone())]), + Expr::bvar(Nat::from(1u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + Expr::bvar(Nat::from(0u64)), + ), + BinderInfo::Default, + ), + BinderInfo::Default, + ); + env.insert( + eq_refl.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: eq_refl, + level_params: vec![u], + typ: refl_ty, + }, + induct: eq_name, + cidx: Nat::from(0u64), + num_params: Nat::from(2u64), + num_fields: Nat::from(0u64), + is_unsafe: false, + }), + ); + + env + } + + #[test] + fn check_rec_k_flag_valid() { + let env = mk_eq_like_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("MyEq", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("MyEq")], + num_params: Nat::from(2u64), + num_indices: Nat::from(1u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(1u64), + rules: vec![], + k: true, + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_ok()); + } + + #[test] + fn check_rec_k_flag_invalid_2_ctors() { + // Nat has 2 constructors — K should fail + let env = mk_nat_env(); + let mut tc = TypeChecker::new(&env); + let rec = ConstantInfo::RecInfo(RecursorVal { + cnst: ConstantVal { + name: mk_name2("Nat", "rec"), + level_params: vec![mk_name("u")], + typ: Expr::sort(Level::param(mk_name("u"))), + }, + all: vec![mk_name("Nat")], + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + num_motives: Nat::from(1u64), + num_minors: Nat::from(2u64), + rules: vec![], + k: true, // invalid: Nat is not in Prop and has 2 ctors + is_unsafe: false, + }); + assert!(tc.check_declar(&rec).is_err()); + } +} diff --git a/src/ix/kernel/upcopy.rs b/src/ix/kernel/upcopy.rs new file mode 100644 index 00000000..89dae8a0 --- /dev/null +++ b/src/ix/kernel/upcopy.rs @@ -0,0 +1,659 @@ +use core::ptr::NonNull; + +use crate::ix::env::{BinderInfo, Name}; + +use super::dag::*; +use super::dll::DLL; + +// ============================================================================ +// Upcopy +// ============================================================================ + +pub fn upcopy(new_child: DAGPtr, cc: ParentPtr) { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + let var = &lam.var; + let new_lam = alloc_lam(var.depth, new_child, None); + let new_lam_ref = &mut *new_lam.as_ptr(); + let bod_ref_ptr = + NonNull::new(&mut new_lam_ref.bod_ref as *mut Parents).unwrap(); + add_to_parents(new_child, bod_ref_ptr); + let new_var_ptr = + NonNull::new(&mut new_lam_ref.var as *mut Var).unwrap(); + for parent in DLL::iter_option(var.parents) { + upcopy(DAGPtr::Var(new_var_ptr), *parent); + } + for parent in DLL::iter_option(lam.parents) { + upcopy(DAGPtr::Lam(new_lam), *parent); + } + }, + ParentPtr::AppFun(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).fun = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(new_child, app.arg); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + match app.copy { + Some(cache) => { + (*cache.as_ptr()).arg = new_child; + }, + None => { + let new_app = alloc_app_no_uplinks(app.fun, new_child); + app.copy = Some(new_app); + for parent in DLL::iter_option(app.parents) { + upcopy(DAGPtr::App(new_app), *parent); + } + }, + } + }, + ParentPtr::FunDom(link) => { + let fun = &mut *link.as_ptr(); + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + new_child, + fun.img, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + // new_child must be a Lam + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("FunImg parent expects Lam child"), + }; + match fun.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_fun = alloc_fun_no_uplinks( + fun.binder_name.clone(), + fun.binder_info.clone(), + fun.dom, + new_lam, + ); + fun.copy = Some(new_fun); + for parent in DLL::iter_option(fun.parents) { + upcopy(DAGPtr::Fun(new_fun), *parent); + } + }, + } + }, + ParentPtr::PiDom(link) => { + let pi = &mut *link.as_ptr(); + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).dom = new_child; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + new_child, + pi.img, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("PiImg parent expects Lam child"), + }; + match pi.copy { + Some(cache) => { + (*cache.as_ptr()).img = new_lam; + }, + None => { + let new_pi = alloc_pi_no_uplinks( + pi.binder_name.clone(), + pi.binder_info.clone(), + pi.dom, + new_lam, + ); + pi.copy = Some(new_pi); + for parent in DLL::iter_option(pi.parents) { + upcopy(DAGPtr::Pi(new_pi), *parent); + } + }, + } + }, + ParentPtr::LetTyp(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).typ = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + new_child, + let_node.val, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetVal(link) => { + let let_node = &mut *link.as_ptr(); + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).val = new_child; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + new_child, + let_node.bod, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + let new_lam = match new_child { + DAGPtr::Lam(p) => p, + _ => panic!("LetBod parent expects Lam child"), + }; + match let_node.copy { + Some(cache) => { + (*cache.as_ptr()).bod = new_lam; + }, + None => { + let new_let = alloc_let_no_uplinks( + let_node.binder_name.clone(), + let_node.non_dep, + let_node.typ, + let_node.val, + new_lam, + ); + let_node.copy = Some(new_let); + for parent in DLL::iter_option(let_node.parents) { + upcopy(DAGPtr::Let(new_let), *parent); + } + }, + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + let new_proj = alloc_proj_no_uplinks( + proj.type_name.clone(), + proj.idx.clone(), + new_child, + ); + for parent in DLL::iter_option(proj.parents) { + upcopy(DAGPtr::Proj(new_proj), *parent); + } + }, + } + } +} + +// ============================================================================ +// No-uplink allocators for upcopy +// ============================================================================ + +fn alloc_app_no_uplinks(fun: DAGPtr, arg: DAGPtr) -> NonNull { + let app_ptr = alloc_val(App { + fun, + arg, + fun_ref: DLL::singleton(ParentPtr::Root), + arg_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let app = &mut *app_ptr.as_ptr(); + app.fun_ref = DLL::singleton(ParentPtr::AppFun(app_ptr)); + app.arg_ref = DLL::singleton(ParentPtr::AppArg(app_ptr)); + } + app_ptr +} + +fn alloc_fun_no_uplinks( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, +) -> NonNull { + let fun_ptr = alloc_val(Fun { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let fun = &mut *fun_ptr.as_ptr(); + fun.dom_ref = DLL::singleton(ParentPtr::FunDom(fun_ptr)); + fun.img_ref = DLL::singleton(ParentPtr::FunImg(fun_ptr)); + } + fun_ptr +} + +fn alloc_pi_no_uplinks( + binder_name: Name, + binder_info: BinderInfo, + dom: DAGPtr, + img: NonNull, +) -> NonNull { + let pi_ptr = alloc_val(Pi { + binder_name, + binder_info, + dom, + img, + dom_ref: DLL::singleton(ParentPtr::Root), + img_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let pi = &mut *pi_ptr.as_ptr(); + pi.dom_ref = DLL::singleton(ParentPtr::PiDom(pi_ptr)); + pi.img_ref = DLL::singleton(ParentPtr::PiImg(pi_ptr)); + } + pi_ptr +} + +fn alloc_let_no_uplinks( + binder_name: Name, + non_dep: bool, + typ: DAGPtr, + val: DAGPtr, + bod: NonNull, +) -> NonNull { + let let_ptr = alloc_val(LetNode { + binder_name, + non_dep, + typ, + val, + bod, + typ_ref: DLL::singleton(ParentPtr::Root), + val_ref: DLL::singleton(ParentPtr::Root), + bod_ref: DLL::singleton(ParentPtr::Root), + copy: None, + parents: None, + }); + unsafe { + let let_node = &mut *let_ptr.as_ptr(); + let_node.typ_ref = DLL::singleton(ParentPtr::LetTyp(let_ptr)); + let_node.val_ref = DLL::singleton(ParentPtr::LetVal(let_ptr)); + let_node.bod_ref = DLL::singleton(ParentPtr::LetBod(let_ptr)); + } + let_ptr +} + +fn alloc_proj_no_uplinks( + type_name: Name, + idx: crate::lean::nat::Nat, + expr: DAGPtr, +) -> NonNull { + let proj_ptr = alloc_val(ProjNode { + type_name, + idx, + expr, + expr_ref: DLL::singleton(ParentPtr::Root), + parents: None, + }); + unsafe { + let proj = &mut *proj_ptr.as_ptr(); + proj.expr_ref = DLL::singleton(ParentPtr::ProjExpr(proj_ptr)); + } + proj_ptr +} + +// ============================================================================ +// Clean up: Clear copy caches after reduction +// ============================================================================ + +pub fn clean_up(cc: &ParentPtr) { + unsafe { + match cc { + ParentPtr::Root => {}, + ParentPtr::LamBod(link) => { + let lam = &*link.as_ptr(); + for parent in DLL::iter_option(lam.var.parents) { + clean_up(parent); + } + for parent in DLL::iter_option(lam.parents) { + clean_up(parent); + } + }, + ParentPtr::AppFun(link) | ParentPtr::AppArg(link) => { + let app = &mut *link.as_ptr(); + if let Some(app_copy) = app.copy { + let App { fun, arg, fun_ref, arg_ref, .. } = + &mut *app_copy.as_ptr(); + app.copy = None; + add_to_parents(*fun, NonNull::new(fun_ref).unwrap()); + add_to_parents(*arg, NonNull::new(arg_ref).unwrap()); + for parent in DLL::iter_option(app.parents) { + clean_up(parent); + } + } + }, + ParentPtr::FunDom(link) | ParentPtr::FunImg(link) => { + let fun = &mut *link.as_ptr(); + if let Some(fun_copy) = fun.copy { + let Fun { dom, img, dom_ref, img_ref, .. } = + &mut *fun_copy.as_ptr(); + fun.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(fun.parents) { + clean_up(parent); + } + } + }, + ParentPtr::PiDom(link) | ParentPtr::PiImg(link) => { + let pi = &mut *link.as_ptr(); + if let Some(pi_copy) = pi.copy { + let Pi { dom, img, dom_ref, img_ref, .. } = + &mut *pi_copy.as_ptr(); + pi.copy = None; + add_to_parents(*dom, NonNull::new(dom_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*img), NonNull::new(img_ref).unwrap()); + for parent in DLL::iter_option(pi.parents) { + clean_up(parent); + } + } + }, + ParentPtr::LetTyp(link) + | ParentPtr::LetVal(link) + | ParentPtr::LetBod(link) => { + let let_node = &mut *link.as_ptr(); + if let Some(let_copy) = let_node.copy { + let LetNode { typ, val, bod, typ_ref, val_ref, bod_ref, .. } = + &mut *let_copy.as_ptr(); + let_node.copy = None; + add_to_parents(*typ, NonNull::new(typ_ref).unwrap()); + add_to_parents(*val, NonNull::new(val_ref).unwrap()); + add_to_parents(DAGPtr::Lam(*bod), NonNull::new(bod_ref).unwrap()); + for parent in DLL::iter_option(let_node.parents) { + clean_up(parent); + } + } + }, + ParentPtr::ProjExpr(link) => { + let proj = &*link.as_ptr(); + for parent in DLL::iter_option(proj.parents) { + clean_up(parent); + } + }, + } + } +} + +// ============================================================================ +// Replace child +// ============================================================================ + +pub fn replace_child(old: DAGPtr, new: DAGPtr) { + unsafe { + if let Some(parents) = get_parents(old) { + for parent in DLL::iter_option(Some(parents)) { + match parent { + ParentPtr::Root => {}, + ParentPtr::LamBod(p) => (*p.as_ptr()).bod = new, + ParentPtr::FunDom(p) => (*p.as_ptr()).dom = new, + ParentPtr::FunImg(p) => match new { + DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam, + _ => panic!("FunImg expects Lam"), + }, + ParentPtr::PiDom(p) => (*p.as_ptr()).dom = new, + ParentPtr::PiImg(p) => match new { + DAGPtr::Lam(lam) => (*p.as_ptr()).img = lam, + _ => panic!("PiImg expects Lam"), + }, + ParentPtr::AppFun(p) => (*p.as_ptr()).fun = new, + ParentPtr::AppArg(p) => (*p.as_ptr()).arg = new, + ParentPtr::LetTyp(p) => (*p.as_ptr()).typ = new, + ParentPtr::LetVal(p) => (*p.as_ptr()).val = new, + ParentPtr::LetBod(p) => match new { + DAGPtr::Lam(lam) => (*p.as_ptr()).bod = lam, + _ => panic!("LetBod expects Lam"), + }, + ParentPtr::ProjExpr(p) => (*p.as_ptr()).expr = new, + } + } + set_parents(old, None); + match get_parents(new) { + None => set_parents(new, Some(parents)), + Some(new_parents) => { + DLL::concat(new_parents, Some(parents)); + }, + } + } + } +} + +// ============================================================================ +// Free dead nodes +// ============================================================================ + +pub fn free_dead_node(node: DAGPtr) { + unsafe { + match node { + DAGPtr::Lam(link) => { + let lam = &*link.as_ptr(); + let bod_ref_ptr = &lam.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(lam.bod, Some(remaining)); + } else { + set_parents(lam.bod, None); + free_dead_node(lam.bod); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::App(link) => { + let app = &*link.as_ptr(); + let fun_ref_ptr = &app.fun_ref as *const Parents; + if let Some(remaining) = (*fun_ref_ptr).unlink_node() { + set_parents(app.fun, Some(remaining)); + } else { + set_parents(app.fun, None); + free_dead_node(app.fun); + } + let arg_ref_ptr = &app.arg_ref as *const Parents; + if let Some(remaining) = (*arg_ref_ptr).unlink_node() { + set_parents(app.arg, Some(remaining)); + } else { + set_parents(app.arg, None); + free_dead_node(app.arg); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Fun(link) => { + let fun = &*link.as_ptr(); + let dom_ref_ptr = &fun.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(fun.dom, Some(remaining)); + } else { + set_parents(fun.dom, None); + free_dead_node(fun.dom); + } + let img_ref_ptr = &fun.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(fun.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(fun.img), None); + free_dead_node(DAGPtr::Lam(fun.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Pi(link) => { + let pi = &*link.as_ptr(); + let dom_ref_ptr = &pi.dom_ref as *const Parents; + if let Some(remaining) = (*dom_ref_ptr).unlink_node() { + set_parents(pi.dom, Some(remaining)); + } else { + set_parents(pi.dom, None); + free_dead_node(pi.dom); + } + let img_ref_ptr = &pi.img_ref as *const Parents; + if let Some(remaining) = (*img_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(pi.img), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(pi.img), None); + free_dead_node(DAGPtr::Lam(pi.img)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Let(link) => { + let let_node = &*link.as_ptr(); + let typ_ref_ptr = &let_node.typ_ref as *const Parents; + if let Some(remaining) = (*typ_ref_ptr).unlink_node() { + set_parents(let_node.typ, Some(remaining)); + } else { + set_parents(let_node.typ, None); + free_dead_node(let_node.typ); + } + let val_ref_ptr = &let_node.val_ref as *const Parents; + if let Some(remaining) = (*val_ref_ptr).unlink_node() { + set_parents(let_node.val, Some(remaining)); + } else { + set_parents(let_node.val, None); + free_dead_node(let_node.val); + } + let bod_ref_ptr = &let_node.bod_ref as *const Parents; + if let Some(remaining) = (*bod_ref_ptr).unlink_node() { + set_parents(DAGPtr::Lam(let_node.bod), Some(remaining)); + } else { + set_parents(DAGPtr::Lam(let_node.bod), None); + free_dead_node(DAGPtr::Lam(let_node.bod)); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Proj(link) => { + let proj = &*link.as_ptr(); + let expr_ref_ptr = &proj.expr_ref as *const Parents; + if let Some(remaining) = (*expr_ref_ptr).unlink_node() { + set_parents(proj.expr, Some(remaining)); + } else { + set_parents(proj.expr, None); + free_dead_node(proj.expr); + } + drop(Box::from_raw(link.as_ptr())); + }, + DAGPtr::Var(link) => { + let var = &*link.as_ptr(); + if let BinderPtr::Free = var.binder { + drop(Box::from_raw(link.as_ptr())); + } + }, + DAGPtr::Sort(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Cnst(link) => drop(Box::from_raw(link.as_ptr())), + DAGPtr::Lit(link) => drop(Box::from_raw(link.as_ptr())), + } + } +} + +// ============================================================================ +// Lambda reduction +// ============================================================================ + +/// Contract a lambda redex: (Fun dom (Lam bod var)) arg → [arg/var]bod. +pub fn reduce_lam(redex: NonNull, lam: NonNull) -> DAGPtr { + unsafe { + let app = &*redex.as_ptr(); + let lambda = &*lam.as_ptr(); + let var = &lambda.var; + let arg = app.arg; + + if DLL::is_singleton(lambda.parents) { + if DLL::is_empty(var.parents) { + return lambda.bod; + } + replace_child(DAGPtr::Var(NonNull::from(var)), arg); + return lambda.bod; + } + + if DLL::is_empty(var.parents) { + return lambda.bod; + } + + // General case: upcopy arg through var's parents + for parent in DLL::iter_option(var.parents) { + upcopy(arg, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); + } + lambda.bod + } +} + +/// Contract a let redex: Let(typ, val, Lam(bod, var)) → [val/var]bod. +pub fn reduce_let(let_node: NonNull) -> DAGPtr { + unsafe { + let ln = &*let_node.as_ptr(); + let lam = &*ln.bod.as_ptr(); + let var = &lam.var; + let val = ln.val; + + if DLL::is_singleton(lam.parents) { + if DLL::is_empty(var.parents) { + return lam.bod; + } + replace_child(DAGPtr::Var(NonNull::from(var)), val); + return lam.bod; + } + + if DLL::is_empty(var.parents) { + return lam.bod; + } + + for parent in DLL::iter_option(var.parents) { + upcopy(val, *parent); + } + for parent in DLL::iter_option(var.parents) { + clean_up(parent); + } + lam.bod + } +} diff --git a/src/ix/kernel/whnf.rs b/src/ix/kernel/whnf.rs new file mode 100644 index 00000000..4fdde07a --- /dev/null +++ b/src/ix/kernel/whnf.rs @@ -0,0 +1,1420 @@ +use core::ptr::NonNull; + +use crate::ix::env::*; +use crate::lean::nat::Nat; +use num_bigint::BigUint; + +use super::convert::{from_expr, to_expr}; +use super::dag::*; +use super::level::{simplify, subst_level}; +use super::upcopy::{reduce_lam, reduce_let}; + + +// ============================================================================ +// Expression helpers (inst, unfold_apps, foldl_apps, subst_expr_levels) +// ============================================================================ + +/// Instantiate bound variables: `body[0 := substs[0], 1 := substs[1], ...]`. +/// `substs[0]` replaces `Bvar(0)` (innermost). +pub fn inst(body: &Expr, substs: &[Expr]) -> Expr { + if substs.is_empty() { + return body.clone(); + } + inst_aux(body, substs, 0) +} + +fn inst_aux(e: &Expr, substs: &[Expr], offset: u64) -> Expr { + match e.as_data() { + ExprData::Bvar(idx, _) => { + let idx_u64 = idx.to_u64().unwrap_or(u64::MAX); + if idx_u64 >= offset { + let adjusted = (idx_u64 - offset) as usize; + if adjusted < substs.len() { + return substs[adjusted].clone(); + } + } + e.clone() + }, + ExprData::App(f, a, _) => { + let f2 = inst_aux(f, substs, offset); + let a2 = inst_aux(a, substs, offset); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = inst_aux(t, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = inst_aux(t, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = inst_aux(t, substs, offset); + let v2 = inst_aux(v, substs, offset); + let b2 = inst_aux(b, substs, offset + 1); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = inst_aux(s, substs, offset); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = inst_aux(inner, substs, offset); + Expr::mdata(kvs.clone(), inner2) + }, + // Terminals with no bound vars + ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Abstract: replace free variable `fvar` with `Bvar(offset)` in `e`. +pub fn abstr(e: &Expr, fvars: &[Expr]) -> Expr { + if fvars.is_empty() { + return e.clone(); + } + abstr_aux(e, fvars, 0) +} + +fn abstr_aux(e: &Expr, fvars: &[Expr], offset: u64) -> Expr { + match e.as_data() { + ExprData::Fvar(..) => { + for (i, fv) in fvars.iter().enumerate().rev() { + if e == fv { + return Expr::bvar(Nat::from(i as u64 + offset)); + } + } + e.clone() + }, + ExprData::App(f, a, _) => { + let f2 = abstr_aux(f, fvars, offset); + let a2 = abstr_aux(a, fvars, offset); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = abstr_aux(t, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = abstr_aux(t, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = abstr_aux(t, fvars, offset); + let v2 = abstr_aux(v, fvars, offset); + let b2 = abstr_aux(b, fvars, offset + 1); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = abstr_aux(s, fvars, offset); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = abstr_aux(inner, fvars, offset); + Expr::mdata(kvs.clone(), inner2) + }, + ExprData::Bvar(..) + | ExprData::Sort(..) + | ExprData::Const(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Decompose `f a1 a2 ... an` into `(f, [a1, a2, ..., an])`. +pub fn unfold_apps(e: &Expr) -> (Expr, Vec) { + let mut args = Vec::new(); + let mut cursor = e.clone(); + loop { + match cursor.as_data() { + ExprData::App(f, a, _) => { + args.push(a.clone()); + cursor = f.clone(); + }, + _ => break, + } + } + args.reverse(); + (cursor, args) +} + +/// Reconstruct `f a1 a2 ... an`. +pub fn foldl_apps(mut fun: Expr, args: impl Iterator) -> Expr { + for arg in args { + fun = Expr::app(fun, arg); + } + fun +} + +/// Substitute universe level parameters in an expression. +pub fn subst_expr_levels( + e: &Expr, + params: &[Name], + values: &[Level], +) -> Expr { + if params.is_empty() { + return e.clone(); + } + subst_expr_levels_aux(e, params, values) +} + +fn subst_expr_levels_aux( + e: &Expr, + params: &[Name], + values: &[Level], +) -> Expr { + match e.as_data() { + ExprData::Sort(level, _) => { + Expr::sort(subst_level(level, params, values)) + }, + ExprData::Const(name, levels, _) => { + let new_levels: Vec = + levels.iter().map(|l| subst_level(l, params, values)).collect(); + Expr::cnst(name.clone(), new_levels) + }, + ExprData::App(f, a, _) => { + let f2 = subst_expr_levels_aux(f, params, values); + let a2 = subst_expr_levels_aux(a, params, values); + Expr::app(f2, a2) + }, + ExprData::Lam(n, t, b, bi, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::lam(n.clone(), t2, b2, bi.clone()) + }, + ExprData::ForallE(n, t, b, bi, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::all(n.clone(), t2, b2, bi.clone()) + }, + ExprData::LetE(n, t, v, b, nd, _) => { + let t2 = subst_expr_levels_aux(t, params, values); + let v2 = subst_expr_levels_aux(v, params, values); + let b2 = subst_expr_levels_aux(b, params, values); + Expr::letE(n.clone(), t2, v2, b2, *nd) + }, + ExprData::Proj(n, i, s, _) => { + let s2 = subst_expr_levels_aux(s, params, values); + Expr::proj(n.clone(), i.clone(), s2) + }, + ExprData::Mdata(kvs, inner, _) => { + let inner2 = subst_expr_levels_aux(inner, params, values); + Expr::mdata(kvs.clone(), inner2) + }, + // No levels to substitute + ExprData::Bvar(..) + | ExprData::Fvar(..) + | ExprData::Mvar(..) + | ExprData::Lit(..) => e.clone(), + } +} + +/// Check if an expression has any loose bound variables above `offset`. +pub fn has_loose_bvars(e: &Expr) -> bool { + has_loose_bvars_aux(e, 0) +} + +fn has_loose_bvars_aux(e: &Expr, depth: u64) -> bool { + match e.as_data() { + ExprData::Bvar(idx, _) => idx.to_u64().unwrap_or(u64::MAX) >= depth, + ExprData::App(f, a, _) => { + has_loose_bvars_aux(f, depth) || has_loose_bvars_aux(a, depth) + }, + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_loose_bvars_aux(t, depth) || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_loose_bvars_aux(t, depth) + || has_loose_bvars_aux(v, depth) + || has_loose_bvars_aux(b, depth + 1) + }, + ExprData::Proj(_, _, s, _) => has_loose_bvars_aux(s, depth), + ExprData::Mdata(_, inner, _) => has_loose_bvars_aux(inner, depth), + _ => false, + } +} + +/// Check if expression contains any free variables (Fvar). +pub fn has_fvars(e: &Expr) -> bool { + match e.as_data() { + ExprData::Fvar(..) => true, + ExprData::App(f, a, _) => has_fvars(f) || has_fvars(a), + ExprData::Lam(_, t, b, _, _) | ExprData::ForallE(_, t, b, _, _) => { + has_fvars(t) || has_fvars(b) + }, + ExprData::LetE(_, t, v, b, _, _) => { + has_fvars(t) || has_fvars(v) || has_fvars(b) + }, + ExprData::Proj(_, _, s, _) => has_fvars(s), + ExprData::Mdata(_, inner, _) => has_fvars(inner), + _ => false, + } +} + +// ============================================================================ +// Name helpers +// ============================================================================ + +pub(crate) fn mk_name2(a: &str, b: &str) -> Name { + Name::str(Name::str(Name::anon(), a.into()), b.into()) +} + +// ============================================================================ +// WHNF +// ============================================================================ + +/// Weak head normal form reduction. +/// +/// Uses DAG-based reduction internally: converts Expr to DAG, reduces using +/// BUBS (reduce_lam/reduce_let) for beta/zeta, falls back to Expr level for +/// iota/quot/nat/projection, and uses DAG-level splicing for delta. +pub fn whnf(e: &Expr, env: &Env) -> Expr { + let mut dag = from_expr(e); + whnf_dag(&mut dag, env); + let result = to_expr(&dag); + free_dag(dag); + result +} + +/// Trail-based WHNF on DAG. Walks down the App spine collecting a trail, +/// then dispatches on the head node. +fn whnf_dag(dag: &mut DAG, env: &Env) { + loop { + // Build trail of App nodes by walking down the fun chain + let mut trail: Vec> = Vec::new(); + let mut cursor = dag.head; + + loop { + match cursor { + DAGPtr::App(app) => { + trail.push(app); + cursor = unsafe { (*app.as_ptr()).fun }; + }, + _ => break, + } + } + + match cursor { + // Beta: Fun at head with args on trail + DAGPtr::Fun(fun_ptr) if !trail.is_empty() => { + let app = trail.pop().unwrap(); + let lam = unsafe { (*fun_ptr.as_ptr()).img }; + let result = reduce_lam(app, lam); + set_dag_head(dag, result, &trail); + continue; + }, + + // Zeta: Let at head + DAGPtr::Let(let_ptr) => { + let result = reduce_let(let_ptr); + set_dag_head(dag, result, &trail); + continue; + }, + + // Const: try iota, quot, nat, then delta + DAGPtr::Cnst(_) => { + // Try iota, quot, nat at Expr level + if try_expr_reductions(dag, env) { + continue; + } + // Try delta (definition unfolding) on DAG + if try_dag_delta(dag, &trail, env) { + continue; + } + return; // stuck + }, + + // Proj: try projection reduction (Expr-level fallback) + DAGPtr::Proj(_) => { + if try_expr_reductions(dag, env) { + continue; + } + return; // stuck + }, + + // Sort: simplify level in place + DAGPtr::Sort(sort_ptr) => { + unsafe { + let sort = &mut *sort_ptr.as_ptr(); + sort.level = simplify(&sort.level); + } + return; + }, + + // Mdata: strip metadata (Expr-level fallback) + DAGPtr::Lit(_) => { + // Check if this is a Nat literal that could be a Nat.succ application + // by trying Expr-level reductions (which handles nat ops) + if !trail.is_empty() { + if try_expr_reductions(dag, env) { + continue; + } + } + return; + }, + + // Everything else (Var, Pi, Lam without args, etc.): already WHNF + _ => return, + } + } +} + +/// Set the DAG head after a reduction step. +/// If trail is empty, the result becomes the new head. +/// If trail is non-empty, splice result into the innermost remaining App. +fn set_dag_head( + dag: &mut DAG, + result: DAGPtr, + trail: &[NonNull], +) { + if trail.is_empty() { + dag.head = result; + } else { + unsafe { + (*trail.last().unwrap().as_ptr()).fun = result; + } + dag.head = DAGPtr::App(trail[0]); + } +} + +/// Try iota/quot/nat/projection reductions at Expr level. +/// Converts current DAG to Expr, attempts reduction, converts back if +/// successful. +fn try_expr_reductions(dag: &mut DAG, env: &Env) -> bool { + let current_expr = to_expr(&DAG { head: dag.head }); + + let (head, args) = unfold_apps(¤t_expr); + + let reduced = match head.as_data() { + ExprData::Const(name, levels, _) => { + // Try iota (recursor) reduction + if let Some(result) = try_reduce_rec(name, levels, &args, env) { + Some(result) + } + // Try quotient reduction + else if let Some(result) = try_reduce_quot(name, &args, env) { + Some(result) + } + // Try nat reduction + else if let Some(result) = + try_reduce_nat(¤t_expr, env) + { + Some(result) + } else { + None + } + }, + ExprData::Proj(type_name, idx, structure, _) => { + reduce_proj(type_name, idx, structure, env) + .map(|result| foldl_apps(result, args.into_iter())) + }, + ExprData::Mdata(_, inner, _) => { + Some(foldl_apps(inner.clone(), args.into_iter())) + }, + _ => None, + }; + + if let Some(result_expr) = reduced { + let result_dag = from_expr(&result_expr); + dag.head = result_dag.head; + true + } else { + false + } +} + +/// Try delta (definition) unfolding on DAG. +/// Looks up the constant, substitutes universe levels in the definition body, +/// converts it to a DAG, and splices it into the current DAG. +fn try_dag_delta( + dag: &mut DAG, + trail: &[NonNull], + env: &Env, +) -> bool { + // Extract constant info from head + let cnst_ref = match dag_head_past_trail(dag, trail) { + DAGPtr::Cnst(cnst) => unsafe { &*cnst.as_ptr() }, + _ => return false, + }; + + let ci = match env.get(&cnst_ref.name) { + Some(c) => c, + None => return false, + }; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) + if d.hints != ReducibilityHints::Opaque => + { + (&d.cnst.level_params, &d.value) + }, + _ => return false, + }; + + if cnst_ref.levels.len() != def_params.len() { + return false; + } + + // Substitute levels at Expr level, then convert to DAG + let val = subst_expr_levels(def_value, def_params, &cnst_ref.levels); + let body_dag = from_expr(&val); + + // Splice body into the working DAG + set_dag_head(dag, body_dag.head, trail); + true +} + +/// Get the head node past the trail (the non-App node at the bottom). +fn dag_head_past_trail( + dag: &DAG, + trail: &[NonNull], +) -> DAGPtr { + if trail.is_empty() { + dag.head + } else { + unsafe { (*trail.last().unwrap().as_ptr()).fun } + } +} + +/// Try to unfold a definition at the head. +pub fn try_unfold_def(e: &Expr, env: &Env) -> Option { + let (head, args) = unfold_apps(e); + let (name, levels) = match head.as_data() { + ExprData::Const(name, levels, _) => (name, levels), + _ => return None, + }; + + let ci = env.get(name)?; + let (def_params, def_value) = match ci { + ConstantInfo::DefnInfo(d) => { + if d.hints == ReducibilityHints::Opaque { + return None; + } + (&d.cnst.level_params, &d.value) + }, + _ => return None, + }; + + if levels.len() != def_params.len() { + return None; + } + + let val = subst_expr_levels(def_value, def_params, levels); + Some(foldl_apps(val, args.into_iter())) +} + +/// Try to reduce a recursor application (iota reduction). +fn try_reduce_rec( + name: &Name, + levels: &[Level], + args: &[Expr], + env: &Env, +) -> Option { + let ci = env.get(name)?; + let rec = match ci { + ConstantInfo::RecInfo(r) => r, + _ => return None, + }; + + let major_idx = rec.num_params.to_u64().unwrap() as usize + + rec.num_motives.to_u64().unwrap() as usize + + rec.num_minors.to_u64().unwrap() as usize + + rec.num_indices.to_u64().unwrap() as usize; + + let major = args.get(major_idx)?; + + // WHNF the major premise + let major_whnf = whnf(major, env); + + // Handle nat literal → constructor + let major_ctor = match major_whnf.as_data() { + ExprData::Lit(Literal::NatVal(n), _) => nat_lit_to_constructor(n), + _ => major_whnf.clone(), + }; + + let (ctor_head, ctor_args) = unfold_apps(&major_ctor); + + // Find the matching rec rule + let ctor_name = match ctor_head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + let rule = rec.rules.iter().find(|r| &r.ctor == ctor_name)?; + + let n_fields = rule.n_fields.to_u64().unwrap() as usize; + let num_params = rec.num_params.to_u64().unwrap() as usize; + let num_motives = rec.num_motives.to_u64().unwrap() as usize; + let num_minors = rec.num_minors.to_u64().unwrap() as usize; + + // The constructor args may have extra params for nested inductives + let ctor_args_wo_params = + if ctor_args.len() >= n_fields { + &ctor_args[ctor_args.len() - n_fields..] + } else { + return None; + }; + + // Substitute universe levels in the rule's RHS + let rhs = subst_expr_levels( + &rule.rhs, + &rec.cnst.level_params, + levels, + ); + + // Apply: params, motives, minors + let prefix_count = num_params + num_motives + num_minors; + let mut result = rhs; + for arg in args.iter().take(prefix_count) { + result = Expr::app(result, arg.clone()); + } + + // Apply constructor fields + for arg in ctor_args_wo_params { + result = Expr::app(result, arg.clone()); + } + + // Apply remaining args after major + for arg in args.iter().skip(major_idx + 1) { + result = Expr::app(result, arg.clone()); + } + + Some(result) +} + +/// Convert a Nat literal to its constructor form. +fn nat_lit_to_constructor(n: &Nat) -> Expr { + if n.0 == BigUint::ZERO { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } else { + let pred = Nat(n.0.clone() - BigUint::from(1u64)); + let pred_expr = Expr::lit(Literal::NatVal(pred)); + Expr::app(Expr::cnst(mk_name2("Nat", "succ"), vec![]), pred_expr) + } +} + +/// Convert a string literal to its constructor form: +/// `"hello"` → `String.mk (List.cons 'h' (List.cons 'e' ... List.nil))` +/// where chars are represented as `Char.ofNat n`. +fn string_lit_to_constructor(s: &str) -> Expr { + let list_name = Name::str(Name::anon(), "List".into()); + let char_name = Name::str(Name::anon(), "Char".into()); + let char_type = Expr::cnst(char_name.clone(), vec![]); + + // Build the list from right to left + // List.nil.{0} : List Char + let nil = Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "nil".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ); + + let result = s.chars().rev().fold(nil, |acc, c| { + let char_val = Expr::app( + Expr::cnst(Name::str(char_name.clone(), "ofNat".into()), vec![]), + Expr::lit(Literal::NatVal(Nat::from(c as u64))), + ); + // List.cons.{0} Char char_val acc + Expr::app( + Expr::app( + Expr::app( + Expr::cnst( + Name::str(list_name.clone(), "cons".into()), + vec![Level::succ(Level::zero())], + ), + char_type.clone(), + ), + char_val, + ), + acc, + ) + }); + + // String.mk list + Expr::app( + Expr::cnst( + Name::str(Name::str(Name::anon(), "String".into()), "mk".into()), + vec![], + ), + result, + ) +} + +/// Try to reduce a projection. +fn reduce_proj( + _type_name: &Name, + idx: &Nat, + structure: &Expr, + env: &Env, +) -> Option { + let structure_whnf = whnf(structure, env); + + // Handle string literal → constructor + let structure_ctor = match structure_whnf.as_data() { + ExprData::Lit(Literal::StrVal(s), _) => { + string_lit_to_constructor(s) + }, + _ => structure_whnf, + }; + + let (ctor_head, ctor_args) = unfold_apps(&structure_ctor); + + let ctor_name = match ctor_head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + // Look up constructor to get num_params + let ci = env.get(ctor_name)?; + let num_params = match ci { + ConstantInfo::CtorInfo(c) => c.num_params.to_u64().unwrap() as usize, + _ => return None, + }; + + let field_idx = num_params + idx.to_u64().unwrap() as usize; + ctor_args.get(field_idx).cloned() +} + +/// Try to reduce a quotient operation. +fn try_reduce_quot( + name: &Name, + args: &[Expr], + env: &Env, +) -> Option { + let ci = env.get(name)?; + let kind = match ci { + ConstantInfo::QuotInfo(q) => &q.kind, + _ => return None, + }; + + let (qmk_idx, rest_idx) = match kind { + QuotKind::Lift => (5, 6), + QuotKind::Ind => (4, 5), + _ => return None, + }; + + let qmk = args.get(qmk_idx)?; + let qmk_whnf = whnf(qmk, env); + + // Check that the head is Quot.mk + let (qmk_head, _) = unfold_apps(&qmk_whnf); + match qmk_head.as_data() { + ExprData::Const(n, _, _) if *n == mk_name2("Quot", "mk") => {}, + _ => return None, + } + + let f = args.get(3)?; + + // Extract the argument of Quot.mk + let qmk_arg = match qmk_whnf.as_data() { + ExprData::App(_, arg, _) => arg, + _ => return None, + }; + + let mut result = Expr::app(f.clone(), qmk_arg.clone()); + for arg in args.iter().skip(rest_idx) { + result = Expr::app(result, arg.clone()); + } + + Some(result) +} + +/// Try to reduce nat operations. +fn try_reduce_nat(e: &Expr, env: &Env) -> Option { + if has_fvars(e) { + return None; + } + + let (head, args) = unfold_apps(e); + let name = match head.as_data() { + ExprData::Const(name, _, _) => name, + _ => return None, + }; + + match args.len() { + 1 => { + if *name == mk_name2("Nat", "succ") { + let arg_whnf = whnf(&args[0], env); + let n = get_nat_value(&arg_whnf)?; + Some(Expr::lit(Literal::NatVal(Nat(n + BigUint::from(1u64))))) + } else { + None + } + }, + 2 => { + let a_whnf = whnf(&args[0], env); + let b_whnf = whnf(&args[1], env); + let a = get_nat_value(&a_whnf)?; + let b = get_nat_value(&b_whnf)?; + + let result = if *name == mk_name2("Nat", "add") { + Some(Expr::lit(Literal::NatVal(Nat(a + b)))) + } else if *name == mk_name2("Nat", "sub") { + Some(Expr::lit(Literal::NatVal(Nat(if a >= b { + a - b + } else { + BigUint::ZERO + })))) + } else if *name == mk_name2("Nat", "mul") { + Some(Expr::lit(Literal::NatVal(Nat(a * b)))) + } else if *name == mk_name2("Nat", "div") { + Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + BigUint::ZERO + } else { + a / b + })))) + } else if *name == mk_name2("Nat", "mod") { + Some(Expr::lit(Literal::NatVal(Nat(if b == BigUint::ZERO { + a + } else { + a % b + })))) + } else if *name == mk_name2("Nat", "beq") { + bool_to_expr(a == b) + } else if *name == mk_name2("Nat", "ble") { + bool_to_expr(a <= b) + } else if *name == mk_name2("Nat", "pow") { + let exp = u32::try_from(&b).unwrap_or(u32::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a.pow(exp))))) + } else if *name == mk_name2("Nat", "land") { + Some(Expr::lit(Literal::NatVal(Nat(a & b)))) + } else if *name == mk_name2("Nat", "lor") { + Some(Expr::lit(Literal::NatVal(Nat(a | b)))) + } else if *name == mk_name2("Nat", "xor") { + Some(Expr::lit(Literal::NatVal(Nat(a ^ b)))) + } else if *name == mk_name2("Nat", "shiftLeft") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a << shift)))) + } else if *name == mk_name2("Nat", "shiftRight") { + let shift = u64::try_from(&b).unwrap_or(u64::MAX); + Some(Expr::lit(Literal::NatVal(Nat(a >> shift)))) + } else if *name == mk_name2("Nat", "blt") { + bool_to_expr(a < b) + } else { + None + }; + result + }, + _ => None, + } +} + +fn get_nat_value(e: &Expr) -> Option { + match e.as_data() { + ExprData::Lit(Literal::NatVal(n), _) => Some(n.0.clone()), + ExprData::Const(name, _, _) if *name == mk_name2("Nat", "zero") => { + Some(BigUint::ZERO) + }, + _ => None, + } +} + +fn bool_to_expr(b: bool) -> Option { + let name = if b { + mk_name2("Bool", "true") + } else { + mk_name2("Bool", "false") + }; + Some(Expr::cnst(name, vec![])) +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + fn mk_name(s: &str) -> Name { + Name::str(Name::anon(), s.into()) + } + + fn nat_type() -> Expr { + Expr::cnst(mk_name("Nat"), vec![]) + } + + fn nat_zero() -> Expr { + Expr::cnst(mk_name2("Nat", "zero"), vec![]) + } + + #[test] + fn test_inst_bvar() { + let body = Expr::bvar(Nat::from(0)); + let arg = nat_zero(); + let result = inst(&body, &[arg.clone()]); + assert_eq!(result, arg); + } + + #[test] + fn test_inst_nested() { + // body = Lam(_, Nat, Bvar(1)) — references outer binder + // After inst with [zero], should become Lam(_, Nat, zero) + let body = Expr::lam( + Name::anon(), + nat_type(), + Expr::bvar(Nat::from(1)), + BinderInfo::Default, + ); + let result = inst(&body, &[nat_zero()]); + let expected = Expr::lam( + Name::anon(), + nat_type(), + nat_zero(), + BinderInfo::Default, + ); + assert_eq!(result, expected); + } + + #[test] + fn test_unfold_apps() { + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = Expr::cnst(mk_name2("Nat", "succ"), vec![]); + let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone()); + let (head, args) = unfold_apps(&e); + assert_eq!(head, f); + assert_eq!(args.len(), 2); + assert_eq!(args[0], a); + assert_eq!(args[1], b); + } + + #[test] + fn test_beta_reduce_identity() { + // (fun x : Nat => x) Nat.zero + let id = Expr::lam( + Name::str(Name::anon(), "x".into()), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + let e = Expr::app(id, nat_zero()); + let env = Env::default(); + let result = whnf(&e, &env); + assert_eq!(result, nat_zero()); + } + + #[test] + fn test_zeta_reduce() { + // let x : Nat := Nat.zero in x + let e = Expr::letE( + Name::str(Name::anon(), "x".into()), + nat_type(), + nat_zero(), + Expr::bvar(Nat::from(0)), + false, + ); + let env = Env::default(); + let result = whnf(&e, &env); + assert_eq!(result, nat_zero()); + } + + // ========================================================================== + // Delta reduction + // ========================================================================== + + fn mk_defn_env(name: &str, value: Expr, typ: Expr) -> Env { + let mut env = Env::default(); + let n = mk_name(name); + env.insert( + n.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: n.clone(), + level_params: vec![], + typ, + }, + value, + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![n], + }), + ); + env + } + + #[test] + fn test_delta_unfold() { + // def myZero := Nat.zero + // whnf(myZero) = Nat.zero + let env = mk_defn_env("myZero", nat_zero(), nat_type()); + let e = Expr::cnst(mk_name("myZero"), vec![]); + let result = whnf(&e, &env); + assert_eq!(result, nat_zero()); + } + + #[test] + fn test_delta_opaque_no_unfold() { + // An opaque definition should NOT unfold + let mut env = Env::default(); + let n = mk_name("opaqueVal"); + env.insert( + n.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: n.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Opaque, + safety: DefinitionSafety::Safe, + all: vec![n.clone()], + }), + ); + let e = Expr::cnst(n.clone(), vec![]); + let result = whnf(&e, &env); + // Should still be the constant, not unfolded + assert_eq!(result, e); + } + + #[test] + fn test_delta_chained() { + // def a := Nat.zero, def b := a => whnf(b) = Nat.zero + let mut env = Env::default(); + let a = mk_name("a"); + let b = mk_name("b"); + env.insert( + a.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: a.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: nat_zero(), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![a.clone()], + }), + ); + env.insert( + b.clone(), + ConstantInfo::DefnInfo(DefinitionVal { + cnst: ConstantVal { + name: b.clone(), + level_params: vec![], + typ: nat_type(), + }, + value: Expr::cnst(a, vec![]), + hints: ReducibilityHints::Abbrev, + safety: DefinitionSafety::Safe, + all: vec![b.clone()], + }), + ); + let e = Expr::cnst(b, vec![]); + let result = whnf(&e, &env); + assert_eq!(result, nat_zero()); + } + + // ========================================================================== + // Nat arithmetic reduction + // ========================================================================== + + fn nat_lit(n: u64) -> Expr { + Expr::lit(Literal::NatVal(Nat::from(n))) + } + + #[test] + fn test_nat_add() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "add"), vec![]), nat_lit(3)), + nat_lit(4), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(7)); + } + + #[test] + fn test_nat_sub_underflow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "sub"), vec![]), nat_lit(3)), + nat_lit(10), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mul() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mul"), vec![]), nat_lit(6)), + nat_lit(7), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(42)); + } + + #[test] + fn test_nat_div() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(3)); + } + + #[test] + fn test_nat_div_by_zero() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "div"), vec![]), nat_lit(10)), + nat_lit(0), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(0)); + } + + #[test] + fn test_nat_mod() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "mod"), vec![]), nat_lit(10)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, nat_lit(1)); + } + + #[test] + fn test_nat_beq_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_beq_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "beq"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_ble_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(3)), + nat_lit(5), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_ble_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "ble"), vec![]), nat_lit(5)), + nat_lit(3), + ); + let result = whnf(&e, &env); + assert_eq!(result, Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + #[test] + fn test_nat_pow() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "pow"), vec![]), nat_lit(2)), + nat_lit(10), + ); + assert_eq!(whnf(&e, &env), nat_lit(1024)); + } + + #[test] + fn test_nat_land() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "land"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1000)); + } + + #[test] + fn test_nat_lor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "lor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b1110)); + } + + #[test] + fn test_nat_xor() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "xor"), vec![]), nat_lit(0b1100)), + nat_lit(0b1010), + ); + assert_eq!(whnf(&e, &env), nat_lit(0b0110)); + } + + #[test] + fn test_nat_shift_left() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftLeft"), vec![]), nat_lit(1)), + nat_lit(8), + ); + assert_eq!(whnf(&e, &env), nat_lit(256)); + } + + #[test] + fn test_nat_shift_right() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "shiftRight"), vec![]), nat_lit(256)), + nat_lit(4), + ); + assert_eq!(whnf(&e, &env), nat_lit(16)); + } + + #[test] + fn test_nat_blt_true() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(3)), + nat_lit(5), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "true"), vec![])); + } + + #[test] + fn test_nat_blt_false() { + let env = Env::default(); + let e = Expr::app( + Expr::app(Expr::cnst(mk_name2("Nat", "blt"), vec![]), nat_lit(5)), + nat_lit(3), + ); + assert_eq!(whnf(&e, &env), Expr::cnst(mk_name2("Bool", "false"), vec![])); + } + + // ========================================================================== + // Sort simplification in WHNF + // ========================================================================== + + #[test] + fn test_string_lit_proj_reduces() { + // Build an env with String, String.mk ctor, List, List.cons, List.nil, Char + let mut env = Env::default(); + let string_name = mk_name("String"); + let string_mk = mk_name2("String", "mk"); + let list_name = mk_name("List"); + let char_name = mk_name("Char"); + + // String : Sort 1 + env.insert( + string_name.clone(), + ConstantInfo::InductInfo(InductiveVal { + cnst: ConstantVal { + name: string_name.clone(), + level_params: vec![], + typ: Expr::sort(Level::succ(Level::zero())), + }, + num_params: Nat::from(0u64), + num_indices: Nat::from(0u64), + all: vec![string_name.clone()], + ctors: vec![string_mk.clone()], + num_nested: Nat::from(0u64), + is_rec: false, + is_unsafe: false, + is_reflexive: false, + }), + ); + // String.mk : List Char → String (1 field, 0 params) + let list_char = Expr::app( + Expr::cnst(list_name, vec![Level::succ(Level::zero())]), + Expr::cnst(char_name, vec![]), + ); + env.insert( + string_mk.clone(), + ConstantInfo::CtorInfo(ConstructorVal { + cnst: ConstantVal { + name: string_mk, + level_params: vec![], + typ: Expr::all( + mk_name("data"), + list_char, + Expr::cnst(string_name.clone(), vec![]), + BinderInfo::Default, + ), + }, + induct: string_name.clone(), + cidx: Nat::from(0u64), + num_params: Nat::from(0u64), + num_fields: Nat::from(1u64), + is_unsafe: false, + }), + ); + + // Proj String 0 "hi" should reduce (not return None) + let proj = Expr::proj( + string_name, + Nat::from(0u64), + Expr::lit(Literal::StrVal("hi".into())), + ); + let result = whnf(&proj, &env); + // The result should NOT be a Proj anymore (it should have reduced) + assert!( + !matches!(result.as_data(), ExprData::Proj(..)), + "String projection should reduce, got: {:?}", + result + ); + } + + #[test] + fn test_whnf_sort_simplifies() { + // Sort(max 0 u) should simplify to Sort(u) + let env = Env::default(); + let u = Level::param(mk_name("u")); + let e = Expr::sort(Level::max(Level::zero(), u.clone())); + let result = whnf(&e, &env); + assert_eq!(result, Expr::sort(u)); + } + + // ========================================================================== + // Already-WHNF terms + // ========================================================================== + + #[test] + fn test_whnf_sort_unchanged() { + let env = Env::default(); + let e = Expr::sort(Level::zero()); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_lambda_unchanged() { + // A lambda without applied arguments is already WHNF + let env = Env::default(); + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + #[test] + fn test_whnf_pi_unchanged() { + let env = Env::default(); + let e = Expr::all( + mk_name("x"), + nat_type(), + nat_type(), + BinderInfo::Default, + ); + let result = whnf(&e, &env); + assert_eq!(result, e); + } + + // ========================================================================== + // Helper function tests + // ========================================================================== + + #[test] + fn test_has_loose_bvars_true() { + assert!(has_loose_bvars(&Expr::bvar(Nat::from(0)))); + } + + #[test] + fn test_has_loose_bvars_false_under_binder() { + // fun x : Nat => x — bvar(0) is bound, not loose + let e = Expr::lam( + mk_name("x"), + nat_type(), + Expr::bvar(Nat::from(0)), + BinderInfo::Default, + ); + assert!(!has_loose_bvars(&e)); + } + + #[test] + fn test_has_loose_bvars_const() { + assert!(!has_loose_bvars(&nat_zero())); + } + + #[test] + fn test_has_fvars_true() { + assert!(has_fvars(&Expr::fvar(mk_name("x")))); + } + + #[test] + fn test_has_fvars_false() { + assert!(!has_fvars(&nat_zero())); + } + + #[test] + fn test_foldl_apps_roundtrip() { + let f = Expr::cnst(mk_name("f"), vec![]); + let a = nat_zero(); + let b = nat_type(); + let e = Expr::app(Expr::app(f.clone(), a.clone()), b.clone()); + let (head, args) = unfold_apps(&e); + let rebuilt = foldl_apps(head, args.into_iter()); + assert_eq!(rebuilt, e); + } + + #[test] + fn test_abstr_simple() { + // abstr(fvar("x"), [fvar("x")]) = bvar(0) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&x, &[x.clone()]); + assert_eq!(result, Expr::bvar(Nat::from(0))); + } + + #[test] + fn test_abstr_not_found() { + // abstr(Nat.zero, [fvar("x")]) = Nat.zero (unchanged) + let x = Expr::fvar(mk_name("x")); + let result = abstr(&nat_zero(), &[x]); + assert_eq!(result, nat_zero()); + } + + #[test] + fn test_subst_expr_levels_simple() { + // Sort(u) with u := 0 => Sort(0) + let u_name = mk_name("u"); + let e = Expr::sort(Level::param(u_name.clone())); + let result = subst_expr_levels(&e, &[u_name], &[Level::zero()]); + assert_eq!(result, Expr::sort(Level::zero())); + } +}