Skip to content

feat: Add reverse/forward counterparts#1036

Open
rsenne wants to merge 1 commit into
JuliaDiff:mainfrom
rsenne:forward-reverse-helpers
Open

feat: Add reverse/forward counterparts#1036
rsenne wants to merge 1 commit into
JuliaDiff:mainfrom
rsenne:forward-reverse-helpers

Conversation

@rsenne

@rsenne rsenne commented Jun 30, 2026

Copy link
Copy Markdown

PR to add the forward/backward counterparts discussed. I wasn't sure if for a generic backend if there was no reasonable counterpart to return e.g., if one did reverse_counterpart(AutoForwardDiff) if the code should throw an ArgumentError or if it should just return the backend. I opted for the former

Resolves #1025.

@rsenne rsenne requested a review from gdalle as a code owner June 30, 2026 02:01
@rsenne rsenne changed the title Add reverse/forward counterparts feat: Add reverse/forward counterparts Jun 30, 2026
@codecov

codecov Bot commented Jun 30, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 91.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 98.95%. Comparing base (5572e56) to head (fd44232).

Files with missing lines Patch % Lines
DifferentiationInterface/src/utils/counterparts.jl 87.50% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1036      +/-   ##
==========================================
+ Coverage   98.21%   98.95%   +0.73%     
==========================================
  Files         138      116      -22     
  Lines        8131     5915    -2216     
==========================================
- Hits         7986     5853    -2133     
+ Misses        145       62      -83     
Flag Coverage Δ
DI 98.95% <91.66%> (-0.02%) ⬇️
DIT ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle gdalle left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for getting this started!
Most test failures are a red herring (test CI runs by forcing the highest possible version of each package, which here leads to incompatibilities between the recent OrderedCollections v2.0 and the not-yet-updated DataFrames)

@@ -0,0 +1,11 @@
## Backend counterparts

# Pin the mode while preserving the function annotation type `A`.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually for Enzyme we have to be a bit careful and copy the mode faithfully: if you replace M with just EnzymeCore.Forward, you lose all the other attributes of that mode. See the definitions in this file to understand better. Typically, what we need here is something like

function DI.forward_counterpart(
    ::AutoEnzyme{
        <:ReverseMode{
            ReturnPrimal,RuntimeActivity,StrongZero,ABI,Holomorphic,ErrIfFuncWritten
        },
        A,
    },
) where {ReturnPrimal,RuntimeActivity,StrongZero,ABI,Holomorphic,ErrIfFuncWritten,A}
    return AutoEnzyme(;
        mode=ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,StrongZero},
        function_annotation=A,
    )
end

It relies on Enzyme internals so it's not ideal (see EnzymeAD/Enzyme.jl#1881), but the fact that I put <:ReverseMode in the signature should at least ensure that additional type parameters won't suddenly break DI.

"""
function forward_counterpart(backend::AbstractADType)
mode(backend) isa ReverseMode &&
throw(ArgumentError("No forward-mode counterpart known for `$backend`."))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather return the backend itself (possibly with a @warn?) since DI allows forward-mode backends to execute pullbacks and vice-versa

Return the forward-mode counterpart of `backend`, if it exists.
"""
function forward_counterpart(backend::AbstractADType)
mode(backend) isa ReverseMode &&

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're forgetting ForwardOrReverseMode and SymbolicMode

@test_throws MethodError pullback_performance(backend)
end

@testset "Counterparts" begin

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also test this for every backend in the corresponding test, same as I do for is_available. Might be time to automate that with a small script in DIT

Comment on lines +26 to +27
@test DifferentiationInterface.forward_counterpart(AutoMooncake()) === AutoMooncakeForward()
@test DifferentiationInterface.forward_counterpart(AutoMooncake(; config)) ===

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you forgot to write this in the Mooncake extension?
Also, we want to test that we carry over the config

rev = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)
fwd = DifferentiationInterface.forward_counterpart(rev)
@test ADTypes.mode(fwd) isa ADTypes.ForwardMode
@test fwd isa AutoEnzyme{<:Any, Enzyme.Const} # function annotation preserved

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also want to test that mode properties are preserved (see long remark above)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Backend counterparts

2 participants