-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathscaletransform.jl
More file actions
35 lines (24 loc) · 884 Bytes
/
scaletransform.jl
File metadata and controls
35 lines (24 loc) · 884 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
    ScaleTransform(s::Real=1.0)

Transformation that multiplies the input elementwise with the scalar `s`.

The scale is stored in the field `s` as a one-element vector. This lets `set!` overwrite it
in place even though `ScaleTransform` itself is an immutable `struct`.

# Examples
```jldoctest
julia> l = rand(); t = ScaleTransform(l); X = rand(100, 10);

julia> map(t, ColVecs(X)) == ColVecs(l .* X)
true
```
"""
struct ScaleTransform{T<:Real} <: Transform
    # One-element container holding the scale; read via `only(t.s)`.
    s::Vector{T}
end
# Outer constructor: wrap the scalar scale `s` (default `1.0`) in a one-element vector with
# the same element type `T` and hand it to the parametric inner constructor.
ScaleTransform(s::T=1.0) where {T<:Real} = ScaleTransform{T}(fill(s, 1))
# NOTE(review): `@functor` is presumably Functors.jl's macro. It registers the fields of
# `ScaleTransform` (here only `s`) as traversable children, so parameter collection and
# reconstruction utilities can reach the scale vector. Confirm against the module's imports.
@functor ScaleTransform
# Overwrite the scale of `t` in place with `ρ`, converted to the element type of `t.s`,
# and return the updated storage vector `t.s`.
# Broadcasting the scalar directly avoids allocating a temporary one-element array.
set!(t::ScaleTransform, ρ::Real) = t.s .= ρ
# Apply the transform to a single input `x` by multiplying it with the scalar scale.
function (t::ScaleTransform)(x)
    scale = only(t.s)
    return scale * x
end
# Elementwise scaling of a plain real-valued vector.
function _map(t::ScaleTransform, x::AbstractVector{<:Real})
    scale = only(t.s)
    return scale .* x
end
# Scale the wrapped data of a `ColVecs` and rewrap it as `ColVecs`.
function _map(t::ScaleTransform, x::ColVecs)
    scale = only(t.s)
    return ColVecs(scale .* x.X)
end
# Scale the wrapped data of a `RowVecs` and rewrap it as `RowVecs`.
function _map(t::ScaleTransform, x::RowVecs)
    scale = only(t.s)
    return RowVecs(scale .* x.X)
end
# Two scale transforms are `isequal` when their scalar scales are `isequal`.
Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))
# `hash` must agree with the `isequal` overload above. Otherwise two transforms with equal
# scales but distinct storage vectors would hash differently and misbehave as `Dict`/`Set`
# keys. Hashing the scale (salted with a type tag) keeps `isequal(a, b)` ⇒ equal hashes.
Base.hash(t::ScaleTransform, h::UInt) = hash(only(t.s), hash(:ScaleTransform, h))
# Compact one-line display of the transform showing its current scale.
function Base.show(io::IO, t::ScaleTransform)
    scale = only(t.s)
    return print(io, "Scale Transform (s = ", scale, ")")
end