-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathardtransform.jl
More file actions
44 lines (31 loc) · 1.18 KB
/
ardtransform.jl
File metadata and controls
44 lines (31 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
ARDTransform(v::AbstractVector)
Transformation that multiplies the input elementwise by `v`.
# Examples
```jldoctest
julia> v = rand(10); t = ARDTransform(v); X = rand(10, 100);
julia> map(t, ColVecs(X)) == ColVecs(v .* X)
true
```
"""
struct ARDTransform{Tv<:AbstractVector{<:Real}} <: Transform
v::Tv
end
"""
ARDTransform(s::Real, dims::Integer)
Create an [`ARDTransform`](@ref) with vector `fill(s, dims)`.
"""
ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims))
# Register `ARDTransform` with the `@functor` macro so that its field `v` is visible to
# structure-traversal utilities (presumably Functors.jl, e.g. for collecting trainable
# parameters — confirm against the package's imports).
@functor ARDTransform
"""
    set!(t::ARDTransform, ρ::AbstractVector)

Overwrite the scale factors of `t` in place with the entries of `ρ` and return `t.v`.

`ρ` must share the element type of `t.v` (enforced by dispatch) and have exactly `dim(t)`
entries.

# Throws
- `DimensionMismatch` if `length(ρ) != dim(t)`.
"""
function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real}
    # Validate with an explicit throw rather than `@assert`, which Julia documents as
    # possibly disabled. The check is load-bearing: without it a length-1 `ρ` would
    # silently broadcast into every entry of `t.v` instead of erroring.
    length(ρ) == dim(t) || throw(
        DimensionMismatch(
            "Trying to set a vector of size $(length(ρ)) to ARDTransform of dimension $(dim(t))",
        ),
    )
    # In-place broadcast assignment; evaluates to (and returns) the destination `t.v`.
    return t.v .= ρ
end
# The dimension of an `ARDTransform` is the number of per-feature scale factors it holds.
function dim(t::ARDTransform)
    return length(t.v)
end
# Scalar input: a real number is a one-dimensional observation, so the transform must hold
# exactly one scale factor (`only` throws otherwise).
function (t::ARDTransform)(x::Real)
    scale = only(t.v)
    return scale * x
end

# Any other single input: scale it elementwise by the scale factors via broadcasting.
function (t::ARDTransform)(x)
    scaled = t.v .* x
    return scaled
end
# A vector of reals is a collection of one-dimensional observations, so the transform must
# hold exactly one scale factor. Unwrap it with `only` (matching the scalar call method)
# so the result is a vector of transformed inputs. The previous `t.v' .* x` produced an
# n×1 `Matrix` here, and silently produced an n×d matrix when `t.v` had d > 1 entries.
_map(t::ARDTransform, x::AbstractVector{<:Real}) = only(t.v) .* x
# Column-wise observations (as in the `ColVecs(v .* X)` doctest): broadcasting the scale
# vector down the rows of `x.X` scales feature i of every column by `t.v[i]`.
function _map(t::ARDTransform, x::ColVecs)
    scaled = t.v .* x.X
    return ColVecs(scaled)
end

# Row-wise observations: lay the scale vector out as a row (`t.v'`) so that column j of
# `x.X`, i.e. feature j of every row, is scaled by `t.v[j]`.
function _map(t::ARDTransform, x::RowVecs)
    scaled = t.v' .* x.X
    return RowVecs(scaled)
end
# Two `ARDTransform`s are `isequal` when their scale-factor vectors are `isequal`.
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

# `isequal` and `hash` must agree (`isequal(a, b)` implies `hash(a) == hash(b)`) for
# Dict/Set lookups to work. The default hash of this immutable struct depends on the
# identity of the mutable vector `v`, so transforms holding equal but distinct vectors
# would be `isequal` yet hash differently. Hash the vector's contents instead, salted
# with the type name so transforms don't collide with bare vectors.
Base.hash(t::ARDTransform, h::UInt) = hash(t.v, hash(:ARDTransform, h))
# Compact one-line display showing only the transform's dimension.
function Base.show(io::IO, t::ARDTransform)
    d = dim(t)
    print(io, "ARD Transform (dims: ", d, ")")
    return nothing
end