|
2 | 2 |
|
3 | 3 | using Manifolds, Manopt |
4 | 4 |
|
| 5 | +export LikelihoodInformed, likelihood_informed |
| 6 | + |
5 | 7 | mutable struct LikelihoodInformed{FT <: Real} <: PairedDataContainerProcessor |
6 | 8 | encoder_mat::Union{Nothing, AbstractMatrix} |
7 | 9 | decoder_mat::Union{Nothing, AbstractMatrix} |
@@ -33,7 +35,7 @@ function initialize_processor!( |
33 | 35 | output_structure_vectors::Dict{Symbol, <:StructureVector}, |
34 | 36 | apply_to::AbstractString, |
35 | 37 | ) where {MM <: AbstractMatrix} |
36 | | - output_dim = size(out_data, 2) |
| 38 | + output_dim = size(out_data, 1) |
37 | 39 |
|
38 | 40 | if isnothing(get_encoder_mat(li)) |
39 | 41 | α = li.α |
@@ -116,7 +118,6 @@ function initialize_processor!( |
116 | 118 | egrad = (_, Vs) -> begin |
117 | 119 | B = Vs * inv(Vs' * obs_noise_cov * Vs) * Vs' |
118 | 120 | prec = noise_cov_inv - B |
119 | | - |
120 | 121 |
|
121 | 122 | -2mean(begin |
122 | 123 | A = ((1-α)I + α^2 * (y - g)*(y - g)') |
@@ -154,19 +155,19 @@ end |
154 | 155 | """ |
155 | 156 | $(TYPEDSIGNATURES) |
156 | 157 |
|
157 | | -Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix |
| 158 | +Apply the `LikelihoodInformed` encoder, on a columns-are-data matrix or a data vector |
158 | 159 | """ |
159 | | -function encode_data(li::LikelihoodInformed, data::MM) where {MM <: AbstractMatrix} |
| 160 | +function encode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}} |
160 | 161 | encoder_mat = get_encoder_mat(li) |
161 | 162 | return encoder_mat * data |
162 | 163 | end |
163 | 164 |
|
164 | 165 | """ |
165 | 166 | $(TYPEDSIGNATURES) |
166 | 167 |
|
167 | | -Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix |
| 168 | +Apply the `LikelihoodInformed` decoder, on a columns-are-data matrix or a data vector |
168 | 169 | """ |
169 | | -function decode_data(li::LikelihoodInformed, data::MM) where {MM <: AbstractMatrix} |
| 170 | +function decode_data(li::LikelihoodInformed, data::MorV) where {MorV <: Union{AbstractMatrix, AbstractVector}} |
170 | 171 | decoder_mat = get_decoder_mat(li) |
171 | 172 | return decoder_mat * data |
172 | 173 | end |
|
0 commit comments