diff --git a/Project.toml b/Project.toml index cad39f49..1701179f 100644 --- a/Project.toml +++ b/Project.toml @@ -18,11 +18,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] AccessorsExt = "Accessors" ChangesOfVariablesExt = "ChangesOfVariables" InverseFunctionsExt = "InverseFunctions" +ReactantExt = "Reactant" [compat] Accessors = "0.1.42" @@ -36,6 +38,7 @@ InverseFunctions = "0.1" LinearAlgebra = "1.6" LogExpFunctions = "0.3" Random = "1.6" +Reactant = "0.2" StaticArrays = "1" julia = "1.10" diff --git a/ext/ReactantExt.jl b/ext/ReactantExt.jl new file mode 100644 index 00000000..91b00618 --- /dev/null +++ b/ext/ReactantExt.jl @@ -0,0 +1,12 @@ +module ReactantExt +using TransformVariables +using Reactant + + +Base.@propagate_inbounds function TransformVariables.tv_getindex(a::Reactant.AnyTracedRArray, i::Integer) + @allowscalar a[i] +end + +TransformVariables._ensure_float(x::Type{T}) where {T<:Reactant.TracedRNumber} = T + +end \ No newline at end of file diff --git a/src/generic.jl b/src/generic.jl index 707c6be4..2b4cfcdc 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -50,7 +50,7 @@ $(SIGNATURES) Initial value for log Jacobian calculations. """ -logjac_zero(::LogJac, ::Type{T}) where {T<:Real} = log(one(T)) +logjac_zero(::LogJac, ::Type{T}) where {T<:Number} = log(one(T)) logjac_zero(::NoLogJac, _) = NOLOGJAC diff --git a/src/scalar.jl b/src/scalar.jl index 093a36dd..a94f9695 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -17,12 +17,12 @@ abstract type ScalarTransform <: AbstractTransform end dimension(::ScalarTransform) = 1 -function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int) - transform(t, @inbounds x[index]), flag, index + 1 +function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index) + transform(t, @inbounds tv_getindex(x, index)), flag, index + 1 end -function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int) - transform_and_logjac(t, @inbounds x[index])..., index + 1 +function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index) + transform_and_logjac(t, @inbounds tv_getindex(x, index))..., index + 1 end function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y) @@ -43,15 +43,15 @@ Identity ``x ↦ x``. """ struct Identity <: ScalarTransform end -transform(::Identity, x::Real) = x +transform(::Identity, x::Number) = x -transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::Identity, ::Type{T}) where T = T inverse(::Identity, x::Number) = x -inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x)) +inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x)) #### #### elementary scalar transforms @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real """ struct TVExp <: ScalarTransform end -transform(::TVExp, x::Real) = exp(x) +transform(::TVExp, x::Number) = exp(x) -transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x +transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T) @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). """ struct TVLogistic <: ScalarTransform end -transform(::TVLogistic, x::Real) = logistic(x) +transform(::TVLogistic, x::Number) = logistic(x) -transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x) +transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x) inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T) @@ -100,13 +100,13 @@ $(TYPEDEF) Shift transformation `x ↦ x + shift`. """ -struct TVShift{T <: Real} <: ScalarTransform +struct TVShift{T} <: ScalarTransform shift::T end -transform(t::TVShift, x::Real) = x + t.shift +transform(t::TVShift, x::Number) = x + t.shift -transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S)) @@ -129,15 +129,15 @@ end TVScale(scale::T) where {T} = TVScale{T}(scale) -transform(t::TVScale, x::Real) = t.scale * x +transform(t::TVScale, x::Number) = t.scale * x -transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale) +transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale) inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S)) inverse(t::TVScale, x::Number) = x / t.scale -inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale) +inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale) """ $(TYPEDEF) @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`. struct TVNeg <: ScalarTransform end -transform(::TVNeg, x::Real) = -x -transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x)) +transform(::TVNeg, x::Number) = -x +transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x)) inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x diff --git a/src/utilities.jl b/src/utilities.jl index 6462a2e9..09681a13 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -2,7 +2,12 @@ ### logistic and logit ### -function logistic_logjac(x::Real) +# Should I be this conservative? +Base.@propagate_inbounds function tv_getindex(a::Any, i::Integer) + return a[i] +end + +function logistic_logjac(x::Number) mx = -abs(x) mx - 2*log1pexp(mx) end