Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[extensions]
AccessorsExt = "Accessors"
ChangesOfVariablesExt = "ChangesOfVariables"
InverseFunctionsExt = "InverseFunctions"
ReactantExt = "Reactant"

[compat]
Accessors = "0.1.42"
Expand All @@ -36,6 +38,7 @@ InverseFunctions = "0.1"
LinearAlgebra = "1.6"
LogExpFunctions = "0.3"
Random = "1.6"
Reactant = "0.2"
StaticArrays = "1"
julia = "1.10"

Expand Down
12 changes: 12 additions & 0 deletions ext/ReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module ReactantExt
using TransformVariables
using Reactant


Base.@propagate_inbounds function TransformVariables.tv_getindex(a::Reactant.AnyTracedRArray, i::Integer)
@allowscalar a[i]
end

TransformVariables._ensure_float(x::Type{T}) where {T<:Reactant.TracedRNumber} = T

end
2 changes: 1 addition & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ $(SIGNATURES)

Initial value for log Jacobian calculations.
"""
logjac_zero(::LogJac, ::Type{T}) where {T<:Real} = log(one(T))
logjac_zero(::LogJac, ::Type{T}) where {T<:Number} = log(one(T))

logjac_zero(::NoLogJac, _) = NOLOGJAC

Expand Down
38 changes: 19 additions & 19 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ abstract type ScalarTransform <: AbstractTransform end

dimension(::ScalarTransform) = 1

function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int)
transform(t, @inbounds x[index]), flag, index + 1
function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit surprised that you need this, since transform_with is called internally and index is an integer.

Can you explain what the actual type is that you need here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes in Reactant you can get a TracedNumber{Integer} if I loop through this and the array x is a AbstractTraced array.

transform(t, @inbounds tv_getindex(x, index)), flag, index + 1
end

function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int)
transform_and_logjac(t, @inbounds x[index])..., index + 1
function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

transform_and_logjac(t, @inbounds tv_getindex(x, index))..., index + 1
end

function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y)
Expand All @@ -43,15 +43,15 @@ Identity ``x ↦ x``.
"""
struct Identity <: ScalarTransform end

transform(::Identity, x::Real) = x
transform(::Identity, x::Number) = x
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle I am OK with widening Real here and in other places, the intent is to exclude Complex. It unfortunate that Base does not have an intermediate type for this purpose, but we can use Number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya I agree it is very sad.


transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::Identity, ::Type{T}) where T = T

inverse(::Identity, x::Number) = x

inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

####
#### elementary scalar transforms
Expand All @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real
"""
struct TVExp <: ScalarTransform end

transform(::TVExp, x::Real) = exp(x)
transform(::TVExp, x::Number) = exp(x)

transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x
transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x

inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct TVLogistic <: ScalarTransform end

transform(::TVLogistic, x::Real) = logistic(x)
transform(::TVLogistic, x::Number) = logistic(x)

transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)
transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x)

inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -100,13 +100,13 @@ $(TYPEDEF)

Shift transformation `x ↦ x + shift`.
"""
struct TVShift{T <: Real} <: ScalarTransform
struct TVShift{T} <: ScalarTransform
shift::T
end

transform(t::TVShift, x::Real) = x + t.shift
transform(t::TVShift, x::Number) = x + t.shift

transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S))

Expand All @@ -129,15 +129,15 @@ end

TVScale(scale::T) where {T} = TVScale{T}(scale)

transform(t::TVScale, x::Real) = t.scale * x
transform(t::TVScale, x::Number) = t.scale * x

transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)
transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale)

inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S))

inverse(t::TVScale, x::Number) = x / t.scale

inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)
inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale)

"""
$(TYPEDEF)
Expand All @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`.
struct TVNeg <: ScalarTransform
end

transform(::TVNeg, x::Real) = -x
transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform(::TVNeg, x::Number) = -x
transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T))
inverse(::TVNeg, x::Number) = -x
Expand Down
7 changes: 6 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
### logistic and logit
###

function logistic_logjac(x::Real)
# Should I be this conservative?
Base.@propagate_inbounds function tv_getindex(a::Any, i::Integer)
return a[i]
end

function logistic_logjac(x::Number)
mx = -abs(x)
mx - 2*log1pexp(mx)
end
Expand Down
Loading