Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions Strata/Languages/Laurel/LiftImperativeExpressions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public import Strata.Languages.Laurel.Grammar.AbstractToConcreteTreeTranslator
public import Strata.Languages.Laurel.LaurelTypes
public import Strata.Languages.Core.Verifier
public import Strata.DL.Util.Map
public import Strata.Languages.Laurel.MapStmtExpr
import Strata.Util.Tactics

namespace Strata
Expand Down Expand Up @@ -216,20 +217,22 @@ private def liftAssignExpr (targets : List VariableMd) (seqValue : StmtExprMd)
setSubst varName snapshotName
| _ => pure ()

mutual
/--
Process an expression in expression context, traversing arguments right to left.
Assignments are lifted to prependedStmts and replaced with snapshot variable references.

Only constructors that require custom lifting logic are handled explicitly.
All other constructors (like `PrimitiveOp`) are traversed generically via
`mapStmtExprChildrenM`, so the pass doesn't need to enumerate language features
it doesn't relate to.
-/
def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
partial def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
match expr with
| AstNode.mk val source =>
match val with
| .Var (.Local name) =>
return ⟨.Var (.Local (← getSubst name)), source⟩

| .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _ => return expr

| .Hole false (some holeType) =>
-- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc)
let holeVar ← freshCondVar
Expand Down Expand Up @@ -262,11 +265,6 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do

return resultExpr

| .PrimitiveOp op args =>
-- Process arguments right to left
let seqArgs ← args.reverse.mapM transformExpr
return ⟨.PrimitiveOp op seqArgs.reverse, source⟩

| .StaticCall callee args =>
let model := (← get).model
let seqArgs ← args.reverse.mapM transformExpr
Expand Down Expand Up @@ -322,13 +320,8 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
prepend (bare (.Var (.Declare ⟨condVar, condType⟩)))
return bare (.Var (.Local condVar))
else
-- No assignments in branches — recurse normally
let seqCond ← transformExpr cond
let seqThen ← transformExpr thenBranch
let seqElse ← match elseBranch with
| some e => pure (some (← transformExpr e))
| none => pure none
return ⟨.IfThenElse seqCond seqThen seqElse, source⟩
-- No assignments in branches — use generic traversal
mapStmtExprChildrenM transformExpr (reverseChildren := true) expr

| .Block stmts labelOption =>
let newStmts := (← stmts.reverse.mapM transformExpr).reverse
Expand All @@ -345,10 +338,17 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
else
return expr

| _ => return expr
termination_by (sizeOf expr, 0)
decreasing_by
all_goals (simp_all; try term_by_mem)
-- Assert and Assume in expression position (e.g. inside blocks) are not
-- recursed into — they are lifted out by onlyKeepSideEffectStmtsAndLast
-- and should reference original variable names, not substituted ones.
| .Assert _ | .Assume _ => return expr

-- All other constructors: delegate to generic right-to-left child traversal
-- via `mapStmtExprChildrenM`. This handles PrimitiveOp, ReferenceEquals,
-- AsType, IsType, InstanceCall, Quantifier, Assigned, Old, Fresh,
-- ProveBy, ContractOf, PureFieldUpdate, Var (.Field ..), and all
-- leaves automatically — the pass doesn't need to know about them.
| _ => mapStmtExprChildrenM transformExpr (reverseChildren := true) expr

/--
Process a statement, handling any assignments in its sub-expressions.
Expand Down Expand Up @@ -451,11 +451,6 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do

| _ =>
return [stmt]
termination_by (sizeOf stmt, 0)
decreasing_by
all_goals (try term_by_mem)
all_goals (apply Prod.Lex.left; try term_by_mem)
end

def transformProcedureBody (body : StmtExprMd) : LiftM StmtExprMd := do
let stmts ← transformStmt body
Expand Down
208 changes: 208 additions & 0 deletions Strata/Languages/Laurel/MapStmtExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ can pattern-match in `f` and fall through for the rest.

Also provides `mapProcedureBodiesM` and `mapProgramM` to eliminate the
`Body`/`Procedure`/`Program` boilerplate shared by nearly every pass.

## Pre/Post Hook Traversal

`mapStmtExprPrePostM` extends the basic traversal with:
- A `pre` hook that can override recursion for specific constructors
- A `reverseChildren` flag for right-to-left sibling traversal

When `pre` returns `some result`, the traversal skips recursion and uses that
result directly. When `pre` returns `none`, the generic recursion + `post`
proceeds as normal.
-/

namespace Strata.Laurel
Expand Down Expand Up @@ -106,6 +116,204 @@ decreasing_by
def mapStmtExpr (f : StmtExprMd → StmtExprMd) (expr : StmtExprMd) : StmtExprMd :=
(mapStmtExprM (m := Id) f expr)

/--
Map a monadic function over the immediate `StmtExprMd` children of a node
(one level only, no recursion). The node is rebuilt with the transformed children.

When `reverseChildren` is `true`, list-valued children (e.g. arguments) are
traversed right-to-left and the results are reversed back to original order.

This is useful for passes that handle specific constructors with custom logic
but want generic child traversal for all other constructors.
-/
def mapStmtExprChildrenM [Monad m] (f : StmtExprMd → m StmtExprMd)
(reverseChildren : Bool := false)
(expr : StmtExprMd) : m StmtExprMd := do
let source := expr.source
match expr.val with
| .IfThenElse cond th el =>
let seqEl ← el.mapM fun e => f e
pure ⟨.IfThenElse (← f cond) (← f th) seqEl, source⟩
| .Block stmts label =>
let mapped ← if reverseChildren then do
let r ← stmts.reverse.mapM f
pure r.reverse
else
stmts.mapM f
pure ⟨.Block mapped label, source⟩
| .While cond invs dec body =>
pure ⟨.While (← f cond) (← invs.mapM f) (← dec.mapM f) (← f body), source⟩
| .Return v =>
pure ⟨.Return (← v.mapM f), source⟩
| .Assign targets value =>
let targets' ← targets.mapM fun v => do
let ⟨vv, vs⟩ := v
match vv with
| .Field target fieldName =>
pure ⟨Variable.Field (← f target) fieldName, vs⟩
| .Local _ | .Declare _ => pure v
pure ⟨.Assign targets' (← f value), source⟩
| .Var (.Field target fieldName) =>
pure ⟨.Var (.Field (← f target) fieldName), source⟩
| .PureFieldUpdate target fieldName newValue =>
pure ⟨.PureFieldUpdate (← f target) fieldName (← f newValue), source⟩
| .StaticCall callee args =>
let mapped ← if reverseChildren then do
let r ← args.reverse.mapM f
pure r.reverse
else
args.mapM f
pure ⟨.StaticCall callee mapped, source⟩
| .PrimitiveOp op args =>
let mapped ← if reverseChildren then do
let r ← args.reverse.mapM f
pure r.reverse
else
args.mapM f
pure ⟨.PrimitiveOp op mapped, source⟩
| .ReferenceEquals lhs rhs =>
pure ⟨.ReferenceEquals (← f lhs) (← f rhs), source⟩
| .AsType target ty =>
pure ⟨.AsType (← f target) ty, source⟩
| .IsType target ty =>
pure ⟨.IsType (← f target) ty, source⟩
| .InstanceCall target callee args =>
let mapped ← if reverseChildren then do
let r ← args.reverse.mapM f
pure r.reverse
else
args.mapM f
pure ⟨.InstanceCall (← f target) callee mapped, source⟩
| .Quantifier mode param trigger body =>
pure ⟨.Quantifier mode param (← trigger.mapM f) (← f body), source⟩
| .Assigned name =>
pure ⟨.Assigned (← f name), source⟩
| .Old value =>
pure ⟨.Old (← f value), source⟩
| .Fresh value =>
pure ⟨.Fresh (← f value), source⟩
| .Assert cond =>
pure ⟨.Assert { cond with condition := ← f cond.condition }, source⟩
| .Assume cond =>
pure ⟨.Assume (← f cond), source⟩
| .ProveBy value proof =>
pure ⟨.ProveBy (← f value) (← f proof), source⟩
| .ContractOf ty func =>
pure ⟨.ContractOf ty (← f func), source⟩
| .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _
| .Var (.Local _) | .Var (.Declare _) | .New _ | .This | .Abstract | .All | .Hole .. => pure expr

/--
Monadic traversal of `StmtExprMd` with pre/post hooks and optional right-to-left
child traversal.

- `pre`: called before recursion. If it returns `some e`, that value is used
directly (no recursion into children). If it returns `none`, the generic
recursion proceeds followed by `post`.
- `post`: applied to the rebuilt node after children have been recursively
traversed. Only called when `pre` returned `none`.
- `reverseChildren`: when `true`, list-valued children (e.g. arguments) are
traversed right-to-left and the results are reversed back to original order.
-/
def mapStmtExprPrePostM [Monad m]
(pre : StmtExprMd → m (Option StmtExprMd))
(post : StmtExprMd → m StmtExprMd)
(reverseChildren : Bool := false)
(expr : StmtExprMd) : m StmtExprMd := do
match ← pre expr with
| some result => return result
| none =>
let source := expr.source
let rebuilt ← match _h : expr.val with
| .IfThenElse cond th el =>
let seqEl ← el.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure ⟨.IfThenElse (← mapStmtExprPrePostM pre post reverseChildren cond)
(← mapStmtExprPrePostM pre post reverseChildren th) seqEl, source⟩
| .Block stmts label =>
let mapped ← if reverseChildren then do
let r ← stmts.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure r.reverse
else
stmts.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure ⟨.Block mapped label, source⟩
| .While cond invs dec body =>
pure ⟨.While (← mapStmtExprPrePostM pre post reverseChildren cond)
(← invs.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e)
(← dec.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e)
(← mapStmtExprPrePostM pre post reverseChildren body), source⟩
| .Return v =>
pure ⟨.Return (← v.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e), source⟩
| .Assign targets value =>
let targets' ← targets.attach.mapM fun ⟨v, _⟩ => do
let ⟨vv, vs⟩ := v
match vv with
| .Field target fieldName =>
pure ⟨Variable.Field (← mapStmtExprPrePostM pre post reverseChildren target) fieldName, vs⟩
| .Local _ | .Declare _ => pure v
pure ⟨.Assign targets' (← mapStmtExprPrePostM pre post reverseChildren value), source⟩
| .Var (.Field target fieldName) =>
pure ⟨.Var (.Field (← mapStmtExprPrePostM pre post reverseChildren target) fieldName), source⟩
| .PureFieldUpdate target fieldName newValue =>
pure ⟨.PureFieldUpdate (← mapStmtExprPrePostM pre post reverseChildren target) fieldName
(← mapStmtExprPrePostM pre post reverseChildren newValue), source⟩
| .StaticCall callee args =>
let mapped ← if reverseChildren then do
let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure r.reverse
else
args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure ⟨.StaticCall callee mapped, source⟩
| .PrimitiveOp op args =>
let mapped ← if reverseChildren then do
let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure r.reverse
else
args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure ⟨.PrimitiveOp op mapped, source⟩
| .ReferenceEquals lhs rhs =>
pure ⟨.ReferenceEquals (← mapStmtExprPrePostM pre post reverseChildren lhs)
(← mapStmtExprPrePostM pre post reverseChildren rhs), source⟩
| .AsType target ty =>
pure ⟨.AsType (← mapStmtExprPrePostM pre post reverseChildren target) ty, source⟩
| .IsType target ty =>
pure ⟨.IsType (← mapStmtExprPrePostM pre post reverseChildren target) ty, source⟩
| .InstanceCall target callee args =>
let mapped ← if reverseChildren then do
let r ← args.attach.reverse.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure r.reverse
else
args.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e
pure ⟨.InstanceCall (← mapStmtExprPrePostM pre post reverseChildren target) callee mapped, source⟩
| .Quantifier mode param trigger body =>
pure ⟨.Quantifier mode param
(← trigger.attach.mapM fun ⟨e, _⟩ => mapStmtExprPrePostM pre post reverseChildren e)
(← mapStmtExprPrePostM pre post reverseChildren body), source⟩
| .Assigned name =>
pure ⟨.Assigned (← mapStmtExprPrePostM pre post reverseChildren name), source⟩
| .Old value =>
pure ⟨.Old (← mapStmtExprPrePostM pre post reverseChildren value), source⟩
| .Fresh value =>
pure ⟨.Fresh (← mapStmtExprPrePostM pre post reverseChildren value), source⟩
| .Assert cond =>
pure ⟨.Assert { cond with condition := ← mapStmtExprPrePostM pre post reverseChildren cond.condition }, source⟩
| .Assume cond =>
pure ⟨.Assume (← mapStmtExprPrePostM pre post reverseChildren cond), source⟩
| .ProveBy value proof =>
pure ⟨.ProveBy (← mapStmtExprPrePostM pre post reverseChildren value)
(← mapStmtExprPrePostM pre post reverseChildren proof), source⟩
| .ContractOf ty func =>
pure ⟨.ContractOf ty (← mapStmtExprPrePostM pre post reverseChildren func), source⟩
| .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ | .LiteralDecimal _
| .Var (.Local _) | .Var (.Declare _) | .New _ | .This | .Abstract | .All | .Hole .. => pure expr
post rebuilt
termination_by sizeOf expr
decreasing_by
all_goals simp_wf
all_goals (try have := AstNode.sizeOf_val_lt expr)
all_goals (try have := Condition.sizeOf_condition_lt ‹_›)
all_goals (try term_by_mem)
all_goals (cases expr; simp_all; omega)

/-- Apply a monadic transformation to all procedure bodies. -/
def mapProcedureBodiesM [Monad m] (f : StmtExprMd → m StmtExprMd) (proc : Procedure) : m Procedure := do
match proc.body with
Expand Down
Loading