Skip to content

Commit 42f37c7

Browse files
committed
replace transform_to_real kwarg, and add add_obs_noise_cov
1 parent c50362d commit 42f37c7

5 files changed

Lines changed: 39 additions & 26 deletions

File tree

src/Emulator.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,6 @@ function predict(
319319
emulator::Emulator{FT},
320320
new_inputs::AM;
321321
encode=nothing, # maps decoded inputs to decoded outputs
322-
add_obs_noise_cov=false,
323322
mlt_kwargs...,
324323
) where {FT <: AbstractFloat, AM <: AbstractMatrix}
325324

src/MachineLearningTools/GaussianProcess.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct GaussianProcess{GPPackage, FT, VV <: AbstractVector} <: MachineLearningTo
7676
noise_learn::Bool
7777
"Additional observational or regularization noise used in GP algorithms"
7878
alg_reg_noise::FT
79-
"Prediction type (`y` to predict the data, `f` to predict the latent function)."
79+
"[Deprecated — use the `add_obs_noise_cov` kwarg when calling `predict`] Prediction type (`y` to predict the data, `f` to predict the latent function)."
8080
prediction_type::PredictionType
8181
"Regularization vector for each output dimension (based on alg_reg_noise)"
8282
regularization::VV
@@ -258,7 +258,7 @@ end
258258
function _predict(
259259
gp::GaussianProcess,
260260
new_inputs::AbstractMatrix{FT},
261-
predict_method::Function,
261+
predict_method::Function;
262262
) where {FT <: AbstractFloat}
263263
M = length(gp.models)
264264
N_samples = size(new_inputs, 2)
@@ -282,11 +282,12 @@ predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}, ::FType) wher
282282
"""
283283
$(DocStringExtensions.TYPEDSIGNATURES)
284284
285-
Predict means and covariances in decorrelated output space using Gaussian process models.
285+
Predict means and covariances in decorrelated output space using Gaussian process models. The use of the stored `FType` and `YType` to control this method is deprecated; the returned covariance is now determined by the `add_obs_noise_cov` kwarg of `predict`.
286286
"""
287-
predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat} =
288-
predict(gp, new_inputs, gp.prediction_type)
289-
287+
function predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}; add_obs_noise_cov=false, mlt_kwargs...) where {FT <: AbstractFloat}
288+
pred_type= add_obs_noise_cov ? YType() : FType()
289+
return predict(gp, new_inputs, pred_type)
290+
end
290291

291292
#now we build the SKLJL implementation
292293
function build_models!(
@@ -369,13 +370,15 @@ function _SKJL_predict_function(gp_model::PyObject, new_inputs::AbstractMatrix{F
369370
μ, σ = gp_model.predict(new_inputs', return_std = true)
370371
return μ, (σ .* σ)
371372
end
372-
function predict(gp::GaussianProcess{SKLJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat}
373+
function predict(gp::GaussianProcess{SKLJL}, new_inputs::AbstractMatrix{FT}; add_obs_noise_cov=false, mlt_kwargs...) where {FT <: AbstractFloat}
373374
μ, σ2 = _predict(gp, new_inputs, _SKJL_predict_function)
374375

375376
# for SKLJL does not return the observational noise (even if return_std = true)
376377
# we must add contribution depending on whether we learnt the noise or not.
377-
for i in 1:size(σ2, 2)
378-
σ2[:, i] = σ2[:, i] + gp.regularization
378+
if add_obs_noise_cov
379+
for i in 1:size(σ2, 2)
380+
σ2[:, i] = σ2[:, i] + gp.regularization
381+
end
379382
end
380383

381384
return μ, σ2
@@ -482,7 +485,7 @@ function optimize_hyperparameters!(gp::GaussianProcess{AGPJL}, args...; kwargs..
482485
@info "AbstractGP already built. Continuing..."
483486
end
484487

485-
function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: AbstractMatrix}
488+
function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM; add_obs_noise_cov=false, mlt_kwargs...) where {AM <: AbstractMatrix}
486489

487490
N_models = length(gp.models)
488491
N_samples = size(new_inputs, 2)
@@ -495,8 +498,10 @@ function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: Abstra
495498
μ[i, :] = mean(pred)
496499
σ2[i, :] = var(pred)
497500
end
498-
for i in 1:size(σ2, 2)
499-
σ2[:, i] .= σ2[:, i] + gp.regularization
501+
if add_obs_noise_cov
502+
for i in 1:size(σ2, 2)
503+
σ2[:, i] .= σ2[:, i] + gp.regularization
504+
end
500505
end
501506
return μ, σ2
502507
end

src/MachineLearningTools/ScalarRandomFeature.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,8 @@ function predict(
631631
srfi::ScalarRandomFeatureInterface,
632632
new_inputs::MM;
633633
multithread = "ensemble",
634+
add_obs_noise_cov=false,
635+
mlt_kwargs...,
634636
) where {MM <: AbstractMatrix}
635637
M = length(get_rfms(srfi))
636638
N_samples = size(new_inputs, 2)
@@ -653,11 +655,13 @@ function predict(
653655
end
654656

655657
# add the noise contribution stored within the regularization
656-
reg = get_regularization(srfi)[1]
657-
reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
658+
if add_obs_noise_cov
659+
reg = get_regularization(srfi)[1]
660+
reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
658661

659-
for i in 1:M
660-
σ2[i, :] .+= reg_diag[i]
662+
for i in 1:M
663+
σ2[i, :] .+= reg_diag[i]
664+
end
661665
end
662666

663667
return μ, σ2

src/MachineLearningTools/VectorRandomFeature.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ $(DocStringExtensions.TYPEDSIGNATURES)
656656
657657
Prediction of data observation (not latent function) at new inputs (passed in as columns in a matrix). That is, we add the observational noise into predictions.
658658
"""
659-
function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <: AbstractMatrix}
659+
function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M; add_obs_noise_cov=false, mlt_kwargs...) where {M <: AbstractMatrix}
660660
input_dim = get_input_dim(vrfi)
661661
output_dim = get_output_dim(vrfi)
662662
rfm = get_rfms(vrfi)[1]
@@ -676,12 +676,14 @@ function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <:
676676
# sizes (output_dim x n_test), (output_dim x output_dim x n_test)
677677
# add the noise contribution from the regularization
678678
# note this is because we are predicting the data here, not the latent function.
679-
lambda = get_regularization(vrfi)[1]
680-
for i in 1:N_samples
681-
σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
682-
683-
if !isposdef(σ2[:, :, i])
684-
σ2[:, :, i] = posdef_correct(σ2[:, :, i])
679+
if add_obs_noise_cov
680+
lambda = get_regularization(vrfi)[1]
681+
for i in 1:N_samples
682+
σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
683+
684+
if !isposdef(σ2[:, :, i])
685+
σ2[:, :, i] = posdef_correct(σ2[:, :, i])
686+
end
685687
end
686688
end
687689
return μ, σ2

src/MarkovChainMonteCarlo.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ function emulator_log_density_model(
255255

256256
# predict is written to apply to columns.
257257
# Returned g is a length-1, Vector{Real} or Vector{Vector}, and g_cov is length-1 Vector{Vector} or Vector{Matrix} respectively
258-
g, g_cov = Emulators.predict(em_or_fmw, reshape(θ, :, 1), transform_to_real = false)
258+
g, g_cov = Emulators.predict(em_or_fmw, reshape(θ, :, 1), encode="out", add_obs_noise_cov=true)
259259

260260
if isa(g_cov[1], Real)
261261
return sum([logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)
@@ -577,8 +577,10 @@ function MCMCWrapper(
577577
eachcol(observation)
578578
end
579579

580-
# encoding works on columns but mcmc wants vec-of-vec
580+
# encoding data works on columns but mcmc wants vec-of-vec
581581
encoded_obs = [vec(encode_data(em_or_fmw, reshape(obs, :, 1), "out")) for obs in obs_slice]
582+
# encoding initial condition
583+
#encoded_init_params = vec(encode_data(em_or_fmw, reshape(init_params,:,1), "in"))
582584

583585
log_posterior_map = EmulatorPosteriorModel(prior, em_or_fmw, encoded_obs)
584586
mh_proposal_sampler = MetropolisHastingsSampler(mcmc_alg, prior)
@@ -594,6 +596,7 @@ function MCMCWrapper(
594596
end
595597

596598
sample_kwargs = (; # set defaults here
599+
# :initial_params => deepcopy(encoded_init_params),
597600
:initial_params => deepcopy(init_params),
598601
:param_names => param_names,
599602
:discard_initial => burnin,

0 commit comments

Comments
 (0)