-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathmulrules.jl
More file actions
51 lines (47 loc) · 1.3 KB
/
mulrules.jl
File metadata and controls
51 lines (47 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Standard arithmetic mul:
# Forward-mode rule for `mul(A, B)` called without an explicit semiring.
# Delegates to the `Semirings.PLUS_TIMES` method, i.e. standard arithmetic.
# The incoming tangent of `mul` itself is discarded; only `ΔA` and `ΔB` are forwarded.
function frule(
    (_, ΔA, ΔB),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose
)
    # The function slot and the semiring slot carry no derivative information.
    # Use `NoTangent()` for them (as the rrule pullback in this file does) rather
    # than a bare `nothing`, which is not a ChainRulesCore tangent type.
    return frule((NoTangent(), ΔA, ΔB, NoTangent()), mul, A, B, Semirings.PLUS_TIMES)
end
# Forward-mode rule for `mul(A, B, Semirings.PLUS_TIMES)`.
# Returns the primal product together with its tangent, computed with the
# product rule: mul(ΔA, B) + mul(A, ΔB).
# The tangents for `mul` itself and for the semiring argument are ignored.
function frule(
    (_, ΔA, ΔB, _),
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES)
)
    primal = mul(A, B)
    # Contribution of each input's perturbation, holding the other input fixed.
    from_A = mul(ΔA, B)
    from_B = mul(A, ΔB)
    return primal, from_A + from_B
end
# Tests will not pass for this, for two reasons.
# First is #25: the output inference is not type stable.
# That's its own issue.
# Second, to_vec currently works by mapping materialized values back and forth, i.e. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.
# Reverse-mode rule for `mul(A, B, Semirings.PLUS_TIMES)`.
# Returns the primal product and a pullback. Given the output cotangent `ΔΩ`, the
# pullback returns `(NoTangent(), ∂A, ∂B, NoTangent())`: no tangent for `mul`
# itself or for the semiring argument.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose,
    ::typeof(Semirings.PLUS_TIMES)
)
    Ω = mul(A, B)
    # Each input's cotangent is computed with that same input passed as `mask`.
    mulpullback(ΔΩ) = (
        NoTangent(),
        mul(ΔΩ, B'; mask=A),
        mul(A', ΔΩ; mask=B),
        NoTangent(),
    )
    return Ω, mulpullback
end
# Reverse-mode rule for `mul(A, B)` called without an explicit semiring.
# Delegates to the `Semirings.PLUS_TIMES` rule and drops the trailing cotangent
# slot, which belongs to the semiring argument that this method does not take.
function rrule(
    ::typeof(mul),
    A::GBMatOrTranspose,
    B::GBMatOrTranspose
)
    Ω, mulpullback = rrule(mul, A, B, Semirings.PLUS_TIMES)
    function pullback(ΔΩ)
        # Destructure instead of slicing with `[1:3]`: range-indexing a
        # heterogeneous tuple does not have a reliably inferred result type,
        # whereas rebuilding the 3-tuple explicitly does.
        ∂self, ∂A, ∂B, _ = mulpullback(ΔΩ)
        return ∂self, ∂A, ∂B
    end
    return Ω, pullback
end