-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathDSTOp.jl
More file actions
74 lines (62 loc) · 1.92 KB
/
DSTOp.jl
File metadata and controls
74 lines (62 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
export DSTOpImpl
# Concrete operator type built by `DSTOp`: a square linear operator that applies a
# scaled discrete sine transform through two precomputed transform plans.
# Type parameters: `T` element type, `vecT` storage type of the `Mv`/`Mtu` buffers,
# `P` / `IP` types of the forward / backward plans.
# NOTE(review): `DSTOp` calls the positional default constructor, so the field order
# below is part of the interface -- do not reorder or insert fields.
mutable struct DSTOpImpl{T, vecT, P, IP} <: DSTOp{T}
const nrow :: Int          # number of rows (DSTOp passes prod(shape))
const ncol :: Int          # number of columns (DSTOp passes prod(shape))
const symmetric :: Bool    # declared symmetry flag handed to LinearOperators
const hermitian :: Bool    # declared hermitian flag handed to LinearOperators
const prod! :: Function    # (res, x) -> forward product, written into res
const tprod! :: Nothing    # no dedicated transpose product is stored
const ctprod! :: Function  # (res, x) -> adjoint/backward product, written into res
nprod :: Int               # product counters; presumably updated by LinearOperators -- confirm
ntprod :: Int
nctprod :: Int
Mv :: vecT                 # buffers of the storage type; DSTOp initializes them with length 0
Mtu :: vecT
const plan :: P            # forward plan (DSTOp stores an FFTW RODFT10 plan here)
const iplan :: IP          # backward plan (DSTOp stores an FFTW RODFT01 plan here)
end
# Report the storage type the operator was parameterized with (its second type parameter).
LinearOperators.storage_type(::DSTOpImpl{<:Any, V}) where {V} = V
"""
DSTOp(T::Type, shape::Tuple)
returns a `LinearOperator` which performs a DST on a given input array.
# Arguments:
* `T::Type` - type of the array to transform
* `shape::Tuple` - size of the array to transform
"""
function LinearOperatorCollection.DSTOp(T::Type; shape::Tuple, S = Array{T})
tmp=similar(S(undef, 0), shape...)
plan = FFTW.plan_r2r!(tmp,FFTW.RODFT10)
iplan = FFTW.plan_r2r!(tmp,FFTW.RODFT01)
w = weights(shape, T)
return DSTOpImpl{T, S, typeof(plan), typeof(iplan)}(prod(shape), prod(shape), true, false
, (res,x) -> dst_multiply!(res,plan,x,tmp,w)
, nothing
, (res,x) -> dst_bmultiply!(res,iplan,x,tmp,w)
, 0, 0, 0, S(undef,0), S(undef, 0)
, plan
, iplan)
end
"""
    weights(s, T::Type)

Return the per-coefficient scaling, flattened to a vector of length `prod(s)`, that
turns FFTW's unnormalized `RODFT10` (DST-II) output for an array of size `s` into the
orthonormal DST-II.

In 1-D, FFTW computes `Y_k = 2 Σ_j X_j sin(π(j+1/2)(k+1)/n)`. The orthonormal DST-II
scales this by `1/sqrt(2n)`, and scales the last coefficient (`k = n-1`) by a further
`1/sqrt(2)`. The multi-dimensional transform is separable, so:
- the overall factor is `1/sqrt(2^length(s) * prod(s))`, and
- the extra `1/sqrt(2)` is applied on the last slice along every dimension.

For 3-D shapes the overall factor is `1/sqrt(8 * prod(s))`.
"""
function weights(s, T::Type)
    w = ones(T, s...) ./ T(sqrt(2^length(s) * prod(s)))
    for (d, n) in enumerate(s)
        # `selectdim` returns a view, so this rescales the last slice of `w` in place.
        lastslice = selectdim(w, d, n)
        lastslice ./= T(sqrt(2))
    end
    return reshape(w, prod(s))
end

"""
    dst_multiply!(res, plan, x, tmp, w)

Forward product `res = w .* DST(x)`. Copy `x` into the scratch array `tmp`, apply
`plan` to it, then write the flattened result, weighted by `w`, into `res`.
Returns `res`.

`plan * tmp` is expected to transform `tmp` in place, as the `FFTW.plan_r2r!` plans
built by `DSTOp` do. Its return value is ignored.
"""
function dst_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{T},
                       tmp::AbstractArray{T,D}, w::AbstractVector{T}) where {T,P,D}
    tmp[:] .= x
    plan * tmp
    res .= vec(tmp) .* w
end

"""
    dst_bmultiply!(res, plan, x, tmp, w)

Backward product. Divide `x` by the weights `w`, apply `plan` to the scratch array
`tmp`, then normalize and write the flattened result into `res`. `plan` is expected
to be the in-place `RODFT01` (DST-III) plan.

FFTW's `RODFT01` inverts `RODFT10` only up to a factor `2n` per transformed dimension,
i.e. `2^D * length(tmp)` for a `D`-dimensional `tmp`. Dividing by that factor makes
this the inverse of `dst_multiply!`. With the orthonormal weights from `weights`,
it is also the adjoint.
"""
function dst_bmultiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{T},
                        tmp::AbstractArray{T,D}, w::AbstractVector{T}) where {T,P,D}
    tmp[:] .= x ./ w
    plan * tmp
    res[:] .= vec(tmp) ./ (2^D * length(tmp))
end
# Create a new operator with the same element type, array shape and storage type as `op`.
# Going through `DSTOp` builds fresh plans and a fresh scratch array, so the copy does not
# share mutable state with `op`; its product counters start at 0.
# (The previous body called `DSTOpImpl(eltype, shape)`, a constructor that does not exist.)
function Base.copy(op::DSTOpImpl)
    # AbstractFFTs plans report the size of the array they were created for.
    shape = size(op.plan)
    return LinearOperatorCollection.DSTOp(eltype(op); shape = shape,
                                          S = LinearOperators.storage_type(op))
end