Add VJP for cumulative max and min scans#3599
Open
devin-lai wants to merge 2 commits into
Open
Conversation
`mx.grad` through `cummax` and `cummin` previously raised because reverse-mode differentiation was not implemented for cumulative min/max scans. Route each output cotangent to the input element that owns the running extreme at that position, using the latest occurrence in scan order for ties. The owner index is reconstructed from the inclusive running extreme and then accumulated with `scatter_add_axis`; exclusive scans shift cotangents by one step in the scan direction. Add Python and C++ tests for max/min scans across forward, reverse, inclusive, and exclusive modes, including tie cases and non-uniform cotangents. Forward-mode JVP remains unimplemented, matching the existing `cumprod` behavior.
zcbenz
reviewed
May 29, 2026
Collaborator
zcbenz
left a comment
There was a problem hiding this comment.
- Local PyTorch cross-check: 24 VJP comparisons across
cummax/cummin, axes, reverse, inclusive/exclusive, ties, and weighted cotangents.
If cross-check had been done locally I think it would be very useful to run them in the python tests.
Move the local PyTorch comparison for cumulative max/min gradients into the Python autograd tests. The new test covers cummax and cummin over each axis, forward and reverse scans, inclusive and exclusive modes, tie cases, and weighted cotangents. PyTorch does not expose MLX's reverse or exclusive scan options directly, so the test models reverse scans with a flip and exclusive scans by shifting the inclusive result by one step. The test is skipped when PyTorch is unavailable, matching the other PyTorch reference tests in the suite.
Author
Thanks!!! I moved the local cross-check into the python autograd tests. It now covers cummax/cummin over both axes, reverse/inclusive/exclusive combinations, tie cases, and weighted cotangents, the test skips if PyTorch is not available, consistent with the other PyTorch reference tests in the suite. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
cummaxandcummininstead of raising at gradient time.Details
cummaxandcumminare public forward ops, butmx.gradthrough them previously raised because cumulative min/max VJP was not implemented. This follows the existingScan::vjpTODO: mark entries equal to the inclusive running extreme, scan over those indices to recover the owning input position, then accumulate cotangents withscatter_add_axis.For ties, the VJP uses the latest occurrence in scan order. That convention matches the owner selected by PyTorch
cummax/cumminbackward in the cases cross-checked locally. Exclusive scans first shift cotangents by one step in the scan direction because the current element is excluded from its own output.The implementation adds a small number of scan ops plus one
scatter_add_axisonly on the VJP path, which previously threw, so it does not regress existing differentiable workloads.Tests
/Users/ldy/Library/Python/3.11/bin/pre-commit run --files mlx/primitives.cpp python/tests/test_autograd.py tests/autograd_tests.cppPYTHONPATH=python:python/tests python3.11 -m unittest python.tests.test_autogradbuild_tests/tests/tests --test-case="test scan grads"cummax/cummin, axes, reverse, inclusive/exclusive, ties, and weighted cotangents.