From f4dba72408b0f0b64ef18ec8e1e02b93a9a68686 Mon Sep 17 00:00:00 2001 From: keyboardDrummer-bot Date: Wed, 6 May 2026 17:39:35 +0000 Subject: [PATCH] Use mapStmtExprChildrenM to simplify transformExpr in LiftImperativeExpressions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add mapStmtExprChildrenM and mapStmtExprPrePostM to MapStmtExpr.lean: - mapStmtExprChildrenM: maps a function over immediate children of a node (one level, no recursion), with optional right-to-left traversal - mapStmtExprPrePostM: full traversal with pre/post hooks Refactor transformExpr in LiftImperativeExpressions.lean: - Remove explicit PrimitiveOp arm - Replace catch-all with mapStmtExprChildrenM transformExpr call - The pass no longer needs to enumerate language features it doesn't relate to (like PrimitiveOp) — new constructors are handled automatically by the generic traversal - Mark transformExpr as partial since mapStmtExprChildrenM is opaque to the termination checker Closes #40 --- .../Laurel/LiftImperativeExpressions.lean | 45 ++-- Strata/Languages/Laurel/MapStmtExpr.lean | 208 ++++++++++++++++++ 2 files changed, 228 insertions(+), 25 deletions(-) diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index 65af0996d5..4b2d91d2c2 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -10,6 +10,7 @@ public import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator public import Strata.Languages.Laurel.LaurelTypes public import Strata.Languages.Core.Verifier public import Strata.DL.Util.Map +public import Strata.Languages.Laurel.MapStmtExpr import Strata.Util.Tactics namespace Strata @@ -216,20 +217,22 @@ private def liftAssignExpr (targets : List VariableMd) (seqValue : StmtExprMd) setSubst varName snapshotName | _ => pure () -mutual /-- Process an expression in expression context, traversing arguments right to left. Assignments are lifted to prependedStmts and replaced with snapshot variable references. + +Only constructors that require custom lifting logic are handled explicitly. +All other constructors (like `PrimitiveOp`) are traversed generically via +`mapStmtExprChildrenM`, so the pass doesn't need to enumerate language features +it doesn't relate to. -/ -def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do +partial def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do match expr with | AstNode.mk val source => match val with | .Var (.Local name) => return ⟨.Var (.Local (← getSubst name)), source⟩ - | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ => return expr - | .Hole false (some holeType) => -- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc) let holeVar ← freshCondVar @@ -262,11 +265,6 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do return resultExpr - | .PrimitiveOp op args => - -- Process arguments right to left - let seqArgs ← args.reverse.mapM transformExpr - return ⟨.PrimitiveOp op seqArgs.reverse, source⟩ - | .StaticCall callee args => let model := (← get).model let seqArgs ← args.reverse.mapM transformExpr @@ -322,13 +320,8 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do prepend (bare (.Var (.Declare ⟨condVar, condType⟩))) return bare (.Var (.Local condVar)) else - -- No assignments in branches — recurse normally - let seqCond ← transformExpr cond - let seqThen ← transformExpr thenBranch - let seqElse ← match elseBranch with - | some e => pure (some (← transformExpr e)) - | none => pure none - return ⟨.IfThenElse seqCond seqThen seqElse, source⟩ + -- No assignments in branches — use generic traversal + mapStmtExprChildrenM transformExpr (reverseChildren := true) expr | .Block stmts labelOption => let newStmts := (← stmts.reverse.mapM transformExpr).reverse @@ -345,10 +338,17 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do else return expr - | _ => return expr - termination_by (sizeOf expr, 0) - decreasing_by - all_goals (simp_all; try term_by_mem) + -- Assert and Assume in expression position (e.g. inside blocks) are not + -- recursed into — they are lifted out by onlyKeepSideEffectStmtsAndLast + -- and should reference original variable names, not substituted ones. + | .Assert _ | .Assume _ => return expr + + -- All other constructors: delegate to generic right-to-left child traversal + -- via `mapStmtExprChildrenM`. This handles PrimitiveOp, ReferenceEquals, + -- AsType, IsType, InstanceCall, Quantifier, Assigned, Old, Fresh, + -- ProveBy, ContractOf, PureFieldUpdate, Var (.Field ..), and all + -- leaves automatically — the pass doesn't need to know about them. + | _ => mapStmtExprChildrenM transformExpr (reverseChildren := true) expr /-- Process a statement, handling any assignments in its sub-expressions. @@ -451,11 +451,6 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do | _ => return [stmt] - termination_by (sizeOf stmt, 0) - decreasing_by - all_goals (try term_by_mem) - all_goals (apply Prod.Lex.left; try term_by_mem) -end def transformProcedureBody (body : StmtExprMd) : LiftM StmtExprMd := do let stmts ← transformStmt body diff --git a/Strata/Languages/Laurel/MapStmtExpr.lean b/Strata/Languages/Laurel/MapStmtExpr.lean index e3892bae93..6f858ca506 100644 --- a/Strata/Languages/Laurel/MapStmtExpr.lean +++ b/Strata/Languages/Laurel/MapStmtExpr.lean @@ -18,6 +18,16 @@ can pattern-match in `f` and fall through for the rest. Also provides `mapProcedureBodiesM` and `mapProgramM` to eliminate the `Body`/`Procedure`/`Program` boilerplate shared by nearly every pass. + +## Pre/Post Hook Traversal + +`mapStmtExprPrePostM` extends the basic traversal with: +- A `pre` hook that can override recursion for specific constructors +- A `reverseChildren` flag for right-to-left sibling traversal + +When `pre` returns `some result`, the traversal skips recursion and uses that +result directly. When `pre` returns `none`, the generic recursion + `post` +proceeds as normal. -/ namespace Strata.Laurel @@ -106,6 +116,204 @@ decreasing_by def mapStmtExpr (f : StmtExprMd → StmtExprMd) (expr : StmtExprMd) : StmtExprMd := (mapStmtExprM (m := Id) f expr) +/-- +Map a monadic function over the immediate `StmtExprMd` children of a node +(one level only, no recursion). The node is rebuilt with the transformed children. + +When `reverseChildren` is `true`, list-valued children (e.g. arguments) are +traversed right-to-left and the results are reversed back to original order. + +This is useful for passes that handle specific constructors with custom logic +but want generic child traversal for all other constructors. +-/ +def mapStmtExprChildrenM [Monad m] (f : StmtExprMd → m StmtExprMd) + (reverseChildren : Bool := false) + (expr : StmtExprMd) : m StmtExprMd := do + let source := expr.source + match expr.val with + | .IfThenElse cond th el => + let seqEl ← el.mapM fun e => f e + pure ⟨.IfThenElse (← f cond) (← f th) seqEl, source⟩ + | .Block stmts label => + let mapped ← if reverseChildren then do + let r ← stmts.reverse.mapM f + pure r.reverse + else + stmts.mapM f + pure ⟨.Block mapped label, source⟩ + | .While cond invs dec body => + pure ⟨.While (← f cond) (← invs.mapM f) (← dec.mapM f) (← f body), source⟩ + | .Return v => + pure ⟨.Return (← v.mapM f), source⟩ + | .Assign targets value => + let targets' ← targets.mapM fun v => do + let ⟨vv, vs⟩ := v + match vv with + | .Field target fieldName => + pure ⟨Variable.Field (← f target) fieldName, vs⟩ + | .Local _ | .Declare _ => pure v + pure ⟨.Assign targets' (← f value), source⟩ + | .Var (.Field target fieldName) => + pure ⟨.Var (.Field (← f target) fieldName), source⟩ + | .PureFieldUpdate target fieldName newValue => + pure ⟨.PureFieldUpdate (← f target) fieldName (← f newValue), source⟩ + | .StaticCall callee args => + let mapped ← if reverseChildren then do + let r ← args.reverse.mapM f + pure r.reverse + else + args.mapM f + pure ⟨.StaticCall callee mapped, source⟩ + | .PrimitiveOp op args => + let mapped ← if reverseChildren then do + let r ← args.reverse.mapM f + pure r.reverse + else + args.mapM f + pure ⟨.PrimitiveOp op mapped, source⟩ + | .ReferenceEquals lhs rhs => + pure ⟨.ReferenceEquals (← f lhs) (← f rhs), source⟩ + | .AsType target ty => + pure ⟨.AsType (← f target) ty, source⟩ + | .IsType target ty => + pure ⟨.IsType (← f target) ty, source⟩ + | .InstanceCall target callee args => + let mapped ← if reverseChildren then do + let r ← args.reverse.mapM f + pure r.reverse + else + args.mapM f + pure ⟨.InstanceCall (← f target) callee mapped, source⟩ + | .Quantifier mode param trigger body => + pure ⟨.Quantifier mode param (← trigger.mapM f) (← f body), source⟩ + | .Assigned name => + pure ⟨.Assigned (← f name), source⟩ + | .Old value => + pure ⟨.Old (← f value), source⟩ + | .Fresh value => + pure ⟨.Fresh (← f value), source⟩ + | .Assert cond => + pure ⟨.Assert { cond with condition := ← f cond.condition }, source⟩ + | .Assume cond => + pure ⟨.Assume (← f cond), source⟩ + | .ProveBy value proof => + pure ⟨.ProveBy (← f value) (← f proof), source⟩ + | .ContractOf ty func => + pure ⟨.ContractOf ty (← f func), source⟩ + | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ + | .Var (.Local _) | .Var (.Declare _) | .New _ | .This | .Abstract | .All | .Hole .. => pure expr + +/-- +Monadic traversal of `StmtExprMd` with pre/post hooks and optional right-to-left +child traversal. + +- `pre`: called before recursion. If it returns `some e`, that value is used + directly (no recursion into children). If it returns `none`, the generic + recursion proceeds followed by `post`. +- `post`: applied to the rebuilt node after children have been recursively + traversed. Only called when `pre` returned `none`. +- `reverseChildren`: when `true`, list-valued children (e.g. arguments) are + traversed right-to-left and the results are reversed back to original order. +-/ +def mapStmtExprPrePostM [Monad m] + (pre : StmtExprMd → m (Option StmtExprMd)) + (post : StmtExprMd → m StmtExprMd) + (reverseChildren : Bool := false) + (expr : StmtExprMd) : m StmtExprMd := do + match ← pre expr with + | some result => return result + | none => + let source := expr.source + let rebuilt ← match _h : expr.val with + | .IfThenElse cond th el => + let seqEl ← el.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure ⟨.IfThenElse (← mapStmtExprPrePostM pre post reverseChildren cond) + (← mapStmtExprPrePostM pre post reverseChildren th) seqEl, source⟩ + | .Block stmts label => + let mapped ← if reverseChildren then do + let r ← stmts.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure r.reverse + else + stmts.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure ⟨.Block mapped label, source⟩ + | .While cond invs dec body => + pure ⟨.While (← mapStmtExprPrePostM pre post reverseChildren cond) + (← invs.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e) + (← dec.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e) + (← mapStmtExprPrePostM pre post reverseChildren body), source⟩ + | .Return v => + pure ⟨.Return (← v.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e), source⟩ + | .Assign targets value => + let targets' ← targets.attach.mapM fun ⟨v, _⟩ => do + let ⟨vv, vs⟩ := v + match vv with + | .Field target fieldName => + pure ⟨Variable.Field (← mapStmtExprPrePostM pre post reverseChildren target) fieldName, vs⟩ + | .Local _ | .Declare _ => pure v + pure ⟨.Assign targets' (← mapStmtExprPrePostM pre post reverseChildren value), source⟩ + | .Var (.Field target fieldName) => + pure ⟨.Var (.Field (← mapStmtExprPrePostM pre post reverseChildren target) fieldName), source⟩ + | .PureFieldUpdate target fieldName newValue => + pure ⟨.PureFieldUpdate (← mapStmtExprPrePostM pre post reverseChildren target) fieldName + (← mapStmtExprPrePostM pre post reverseChildren newValue), source⟩ + | .StaticCall callee args => + let mapped ← if reverseChildren then do + let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure r.reverse + else + args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure ⟨.StaticCall callee mapped, source⟩ + | .PrimitiveOp op args => + let mapped ← if reverseChildren then do + let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure r.reverse + else + args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure ⟨.PrimitiveOp op mapped, source⟩ + | .ReferenceEquals lhs rhs => + pure ⟨.ReferenceEquals (← mapStmtExprPrePostM pre post reverseChildren lhs) + (← mapStmtExprPrePostM pre post reverseChildren rhs), source⟩ + | .AsType target ty => + pure ⟨.AsType (← mapStmtExprPrePostM pre post reverseChildren target) ty, source⟩ + | .IsType target ty => + pure ⟨.IsType (← mapStmtExprPrePostM pre post reverseChildren target) ty, source⟩ + | .InstanceCall target callee args => + let mapped ← if reverseChildren then do + let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure r.reverse + else + args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e + pure ⟨.InstanceCall (← mapStmtExprPrePostM pre post reverseChildren target) callee mapped, source⟩ + | .Quantifier mode param trigger body => + pure ⟨.Quantifier mode param + (← trigger.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e) + (← mapStmtExprPrePostM pre post reverseChildren body), source⟩ + | .Assigned name => + pure ⟨.Assigned (← mapStmtExprPrePostM pre post reverseChildren name), source⟩ + | .Old value => + pure ⟨.Old (← mapStmtExprPrePostM pre post reverseChildren value), source⟩ + | .Fresh value => + pure ⟨.Fresh (← mapStmtExprPrePostM pre post reverseChildren value), source⟩ + | .Assert cond => + pure ⟨.Assert { cond with condition := ← mapStmtExprPrePostM pre post reverseChildren cond.condition }, source⟩ + | .Assume cond => + pure ⟨.Assume (← mapStmtExprPrePostM pre post reverseChildren cond), source⟩ + | .ProveBy value proof => + pure ⟨.ProveBy (← mapStmtExprPrePostM pre post reverseChildren value) + (← mapStmtExprPrePostM pre post reverseChildren proof), source⟩ + | .ContractOf ty func => + pure ⟨.ContractOf ty (← mapStmtExprPrePostM pre post reverseChildren func), source⟩ + | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ + | .Var (.Local _) | .Var (.Declare _) | .New _ | .This | .Abstract | .All | .Hole .. => pure expr + post rebuilt +termination_by sizeOf expr +decreasing_by + all_goals simp_wf + all_goals (try have := AstNode.sizeOf_val_lt expr) + all_goals (try have := Condition.sizeOf_condition_lt ‹_›) + all_goals (try term_by_mem) + all_goals (cases expr; simp_all; omega) + /-- Apply a monadic transformation to all procedure bodies. -/ def mapProcedureBodiesM [Monad m] (f : StmtExprMd → m StmtExprMd) (proc : Procedure) : m Procedure := do match proc.body with