Skip to content

Commit 11f527a

Browse files
committed
Fix bugs
1 parent 074389b commit 11f527a

5 files changed

Lines changed: 20 additions & 18 deletions

File tree

src/Utilities.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,5 +475,6 @@ end
475475
include("Utilities/canonical_correlation.jl")
476476
include("Utilities/decorrelator.jl")
477477
include("Utilities/elementwise_scaler.jl")
478+
include("Utilities/likelihood_informed.jl")
478479

479480
end # module

src/Utilities/canonical_correlation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ initialize_processor!(
181181
"""
182182
$(TYPEDSIGNATURES)
183183
184-
Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix
184+
Apply the `CanonicalCorrelation` encoder, on a columns-are-data matrix or a data vector
185185
"""
186-
function encode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
186+
function encode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
187187
data_mean = get_data_mean(cc)[1]
188188
encoder_mat = get_encoder_mat(cc)[1]
189189
return encoder_mat * (data .- data_mean)
@@ -192,9 +192,9 @@ end
192192
"""
193193
$(TYPEDSIGNATURES)
194194
195-
Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix
195+
Apply the `CanonicalCorrelation` decoder, on a columns-are-data matrix or a data vector
196196
"""
197-
function decode_data(cc::CanonicalCorrelation, data::MM) where {MM <: AbstractMatrix}
197+
function decode_data(cc::CanonicalCorrelation, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
198198
data_mean = get_data_mean(cc)[1]
199199
decoder_mat = get_decoder_mat(cc)[1]
200200
return decoder_mat * data .+ data_mean

src/Utilities/decorrelator.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ end
187187
"""
188188
$(TYPEDSIGNATURES)
189189
190-
Apply the `Decorrelator` encoder, on a columns-are-data matrix
190+
Apply the `Decorrelator` encoder, on a columns-are-data matrix or a data vector
191191
"""
192-
function encode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
192+
function encode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
193193
data_mean = get_data_mean(dd)[1]
194194
encoder_mat = get_encoder_mat(dd)[1]
195195
return encoder_mat * (data .- data_mean)
@@ -198,9 +198,9 @@ end
198198
"""
199199
$(TYPEDSIGNATURES)
200200
201-
Apply the `Decorrelator` decoder, on a columns-are-data matrix
201+
Apply the `Decorrelator` decoder, on a columns-are-data matrix or a data vector
202202
"""
203-
function decode_data(dd::Decorrelator, data::MM) where {MM <: AbstractMatrix}
203+
function decode_data(dd::Decorrelator, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
204204
data_mean = get_data_mean(dd)[1]
205205
decoder_mat = get_decoder_mat(dd)[1]
206206
return decoder_mat * data .+ data_mean

src/Utilities/elementwise_scaler.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ end
126126
"""
127127
$(TYPEDSIGNATURES)
128128
129-
Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix
129+
Apply the `ElementwiseScaler` encoder, on a columns-are-data matrix or a data vector
130130
"""
131-
function encode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
131+
function encode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
132132
out = deepcopy(data)
133133
for i in 1:size(out, 1)
134134
out[i, :] .-= get_shift(es)[i]
@@ -140,9 +140,9 @@ end
140140
"""
141141
$(TYPEDSIGNATURES)
142142
143-
Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix
143+
Apply the `ElementwiseScaler` decoder, on a columns-are-data matrix or a data vector
144144
"""
145-
function decode_data(es::ElementwiseScaler, data::MM) where {MM <: AbstractMatrix}
145+
function decode_data(es::ElementwiseScaler, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
146146
out = deepcopy(data)
147147
for i in 1:size(out, 1)
148148
out[i, :] *= get_scale(es)[i]

src/Utilities/likelihood_informed.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
using Manifolds, Manopt
44

5+
export LikelihoodInformed, likelihood_informed
6+
57
mutable struct LikelihoodInformed{FT <: Real} <: PairedDataContainerProcessor
68
encoder_mat::Union{Nothing, AbstractMatrix}
79
decoder_mat::Union{Nothing, AbstractMatrix}
@@ -33,7 +35,7 @@ function initialize_processor!(
3335
output_structure_vectors::Dict{Symbol, <:StructureVector},
3436
apply_to::AbstractString,
3537
) where {MM <: AbstractMatrix}
36-
output_dim = size(out_data, 2)
38+
output_dim = size(out_data, 1)
3739

3840
if isnothing(get_encoder_mat(li))
3941
α = li.α
@@ -116,7 +118,6 @@ function initialize_processor!(
116118
egrad = (_, Vs) -> begin
117119
B = Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
118120
prec = noise_cov_inv - B
119-
120121

121122
-2mean(begin
122123
A = ((1-α)I + α^2 * (y - g)*(y - g)')
@@ -154,19 +155,19 @@ end
154155
"""
155156
$(TYPEDSIGNATURES)
156157
157-
Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix
158+
Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix or a data vector
158159
"""
159-
function encode_data(li::LikelihoodInformed, data::MM) where {MM <: AbstractMatrix}
160+
function encode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
160161
encoder_mat = get_encoder_mat(li)
161162
return encoder_mat * data
162163
end
163164

164165
"""
165166
$(TYPEDSIGNATURES)
166167
167-
Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix
168+
Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix or a data vector
168169
"""
169-
function decode_data(li::LikelihoodInformed, data::MM) where {MM <: AbstractMatrix}
170+
function decode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}}
170171
decoder_mat = get_decoder_mat(li)
171172
return decoder_mat * data
172173
end

0 commit comments

Comments
 (0)