diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index ee730020..9c4ed0c8 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -56,6 +56,8 @@ all(>(2.9), diagview(Dtrunc)) true ``` +Use `maxrank` together with a tolerance to keep at most `maxrank` values above the tolerance (intersection): + ```jldoctest truncations; output=false Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9)); size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) @@ -64,6 +66,16 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) true ``` +Use `minrank` together with a tolerance to guarantee at least `minrank` values are kept (union): + +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 3.5, minrank = 2)); +size(Dtrunc, 1) >= 2 + +# output +true +``` + In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring: ```@docs; canonical = false @@ -84,6 +96,8 @@ size(Dtrunc, 1) <= 2 true ``` +Strategies can be combined with `&` (intersection: keep values satisfying **all** conditions): + ```jldoctest truncations; output=false Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9)) size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) @@ -92,6 +106,17 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) true ``` +Strategies can also be combined with `|` (union: keep values satisfying **any** condition). +This is useful to set a lower bound on the number of kept values with `minrank`: + +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = trunctol(; atol = 3.5) | truncrank(2)) +size(Dtrunc, 1) >= 2 + +# output +true +``` + ## Truncation Strategies MatrixAlgebraKit provides several built-in truncation strategies: @@ -104,11 +129,20 @@ truncfilter truncerror ``` -Truncation strategies can be combined using the `&` operator to create intersection-based truncation. -When strategies are combined, only the values that satisfy all conditions are kept. +Strategies can be composed using the `&` operator ([`TruncationIntersection`](@ref)) to keep only values satisfying **all** conditions, +or the `|` operator ([`TruncationUnion`](@ref)) to keep values satisfying **any** condition. + +```@docs; canonical=false +TruncationIntersection +TruncationUnion +``` ```julia +# Keep at most 10 values, all above tolerance (intersection) combined_trunc = truncrank(10) & trunctol(; atol = 1e-6); + +# Keep values above tolerance, but always at least 3 (union) +combined_trunc = trunctol(; atol = 1e-6) | truncrank(3); ``` ## Truncation Error diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 962a599d..ebec311e 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -161,9 +161,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end -# TODO: intersect doesn't work on GPU +# TODO: intersect/union don't work on GPU MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::ROCVector{Int}) = + MatrixAlgebraKit._ind_union(A, collect(B)) +MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::AbstractVector{<:Integer}) = + MatrixAlgebraKit._ind_union(collect(A), B) +MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::ROCVector{Int}) = + MatrixAlgebraKit._ind_union(collect(A), collect(B)) function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix) hX = sylvester(collect(A), collect(B), collect(C)) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 945f81d1..937027a4 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -165,9 +165,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end -# TODO: intersect doesn't work on GPU +# TODO: intersect/union don't work on GPU MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::CuVector{Int}) = + MatrixAlgebraKit._ind_union(A, collect(B)) +MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::AbstractVector{<:Integer}) = + MatrixAlgebraKit._ind_union(collect(A), B) +MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::CuVector{Int}) = + MatrixAlgebraKit._ind_union(collect(A), collect(B)) function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) # https://github.com/JuliaGPU/CUDA.jl/issues/3021 diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 03fb05bd..ac95ae65 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -57,7 +57,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter eval( Expr( :public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue, - :TruncationByError, :TruncationIntersection, :truncate + :TruncationByError, :TruncationIntersection, :TruncationUnion, :truncate ) ) eval( diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 8d0229ba..18f56413 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -137,6 +137,32 @@ _ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_interse # when all else fails, call intersect _ind_intersect(A, B) = intersect(A, B) +function findtruncated(values::AbstractVector, strategy::TruncationUnion) + length(strategy.components) == 0 && return Base.OneTo(0) + length(strategy.components) == 1 && return findtruncated(values, only(strategy.components)) + ind1 = findtruncated(values, strategy.components[1]) + ind2 = findtruncated(values, TruncationUnion(Base.tail(strategy.components))) + return _ind_union(ind1, ind2) +end +function findtruncated_svd(values::AbstractVector, strategy::TruncationUnion) + length(strategy.components) == 0 && return Base.OneTo(0) + length(strategy.components) == 1 && return findtruncated_svd(values, only(strategy.components)) + ind1 = findtruncated_svd(values, strategy.components[1]) + ind2 = findtruncated_svd(values, TruncationUnion(Base.tail(strategy.components))) + return _ind_union(ind1, ind2) +end + +_ind_union(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .| B +function _ind_union(A::AbstractVector{Bool}, B::AbstractVector) + result = copy(A) + result[B] .= true + return result +end +_ind_union(A::AbstractVector, B::AbstractVector{Bool}) = _ind_union(B, A) +_ind_union(A::Base.OneTo, B::Base.OneTo) = Base.OneTo(max(length(A), length(B))) +_ind_union(A::AbstractUnitRange, B::AbstractUnitRange) = union(A, B) +_ind_union(A, B) = union(A, B) + # Truncation error # ---------------- truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind) diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 26141cc1..4d4e9d10 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -1,9 +1,10 @@ const docs_truncation_kwargs = """ -* `atol::Real` : Absolute tolerance for the truncation -* `rtol::Real` : Relative tolerance for the truncation -* `maxrank::Real` : Maximal rank for the truncation -* `maxerror::Real` : Maximal truncation error. -* `filter` : Custom filter to select truncated values. +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxrank::Integer` : Maximal rank for the truncation +* `minrank::Integer` : Minimal rank for the truncation +* `maxerror::Real` : Maximal truncation error +* `filter` : Custom filter to select truncated values """ const docs_truncation_strategies = """ @@ -28,16 +29,18 @@ Select a truncation strategy based on the provided keyword arguments. ## Keyword arguments The following keyword arguments are all optional, and their default value (`nothing`) will be ignored. It is also allowed to combine multiple of these, in which case the kept -values will consist of the intersection of the different truncated strategies. +values will consist of the intersection of the different truncated strategies (except +`minrank`, which uses union semantics to guarantee a lower bound on the number of kept values). $docs_truncation_kwargs """ function TruncationStrategy(; atol::Union{Real, Nothing} = nothing, rtol::Union{Real, Nothing} = nothing, - maxrank::Union{Real, Nothing} = nothing, + maxrank::Union{Integer, Nothing} = nothing, + minrank::Union{Integer, Nothing} = nothing, maxerror::Union{Real, Nothing} = nothing, - filter = nothing + filter = nothing, ) strategy = notrunc() @@ -51,6 +54,14 @@ function TruncationStrategy(; isnothing(maxerror) || (strategy &= truncerror(; atol = maxerror)) isnothing(filter) || (strategy &= truncfilter(filter)) + # union constraint: guarantee a lower bound on number of kept values + # special-case NoTruncation: keeping everything already satisfies any minrank + if !isnothing(minrank) && !(strategy isa NoTruncation) + strategy |= truncrank(minrank) + elseif !isnothing(minrank) + strategy = truncrank(minrank) + end + return strategy end @@ -222,6 +233,43 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc() Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc +""" + TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...) + +Truncation strategy that composes multiple truncation strategies, keeping values that are +present in any of them. +""" +struct TruncationUnion{T <: Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy + components::T +end +function TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...) + return TruncationUnion((trunc, truncs...)) +end + +function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + return TruncationUnion((trunc1, trunc2)) +end + +# flatten components +function Base.:|(trunc1::TruncationUnion, trunc2::TruncationUnion) + return TruncationUnion((trunc1.components..., trunc2.components...)) +end +function Base.:|(trunc1::TruncationUnion, trunc2::TruncationStrategy) + return TruncationUnion((trunc1.components..., trunc2)) +end +function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationUnion) + return TruncationUnion((trunc1, trunc2.components...)) +end + +# NoTruncation is the absorbing element for | (union with "keep all" = keep all) +Base.:|(::NoTruncation, ::TruncationStrategy) = notrunc() +Base.:|(::TruncationStrategy, ::NoTruncation) = notrunc() +Base.:|(::NoTruncation, ::NoTruncation) = notrunc() + +# disambiguate +Base.:|(::NoTruncation, ::TruncationUnion) = notrunc() +Base.:|(::TruncationUnion, ::NoTruncation) = notrunc() + @doc """ truncation_error(values, ind) Compute the truncation error as the 2-norm of the values that are not kept by `ind`. diff --git a/test/testsuite/decompositions/svd.jl b/test/testsuite/decompositions/svd.jl index 6d5799ba..a6aa4fb9 100644 --- a/test/testsuite/decompositions/svd.jl +++ b/test/testsuite/decompositions/svd.jl @@ -212,6 +212,29 @@ function test_svd_trunc( @test diagview(S2) ≈ diagview(S)[1:2] end end + @testset "mix minrank and tol" begin + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01]) + S = Diagonal(Sdiag) + Vᴴ = instantiate_unitary(T, A, m4) + A = U * S * Vᴴ + for trunc_fun in ( + (rtol, minrank) -> (; rtol, minrank), + (rtol, minrank) -> trunctol(; rtol) | truncrank(minrank), + ) + # trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3 + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3)) + @test length(diagview(S1)) == 3 + @test diagview(S1) ≈ diagview(S)[1:3] + + # trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2 + U2, S2, V2ᴴ = svd_trunc_no_error(A; trunc = trunc_fun(0.2, 1)) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] + end + end @testset "specify truncation algorithm" begin atol = sqrt(eps(real(eltype(T)))) m4 = 4 @@ -294,6 +317,29 @@ function test_svd_trunc_algs( @test collect(diagview(S2)) ≈ collect(diagview(S)[1:2]) end end + @testset "mix minrank and tol" begin + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + S = Diagonal(Sdiag) + Vᴴ = instantiate_unitary(T, A, m4) + A = U * S * Vᴴ + for trunc_fun in ( + (rtol, minrank) -> (; rtol, minrank), + (rtol, minrank) -> trunctol(; rtol) | truncrank(minrank), + ) + # trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3 + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3), alg) + @test length(diagview(S1)) == 3 + @test collect(diagview(S1)) ≈ collect(diagview(S)[1:3]) + + # trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2 + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg) + @test length(diagview(S2)) == 2 + @test collect(diagview(S2)) ≈ collect(diagview(S)[1:2]) + end + end @testset "specify truncation algorithm" begin atol = sqrt(eps(real(eltype(T)))) m4 = 4 diff --git a/test/truncate.jl b/test/truncate.jl index f2d24a7a..8be5e700 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -1,8 +1,8 @@ using MatrixAlgebraKit using Test using TestExtras -using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, - TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd +using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationUnion, + TruncationByOrder, TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd @testset "truncate" begin trunc = @constinferred TruncationStrategy() @@ -65,4 +65,32 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder, @test issetequal(values[@constinferred(findtruncated(values, strategy))], values[2:5]) vals_sorted = sort(values; by = abs, rev = true) @test vals_sorted[@constinferred(findtruncated_svd(vals_sorted, strategy))] == vals_sorted[1:4] + + # TruncationUnion / minrank + trunc = @constinferred TruncationStrategy(; minrank = 3) + @test trunc isa TruncationByOrder + @test trunc == truncrank(3) + + trunc = @constinferred TruncationStrategy(; atol, minrank = 3) + @test trunc isa TruncationUnion + @test trunc == trunctol(; atol) | truncrank(3) + + # | operator + values2 = [1.0, 0.9, 0.5, 0.3, 0.01] + # trunctol keeps 1:3 (above 0.4), truncrank(4) keeps 1:4, union keeps 1:4 + strategy = trunctol(; atol = 0.4) | truncrank(4) + @test @constinferred(findtruncated_svd(values2, strategy)) == 1:4 + # trunctol keeps 1:3, truncrank(2) keeps 1:2, union keeps 1:3 + strategy = trunctol(; atol = 0.4) | truncrank(2) + @test @constinferred(findtruncated_svd(values2, strategy)) == 1:3 + + # notrunc is absorbing for | + @test (notrunc() | truncrank(3)) isa NoTruncation + @test (truncrank(3) | notrunc()) isa NoTruncation + + # TruncationUnion flattening + union1 = truncrank(2) | trunctol(; atol = 0.4) + union2 = union1 | truncrank(4) + @test union2 isa TruncationUnion + @test length(union2.components) == 3 end