Skip to content

Commit b02758f

Browse files
committed
first update to predict
Deprecate the `transform_to_real` kwarg, replacing it with `encode`, and add `add_obs_noise_cov`
1 parent eafcc06 commit b02758f

5 files changed

Lines changed: 99 additions & 31 deletions

File tree

src/Emulator.jl

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ Return type of N inputs: (in the output space)
318318
function predict(
319319
emulator::Emulator{FT},
320320
new_inputs::AM;
321-
transform_to_real = false,
321+
encode=nothing, # maps decoded inputs to decoded outputs
322322
mlt_kwargs...,
323323
) where {FT <: AbstractFloat, AM <: AbstractMatrix}
324+
324325
# Check if the size of new_inputs is consistent with the training data input
325326
input_dim, output_dim = size(get_io_pairs(emulator), 1)
326327
encoded_input_dim, encoded_output_dim = size(get_encoded_io_pairs(emulator), 1)
@@ -335,8 +336,17 @@ function predict(
335336
)
336337
end
337338

339+
# note the logic below
340+
in_already_encoded = encode ∈ ["in", "in_and_out"]
341+
out_to_be_decoded = encode ∈ ["out", "in_and_out"]
342+
343+
338344
# encode the new input data
339-
encoded_inputs = encode_data(emulator, new_inputs, "in")
345+
if !in_already_encoded
346+
encoded_inputs = encode_data(emulator, new_inputs, "in")
347+
else
348+
encoded_inputs = new_inputs
349+
end
340350
# predict in encoding space
341351
# returns outputs: [enc_out_dim x n_samples]
342352
# Scalar-methods uncertainties=variances: [enc_out_dim x n_samples]
@@ -346,7 +356,7 @@ function predict(
346356
var_or_cov = (ndims(encoded_uncertainties) == 2) ? "var" : "cov"
347357

348358
# return decoded or encoded?
349-
if transform_to_real
359+
if out_to_be_decoded
350360
decoded_outputs = decode_data(emulator, encoded_outputs, "out")
351361

352362
decoded_covariances = zeros(eltype(encoded_outputs), output_dim, output_dim, size(encoded_uncertainties)[end])
@@ -433,7 +443,8 @@ end
433443
function predict(
434444
fmw::FMW,
435445
new_inputs::AM;
436-
transform_to_real = false,
446+
encode=nothing, # maps decoded inputs to decoded outputs
447+
add_obs_noise_cov=false,
437448
) where {FMW <: ForwardMapWrapper, AM <: AbstractMatrix}
438449
# Check if the size of new_inputs is consistent with the training input data
439450
input_dim, output_dim = size(get_io_pairs(fmw), 1)
@@ -449,6 +460,15 @@ function predict(
449460
)
450461
end
451462

463+
in_already_encoded = encode ∈ ["in", "in_and_out"]
464+
out_to_be_decoded = encode ∈ ["out", "in_and_out"]
465+
#need to boost to decoded inputs
466+
if in_already_encoded
467+
# Sample from the null space
468+
decoded_inputs = ...
469+
else
470+
decoded_inputs = new_inputs
471+
end
452472
# Scalar-methods uncertainties=variances: [enc_out_dim x n_samples]
453473
# Vector-methods uncertainties=covariances: [enc_out_dim x enc_out_dim x n_samples)
454474

@@ -457,10 +477,10 @@ function predict(
457477
forward_map = get_forward_map(fmw)
458478
fm_unc = x -> forward_map(transform_unconstrained_to_constrained(prior, x))
459479

460-
decoded_outputs = reduce(hcat, map(fm_unc, eachcol(new_inputs))) # apply map and return: [out_dim x n_samples]
480+
decoded_outputs = reduce(hcat, map(fm_unc, eachcol(decoded_inputs))) # apply map and return: [out_dim x n_samples]
461481

462482
var_or_cov = (output_dim == 1) ? "var" : "cov"
463-
if transform_to_real
483+
if out_to_be_decoded
464484
# uncertainty returned is just `I` in encoded space
465485
decoded_cov = Matrix(decode_structure_matrix(fmw, I(output_dim), "out"))
466486

@@ -496,4 +516,38 @@ function predict(
496516
end
497517
end
498518

519+
520+
### Deprecated keywords
521+
522+
function predict(
523+
em_or_fmw::EorFMW,
524+
new_inputs::AM;
525+
transform_to_real = nothing,
526+
kwargs...,
527+
) where {AM <: AbstractMatrix, EorFMW <: Union{Emulator, ForwardMapWrapper}}
528+
529+
if !isnothing(transform_to_real)
530+
Base.depwarn(
531+
"""`transform_to_real` keyword is deprecated. Please use the `encode` and `add_obs_noise_cov` keywords instead.
532+
533+
Recommended usage for users is now set by default as:
534+
- `encode=nothing`, `add_obs_noise_cov=false`
535+
This behaviour takes in non-encoded inputs, and returns non-encoded outputs. It gives only the uncertainty from the Machine Learning Tool (not inflated by observational noise)
536+
537+
This call will continue with the old behavior:
538+
- `transform_to_real=true` replaced with `encode=nothing, add_obs_noise_cov=true`
539+
- `transform_to_real=false` replaced with `encode="out", add_obs_noise_cov=true`
540+
""",
541+
:predict,
542+
)
543+
544+
# modify kwargs
545+
kw = Dict(kwargs)
546+
kw[:add_obs_noise_cov] = true
547+
kw[:encode] = transform_to_real ? nothing : "out"
548+
predict(em_or_fmw, new_inputs; kw...)
549+
end
550+
551+
return predict(em_or_fmw, new_inputs; kwargs...)
552+
499553
end

src/MachineLearningTools/GaussianProcess.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct GaussianProcess{GPPackage, FT, VV <: AbstractVector} <: MachineLearningTo
7878
noise_learn::Bool
7979
"Additional observational or regularization noise used in GP algorithms"
8080
alg_reg_noise::FT
81-
"Prediction type (`y` to predict the data, `f` to predict the latent function)."
81+
"[Deprecated - use `add_obs_noise_cov` kwarg when calling `predict(`] Prediction type (`y` to predict the data, `f` to predict the latent function)."
8282
prediction_type::PredictionType
8383
"Regularization vector for each output dimension (based on alg_reg_noise)"
8484
regularization::VV
@@ -260,7 +260,7 @@ end
260260
function _predict(
261261
gp::GaussianProcess,
262262
new_inputs::AbstractMatrix{FT},
263-
predict_method::Function,
263+
predict_method::Function;
264264
) where {FT <: AbstractFloat}
265265
M = length(gp.models)
266266
N_samples = size(new_inputs, 2)
@@ -284,11 +284,12 @@ predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}, ::FType) wher
284284
"""
285285
$(DocStringExtensions.TYPEDSIGNATURES)
286286
287-
Predict means and covariances in decorrelated output space using Gaussian process models.
287+
Predict means and covariances in decorrelated output space using Gaussian process models. The use of stored `FType` and `YType` to control this method is deprecated; the returned covariance is now determined by the `add_obs_noise_cov` kwarg of `predict`.
288288
"""
289-
predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat} =
290-
predict(gp, new_inputs, gp.prediction_type)
291-
289+
function predict(gp::GaussianProcess{GPJL}, new_inputs::AbstractMatrix{FT}; add_obs_noise_cov=false, mlt_kwargs...) where {FT <: AbstractFloat}
290+
pred_type= add_obs_noise_cov ? YType() : FType()
291+
return predict(gp, new_inputs, pred_type)
292+
end
292293

293294
#now we build the SKLJL implementation
294295
function build_models!(
@@ -371,13 +372,15 @@ function _SKJL_predict_function(gp_model::PyObject, new_inputs::AbstractMatrix{F
371372
μ, σ = gp_model.predict(new_inputs', return_std = true)
372373
return μ, (σ .* σ)
373374
end
374-
function predict(gp::GaussianProcess{SKLJL}, new_inputs::AbstractMatrix{FT}) where {FT <: AbstractFloat}
375+
function predict(gp::GaussianProcess{SKLJL}, new_inputs::AbstractMatrix{FT}; add_obs_noise_cov=false, mlt_kwargs...) where {FT <: AbstractFloat}
375376
μ, σ2 = _predict(gp, new_inputs, _SKJL_predict_function)
376377

377378
# SKLJL does not return the observational noise (even if return_std = true)
378379
# we must add contribution depending on whether we learnt the noise or not.
379-
for i in 1:size(σ2, 2)
380-
σ2[:, i] = σ2[:, i] + gp.regularization
380+
if add_obs_noise_cov
381+
for i in 1:size(σ2, 2)
382+
σ2[:, i] = σ2[:, i] + gp.regularization
383+
end
381384
end
382385

383386
return μ, σ2
@@ -484,7 +487,7 @@ function optimize_hyperparameters!(gp::GaussianProcess{AGPJL}, args...; kwargs..
484487
@info "AbstractGP already built. Continuing..."
485488
end
486489

487-
function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: AbstractMatrix}
490+
function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM; add_obs_noise_cov=false, mlt_kwargs...) where {AM <: AbstractMatrix}
488491

489492
N_models = length(gp.models)
490493
N_samples = size(new_inputs, 2)
@@ -497,8 +500,10 @@ function predict(gp::GaussianProcess{AGPJL}, new_inputs::AM) where {AM <: Abstra
497500
μ[i, :] = mean(pred)
498501
σ2[i, :] = var(pred)
499502
end
500-
for i in 1:size(σ2, 2)
501-
σ2[:, i] .= σ2[:, i] + gp.regularization
503+
if add_obs_noise_cov
504+
for i in 1:size(σ2, 2)
505+
σ2[:, i] .= σ2[:, i] + gp.regularization
506+
end
502507
end
503508
return μ, σ2
504509
end

src/MachineLearningTools/ScalarRandomFeature.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,8 @@ function predict(
631631
srfi::ScalarRandomFeatureInterface,
632632
new_inputs::MM;
633633
multithread = "ensemble",
634+
add_obs_noise_cov=false,
635+
mlt_kwargs...,
634636
) where {MM <: AbstractMatrix}
635637
M = length(get_rfms(srfi))
636638
N_samples = size(new_inputs, 2)
@@ -653,11 +655,13 @@ function predict(
653655
end
654656

655657
# add the noise contribution stored within the regularization
656-
reg = get_regularization(srfi)[1]
657-
reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
658+
if add_obs_noise_cov
659+
reg = get_regularization(srfi)[1]
660+
reg_diag = isa(reg, UniformScaling) ? reg.λ * ones(M) : diag(reg)
658661

659-
for i in 1:M
660-
σ2[i, :] .+= reg_diag[i]
662+
for i in 1:M
663+
σ2[i, :] .+= reg_diag[i]
664+
end
661665
end
662666

663667
return μ, σ2

src/MachineLearningTools/VectorRandomFeature.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ $(DocStringExtensions.TYPEDSIGNATURES)
656656
657657
Prediction of data observation (not latent function) at new inputs (passed in as columns in a matrix). That is, we add the observational noise into predictions.
658658
"""
659-
function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <: AbstractMatrix}
659+
function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M; add_obs_noise_cov=false, mlt_kwargs...) where {M <: AbstractMatrix}
660660
input_dim = get_input_dim(vrfi)
661661
output_dim = get_output_dim(vrfi)
662662
rfm = get_rfms(vrfi)[1]
@@ -676,12 +676,14 @@ function predict(vrfi::VectorRandomFeatureInterface, new_inputs::M) where {M <:
676676
# sizes (output_dim x n_test), (output_dim x output_dim x n_test)
677677
# add the noise contribution from the regularization
678678
# note this is because we are predicting the data here, not the latent function.
679-
lambda = get_regularization(vrfi)[1]
680-
for i in 1:N_samples
681-
σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
682-
683-
if !isposdef(σ2[:, :, i])
684-
σ2[:, :, i] = posdef_correct(σ2[:, :, i])
679+
if add_obs_noise_cov
680+
lambda = get_regularization(vrfi)[1]
681+
for i in 1:N_samples
682+
σ2[:, :, i] = 0.5 * (σ2[:, :, i] + permutedims(σ2[:, :, i], (2, 1))) + lambda
683+
684+
if !isposdef(σ2[:, :, i])
685+
σ2[:, :, i] = posdef_correct(σ2[:, :, i])
686+
end
685687
end
686688
end
687689
return μ, σ2

src/MarkovChainMonteCarlo.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ function emulator_log_density_model(
255255

256256
# predict is written to apply to columns.
257257
# Returned g is a length-1, Vector{Real} or Vector{Vector}, and g_cov is length-1 Vector{Vector} or Vector{Matrix} respectively
258-
g, g_cov = Emulators.predict(em_or_fmw, reshape(θ, :, 1), transform_to_real = false)
258+
g, g_cov = Emulators.predict(em_or_fmw, reshape(θ, :, 1), encode="out", add_obs_noise_cov=true)
259259

260260
if isa(g_cov[1], Real)
261261
return sum([logpdf(MvNormal(obs, g_cov[1] * I), vec(g)) for obs in obs_vec]) + logpdf(prior, θ)
@@ -577,8 +577,10 @@ function MCMCWrapper(
577577
eachcol(observation)
578578
end
579579

580-
# encoding works on columns but mcmc wants vec-of-vec
580+
# encoding data works on columns but mcmc wants vec-of-vec
581581
encoded_obs = [vec(encode_data(em_or_fmw, reshape(obs, :, 1), "out")) for obs in obs_slice]
582+
# encoding initial condition
583+
#encoded_init_params = vec(encode_data(em_or_fmw, reshape(init_params,:,1), "in"))
582584

583585
log_posterior_map = EmulatorPosteriorModel(prior, em_or_fmw, encoded_obs)
584586
mh_proposal_sampler = MetropolisHastingsSampler(mcmc_alg, prior)
@@ -594,6 +596,7 @@ function MCMCWrapper(
594596
end
595597

596598
sample_kwargs = (; # set defaults here
599+
# :initial_params => deepcopy(encoded_init_params),
597600
:initial_params => deepcopy(init_params),
598601
:param_names => param_names,
599602
:discard_initial => burnin,

0 commit comments

Comments
 (0)