@@ -35,31 +35,26 @@ abbrev SynthM := StateM SynthState
3535 (stt.circuitModule.freezeOracles, a)
3636
3737@[inline] def flush (direction : FlushDirection) (channelId : ChannelId)
38- (selector : OracleIdx) (args : @& Array OracleIdx ) (multiplicity : UInt64) : SynthM Unit :=
38+ (selector : OracleIdx) (args : @& Array OracleOrConst ) (multiplicity : UInt64) : SynthM Unit :=
3939 modify fun stt =>
4040 let circuitModule := stt.circuitModule.flush direction channelId selector
4141 args multiplicity
4242 { stt with circuitModule }
4343
44- @[inline] def send (channelId : ChannelId) (args : @& Array OracleIdx ) : SynthM Unit :=
44+ @[inline] def send (channelId : ChannelId) (args : @& Array OracleOrConst ) : SynthM Unit :=
4545 flush .push channelId CircuitModule.selector args 1
4646
47- @[inline] def recv (channelId : ChannelId) (args : @& Array OracleIdx ) : SynthM Unit :=
47+ @[inline] def recv (channelId : ChannelId) (args : @& Array OracleOrConst ) : SynthM Unit :=
4848 flush .pull channelId CircuitModule.selector args 1
4949
5050@[inline] def assertZero (name : @& String) (expr : @& ArithExpr) : SynthM Unit :=
5151 modify fun stt =>
5252 { stt with circuitModule := stt.circuitModule.assertZero name #[] expr }
5353
54- @[inline] def assertDynamicExp (expBits : @& Array OracleIdx) (result base : OracleIdx ) :
54+ @[inline] def assertExp (expBits : @& Array OracleIdx) (result : OracleIdx) ( base : @& OracleOrConst ) :
5555 SynthM Unit :=
5656 modify fun stt =>
57- { stt with circuitModule := stt.circuitModule.assertDynamicExp expBits result base }
58-
59- @[inline] def assertStaticExp (expBits : @& Array OracleIdx) (result : OracleIdx)
60- (base : @& UInt128) (baseTF : TowerField): SynthM Unit :=
61- modify fun stt =>
62- { stt with circuitModule := stt.circuitModule.assertStaticExp expBits result base baseTF }
57+ { stt with circuitModule := stt.circuitModule.assertExp expBits result base }
6358
6459def addCommitted (name : @& String) (tf : TowerField) (relativeHeight : RelativeHeight) :
6560 SynthM OracleIdx :=
@@ -97,15 +92,6 @@ def addProjected (name : @& String) (inner : OracleIdx) (selection : UInt64)
9792 let (o, circuitModule) := stt.circuitModule.addProjected name inner selection chunkSize
9893 (o, { stt with circuitModule })
9994
100- def cacheConst (value : UInt128) : SynthM OracleIdx :=
101- modifyGet fun stt => match stt.cachedOracles.find? (.const value) with
102- | some o => (o, stt)
103- | none =>
104- let (o, circuitModule) := stt.circuitModule.addTransparent
105- s! "cached-const-{ value} " (.const value) .base
106- let cachedOracles := stt.cachedOracles.insert (.const value) o
107- (o, ⟨circuitModule, cachedOracles⟩)
108-
10995def cacheLc (expr : ArithExpr) : SynthM OracleIdx :=
11096 let key := .lc expr
11197 modifyGet fun stt => match stt.cachedOracles.find? key with
@@ -127,16 +113,15 @@ where
127113 | _ => unreachable!
128114
129115def provide (channelId : ChannelId) (multiplicity : OracleIdx)
130- (args : Array OracleIdx) : SynthM Unit := do
131- let ones ← cacheConst 1
132- send channelId (args.push ones)
133- recv channelId (args.push multiplicity)
116+ (args : Array OracleOrConst) : SynthM Unit := do
117+ send channelId (args.push (.const 1 .b1))
118+ recv channelId (args.push (.oracle multiplicity))
134119
135- def require (channelId : ChannelId) (prevIdx : OracleIdx) (args : Array OracleIdx )
120+ def require (channelId : ChannelId) (prevIdx : OracleIdx) (args : Array OracleOrConst )
136121 (sel : OracleIdx) : SynthM Unit := do
137122 let idx ← addLinearCombination s! "index-{ channelId.toUSize} " 0 #[(prevIdx, B64_MULT_GEN)] .base
138- flush .pull channelId sel (args.push prevIdx) 1
139- flush .push channelId sel (args.push idx) 1
123+ flush .pull channelId sel (args.push (.oracle prevIdx) ) 1
124+ flush .push channelId sel (args.push (.oracle idx) ) 1
140125
141126def synthesizeFunction (funcIdx : FuncIdx) (function : Bytecode.Function)
142127 (layout : Layout) (aiurChannels : AiurChannels) : SynthM Columns := do
@@ -154,13 +139,15 @@ def synthesizeFunction (funcIdx : FuncIdx) (function : Bytecode.Function)
154139 constraints.recvs.forM fun (channel, sel, args) => do
155140 let sel ← cacheLc sel
156141 let args ← args.mapM cacheLc
142+ let args := args.map .oracle
157143 match channel with
158144 | .add => flush .pull aiurChannels.add sel args 1
159145 | .mul => flush .pull aiurChannels.mul sel args 1
160146 | _ => unreachable!
161147 constraints.requires.forM fun (channel, sel, prevIdx, values) => do
162148 let sel ← cacheLc sel
163149 let values ← values.mapM cacheLc
150+ let values := values.map .oracle
164151 match channel with
165152 | .func funcIdx =>
166153 let funcChannel := aiurChannels.func[funcIdx]!
@@ -192,7 +179,7 @@ def synthesizeAdd (channelId : ChannelId) : SynthM AddColumns := do
192179 let coutProjected ← addProjected "add-cout-projected" cout 63 64
193180 assertZero "add-sum" $ xin + yin + cin - zout
194181 assertZero "add-carry" $ (xin + cin) * (yin + cin) + cin - cout
195- send channelId #[xinPacked, yinPacked, zoutPacked, coutProjected]
182+ send channelId $ #[xinPacked, yinPacked, zoutPacked, coutProjected].map .oracle
196183 pure { xin, yin, zout, cout }
197184
198185structure MulColumns where
@@ -224,7 +211,7 @@ def synthesizeMul (channelId : ChannelId) : SynthM MulColumns := do
224211 ← mul xinBits yinBits
225212 let zoutLow := zoutBits.extract (stop := 64 )
226213 bitDecomposition "mul-bit-decomposition-zout" zoutLow zout
227- send channelId #[xin, yin, zout]
214+ send channelId $ #[xin, yin, zout].map .oracle
228215 pure {
229216 xin,
230217 yin,
@@ -262,10 +249,10 @@ where
262249 let zoutLow := zoutBits.extract (stop := outSize)
263250 let zoutHigh := zoutBits.extract (start := outSize)
264251
265- assertStaticExp xinBits xinExpResult B128_MULT_GEN .b128
266- assertDynamicExp yinBits yinExpResult xinExpResult
267- assertStaticExp zoutLow zoutLowExpResult B128_MULT_GEN .b128
268- assertStaticExp zoutHigh zoutHighExpResult B128GenPow2To64 .b128
252+ assertExp xinBits xinExpResult (.const B128_MULT_GEN .b128)
253+ assertExp yinBits yinExpResult (.oracle xinExpResult)
254+ assertExp zoutLow zoutLowExpResult (.const B128_MULT_GEN .b128)
255+ assertExp zoutHigh zoutHighExpResult (.const B128GenPow2To64 .b128)
269256
270257 pure (xinExpResult, yinExpResult, zoutLowExpResult, zoutHighExpResult, zoutBits)
271258
@@ -277,7 +264,7 @@ def synthesizeMemory (channelId : ChannelId) (width : Nat) : SynthM MemoryColumn
277264 let address ← addTransparent s! "mem-{ width} -address" .incremental .base
278265 let values ← Array.range width |>.mapM (addCommitted s! "mem-{ width} -value-{ ·} " .b64 .base)
279266 let multiplicity ← addCommitted s! "mem-{ width} -multiplicity" .b64 .base
280- let args := #[address] ++ values
267+ let args := #[address] ++ values |>.map .oracle
281268 provide channelId multiplicity args
282269 pure { values, multiplicity }
283270
0 commit comments