Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions Strata/Languages/Laurel/LiftImperativeExpressions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,13 @@ def containsAssignmentOrImperativeCall (model: SemanticModel) (expr : StmtExprMd
containsAssignmentOrImperativeCall model cond ||
containsAssignmentOrImperativeCall model th ||
match el with | some e => containsAssignmentOrImperativeCall model e | none => false
| .Assert cond => containsAssignmentOrImperativeCall model cond.condition
| .Assume cond => containsAssignmentOrImperativeCall model cond
| _ => false
termination_by expr
decreasing_by
all_goals ((try cases x); simp_all; try term_by_mem)
all_goals (have := Condition.sizeOf_condition_lt cond; omega)

/-- Check if an expression contains any nondeterministic holes (recursively). -/
private def containsNondetHole (expr : StmtExprMd) : Bool :=
Expand Down Expand Up @@ -330,6 +333,28 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
| none => pure none
return ⟨.IfThenElse seqCond seqThen seqElse, source⟩

| .Assert cond =>
let prePrepends ← takePrepends
let savedSubst := (← get).subst
modify fun s => { s with subst := [] }
let seqCond ← transformExpr cond.condition
let argPrepends ← takePrepends
modify fun s => { s with subst := savedSubst }
let liftedAssert := [⟨.Assert { cond with condition := seqCond }, source⟩]
modify fun s => { s with prependedStmts := s.prependedStmts ++ argPrepends ++ liftedAssert ++ prePrepends }
return bare (.LiteralBool true)

| .Assume cond =>
let prePrepends ← takePrepends
let savedSubst := (← get).subst
modify fun s => { s with subst := [] }
let seqCond ← transformExpr cond
let argPrepends ← takePrepends
modify fun s => { s with subst := savedSubst }
let liftedAssume := [⟨.Assume seqCond, source⟩]
modify fun s => { s with prependedStmts := s.prependedStmts ++ argPrepends ++ liftedAssume ++ prePrepends }
return bare (.LiteralBool true)

| .Block stmts labelOption =>
let newStmts := (← stmts.reverse.mapM transformExpr).reverse
return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, source⟩
Expand All @@ -349,6 +374,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
termination_by (sizeOf expr, 0)
decreasing_by
all_goals (simp_all; try term_by_mem)
all_goals (have := Condition.sizeOf_condition_lt cond; omega)

/--
Process a statement, handling any assignments in its sub-expressions.
Expand All @@ -359,24 +385,16 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do
| AstNode.mk val source =>
match val with
| .Assert cond =>
-- Do not transform assert conditions with assignments — they must be rejected.
-- But nondeterministic holes need to be lifted.
if containsNondetHole cond.condition && !containsAssignmentOrImperativeCall (← get).model cond.condition then
let seqCond ← transformExpr cond.condition
let prepends ← takePrepends
modify fun s => { s with subst := [] }
return prepends ++ [⟨.Assert { cond with condition := seqCond }, source⟩]
else
return [stmt]
let seqCond ← transformExpr cond.condition
let prepends ← takePrepends
modify fun s => { s with subst := [] }
return prepends ++ [⟨.Assert { cond with condition := seqCond }, source⟩]

| .Assume cond =>
if containsNondetHole cond && !containsAssignmentOrImperativeCall (← get).model cond then
let seqCond ← transformExpr cond
let prepends ← takePrepends
modify fun s => { s with subst := [] }
return prepends ++ [⟨.Assume seqCond, source⟩]
else
return [stmt]
let seqCond ← transformExpr cond
let prepends ← takePrepends
modify fun s => { s with subst := [] }
return prepends ++ [⟨.Assume seqCond, source⟩]

| .Block stmts metadata =>
let seqStmts ← stmts.mapM transformStmt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ procedure impureContractIsNotLegal1(x: int)
opaque
{
assert impure() == 1
// ^^^^^^^^ error: calls to procedures are not supported in functions or contracts
};

procedure impureContractIsNotLegal2(x: int)
Expand All @@ -53,7 +52,6 @@ procedure impureContractIsNotLegal2(x: int)
opaque
{
assert (x := 2) == 2
// ^^^^^^ error: destructive assignments are not supported in functions or contracts
};
"

Expand Down
40 changes: 40 additions & 0 deletions StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,44 @@ info: procedure assertInBlockExpr()
for proc in program.staticProcedures do
IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc)))

def assertWithAssignProgram : String := r"
procedure foo()
{
var x: int := 1;
var y: int := {
assert x > 0;
3
};
assert (x := 2) == 2
};
"

/--
info: procedure foo()
{ var x: int := 1; assert x > 0; var y: int := { 3 }; var $x_0: int := x; x := 2; assert x == 2 };
-/
#guard_msgs in
#eval! do
let program ← parseLaurelAndLift assertWithAssignProgram
for proc in program.staticProcedures do
IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc)))

def assumeWithAssignProgram : String := r"
procedure bar()
{
var x: int := 1;
assume (x := 2) == 2
};
"

/--
info: procedure bar()
{ var x: int := 1; var $x_0: int := x; x := 2; assume x == 2 };
-/
#guard_msgs in
#eval! do
let program ← parseLaurelAndLift assumeWithAssignProgram
for proc in program.staticProcedures do
IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc)))

end Laurel
Loading