From b5bb9e7749da5bc3657c20dede3609ff224ef4c0 Mon Sep 17 00:00:00 2001 From: Tom Wright Date: Wed, 27 Oct 2021 17:04:54 -0400 Subject: [PATCH] Adding a MOInputHeterotopic type --- Project.toml | 2 +- src/KernelFunctions.jl | 3 ++- src/mokernels/moinput.jl | 47 +++++++++++++++++++++++++++++++++++++++ test/mokernels/moinput.jl | 21 +++++++++++++++++ 4 files changed, 71 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3d58062e3..a17fb8d99 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.23" +version = "0.10.24" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 37bde4a65..e7b92c44a 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -37,7 +37,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel export ColVecs, RowVecs -export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_output_data +export MOInput, MOInputHeterotopic, + prepare_isotopic_multi_output_data, prepare_heterotopic_multi_output_data export IndependentMOKernel, LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel diff --git a/src/mokernels/moinput.jl b/src/mokernels/moinput.jl index 9a802280d..88db65eb0 100644 --- a/src/mokernels/moinput.jl +++ b/src/mokernels/moinput.jl @@ -58,6 +58,43 @@ struct MOInputIsotopicByOutputs{S,T<:AbstractVector{S}} <: AbstractVector{Tuple{ out_dim::Integer end +""" + MOInputsHeterotopic(x::AbstractVector, output_indices::Integer) + +`MOInputsHeterotopic(x, output_indices)` has length `length(x)`. + +```jldoctest +julia> x = [1, 2, 3, 4, 5, 6]; + +julia> out_inds = [1, 1, 2, 3, 2, 1]; + +julia> KernelFunctions.MOInputsHeterotopic(x, out_inds) +6-element KernelFunctions.MOInputsHeterotopic{Int64, Vector{Int64}}: + (1, 1) + (2, 1) + (3, 2) + (4, 3) + (5, 2) + (6, 1) +``` + +Accommodates modelling multi-dimensional output data where not all outputs are observed +for every input. + +As shown above, an `MOInputsHeterotopic` represents a vector of tuples. +The `length(x)` elements represent the inputs that are observed at the locations specified +by `output_indices`. +""" +struct MOInputsHeterotopic{S ,T<:AbstractVector{S}} <: AbstractVector{Tuple{S,Int}} + x::T + output_indices::AbstractVector{Int} +end + +# Return the inputs at a specific output +function get_inputs_at_output(inp::MOInputsHeterotopic, output) + return [input[1] for input in inputs if input[2]==output] +end + const IsotopicMOInputsUnion = Union{MOInputIsotopicByFeatures,MOInputIsotopicByOutputs} function Base.getindex(inp::MOInputIsotopicByOutputs, ind::Integer) @@ -74,7 +111,13 @@ function Base.getindex(inp::MOInputIsotopicByFeatures, ind::Integer) return feature, output_index end +function Base.getindex(inp::MOInputsHeterotopic, ind::Integer) + @boundscheck checkbounds(inp, ind) + return inp.x[ind], inp.output_indices[ind] +end + Base.size(inp::IsotopicMOInputsUnion) = (inp.out_dim * length(inp.x),) +Base.size(inp::MOInputsHeterotopic) = (length(inp.output_indices),) function Base.vcat(x::MOInputIsotopicByFeatures, y::MOInputIsotopicByFeatures) x.out_dim == y.out_dim || throw(DimensionMismatch("out_dim mismatch")) @@ -86,6 +129,10 @@ function Base.vcat(x::MOInputIsotopicByOutputs, y::MOInputIsotopicByOutputs) return MOInputIsotopicByOutputs(vcat(x.x, y.x), x.out_dim) end +function Base.vcat(x::MOInputsHeterotopic, y::MOInputsHeterotopic) + return MOInputsHeterotopic(vcat(x.x, y.x), vcat(x.output_indices, y.output_indices)) +end + """ MOInput(x::AbstractVector, out_dim::Integer) diff --git a/test/mokernels/moinput.jl b/test/mokernels/moinput.jl index 6fab09264..05f845ad2 100644 --- a/test/mokernels/moinput.jl +++ b/test/mokernels/moinput.jl @@ -48,6 +48,27 @@ @test all([(x_, i) for x_ in x for i in 1:3] .== ibf) end + @testset "heterotopic" begin + out_inds = [1, 2, 3, 2] + mo_input = KernelFunctions.MOInputsHeterotopic(x, out_inds) + @test isa(mo_input, type_1) == true + @test isa(mo_input, type_2) == false + + @test length(mo_input) == 4 + @test size(mo_input) == (4,) + @test size(mo_input, 1) == 4 + @test size(mo_input, 2) == 1 + @test lastindex(mo_input) == 4 + @test firstindex(mo_input) == 1 + @test_throws BoundsError mo_input[0] + @test vcat(mo_input, mo_input) == KernelFunctions.MOInputsHeterotopic(vcat(x, x), vcat(out_inds, out_inds)) + + @test mo_input[2] == (x[2], 2) + @test mo_input[3] == (x[3], 3) + @test mo_input[4] == (x[4], 2) + @test all([(x_, i) for (x_, i) in zip(x, out_inds)] .== mo_input) + end + @testset "prepare_isotopic_multi_output_data" begin @testset "ColVecs" begin N = 5