Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ all(>(2.9), diagview(Dtrunc))
true
```

Use `maxrank` together with a tolerance to keep at most `maxrank` values above the tolerance (intersection):

```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
Expand All @@ -64,6 +66,16 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
true
```

Use `minrank` together with a tolerance to guarantee at least `minrank` values are kept (union):

```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 3.5, minrank = 2));
size(Dtrunc, 1) >= 2

# output
true
```

In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring:

```@docs; canonical = false
Expand All @@ -84,6 +96,8 @@ size(Dtrunc, 1) <= 2
true
```

Strategies can be combined with `&` (intersection: keep values satisfying **all** conditions):

```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
Expand All @@ -92,6 +106,17 @@ size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))
true
```

Strategies can also be combined with `|` (union: keep values satisfying **any** condition).
This is useful to set a lower bound on the number of kept values with `minrank`:

```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = trunctol(; atol = 3.5) | truncrank(2))
size(Dtrunc, 1) >= 2

# output
true
```

## Truncation Strategies

MatrixAlgebraKit provides several built-in truncation strategies:
Expand All @@ -104,11 +129,20 @@ truncfilter
truncerror
```

Truncation strategies can be combined using the `&` operator to create intersection-based truncation.
When strategies are combined, only the values that satisfy all conditions are kept.
Strategies can be composed using the `&` operator ([`TruncationIntersection`](@ref)) to keep only values
satisfying all conditions, or the `|` operator ([`TruncationUnion`](@ref)) to keep values satisfying any condition.

```@docs; canonical=false
TruncationIntersection
TruncationUnion
```

```julia
# Keep at most 10 values, all above tolerance (intersection)
combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);

# Keep values above tolerance, but always at least 3 (union)
combined_trunc = trunctol(; atol = 1e-6) | truncrank(3);
```

## Truncation Error
Expand Down
8 changes: 7 additions & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix
return C
end

# TODO: intersect doesn't work on GPU
# TODO: intersect/union don't work on GPU
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_union(A, collect(B))
MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::AbstractVector{<:Integer}) =
MatrixAlgebraKit._ind_union(collect(A), B)
MatrixAlgebraKit._ind_union(A::ROCVector{Int}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_union(collect(A), collect(B))

function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
hX = sylvester(collect(A), collect(B), collect(C))
Expand Down
8 changes: 7 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,15 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T
return C
end

# TODO: intersect doesn't work on GPU
# TODO: intersect/union don't work on GPU
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))
MatrixAlgebraKit._ind_union(A::AbstractVector{<:Integer}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_union(A, collect(B))
MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::AbstractVector{<:Integer}) =
MatrixAlgebraKit._ind_union(collect(A), B)
MatrixAlgebraKit._ind_union(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_union(collect(A), collect(B))

function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
eval(
Expr(
:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
:TruncationByError, :TruncationIntersection, :truncate
:TruncationByError, :TruncationIntersection, :TruncationUnion, :truncate
)
)
eval(
Expand Down
26 changes: 26 additions & 0 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ _ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_interse
# when all else fails, call intersect
_ind_intersect(A, B) = intersect(A, B)

function findtruncated(values::AbstractVector, strategy::TruncationUnion)
length(strategy.components) == 0 && return Base.OneTo(0)
length(strategy.components) == 1 && return findtruncated(values, only(strategy.components))
ind1 = findtruncated(values, strategy.components[1])
ind2 = findtruncated(values, TruncationUnion(Base.tail(strategy.components)))
return _ind_union(ind1, ind2)
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationUnion)
length(strategy.components) == 0 && return Base.OneTo(0)
length(strategy.components) == 1 && return findtruncated_svd(values, only(strategy.components))
ind1 = findtruncated_svd(values, strategy.components[1])
ind2 = findtruncated_svd(values, TruncationUnion(Base.tail(strategy.components)))
return _ind_union(ind1, ind2)
end

_ind_union(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .| B
function _ind_union(A::AbstractVector{Bool}, B::AbstractVector)
result = copy(A)
result[B] .= true
return result
end
_ind_union(A::AbstractVector, B::AbstractVector{Bool}) = _ind_union(B, A)
_ind_union(A::Base.OneTo, B::Base.OneTo) = Base.OneTo(max(length(A), length(B)))
_ind_union(A::AbstractUnitRange, B::AbstractUnitRange) = union(A, B)
_ind_union(A, B) = union(A, B)

# Truncation error
# ----------------
truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind)
Expand Down
64 changes: 56 additions & 8 deletions src/interface/truncation.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
const docs_truncation_kwargs = """
* `atol::Real` : Absolute tolerance for the truncation
* `rtol::Real` : Relative tolerance for the truncation
* `maxrank::Real` : Maximal rank for the truncation
* `maxerror::Real` : Maximal truncation error.
* `filter` : Custom filter to select truncated values.
* `atol::Real` : Absolute tolerance for the truncation
* `rtol::Real` : Relative tolerance for the truncation
* `maxrank::Integer` : Maximal rank for the truncation
* `minrank::Integer` : Minimal rank for the truncation
* `maxerror::Real` : Maximal truncation error.
* `filter` : Custom filter to select truncated values.
"""

const docs_truncation_strategies = """
Expand All @@ -28,16 +29,18 @@ Select a truncation strategy based on the provided keyword arguments.
## Keyword arguments
The following keyword arguments are all optional, and their default value (`nothing`)
will be ignored. It is also allowed to combine multiple of these, in which case the kept
values will consist of the intersection of the different truncated strategies.
values will consist of the intersection of the different truncated strategies (except
`minrank`, which uses union semantics to guarantee a lower bound on the number of kept values).

$docs_truncation_kwargs
"""
function TruncationStrategy(;
atol::Union{Real, Nothing} = nothing,
rtol::Union{Real, Nothing} = nothing,
maxrank::Union{Real, Nothing} = nothing,
maxrank::Union{Integer, Nothing} = nothing,
minrank::Union{Integer, Nothing} = nothing,
maxerror::Union{Real, Nothing} = nothing,
filter = nothing
filter = nothing,
)
strategy = notrunc()

Expand All @@ -51,6 +54,14 @@ function TruncationStrategy(;
isnothing(maxerror) || (strategy &= truncerror(; atol = maxerror))
isnothing(filter) || (strategy &= truncfilter(filter))

# union constraint: guarantee a lower bound on number of kept values
# special-case NoTruncation: keeping everything already satisfies any minrank
if !isnothing(minrank) && !(strategy isa NoTruncation)
strategy |= truncrank(minrank)
elseif !isnothing(minrank)
strategy = truncrank(minrank)
end

return strategy
end

Expand Down Expand Up @@ -222,6 +233,43 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc()
Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc
Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc

"""
TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...)

Truncation strategy that composes multiple truncation strategies, keeping values that are
present in any of them.
"""
struct TruncationUnion{T <: Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy
components::T
end
function TruncationUnion(trunc::TruncationStrategy, truncs::TruncationStrategy...)
return TruncationUnion((trunc, truncs...))
end

function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
return TruncationUnion((trunc1, trunc2))
end

# flatten components
function Base.:|(trunc1::TruncationUnion, trunc2::TruncationUnion)
return TruncationUnion((trunc1.components..., trunc2.components...))
end
function Base.:|(trunc1::TruncationUnion, trunc2::TruncationStrategy)
return TruncationUnion((trunc1.components..., trunc2))
end
function Base.:|(trunc1::TruncationStrategy, trunc2::TruncationUnion)
return TruncationUnion((trunc1, trunc2.components...))
end

# NoTruncation is the absorbing element for | (union with "keep all" = keep all)
Base.:|(::NoTruncation, ::TruncationStrategy) = notrunc()
Base.:|(::TruncationStrategy, ::NoTruncation) = notrunc()
Base.:|(::NoTruncation, ::NoTruncation) = notrunc()

# disambiguate
Base.:|(::NoTruncation, ::TruncationUnion) = notrunc()
Base.:|(::TruncationUnion, ::NoTruncation) = notrunc()

@doc """
truncation_error(values, ind)
Compute the truncation error as the 2-norm of the values that are not kept by `ind`.
Expand Down
46 changes: 46 additions & 0 deletions test/testsuite/decompositions/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,29 @@ function test_svd_trunc(
@test diagview(S2) ≈ diagview(S)[1:2]
end
end
@testset "mix minrank and tol" begin
m4 = 4
U = instantiate_unitary(T, A, m4)
Sdiag = similar(A, real(eltype(T)), m4)
copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01])
S = Diagonal(Sdiag)
Vᴴ = instantiate_unitary(T, A, m4)
A = U * S * Vᴴ
for trunc_fun in (
(rtol, minrank) -> (; rtol, minrank),
(rtol, minrank) -> trunctol(; rtol) | truncrank(minrank),
)
# trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3))
@test length(diagview(S1)) == 3
@test diagview(S1) ≈ diagview(S)[1:3]

# trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2
U2, S2, V2ᴴ = svd_trunc_no_error(A; trunc = trunc_fun(0.2, 1))
@test length(diagview(S2)) == 2
@test diagview(S2) ≈ diagview(S)[1:2]
end
end
@testset "specify truncation algorithm" begin
atol = sqrt(eps(real(eltype(T))))
m4 = 4
Expand Down Expand Up @@ -294,6 +317,29 @@ function test_svd_trunc_algs(
@test collect(diagview(S2)) ≈ collect(diagview(S)[1:2])
end
end
@testset "mix minrank and tol" begin
m4 = 4
U = instantiate_unitary(T, A, m4)
Sdiag = similar(A, real(eltype(T)), m4)
copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
S = Diagonal(Sdiag)
Vᴴ = instantiate_unitary(T, A, m4)
A = U * S * Vᴴ
for trunc_fun in (
(rtol, minrank) -> (; rtol, minrank),
(rtol, minrank) -> trunctol(; rtol) | truncrank(minrank),
)
# trunctol(rtol=0.5) keeps 1 value, truncrank(3) keeps 3, union keeps 3
U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.5, 3), alg)
@test length(diagview(S1)) == 3
@test collect(diagview(S1)) ≈ collect(diagview(S)[1:3])

# trunctol(rtol=0.2) keeps 2 values, truncrank(1) keeps 1, union keeps 2
U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg)
@test length(diagview(S2)) == 2
@test collect(diagview(S2)) ≈ collect(diagview(S)[1:2])
end
end
@testset "specify truncation algorithm" begin
atol = sqrt(eps(real(eltype(T))))
m4 = 4
Expand Down
32 changes: 30 additions & 2 deletions test/truncate.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder,
TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationUnion,
TruncationByOrder, TruncationByValue, TruncationStrategy, findtruncated, findtruncated_svd

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
Expand Down Expand Up @@ -65,4 +65,32 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationByOrder,
@test issetequal(values[@constinferred(findtruncated(values, strategy))], values[2:5])
vals_sorted = sort(values; by = abs, rev = true)
@test vals_sorted[@constinferred(findtruncated_svd(vals_sorted, strategy))] == vals_sorted[1:4]

# TruncationUnion / minrank
trunc = @constinferred TruncationStrategy(; minrank = 3)
@test trunc isa TruncationByOrder
@test trunc == truncrank(3)

trunc = @constinferred TruncationStrategy(; atol, minrank = 3)
@test trunc isa TruncationUnion
@test trunc == trunctol(; atol) | truncrank(3)

# | operator
values2 = [1.0, 0.9, 0.5, 0.3, 0.01]
# trunctol keeps 1:3 (above 0.4), truncrank(4) keeps 1:4, union keeps 1:4
strategy = trunctol(; atol = 0.4) | truncrank(4)
@test @constinferred(findtruncated_svd(values2, strategy)) == 1:4
# trunctol keeps 1:3, truncrank(2) keeps 1:2, union keeps 1:3
strategy = trunctol(; atol = 0.4) | truncrank(2)
@test @constinferred(findtruncated_svd(values2, strategy)) == 1:3

# notrunc is absorbing for |
@test (notrunc() | truncrank(3)) isa NoTruncation
@test (truncrank(3) | notrunc()) isa NoTruncation

# TruncationUnion flattening
union1 = truncrank(2) | trunctol(; atol = 0.4)
union2 = union1 | truncrank(4)
@test union2 isa TruncationUnion
@test length(union2.components) == 3
end
Loading