diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index d1610ce24..92328b515 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -28,7 +28,7 @@ end function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) - if T === Bool # TODO use nondifftype here + if T === Bool || T === Union{} # TODO use nondifftype here # 1: Trivial case: non-differentiable output, e.g. `x .> 0` @debug("split broadcasting trivial", f, T) bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index d2993774d..de7516bfb 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -177,6 +177,20 @@ BT1 = Broadcast.BroadcastStyle(Tuple) end @testset "bugs" begin + @testset "broadcast over empty tuple" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/830 + y, bk = rrule(CFG, copy∘broadcasted, BT1, isone, ()) + @test y == () + @test bk(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent()) + + y2, bk2 = rrule(CFG, copy∘broadcasted, BT1, sin, ()) + @test y2 == () + @test bk2(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent()) + + # Multi-argument case + y3, bk3 = rrule(CFG, copy∘broadcasted, BT1, atan, (), ()) + @test y3 == () + @test bk3(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()) + end @testset "unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661 @test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type @test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)