Skip to content

Commit 1e7ceef

Browse files
chore: simplify archon (#175)
*Archon flush values are now of `OracleOrConst`, following the Binius design more closely. * Unify Archon exponentials by using `OracleOrConst` as the type for the base.
1 parent 150a5fe commit 1e7ceef

15 files changed

Lines changed: 205 additions & 163 deletions

File tree

Ix/Aiur/Constraints.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ structure Constraints where
4646
recvs : Array (Channel × ArithExpr × Array ArithExpr)
4747
requires : Array (Channel × ArithExpr × OracleIdx × Array ArithExpr)
4848
topmostSelector : ArithExpr
49-
io : Array OracleIdx
49+
io : Array OracleOrConst
5050
multiplicity : OracleIdx
5151

5252
def blockSelector (block : Bytecode.Block) (columns : Columns) : ArithExpr :=
@@ -63,7 +63,7 @@ def new (function : Bytecode.Function) (layout : Layout) (columns : Columns) : C
6363
recvs := #[]
6464
requires := #[]
6565
topmostSelector := blockSelector function.body columns
66-
io := columns.inputs ++ columns.outputs
66+
io := columns.inputs ++ columns.outputs |>.map .oracle
6767
multiplicity := columns.multiplicity }
6868

6969
@[inline] def pushUnique (constraints : Constraints) (expr : ArithExpr) : Constraints :=

Ix/Aiur/Synthesis.lean

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,26 @@ abbrev SynthM := StateM SynthState
3535
(stt.circuitModule.freezeOracles, a)
3636

3737
@[inline] def flush (direction : FlushDirection) (channelId : ChannelId)
38-
(selector : OracleIdx) (args : @& Array OracleIdx) (multiplicity : UInt64) : SynthM Unit :=
38+
(selector : OracleIdx) (args : @& Array OracleOrConst) (multiplicity : UInt64) : SynthM Unit :=
3939
modify fun stt =>
4040
let circuitModule := stt.circuitModule.flush direction channelId selector
4141
args multiplicity
4242
{ stt with circuitModule }
4343

44-
@[inline] def send (channelId : ChannelId) (args : @& Array OracleIdx) : SynthM Unit :=
44+
@[inline] def send (channelId : ChannelId) (args : @& Array OracleOrConst) : SynthM Unit :=
4545
flush .push channelId CircuitModule.selector args 1
4646

47-
@[inline] def recv (channelId : ChannelId) (args : @& Array OracleIdx) : SynthM Unit :=
47+
@[inline] def recv (channelId : ChannelId) (args : @& Array OracleOrConst) : SynthM Unit :=
4848
flush .pull channelId CircuitModule.selector args 1
4949

5050
@[inline] def assertZero (name : @& String) (expr : @& ArithExpr) : SynthM Unit :=
5151
modify fun stt =>
5252
{ stt with circuitModule := stt.circuitModule.assertZero name #[] expr }
5353

54-
@[inline] def assertDynamicExp (expBits : @& Array OracleIdx) (result base : OracleIdx) :
54+
@[inline] def assertExp (expBits : @& Array OracleIdx) (result : OracleIdx) (base : @& OracleOrConst) :
5555
SynthM Unit :=
5656
modify fun stt =>
57-
{ stt with circuitModule := stt.circuitModule.assertDynamicExp expBits result base }
58-
59-
@[inline] def assertStaticExp (expBits : @& Array OracleIdx) (result : OracleIdx)
60-
(base : @& UInt128) (baseTF : TowerField): SynthM Unit :=
61-
modify fun stt =>
62-
{ stt with circuitModule := stt.circuitModule.assertStaticExp expBits result base baseTF }
57+
{ stt with circuitModule := stt.circuitModule.assertExp expBits result base }
6358

6459
def addCommitted (name : @& String) (tf : TowerField) (relativeHeight : RelativeHeight) :
6560
SynthM OracleIdx :=
@@ -97,15 +92,6 @@ def addProjected (name : @& String) (inner : OracleIdx) (selection : UInt64)
9792
let (o, circuitModule) := stt.circuitModule.addProjected name inner selection chunkSize
9893
(o, { stt with circuitModule })
9994

100-
def cacheConst (value : UInt128) : SynthM OracleIdx :=
101-
modifyGet fun stt => match stt.cachedOracles.find? (.const value) with
102-
| some o => (o, stt)
103-
| none =>
104-
let (o, circuitModule) := stt.circuitModule.addTransparent
105-
s!"cached-const-{value}" (.const value) .base
106-
let cachedOracles := stt.cachedOracles.insert (.const value) o
107-
(o, ⟨circuitModule, cachedOracles⟩)
108-
10995
def cacheLc (expr : ArithExpr) : SynthM OracleIdx :=
11096
let key := .lc expr
11197
modifyGet fun stt => match stt.cachedOracles.find? key with
@@ -127,16 +113,15 @@ where
127113
| _ => unreachable!
128114

129115
def provide (channelId : ChannelId) (multiplicity : OracleIdx)
130-
(args : Array OracleIdx) : SynthM Unit := do
131-
let ones ← cacheConst 1
132-
send channelId (args.push ones)
133-
recv channelId (args.push multiplicity)
116+
(args : Array OracleOrConst) : SynthM Unit := do
117+
send channelId (args.push (.const 1 .b1))
118+
recv channelId (args.push (.oracle multiplicity))
134119

135-
def require (channelId : ChannelId) (prevIdx : OracleIdx) (args : Array OracleIdx)
120+
def require (channelId : ChannelId) (prevIdx : OracleIdx) (args : Array OracleOrConst)
136121
(sel : OracleIdx) : SynthM Unit := do
137122
let idx ← addLinearCombination s!"index-{channelId.toUSize}" 0 #[(prevIdx, B64_MULT_GEN)] .base
138-
flush .pull channelId sel (args.push prevIdx) 1
139-
flush .push channelId sel (args.push idx) 1
123+
flush .pull channelId sel (args.push (.oracle prevIdx)) 1
124+
flush .push channelId sel (args.push (.oracle idx)) 1
140125

141126
def synthesizeFunction (funcIdx : FuncIdx) (function : Bytecode.Function)
142127
(layout : Layout) (aiurChannels : AiurChannels) : SynthM Columns := do
@@ -154,13 +139,15 @@ def synthesizeFunction (funcIdx : FuncIdx) (function : Bytecode.Function)
154139
constraints.recvs.forM fun (channel, sel, args) => do
155140
let sel ← cacheLc sel
156141
let args ← args.mapM cacheLc
142+
let args := args.map .oracle
157143
match channel with
158144
| .add => flush .pull aiurChannels.add sel args 1
159145
| .mul => flush .pull aiurChannels.mul sel args 1
160146
| _ => unreachable!
161147
constraints.requires.forM fun (channel, sel, prevIdx, values) => do
162148
let sel ← cacheLc sel
163149
let values ← values.mapM cacheLc
150+
let values := values.map .oracle
164151
match channel with
165152
| .func funcIdx =>
166153
let funcChannel := aiurChannels.func[funcIdx]!
@@ -192,7 +179,7 @@ def synthesizeAdd (channelId : ChannelId) : SynthM AddColumns := do
192179
let coutProjected ← addProjected "add-cout-projected" cout 63 64
193180
assertZero "add-sum" $ xin + yin + cin - zout
194181
assertZero "add-carry" $ (xin + cin) * (yin + cin) + cin - cout
195-
send channelId #[xinPacked, yinPacked, zoutPacked, coutProjected]
182+
send channelId $ #[xinPacked, yinPacked, zoutPacked, coutProjected].map .oracle
196183
pure { xin, yin, zout, cout }
197184

198185
structure MulColumns where
@@ -224,7 +211,7 @@ def synthesizeMul (channelId : ChannelId) : SynthM MulColumns := do
224211
← mul xinBits yinBits
225212
let zoutLow := zoutBits.extract (stop := 64)
226213
bitDecomposition "mul-bit-decomposition-zout" zoutLow zout
227-
send channelId #[xin, yin, zout]
214+
send channelId $ #[xin, yin, zout].map .oracle
228215
pure {
229216
xin,
230217
yin,
@@ -262,10 +249,10 @@ where
262249
let zoutLow := zoutBits.extract (stop := outSize)
263250
let zoutHigh := zoutBits.extract (start := outSize)
264251

265-
assertStaticExp xinBits xinExpResult B128_MULT_GEN .b128
266-
assertDynamicExp yinBits yinExpResult xinExpResult
267-
assertStaticExp zoutLow zoutLowExpResult B128_MULT_GEN .b128
268-
assertStaticExp zoutHigh zoutHighExpResult B128GenPow2To64 .b128
252+
assertExp xinBits xinExpResult (.const B128_MULT_GEN .b128)
253+
assertExp yinBits yinExpResult (.oracle xinExpResult)
254+
assertExp zoutLow zoutLowExpResult (.const B128_MULT_GEN .b128)
255+
assertExp zoutHigh zoutHighExpResult (.const B128GenPow2To64 .b128)
269256

270257
pure (xinExpResult, yinExpResult, zoutLowExpResult, zoutHighExpResult, zoutBits)
271258

@@ -277,7 +264,7 @@ def synthesizeMemory (channelId : ChannelId) (width : Nat) : SynthM MemoryColumn
277264
let address ← addTransparent s!"mem-{width}-address" .incremental .base
278265
let values ← Array.range width |>.mapM (addCommitted s!"mem-{width}-value-{·}" .b64 .base)
279266
let multiplicity ← addCommitted s!"mem-{width}-multiplicity" .b64 .base
280-
let args := #[address] ++ values
267+
let args := #[address] ++ values |>.map .oracle
281268
provide channelId multiplicity args
282269
pure { values, multiplicity }
283270

Ix/Archon/Circuit.lean

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import Blake3
22
import Ix.Archon.ArithExpr
33
import Ix.Archon.OracleIdx
44
import Ix.Archon.RelativeHeight
5+
import Ix.Archon.OracleOrConst
56
import Ix.Archon.Transparent
67
import Ix.Archon.Witness
78
import Ix.Binius.Common
@@ -32,7 +33,7 @@ opaque initWitnessModule : @& CircuitModule → WitnessModule
3233
/-- **Invalidates** the input `CircuitModule` -/
3334
@[never_extract, extern "c_rs_circuit_module_flush"]
3435
opaque flush : CircuitModule → Binius.FlushDirection → Binius.ChannelId →
35-
(selector : OracleIdx) → @& Array OracleIdx → (multiplicity : UInt64) → CircuitModule
36+
(selector : OracleIdx) → @& Array OracleOrConst → (multiplicity : UInt64) → CircuitModule
3637

3738
/-- **Invalidates** the input `CircuitModule` -/
3839
@[never_extract, extern "c_rs_circuit_module_assert_zero"]
@@ -44,14 +45,9 @@ opaque assertZero : CircuitModule → @& String → @& Array OracleIdx →
4445
opaque assertNotZero : CircuitModule → OracleIdx → CircuitModule
4546

4647
/-- **Invalidates** the input `CircuitModule` -/
47-
@[never_extract, extern "c_rs_circuit_module_assert_dynamic_exp"]
48-
opaque assertDynamicExp : CircuitModule → (expBits : @& Array OracleIdx) →
49-
(result : OracleIdx) → (base : OracleIdx) → CircuitModule
50-
51-
/-- **Invalidates** the input `CircuitModule` -/
52-
@[never_extract, extern "c_rs_circuit_module_assert_static_exp"]
53-
opaque assertStaticExp : CircuitModule → (expBits : @& Array OracleIdx) →
54-
(result : OracleIdx) → (base : @& UInt128) → (baseTF : TowerField) → CircuitModule
48+
@[never_extract, extern "c_rs_circuit_module_assert_exp"]
49+
opaque assertExp : CircuitModule → (expBits : @& Array OracleIdx) →
50+
(result : OracleIdx) → (base : @& OracleOrConst) → CircuitModule
5551

5652
/-- **Invalidates** the input `CircuitModule` -/
5753
@[never_extract, extern "c_rs_circuit_module_add_committed"]

Ix/Archon/OracleOrConst.lean

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import Ix.Archon.OracleIdx
2+
import Ix.Archon.TowerField
3+
import Ix.Unsigned
4+
5+
namespace Archon
6+
7+
inductive OracleOrConst
8+
| oracle : OracleIdx → OracleOrConst
9+
| const : UInt128 → TowerField → OracleOrConst
10+
deriving Inhabited
11+
12+
namespace OracleOrConst
13+
14+
def toString : OracleOrConst → String
15+
| oracle o => s!"Oracle({o.toUSize})"
16+
| const base tf => s!"Const({base}, {tf})"
17+
18+
instance : ToString OracleOrConst := ⟨toString⟩
19+
20+
def toBytes : @& OracleOrConst → ByteArray
21+
| oracle o => ⟨#[0]⟩ ++ o.toUSize.toLEBytes
22+
| const base tf => ⟨#[1]⟩ ++ base.toLEBytes |>.push tf.logDegree.toUInt8
23+
24+
/--
25+
Function meant for testing that tells whether the Lean→Rust mapping of OracleOrConst
26+
results on the same expression as deserializing the provided bytes.
27+
-/
28+
@[extern "rs_oracle_or_const_is_equivalent_to_bytes"]
29+
opaque isEquivalentToBytes : @& OracleOrConst → @& ByteArray → Bool
30+
31+
end OracleOrConst
32+
33+
end Archon

Ix/Archon/TowerField.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ namespace Archon
44

55
inductive TowerField
66
| b1 | b2 | b4 | b8 | b16 | b32 | b64 | b128
7+
deriving Inhabited
78

89
def TowerField.logDegree : TowerField → USize
910
| .b1 => 0

Ix/Binius/Common.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ structure ChannelId where
66
toUSize : USize
77
deriving Inhabited
88

9-
-- We can delete this later if we don't need it
109
structure ChannelAllocator where
1110
nextId : USize
1211

Tests/Aiur.lean

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,11 @@ where
145145
let (x, circuitModule) := circuitModule.addPacked "x-bits-packed" xBits 6
146146
let (y, circuitModule) := circuitModule.addPacked "y-bits-packed" yBits 6
147147
let (z, circuitModule) := circuitModule.addPacked "z-bits-packed" zBits 6
148-
let args := #[x, y, z]
149-
let (ones, circuitModule) := circuitModule.addTransparent "ones" (.const 1) .base
148+
let args := #[x, y, z].map .oracle
150149
let circuitModule := circuitModule.flush .push channelId CircuitModule.selector
151-
(args.push ones) 1
150+
(args.push (.const 1 .b1)) 1
152151
let circuitModule := circuitModule.flush .pull channelId CircuitModule.selector
153-
(args.push multiplicity) 1
152+
(args.push (.oracle multiplicity)) 1
154153
(circuitModule.popNamespace, #[xBits, yBits, multiplicity])
155154
populate entries oracles witnessModule :=
156155
if entries.isEmpty then (witnessModule, .inactive) else

Tests/FFIConsistency.lean

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import LSpec
22
import Tests.Common
33
import Ix.Binius.Boundary
44
import Ix.Archon.ArithExpr
5+
import Ix.Archon.OracleOrConst
56
import Ix.Archon.ModuleMode
67
import Ix.Archon.RelativeHeight
78
import Ix.Archon.Transparent
@@ -88,6 +89,28 @@ instance : Repr ModuleMode where
8889

8990
instance : SampleableExt ModuleMode := SampleableExt.mkSelfContained genModuleMode
9091

92+
/- OracleOrConst -/
93+
94+
def genOracleIdx : Gen OracleIdx :=
95+
OracleIdx.mk <$> genUSize
96+
97+
def genTowerField : Gen TowerField :=
98+
elements #[.b1, .b2, .b4, .b8, .b16, .b32, .b64, .b128]
99+
100+
def genOracleOrConst : Gen OracleOrConst :=
101+
frequency [
102+
(5, .oracle <$> genOracleIdx),
103+
(10, .const <$> genUInt128 <*> genTowerField),
104+
]
105+
106+
instance : Shrinkable OracleOrConst where
107+
shrink _ := []
108+
109+
instance : Repr OracleOrConst where
110+
reprPrec oc _ := oc.toString
111+
112+
instance : SampleableExt OracleOrConst := SampleableExt.mkSelfContained genOracleOrConst
113+
91114
/- RelativeHeight -/
92115

93116
def genRelativeHeight : Gen RelativeHeight :=
@@ -132,6 +155,8 @@ def Tests.FFIConsistency.suite := [
132155
(∀ boundary : Boundary, boundary.isEquivalentToBytes boundary.toBytes),
133156
check "ModuleMode Lean->Rust mapping matches the deserialized bytes"
134157
(∀ moduleMode : ModuleMode, moduleMode.isEquivalentToBytes moduleMode.toBytes),
158+
check "OracleOrConst Lean->Rust mapping matches the deserialized bytes"
159+
(∀ oc : OracleOrConst, oc.isEquivalentToBytes oc.toBytes),
135160
check "RelativeHeight Lean->Rust mapping matches the deserialized bytes"
136161
(∀ relativeHeight : RelativeHeight, relativeHeight.isEquivalentToBytes relativeHeight.toBytes),
137162
check "Transparent Lean->Rust mapping matches the deserialized bytes"

c/archon.c

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ extern lean_obj_res c_rs_circuit_module_flush(
216216
bool direction_pull,
217217
size_t channel_id,
218218
size_t selector,
219-
b_lean_obj_arg oracle_idxs,
219+
b_lean_obj_arg values,
220220
uint64_t multiplicity
221221
) {
222222
linear_object *linear = validated_linear(l_circuit);
@@ -225,7 +225,7 @@ extern lean_obj_res c_rs_circuit_module_flush(
225225
direction_pull,
226226
channel_id,
227227
selector,
228-
oracle_idxs,
228+
values,
229229
multiplicity
230230
);
231231
linear_object *new_linear = linear_bump(linear);
@@ -260,14 +260,14 @@ extern lean_obj_res c_rs_circuit_module_assert_not_zero(
260260
return alloc_lean_linear_object(new_linear);
261261
}
262262

263-
extern lean_obj_res c_rs_circuit_module_assert_dynamic_exp(
263+
extern lean_obj_res c_rs_circuit_module_assert_exp(
264264
lean_obj_arg l_circuit,
265265
b_lean_obj_arg exp_bits,
266266
size_t result,
267-
size_t base
267+
b_lean_obj_arg base
268268
) {
269269
linear_object *linear = validated_linear(l_circuit);
270-
rs_circuit_module_assert_dynamic_exp(
270+
rs_circuit_module_assert_exp(
271271
get_object_ref(linear),
272272
exp_bits,
273273
result,
@@ -277,25 +277,6 @@ extern lean_obj_res c_rs_circuit_module_assert_dynamic_exp(
277277
return alloc_lean_linear_object(new_linear);
278278
}
279279

280-
extern lean_obj_res c_rs_circuit_module_assert_static_exp(
281-
lean_obj_arg l_circuit,
282-
b_lean_obj_arg exp_bits,
283-
size_t result,
284-
b_lean_obj_arg base,
285-
uint8_t base_tower_level
286-
) {
287-
linear_object *linear = validated_linear(l_circuit);
288-
rs_circuit_module_assert_static_exp(
289-
get_object_ref(linear),
290-
exp_bits,
291-
result,
292-
lean_get_external_data(base),
293-
base_tower_level
294-
);
295-
linear_object *new_linear = linear_bump(linear);
296-
return alloc_lean_linear_object(new_linear);
297-
}
298-
299280
extern lean_obj_res c_rs_circuit_module_add_committed(
300281
lean_obj_arg l_circuit,
301282
b_lean_obj_arg name,

c/rust.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ void rs_circuit_module_assert_zero(
3737
void*, char const*, b_lean_obj_arg, b_lean_obj_arg
3838
);
3939
void rs_circuit_module_assert_not_zero(void*, size_t);
40-
void rs_circuit_module_assert_dynamic_exp(void*, b_lean_obj_arg, size_t, size_t);
41-
void rs_circuit_module_assert_static_exp(void*, b_lean_obj_arg, size_t, uint8_t*, uint8_t);
40+
void rs_circuit_module_assert_exp(void*, b_lean_obj_arg, size_t, b_lean_obj_arg);
4241
size_t rs_circuit_module_add_committed(void*, char const *, uint8_t, b_lean_obj_arg);
4342
size_t rs_circuit_module_add_transparent(void*, char const *, b_lean_obj_arg, b_lean_obj_arg);
4443
size_t rs_circuit_module_add_linear_combination(

0 commit comments

Comments
 (0)