Commit f6a2727

new predict method API, and reduced sampling in input space (#404)
* added ForwardMapWrapper, made compatible with MCMC and encoder utilities
* format
* Update src/Emulator.jl (Co-authored-by: ArneBouillon <45404227+ArneBouillon@users.noreply.github.com>)
* first update to predict: deprecate transform_to_real; replace transform_to_real kwarg, and add add_obs_noise_cov
* restore format
* initial encode/decode inputs
* add orth. compl. piece
* format
* precompiles
* corrected logic for encode/decode flags
* add maxlog to warning message
* add get_en/decoder_from_schedule
* running version of sampler with encoded inputs (just decoding for posterior currently)
* return full affine shift from get_encoder_from_schedule, and compute gaussian lifts correctly
* compute proposals correctly, resolve input dim checks
* format
* add guardrails for getting empty encoders
* reshape bug in test
* format
* due to reverse-diff slowdowns with LinearMaps.jl, reduced the MCMC steps during tests
* change lengths
* tests for utils
* test retrieval of encoder
* tests for the enriching
* typo
* format
* updated and added new docstrings
* larger tol as it is borderline
* Address review comments: reduced prior storage/sampling, always use decode_and_add_noise
* add a few extra tests
* suitable tols
* typos
* add scaling, rename boost_for_loss... to noise_injector_threshold
* resolve review
* remove all "transform_to_real" in tests
* add one example to test the deprecation message...

---------

Co-authored-by: ArneBouillon <45404227+ArneBouillon@users.noreply.github.com>
1 parent 5d20ea6 commit f6a2727

12 files changed

Lines changed: 791 additions & 132 deletions

src/Emulator.jl

Lines changed: 188 additions & 28 deletions
Large diffs are not rendered by default.
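
The src/Emulator.jl diff itself is not rendered above, but the commit message describes its direction: the `transform_to_real` keyword of `predict` is deprecated, and an `add_obs_noise_cov` keyword now controls whether observational noise is included in returned covariances. The snippet below is only an illustrative sketch of that change in usage, not the actual Emulator-level signature; `emulator` and `test_inputs` are placeholder objects.

```julia
# Illustrative sketch only - the full src/Emulator.jl diff is not shown on this page.
# `emulator` (a built Emulator) and `test_inputs` (an input_dim × N matrix) are placeholders.

# Deprecated style (per the commit message): output handling driven by transform_to_real
# μ, Σ = predict(emulator, test_inputs; transform_to_real = true)

# New style: ask explicitly for observational noise in the returned covariance
μ, Σ = predict(emulator, test_inputs; add_obs_noise_cov = true)
```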

src/MachineLearningTools/GaussianProcess.jl

Lines changed: 32 additions & 12 deletions
```diff
@@ -78,7 +78,7 @@ struct GaussianProcess{GPPackage, FT, VV <: AbstractVector} <: MachineLearningTo
     noise_learn::Bool
     "Additional observational or regularization noise in used in GP algorithms"
     alg_reg_noise::FT
-    "Prediction type (`y` to predict the data, `f` to predict the latent function)."
+    "[Deprecated - use `add_obs_noise_cov` kwarg when calling `predict(`] Prediction type (`y` to predict the data, `f` to predict the latent function)."
     prediction_type::PredictionType
     "Regularization vector for each output dimension (based on alg_reg_noise"
     regularization::VV
@@ -260,7 +260,7 @@ end
 function _predict(
     gp::GaussianProcess,
     new_inputs::AbstractMatrix{FT},
-    predict_method::Function,
+    predict_method::Function;
 ) where {FT <: AbstractFloat}
     M = length(gp.models)
     N_samples = size(new_inputs, 2)
@@ -284,11 +284,17 @@ predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}, ::FType) wher
 """
 $(DocStringExtensions.TYPEDSIGNATURES)
 
-Predict means and covariances in decorrelated output space using Gaussian process models.
+Predict means and covariances in decorrelated output space using Gaussian process models. The use of stored `FType` and `YType` to control this method is deprecated, the return covariance is now determined by the `predict(` kwarg `add_obs_noise_cov`
 """
-predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat} =
-    predict(gp, new_inputs, gp.prediction_type)
-
+function predict(
+    gp::GaussianProcess{GPJL},
+    new_inputs::AbstractMatrix{FT};
+    add_obs_noise_cov = false,
+    mlt_kwargs...,
+) where {FT <: AbstractFloat}
+    pred_type = add_obs_noise_cov ? YType() : FType()
+    return predict(gp, new_inputs, pred_type)
+end
 
 #now we build the SKLJL implementation
 function build_models!(
@@ -371,13 +377,20 @@ function _SKJL_predict_function(gp_model::PyObject, new_inputs::AbstractMatrix{F
     μ, σ = gp_model.predict(new_inputs', return_std = true)
     return μ, (σ .* σ)
 end
-function predict(gp::GaussianProcess{SKLJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat}
+function predict(
+    gp::GaussianProcess{SKLJL},
+    new_inputs::AbstractMatrix{FT};
+    add_obs_noise_cov = false,
+    mlt_kwargs...,
+) where {FT <: AbstractFloat}
     μ, σ2 = _predict(gp, new_inputs, _SKJL_predict_function)
 
     # for SKLJL does not return the observational noise (even if return_std = true)
     # we must add contribution depending on whether we learnt the noise or not.
-    for i in 1:size(σ2, 2)
-        σ2[:, i] = σ2[:, i] + gp.regularization
+    if add_obs_noise_cov
+        for i in 1:size(σ2, 2)
+            σ2[:, i] = σ2[:, i] + gp.regularization
+        end
     end
 
     return μ, σ2
@@ -484,7 +497,12 @@ function optimize_hyperparameters!(gp::GaussianProcess{AGPJL}, args...; kwargs..
     @info "AbstractGP already built. Continuing..."
 end
 
-function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: AbstractMatrix}
+function predict(
+    gp::GaussianProcess{AGPJL},
+    new_inputs::AM;
+    add_obs_noise_cov = false,
+    mlt_kwargs...,
+) where {AM <: AbstractMatrix}
 
     N_models = length(gp.models)
     N_samples = size(new_inputs, 2)
@@ -497,8 +515,10 @@ function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: Abstra
         μ[i, :] = mean(pred)
         σ2[i, :] = var(pred)
     end
-    for i in 1:size(σ2, 2)
-        σ2[:, i] .= σ2[:, i] + gp.regularization
+    if add_obs_noise_cov
+        for i in 1:size(σ2, 2)
+            σ2[:, i] .= σ2[:, i] + gp.regularization
+        end
     end
     return μ, σ2
 end
```
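
In usage terms, each backend's `predict` now takes the same keyword: with the default `add_obs_noise_cov = false` the returned covariance is for the latent function, while `add_obs_noise_cov = true` adds the stored observational/regularization noise (for GPJL this maps onto the existing `FType()`/`YType()` dispatch). A minimal sketch, assuming `gp` is an already-built `GaussianProcess` and `new_inputs` is an `input_dim × N_samples` matrix:

```julia
# Assumes `gp` and `new_inputs` already exist; both are placeholders here.

# Latent-function prediction (default): no observational noise in σ2
μ_f, σ2_f = predict(gp, new_inputs)

# Data prediction: stored observational/regularization noise added to σ2
μ_y, σ2_y = predict(gp, new_inputs; add_obs_noise_cov = true)
```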

src/MachineLearningTools/ScalarRandomFeature.jl

Lines changed: 8 additions & 4 deletions
```diff
@@ -631,6 +631,8 @@ function predict(
     srfi::ScalarRandomFeatureInterface,
     new_inputs::MM;
     multithread = "ensemble",
+    add_obs_noise_cov = false,
+    mlt_kwargs...,
 ) where {MM <: AbstractMatrix}
     M = length(get_rfms(srfi))
     N_samples = size(new_inputs, 2)
@@ -653,11 +655,13 @@ function predict(
     end
 
     # add the noise contribution stored within the regularization
-    reg = get_regularization(srfi)[1]
-    reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
+    if add_obs_noise_cov
+        reg = get_regularization(srfi)[1]
+        reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
 
-    for i in 1:M
-        σ2[i, :] .+= reg_diag[i]
+        for i in 1:M
+            σ2[i, :] .+= reg_diag[i]
+        end
     end
 
     return μ, σ2
```
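
The scalar random-feature interface gains the same keyword alongside its existing `multithread` option; when `add_obs_noise_cov = true` the diagonal of the stored regularization is added to each output's variance. A brief sketch with placeholder objects `srfi` and `new_inputs`:

```julia
# Placeholders: `srfi` is a trained ScalarRandomFeatureInterface, `new_inputs` an input matrix.
μ, σ2 = predict(srfi, new_inputs; multithread = "ensemble", add_obs_noise_cov = true)
# With add_obs_noise_cov = false (the default), σ2 excludes the regularization
# (observational-noise) contribution.
```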

src/MachineLearningTools/VectorRandomFeature.jl

Lines changed: 14 additions & 7 deletions
```diff
@@ -656,7 +656,12 @@ $(DocStringExtensions.TYPEDSIGNATURES)
 
 Prediction of data observation (not latent function) at new inputs (passed in as columns in a matrix). That is, we add the observational noise into predictions.
 """
-function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <: AbstractMatrix}
+function predict(
+    vrfi::VectorRandomFeatureInterface,
+    new_inputs::M;
+    add_obs_noise_cov = false,
+    mlt_kwargs...,
+) where {M <: AbstractMatrix}
     input_dim = get_input_dim(vrfi)
     output_dim = get_output_dim(vrfi)
     rfm = get_rfms(vrfi)[1]
@@ -676,12 +681,14 @@ function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <:
     # sizes (output_dim x n_test), (output_dim x output_dim x n_test)
     # add the noise contribution from the regularization
     # note this is because we are predicting the data here, not the latent function.
-    lambda = get_regularization(vrfi)[1]
-    for i in 1:N_samples
-        σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
-
-        if !isposdef(σ2[:, :, i])
-            σ2[:, :, i] = posdef_correct(σ2[:, :, i])
+    if add_obs_noise_cov
+        lambda = get_regularization(vrfi)[1]
+        for i in 1:N_samples
+            σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
+
+            if !isposdef(σ2[:, :, i])
+                σ2[:, :, i] = posdef_correct(σ2[:, :, i])
+            end
         end
     end
     return μ, σ2
```
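
For the vector interface, the covariance output is a full `output_dim × output_dim` matrix per test point (see the size comment in the diff), and the symmetrization and `posdef_correct` step is now applied only when observational noise is requested. A brief sketch with placeholder objects `vrfi` and `new_inputs`:

```julia
# Placeholders: `vrfi` is a trained VectorRandomFeatureInterface, `new_inputs` an input matrix.
μ, σ2 = predict(vrfi, new_inputs; add_obs_noise_cov = true)
# μ is output_dim × n_test; σ2 is output_dim × output_dim × n_test, with the regularization
# added and each slice symmetrized and corrected to be positive definite.
```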
