Skip to content

Commit 1526a11

Browse files
committed
sketched neural network spec language
1 parent 3074194 commit 1526a11

5 files changed

Lines changed: 180 additions & 4 deletions

File tree

Pullback/NN/Basic.lean

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import Mathlib
2+
3+
/-
4+
`NDMatrix α shape` is a numpy ndarray with shape `shape` and elements from `α`
5+
-/
6+
@[reducible]
7+
def NDMatrix (α : Type u) : List Nat → Type u
8+
| [] => α
9+
| (a::l) => Fin a → NDMatrix α l
10+
11+
def NDMatrix.map {α : Type u} (f : α → α) : {shape : List Nat} → NDMatrix α shape → NDMatrix α shape
12+
| [] => f
13+
| _::l => fun x => fun i => NDMatrix.map f (shape := l) (x i)
14+
15+
def NDMatrix.sum {α : Type u} [Zero α] [Add α] : {shape : List Nat} → NDMatrix α shape → α
16+
| [] => id
17+
| _::l => fun x => Fin.sum (fun i => NDMatrix.sum (shape := l) (x i))
18+
19+
def NDMatrix.entrywise {α : Type u} (f : α → α → α) : {shape : List Nat} → NDMatrix α shape → NDMatrix α shape → NDMatrix α shape
20+
| [] => f
21+
| _::l =>
22+
fun x y =>
23+
fun i => NDMatrix.entrywise f (shape := l) (x i) (y i)
24+
25+
instance instNonempty {α : Type u} [Nonempty α] : {shape : List Nat} → Nonempty (NDMatrix α shape)
26+
| [] => inferInstance
27+
| _ :: shape => letI := instNonempty (α := α) (shape := shape); inferInstance
28+
29+
instance {α : Type u} [Sub α] {shape : (List Nat)} : Sub (NDMatrix α shape) := ⟨NDMatrix.entrywise (· - · : α → α → α)⟩
30+
31+
32+
instance {α : Type u} [Mul α] {shape : (List Nat)} : Mul (NDMatrix α shape) := ⟨NDMatrix.entrywise (· * · : α → α → α)⟩
33+
34+
instance{α : Type u} [Sub α] : Sub (NDMatrix α [0]) := by infer_instance
35+
36+
37+
def List.shapesize (shape : List Nat) : Nat := List.foldr (· * ·) 1 shape
38+
39+
def NN (α : Type u) (shape₁ shape₂ : List Nat) := NDMatrix α shape₁ → NDMatrix α shape₂
40+
41+
theorem Function.comp_invFun {α : Sort u} {β} [Nonempty α] (f : α → β) (hf : Function.Surjective f) : f ∘ Function.invFun f = id := by sorry
42+
43+
def NN.implBy {α : Type u} [Nonempty α] {shape₁ shape₂ : List Nat} (nn : NN α shape₁ shape₂) (view₁ : NDMatrix α shape₁ → Vector α shape₁.shapesize) (view₂ : NDMatrix α shape₂ → Vector α shape₂.shapesize) (impl : Vector α shape₁.shapesize → Vector α shape₂.shapesize) : Prop := nn = (Function.invFun view₂) ∘ impl ∘ view₁
44+
45+
notation:50 nn " ⊧[" view₁ "," view₂ "] " impl:max =>
46+
NN.implBy nn view₁ view₂ impl
47+
48+
theorem NN.comp_implBy {α : Type u} [Nonempty α] {shape₁ shape₂ shape₃ : List Nat} (nn₁ : NN α shape₁ shape₂) (nn₂ : NN α shape₂ shape₃) (view₁ : NDMatrix α shape₁ → Vector α shape₁.shapesize) (view₂ : NDMatrix α shape₂ → Vector α shape₂.shapesize) (hview₂ : Function.Surjective view₂) (view₃ : _) (impl₁ : Vector α shape₁.shapesize → Vector α shape₂.shapesize) (impl₂ : Vector α shape₂.shapesize → Vector α shape₃.shapesize) : nn₁ ⊧[view₁, view₂] impl₁ → nn₂ ⊧[view₂, view₃] impl₂ → (nn₂ ∘ nn₁) ⊧[view₁, view₃] (impl₂ ∘ impl₁) := by
49+
unfold implBy
50+
intro h1 h2
51+
rw [h1, h2]
52+
calc
53+
_ = Function.invFun view₃ ∘ impl₂ ∘ (view₂ ∘ Function.invFun view₂) ∘ impl₁ ∘ view₁ := by simp [Function.comp_assoc]
54+
_ = Function.invFun view₃ ∘ impl₂ ∘ id ∘ impl₁ ∘ view₁ := by
55+
grind [Function.comp_invFun]

Pullback/NN/Transformer.lean

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import Mathlib
2+
import Pullback.NN.Basic
3+
4+
noncomputable section
5+
6+
variable {l d : Nat} (ε : ℝ)
7+
8+
def mean (x : NDMatrix ℝ [d]) : ℝ :=
9+
(1 / d•1) * Fin.sum (fun i => x i)
10+
11+
def var (x : NDMatrix ℝ [d]) : ℝ :=
12+
(1 / d) * Fin.sum (fun i : Fin d => (x i - mean x) ^ 2)
13+
14+
def rms (x : NDMatrix ℝ [d]) : ℝ :=
15+
Real.sqrt ((1 / d) * Fin.sum (fun i : Fin d => (x i) ^ 2))
16+
17+
def layerNorm
18+
(γ β : NDMatrix ℝ [d]) : NN ℝ [d] [d] := fun x =>
19+
fun i =>
20+
((x i - mean x) / Real.sqrt (var x + ε)) * γ i + β i
21+
22+
def biasNorm
23+
(γ β : NDMatrix ℝ [d]) : NN ℝ [d] [d] := fun x =>
24+
fun i =>
25+
x i / (rms (x - β)) * Real.exp (γ i)
26+
27+
/-
28+
attention for a single query
29+
TODO :: add softmax and multiplication by values
30+
-/
31+
def attention (keys : Array (NDMatrix ℝ [d])) (query : NDMatrix ℝ [d]) : Array ℝ := keys.map (fun key => (key * query).sum)
32+
33+
-- /-
34+
-- batches attention computation across all keys from the given token sequence
35+
-- -/
36+
-- def attentionSeqBatch (keys : NDMatrix ℝ [l, d]) (queries : NDMatrix ℝ [l, d]) : NDMatrix ℝ [l, l] := queries.transpose.matmul keys

lake-manifest.json

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,95 @@
11
{"version": "1.1.0",
22
"packagesDir": ".lake/packages",
33
"packages":
4-
[{"url": "https://github.com/leanprover-community/quote4.git",
4+
[{"url": "https://github.com/leanprover-community/mathlib4",
5+
"type": "git",
6+
"subDir": null,
7+
"scope": "leanprover-community",
8+
"rev": "4b0276049bfa34119c2049b334d59545ab46cb98",
9+
"name": "mathlib",
10+
"manifestFile": "lake-manifest.json",
11+
"inputRev": "master",
12+
"inherited": false,
13+
"configFile": "lakefile.lean"},
14+
{"url": "https://github.com/leanprover-community/quote4.git",
515
"type": "git",
616
"subDir": null,
717
"scope": "",
8-
"rev": "95c2f8afe09d9e49d3cacca667261da04f7f93f7",
18+
"rev": "23324752757bf28124a518ec284044c8db79fee5",
919
"name": "Qq",
1020
"manifestFile": "lake-manifest.json",
1121
"inputRev": null,
1222
"inherited": false,
23+
"configFile": "lakefile.toml"},
24+
{"url": "https://github.com/leanprover-community/plausible",
25+
"type": "git",
26+
"subDir": null,
27+
"scope": "leanprover-community",
28+
"rev": "7311586e1a56af887b1081d05e80c11b6c41d212",
29+
"name": "plausible",
30+
"manifestFile": "lake-manifest.json",
31+
"inputRev": "main",
32+
"inherited": true,
33+
"configFile": "lakefile.toml"},
34+
{"url": "https://github.com/leanprover-community/LeanSearchClient",
35+
"type": "git",
36+
"subDir": null,
37+
"scope": "leanprover-community",
38+
"rev": "5ce7f0a355f522a952a3d678d696bd563bb4fd28",
39+
"name": "LeanSearchClient",
40+
"manifestFile": "lake-manifest.json",
41+
"inputRev": "main",
42+
"inherited": true,
43+
"configFile": "lakefile.toml"},
44+
{"url": "https://github.com/leanprover-community/import-graph",
45+
"type": "git",
46+
"subDir": null,
47+
"scope": "leanprover-community",
48+
"rev": "875ad9d88ed684e39c16bdea260e6ecfa15afd60",
49+
"name": "importGraph",
50+
"manifestFile": "lake-manifest.json",
51+
"inputRev": "main",
52+
"inherited": true,
53+
"configFile": "lakefile.toml"},
54+
{"url": "https://github.com/leanprover-community/ProofWidgets4",
55+
"type": "git",
56+
"subDir": null,
57+
"scope": "leanprover-community",
58+
"rev": "6d65c6e0a25b8a52c13c3adeb63ecde3bfbb6294",
59+
"name": "proofwidgets",
60+
"manifestFile": "lake-manifest.json",
61+
"inputRev": "v0.0.86",
62+
"inherited": true,
63+
"configFile": "lakefile.lean"},
64+
{"url": "https://github.com/leanprover-community/aesop",
65+
"type": "git",
66+
"subDir": null,
67+
"scope": "leanprover-community",
68+
"rev": "f08e838d4f9aea519f3cde06260cfb686fd4bab0",
69+
"name": "aesop",
70+
"manifestFile": "lake-manifest.json",
71+
"inputRev": "master",
72+
"inherited": true,
73+
"configFile": "lakefile.toml"},
74+
{"url": "https://github.com/leanprover-community/batteries",
75+
"type": "git",
76+
"subDir": null,
77+
"scope": "leanprover-community",
78+
"rev": "cabbb5a025bfbbc50af9184ed2dfdde6ea4f53a7",
79+
"name": "batteries",
80+
"manifestFile": "lake-manifest.json",
81+
"inputRev": "main",
82+
"inherited": true,
83+
"configFile": "lakefile.toml"},
84+
{"url": "https://github.com/leanprover/lean4-cli",
85+
"type": "git",
86+
"subDir": null,
87+
"scope": "leanprover",
88+
"rev": "28e0856d4424863a85b18f38868c5420c55f9bae",
89+
"name": "Cli",
90+
"manifestFile": "lake-manifest.json",
91+
"inputRev": "v4.28.0-rc1",
92+
"inherited": true,
1393
"configFile": "lakefile.toml"}],
1494
"name": "pullback",
1595
"lakeDir": ".lake"}

lakefile.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,9 @@ root = "Main"
1111

1212
[[require]]
1313
name = "Qq"
14-
git = "https://github.com/leanprover-community/quote4.git"
14+
git = "https://github.com/leanprover-community/quote4.git"
15+
16+
17+
[[require]]
18+
name = "mathlib"
19+
scope = "leanprover-community"

lean-toolchain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
leanprover/lean4:v4.25.0-rc2
1+
leanprover/lean4:v4.28.0-rc1

0 commit comments

Comments
 (0)