RFC: don't automatically unthunk grad outputs#134
Conversation
|
FWIW, JuliaDiff/Diffractor.jl#79 went the other way. There's something to be said for the high-level functions never giving you a thunk. You can of course write `unthunk` yourself.
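Something like this sketch, where `raw_grad` is a made-up stand-in for a lower-level API that returns thunked outputs:

```julia
using ChainRulesCore: @thunk, unthunk

# hypothetical lower-level API that returns thunked gradients
raw_grad(x) = (@thunk(2x),)

# user-facing wrapper: never hand the caller a thunk
grad_materialized(x) = map(unthunk, raw_grad(x))

grad_materialized(3.0)  # (6.0,)
```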
|
Either way, I don't think any of this interferes with JuliaDiff/ChainRulesCore.jl#568? The mention of Diffractor using escape analysis instead initially excited me, but given how far that's (not) developed and where Diffractor sits in the AD performance ranking, seeing thunks bring tangible performance benefits is far more appealing right now.
|
Is it correct to say that thunks only save compute and memory during differentiation, but eventually they need to be materialized, at which point they take the same amount of resources as eager execution would?
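To make the question concrete, here is my mental model as a sketch (`expensive_grad` is made up):

```julia
using ChainRulesCore: @thunk, unthunk

expensive_grad(n) = randn(n, n)   # stands in for a costly gradient term

t = @thunk expensive_grad(1_000)  # cheap: just a closure, no 8 MB array yet
v = unthunk(t)                    # only now is the work done and memory spent
```

i.e. is the only real win when a thunk is never forced at all?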
|
Yes. My Diffractor link above discusses many things, sorry. But it also shows Yota (at the time) doing this right.
|
I remember seeing that, which is why I was surprised when it didn't work on the ADTests example! The extra amount allocated (8 MiB on top of 65 MiB) is exactly the size of the gradient of `x`:

```julia
using Yota
using BenchmarkTools
using Random: seed!

seed!(123)

const bs = 4096
const f = 256
const h1 = 512

struct Linear{A,B}
    w::A
    b::B
end

(m::Linear)(x) = exp.(m.w * x .+ m.b)

let
    w1 = randn(h1, f) .* 0.01;
    b1 = randn(h1);
    x = randn(f, bs) .* 0.01;

    m = Linear(w1, b1)
    loss(m) = sum(m(x))

    @btime grad($loss, $m; seed=1.0);  # 41.837 ms (64 allocations: 73.01 MiB)

    tape_func = Yota.GRAD_CACHE |> values |> first
    @bprofile $tape_func($loss, $m);
    # VSCodeServer.view_profile()
end
```

Running Cthulhu on `tape_func` shows where the extra allocation comes from, which is clearly the materialization of the gradient of `x`.
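A quick sanity check that the extra 8 MiB matches a 256×4096 Float64 array:

```julia
julia> Base.format_bytes(f * bs * sizeof(Float64))
"8.000 MiB"
```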
|
Here's my take on thunks. I believe that in the short term we may use thunks for micro-optimisations like this one, but in the longer term it's more advantageous to get rid of them. I expect Julia AD to eventually converge to 2-3 large engines with different tradeoffs: perhaps one similar to Diffractor, with heavy compiler-level optimisations and targeting SciML; another, more traditional one based on computational graphs and targeting conventional deep learning; and maybe one more experimental. But all of these engines will be pretty sophisticated and bring a lot of optimisations themselves. These may include, for example, gradient checkpointing or mixed-precision training - huge optimisations that can reduce training time and memory several-fold, much more than thunks, and that are easier to implement without thunks involved. So in this bright future I imagine ChainRules as a minimalistic repository of rules without any added machinery, with the AD engines taking care of all the optimisations.

At the same time, we are not there yet, and an 11% memory reduction (8 MiB out of 65 + 8 MiB) is quite a lot. If it's reproducible on other tests too (maybe not 11%, but even 2-3% is still good), let's incorporate the change.
|
In fact, we can indeed add an option to ignore the gradients of some inputs and not unthunk them, or even not calculate them at all. This should be a non-breaking change.
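As a rough sketch of what that could look like (the helper and its name are made up, not an actual API):

```julia
using ChainRulesCore: @thunk, unthunk, NoTangent

# unthunk only the outputs the caller asked for, drop the rest;
# a keyword on `grad` could route through something like this
selective_unthunk(g::Tuple, keep) =
    ntuple(i -> i in keep ? unthunk(g[i]) : NoTangent(), length(g))

g = (@thunk(fill(1.0, 3)), @thunk(fill(2.0, 3)))
selective_unthunk(g, (1,))  # ([1.0, 1.0, 1.0], NoTangent())
```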
|
Thanks for the digging above. It sounds like the problem is the unconditional unthunk at the end of the tape. I think that never returning thunks from user-facing functions is probably a good policy: you ask for a gradient, you get a gradient, not some weird internal delayed object.

Re benchmarks: for typical neural network things, materialising the gradient with respect to the data is likely to always be a tiny effect. It only matters for the first layer; all subsequent layers cannot avoid this work. The place where something like thunks can have a large effect is calculations with many static inputs along the way (e.g. suppose you wanted only the gradient with respect to one early input).
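For example, a rule written in the usual thunked style lets an engine skip the data gradient entirely if nothing ever forces it. A minimal sketch (the toy `dense` function is made up, this is not the actual ChainRules rule for `*`):

```julia
using ChainRulesCore

dense(W, x) = W * x

function ChainRulesCore.rrule(::typeof(dense), W, x)
    y = dense(W, x)
    function dense_pullback(ȳ)
        dW = @thunk(ȳ * x')  # needed whenever the weights are trained
        dx = @thunk(W' * ȳ)  # skipped for free if nothing ever forces it
        return NoTangent(), dW, dx
    end
    return y, dense_pullback
end
```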
|
Agreed with all the above. There are certainly cases where in-place accumulation could save a decent amount of memory.
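For instance, a sketch using `InplaceableThunk` from ChainRulesCore (the arrays and the rule here are made up for illustration):

```julia
using ChainRulesCore: InplaceableThunk, @thunk, add!!
using LinearAlgebra: mul!

A, dy = randn(3, 3), randn(3)

# the engine may accumulate into an existing buffer instead of
# allocating `A' * dy` and then adding it
t = InplaceableThunk(dx -> mul!(dx, A', dy, true, true), @thunk(A' * dy))

acc = zeros(3)
add!!(acc, t)  # adds in place; no fresh gradient array is allocated
```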
This is the other change I made while optimizing https://github.com/jeremiedb/ADTests.jl/blob/main/experiments/yota/dense.jl. By not unthunking all outputs by default, we avoid materializing the gradient of the input `x`. This saves a significant amount of both compute and memory. It's also highly breaking, hence the draft PR. Would there be any interest in having this as an option in `grad` or a lower-level API?
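Under the proposed default, usage would look something like this (hypothetical post-PR behavior, not what `grad` does today):

```julia
using Yota
using ChainRulesCore: unthunk

f(a, b) = sum(a * b)
a, b = randn(2, 2), randn(2, 2)

val, g = grad(f, a, b)  # with this PR, entries of g may be thunks
da = unthunk(g[2])      # force only what you need; a never-forced
                        # g[3] would never be materialized
```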