
Improve type stability of LayerNorm and Dropout#2005

Open
ToucheSir wants to merge 3 commits into `master` from `bc/norm-type-stability`

Conversation

@ToucheSir
Member

@ToucheSir ToucheSir commented Jun 23, 2022

These two layers made use of explicit or implicit control flow (e.g. default keyword argument values), which Zygote does not handle well. This PR is essentially a set of small hacks to work around that.

Any ideas on how to avoid `return_type` in `_dropout` would be much appreciated, but for now it seems to work.
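For readers following along, here is a minimal, hypothetical sketch (not the PR's actual code, and `_dropout_sketch` is an invented name) of the kind of type-stability fix being discussed: making both branches of a dropout-like function produce the same eltype, so inference sees a single concrete return type regardless of the `active` flag, instead of querying `return_type`.

```julia
using Random

# Hypothetical sketch, not the PR's implementation: both branches divide by
# a value of type typeof(1 - p), so the active and inactive paths infer to
# the same array eltype and the function is type stable either way.
function _dropout_sketch(rng::AbstractRNG, x::AbstractArray, p; active::Bool=true)
    active || return x ./ one(1 - p)           # no-op branch, same eltype as below
    mask = rand(rng, eltype(x), size(x)) .> p  # keep each entry with probability 1 - p
    return x .* mask ./ (1 - p)                # rescale so the expectation matches x
end
```
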

TODO benchmarks.

PR Checklist

  • Entry in NEWS.md

@ToucheSir ToucheSir force-pushed the bc/norm-type-stability branch from 25f0a1b to 9259e4a Compare June 24, 2022 19:54
@ToucheSir
Member Author

TTFG timings using the following snippet:

Test code:

```julia
using Metalhead, Flux, Zygote
using Metalhead: ChannelLayerNorm

model = ConvNeXt(:tiny; inchannels=1, nclasses=1).layers
# ChannelLayerNorm isn't type stable yet (for the same reason LayerNorm wasn't),
# so remove it for this demo
model = fmap(Returns(identity), model; exclude=Base.Fix2(isa, ChannelLayerNorm))

# display(model); println()

loss(m, x) = sum(m(x))

inputs = randn(Float32, 32, 32, 1, 1)
# @time loss(model, inputs)
# @time loss(model, inputs)

loss_grad(m, x) = gradient((m, x) -> loss(m, x), m, x)

@time loss_grad(model, inputs)
# @time loss_grad(model, inputs)
```

Results:

```
julia> @time loss_grad(model, inputs)
 34.835647 seconds (87.12 M allocations: 4.701 GiB, 3.14% gc time, 99.38% compilation time) # 0.13.3
 30.679322 seconds (78.88 M allocations: 4.300 GiB, 3.46% gc time, 98.96% compilation time) # this PR
```
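As an aside, one quick way to spot-check the kind of type stability this PR targets (with hypothetical layer sizes, independent of the Metalhead model above) is `@code_warntype`:

```julia
using Flux

ln = LayerNorm(8)     # hypothetical size, just for inspection
x = randn(Float32, 8, 4)
# Red Union/Any annotations in the output indicate a type-unstable path.
@code_warntype ln(x)
```
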

Replacing the Chain{Vector} with a Chain{Tuple} creates a larger gap:

```
julia> @time loss_grad(model, inputs)
 79.846248 seconds (98.87 M allocations: 5.243 GiB, 1.68% gc time, 99.67% compilation time) # 0.13.3
 63.024710 seconds (79.23 M allocations: 4.245 GiB, 1.92% gc time, 99.45% compilation time) # this PR
 52.838056 seconds (70.81 M allocations: 3.745 GiB, 1.98% gc time, 99.60% compilation time) # this PR + Zygote#1248
```
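For anyone reproducing the comparison, a tuple-backed chain can be obtained from the vector-backed one by splatting its layers (a sketch; `model` here is the Chain from the snippet above, and `tuple_model` is an invented name):

```julia
# Splatting the layer vector produces a Chain backed by a Tuple, which is
# what the Chain{Tuple} timings above refer to.
tuple_model = Chain(model.layers...)
```
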

@ToucheSir
Member Author

ToucheSir commented Aug 1, 2022

For kicks, here is Diffractor with JuliaDiff/ChainRules.jl#644:

```
julia> @time loss_grad(model, inputs)
 30.442982 seconds (92.61 M allocations: 4.148 GiB, 3.18% gc time, 89.07% compilation time) # tuple chain
 23.051121 seconds (88.06 M allocations: 3.920 GiB, 3.81% gc time, 85.11% compilation time) # vector chain, requires https://github.com/JuliaDiff/Diffractor.jl/pull/82
```

Re-enabling ChannelLayerNorm adds only ~1s to the total. Note that even the tuple Chain here is faster than any tested Zygote configuration.

Edit: added times for vector chains using a patched Diffractor.

@theabhirath
Member

Does Diffractor already work with most Flux models (or at least those with built-in layers)? I was under the impression that it wasn't there yet 😅

@ToucheSir
Member Author

Not OOTB, which is why that ChainRules PR is required.

@chengchingwen
Member

@ToucheSir Could you try running the layer norm gradient on GPU? I have tried that manual broadcast fusion before, but `CUDA.@time` said it actually allocated more GPU memory.

@ToucheSir
Member Author

You're right: it allocates one extra time, with over 2x the memory overhead. I also found this out the hard way recently while trying to fuse the RNN cell kernels for #2023, but forgot about the change here.
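For reference, a measurement along the lines @chengchingwen describes might look like this sketch (requires a CUDA-capable GPU; the sizes are hypothetical):

```julia
using CUDA, Flux, Zygote

ln = LayerNorm(128) |> gpu
x = CUDA.randn(Float32, 128, 64)
# CUDA.@time reports GPU-side allocations, which is what reveals the extra
# allocation from the fused broadcast discussed above.
CUDA.@time gradient(x -> sum(ln(x)), x)
```
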

@ToucheSir ToucheSir force-pushed the bc/norm-type-stability branch from 9259e4a to 29ef2ff Compare August 1, 2022 05:06
@codecov-commenter

codecov-commenter commented Aug 1, 2022

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 87.37%. Comparing base (d66d2c4) to head (29ef2ff).
⚠️ Report is 555 commits behind head on master.

Additional details and impacted files
```
@@            Coverage Diff             @@
##           master    #2005      +/-   ##
==========================================
+ Coverage   87.10%   87.37%   +0.27%
==========================================
  Files          20       20
  Lines        1528     1553      +25
==========================================
+ Hits         1331     1357      +26
+ Misses        197      196       -1
```


@darsnack
Member

Any updates on this (like benchmarks after unfusing)?

