diff --git a/lib/ModelingToolkitTearing/src/reassemble.jl b/lib/ModelingToolkitTearing/src/reassemble.jl index ebdb4cd..a6b46b3 100644 --- a/lib/ModelingToolkitTearing/src/reassemble.jl +++ b/lib/ModelingToolkitTearing/src/reassemble.jl @@ -787,8 +787,11 @@ function __reduce_linear_system!(A::StateSelection.CLIL.SparseMatrixCLIL{Num, In cst += coeff * other_cst end - # `sparsevec` sums duplicate indices but keeps explicit zeros; drop them. - aliases[var] = SparseArrays.dropzeros!(SparseArrays.sparsevec(new_I, new_V, length(coeffs))) + # `sparsevec` sums duplicate indices but keeps explicit zeros. Drop them + # with a structural zero check + v = SparseArrays.sparsevec(new_I, new_V, N) + SparseArrays.fkeep!((_, x) -> !StateSelection.CLIL.cheap_iszero(x), v) + aliases[var] = v constants[var] = cst end diff --git a/lib/ModelingToolkitTearing/src/stateselection_interface.jl b/lib/ModelingToolkitTearing/src/stateselection_interface.jl index 7995305..9eb5e39 100644 --- a/lib/ModelingToolkitTearing/src/stateselection_interface.jl +++ b/lib/ModelingToolkitTearing/src/stateselection_interface.jl @@ -105,6 +105,13 @@ function StateSelection.linear_subsys_adjmat!(state::TearingState; kwargs...) return mm end +# Structural zero check for symbolic CLIL values: `Base.iszero(::Num)` performs +# a semantic (expansion-based) zero test that can OOM on large coefficient +# expressions (e.g. multibody models), while explicit stored zeros produced by +# duplicate-index summation are always structural `Const(0)`. +StateSelection.CLIL.cheap_iszero(x::Num) = SU._iszero(Symbolics.unwrap(x)) +StateSelection.CLIL.cheap_iszero(x::SymbolicT) = SU._iszero(x) + function maybe_zeros_descend(ex::SymbolicT) @match ex begin BSImpl.AddMul(; variant) => return variant === SU.AddMulVariant.MUL diff --git a/src/math/sparsematrixclil.jl b/src/math/sparsematrixclil.jl index ad5a5a1..2d9b15e 100644 --- a/src/math/sparsematrixclil.jl +++ b/src/math/sparsematrixclil.jl @@ -91,6 +91,18 @@ zero!(a::SparseVector) = (empty!(a.nzind); empty!(a.nzval)) zero!(a::CLILVector) = zero!(a.vec) SparseArrays.dropzeros!(a::CLILVector) = SparseArrays.dropzeros!(a.vec) +""" + cheap_iszero(x) + +Structural zero check used by [`SparseArrays.dropzeros!`](@ref) on +`SparseMatrixCLIL`. Defaults to `Base.iszero`. Downstream packages whose CLIL +value type is symbolic should overload this with a cheap *structural* check: +`Base.iszero` on e.g. `Symbolics.Num` performs a semantic (expansion-based) +zero test that can be arbitrarily expensive on large expressions, while +explicit stored zeros are always structural zeros. +""" +cheap_iszero(x) = iszero(x) + # Remove explicitly-stored zeros from each row, in place. function SparseArrays.dropzeros!(S::SparseMatrixCLIL) for r in eachindex(S.row_vals) @@ -98,7 +110,7 @@ function SparseArrays.dropzeros!(S::SparseMatrixCLIL) vals = S.row_vals[r] j = 0 for k in eachindex(vals) - iszero(vals[k]) && continue + cheap_iszero(vals[k]) && continue j += 1 cols[j] = cols[k] vals[j] = vals[k] diff --git a/src/utils.jl b/src/utils.jl index 5960c36..ccc2f4e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -345,7 +345,8 @@ function get_new_mm( # entry: a prior cancellation may have `pop!`ed the matching entry. if !isempty(final_row_cols) && col == final_row_cols[end] final_row_vals[end] += new_row_val_i[indices[i]] - if iszero(final_row_vals[end]) + # Structural zero check: + if CLIL.cheap_iszero(final_row_vals[end]) pop!(final_row_cols) pop!(final_row_vals) end diff --git a/test/runtests.jl b/test/runtests.jl index 1a03238..83730d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,16 @@ using Test include("bareiss.jl") include("carpanzano_tearing.jl") +# A value type whose `Base.iszero` is "semantic" and must never be consulted by +# `dropzeros!` — stands in for symbolic value types (e.g. `Symbolics.Num`), +# where the semantic zero test is arbitrarily expensive. `dropzeros!` must go +# through the `cheap_iszero` hook instead. +struct SemanticZero + x::Int +end +Base.iszero(::SemanticZero) = error("semantic iszero must not be called by dropzeros!") +SSel.CLIL.cheap_iszero(v::SemanticZero) = v.x == 0 + @testset "`get_new_mm`" begin mm = SSel.CLIL.SparseMatrixCLIL( [ @@ -72,3 +82,11 @@ include("carpanzano_tearing.jl") @test mm2.row_cols == [[2]] @test mm2.row_vals == [[1]] end + +@testset "`dropzeros!` uses the `cheap_iszero` hook" begin + mm = SSel.CLIL.SparseMatrixCLIL{SemanticZero, Int}( + 1, 2, [1], [[1, 2]], [[SemanticZero(0), SemanticZero(3)]]) + SparseArrays.dropzeros!(mm) + @test mm.row_cols == [[2]] + @test map(v -> v.x, only(mm.row_vals)) == [3] +end