feat: Add reverse/forward counterparts#1036
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1036 +/- ##
==========================================
+ Coverage 98.21% 98.95% +0.73%
==========================================
Files 138 116 -22
Lines 8131 5915 -2216
==========================================
- Hits 7986 5853 -2133
+ Misses 145 62 -83
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
gdalle
left a comment
There was a problem hiding this comment.
Thank you for getting this started!
Most test failures are a red herring (test CI runs by forcing the highest possible version of each package, which here leads to incompatibilities between the recent OrderedCollections v2.0 and the not-yet-updated DataFrames)
| @@ -0,0 +1,11 @@ | |||
| ## Backend counterparts | |||
|
|
|||
| # Pin the mode while preserving the function annotation type `A`. | |||
There was a problem hiding this comment.
Actually for Enzyme we have to be a bit careful and copy the mode faithfully: if you replace M with just EnzymeCore.Forward, you lose all the other attributes of that mode. See the definitions in this file to understand better. Typically, what we need here is something like
function DI.forward_counterpart(
::AutoEnzyme{
<:ReverseMode{
ReturnPrimal,RuntimeActivity,StrongZero,ABI,Holomorphic,ErrIfFuncWritten
},
A,
},
) where {ReturnPrimal,RuntimeActivity,StrongZero,ABI,Holomorphic,ErrIfFuncWritten,A}
return AutoEnzyme(;
mode=ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero},
function_annotation=A,
)
endIt relies on Enzyme internals so it's not ideal (see EnzymeAD/Enzyme.jl#1881), but the fact that I put <:ReverseMode in the signature should at least ensure that additional type parameters won't suddenly break DI.
| """ | ||
| function forward_counterpart(backend::AbstractADType) | ||
| mode(backend) isa ReverseMode && | ||
| throw(ArgumentError("No forward-mode counterpart known for `$backend`.")) |
There was a problem hiding this comment.
I'd rather return the backend itself (possibly with a @warn?) since DI allows forward-mode backends to execute pullbacks and vice-versa
| Return the forward-mode counterpart of `backend`, if it exists. | ||
| """ | ||
| function forward_counterpart(backend::AbstractADType) | ||
| mode(backend) isa ReverseMode && |
There was a problem hiding this comment.
You're forgetting ForwardOrReverseMode and SymbolicMode
| @test_throws MethodError pullback_performance(backend) | ||
| end | ||
|
|
||
| @testset "Counterparts" begin |
There was a problem hiding this comment.
We should also test this for every backend in the corresponding test, same as I do for is_available. Might be time to automate that with a small script in DIT
| @test DifferentiationInterface.forward_counterpart(AutoMooncake()) === AutoMooncakeForward() | ||
| @test DifferentiationInterface.forward_counterpart(AutoMooncake(; config)) === |
There was a problem hiding this comment.
I think you forgot to write this in the Mooncake extension?
Also, we want to test that we carry over the config
| rev = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const) | ||
| fwd = DifferentiationInterface.forward_counterpart(rev) | ||
| @test ADTypes.mode(fwd) isa ADTypes.ForwardMode | ||
| @test fwd isa AutoEnzyme{<:Any, Enzyme.Const} # function annotation preserved |
There was a problem hiding this comment.
We also want to test that mode properties are preserved (see long remark above)
PR to add the forward/backward counterparts discussed. I wasn't sure if for a generic backend if there was no reasonable counterpart to return e.g., if one did
reverse_counterpart(AutoForwardDiff)if the code should throw an ArgumentError or if it should just return the backend. I opted for the formerResolves #1025.