From 1b1308b02a87377b3f0925521e160dec65ed0bb0 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 26 Jan 2026 14:24:37 -0500 Subject: [PATCH 1/3] Loosen types for Reactant --- src/scalar.jl | 34 +++++++++++++++++----------------- src/utilities.jl | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/scalar.jl b/src/scalar.jl index 093a36d..bd311da 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -17,11 +17,11 @@ abstract type ScalarTransform <: AbstractTransform end dimension(::ScalarTransform) = 1 -function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int) +function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index) transform(t, @inbounds x[index]), flag, index + 1 end -function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int) +function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index) transform_and_logjac(t, @inbounds x[index])..., index + 1 end @@ -43,15 +43,15 @@ Identity ``x ↦ x``. """ struct Identity <: ScalarTransform end -transform(::Identity, x::Real) = x +transform(::Identity, x::Number) = x -transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::Identity, ::Type{T}) where T = T inverse(::Identity, x::Number) = x -inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) #### #### elementary scalar transforms @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real """ struct TVExp <: ScalarTransform end -transform(::TVExp, x::Real) = exp(x) +transform(::TVExp, x::Number) = exp(x) -transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x +transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T) @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). """ struct TVLogistic <: ScalarTransform end -transform(::TVLogistic, x::Real) = logistic(x) +transform(::TVLogistic, x::Number) = logistic(x) -transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) +transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x) inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T) @@ -100,13 +100,13 @@ $(TYPEDEF) Shift transformation `x ↦ x + shift`. """ -struct TVShift{T <: Real} <: ScalarTransform +struct TVShift{T} <: ScalarTransform shift::T end -transform(t::TVShift, x::Real) = x + t.shift +transform(t::TVShift, x::Number) = x + t.shift -transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S)) @@ -129,15 +129,15 @@ end TVScale(scale::T) where {T} = TVScale{T}(scale) -transform(t::TVScale, x::Real) = t.scale * x +transform(t::TVScale, x::Number) = t.scale * x -transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) +transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale) inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S)) inverse(t::TVScale, x::Number) = x / t.scale -inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale) +inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`. struct TVNeg <: ScalarTransform end -transform(::TVNeg, x::Real) = -x -transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform(::TVNeg, x::Number) = -x +transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x diff --git a/src/utilities.jl b/src/utilities.jl index 6462a2e..06a24e8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -2,7 +2,7 @@ ### logistic and logit ### -function logistic_logjac(x::Real) +function logistic_logjac(x::Number) mx = -abs(x) mx - 2*log1pexp(mx) end From ef2c0d828dfddf3a2043aa7c527d8318b934c933 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 4 Feb 2026 15:35:51 -0500 Subject: [PATCH 2/3] Add tv_getindex to allow for better Reactant integration --- Project.toml | 3 +++ ext/ReactantExt.jl | 12 ++++++++++++ src/aggregation.jl | 13 +++++++++++-- src/generic.jl | 2 +- src/scalar.jl | 4 ++-- src/utilities.jl | 5 +++++ 6 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 ext/ReactantExt.jl diff --git a/Project.toml b/Project.toml index cad39f4..1701179 100644 --- a/Project.toml +++ b/Project.toml @@ -18,11 +18,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] AccessorsExt = "Accessors" ChangesOfVariablesExt = "ChangesOfVariables" InverseFunctionsExt = "InverseFunctions" +ReactantExt = "Reactant" [compat] Accessors = "0.1.42" @@ -36,6 +38,7 @@ InverseFunctions = "0.1" LinearAlgebra = "1.6" LogExpFunctions = "0.3" Random = "1.6" +Reactant = "0.2" StaticArrays = "1" julia = "1.10" diff --git a/ext/ReactantExt.jl b/ext/ReactantExt.jl new file mode 100644 index 0000000..91b0061 --- /dev/null +++ b/ext/ReactantExt.jl @@ -0,0 +1,12 @@ +module ReactantExt +using TransformVariables +using Reactant + + +Base.@propagate_inbounds function TransformVariables.tv_getindex(a::Reactant.AnyTracedRArray, i::Integer) + @allowscalar a[i] +end + +TransformVariables._ensure_float(x::Type{T}) where {T<:Reactant.TracedRNumber} = T + +end \ No newline at end of file diff --git a/src/aggregation.jl b/src/aggregation.jl index ec9acda..efc6671 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -385,8 +385,17 @@ _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) = function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts) tfirst = first(ts) - yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index) - yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts)) + ofirst = transform_with(flag, tfirst, x, index) + # Strange Reactant thing here. It looks like it raises the output to an Array + yfirst = tv_getindex(ofirst, 1) + ℓfirst = tv_getindex(ofirst, 2) + index′ = tv_getindex(ofirst, 3) + + ofrest = _transform_tuple(flag, x, index′, Base.tail(ts)) + yrest = tv_getindex(ofrest, 1) + ℓrest = tv_getindex(ofrest, 2) + index′′ = tv_getindex(ofrest, 3) + (yfirst, yrest...), ℓfirst + ℓrest, index′′ end diff --git a/src/generic.jl b/src/generic.jl index 707c6be..2b4cfcd 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -50,7 +50,7 @@ $(SIGNATURES) Initial value for log Jacobian calculations. """ -logjac_zero(::LogJac, ::Type{T}) where {T<:Real} = log(one(T)) +logjac_zero(::LogJac, ::Type{T}) where {T<:Number} = log(one(T)) logjac_zero(::NoLogJac, _) = NOLOGJAC diff --git a/src/scalar.jl b/src/scalar.jl index bd311da..a94f969 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -18,11 +18,11 @@ abstract type ScalarTransform <: AbstractTransform end dimension(::ScalarTransform) = 1 function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index) - transform(t, @inbounds x[index]), flag, index + 1 + transform(t, @inbounds tv_getindex(x, index)), flag, index + 1 end function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index) - transform_and_logjac(t, @inbounds x[index])..., index + 1 + transform_and_logjac(t, @inbounds tv_getindex(x, index))..., index + 1 end function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y) diff --git a/src/utilities.jl b/src/utilities.jl index 06a24e8..09681a1 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -2,6 +2,11 @@ ### logistic and logit ### +# Should I be this conservative? +Base.@propagate_inbounds function tv_getindex(a::Any, i::Integer) + return a[i] +end + function logistic_logjac(x::Number) mx = -abs(x) mx - 2*log1pexp(mx) From 9ffba74406ae0814ff9663250d266f87e4837565 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 4 Feb 2026 15:59:23 -0500 Subject: [PATCH 3/3] Revert transform_tuple --- src/aggregation.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index efc6671..ec9acda 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -385,17 +385,8 @@ _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) = function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts) tfirst = first(ts) - ofirst = transform_with(flag, tfirst, x, index) - # Strange Reactant thing here. It looks like it raises the output to an Array - yfirst = tv_getindex(ofirst, 1) - ℓfirst = tv_getindex(ofirst, 2) - index′ = tv_getindex(ofirst, 3) - - ofrest = _transform_tuple(flag, x, index′, Base.tail(ts)) - yrest = tv_getindex(ofrest, 1) - ℓrest = tv_getindex(ofrest, 2) - index′′ = tv_getindex(ofrest, 3) - + yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index) + yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts)) (yfirst, yrest...), ℓfirst + ℓrest, index′′ end