diff --git a/lib/ModelingToolkitTearing/src/tearingstate.jl b/lib/ModelingToolkitTearing/src/tearingstate.jl index f466f7c..a7d783b 100644 --- a/lib/ModelingToolkitTearing/src/tearingstate.jl +++ b/lib/ModelingToolkitTearing/src/tearingstate.jl @@ -382,22 +382,6 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS neqs = length(eqs) symbolic_incidence = symbolic_incidence[eqs_to_retain] - if sort_eqs - # sort equations lexicographically to reduce simplification issues - # depending on order due to NP-completeness of tearing. Sort on a - # bounded prefix of the printed form: the full `string` of an equation - # is exponential in the sharing depth of hash-consed expressions, and - # ties on the first 4096 bytes keep their original (deterministic) - # relative order since the default sort is stable. - sortidxs = Base.sortperm(map(Base.Fix2(bounded_string, 4096), eqs)) - eqs = eqs[sortidxs] - original_eqs = original_eqs[sortidxs] - symbolic_incidence = symbolic_incidence[sortidxs] - if !isempty(sources) - sources = sources[sortidxs] - end - end - dervaridxs = OrderedSet{Int}() add_intermediate_derivatives!(fullvars, dervaridxs, addvar!) # Handle shifts - find lowest shift and add intermediates with derivative edges @@ -414,12 +398,27 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS # build `var_to_diff` var_to_diff = build_var_to_diff(fullvars, ndervars, var2idx, iv) - # build incidence graph - graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx) - state_priorities = build_state_priorities(sys, fullvars, var_to_diff) canonical_ranks = build_canonical_ranks(fullvars) + if sort_eqs + sortkeys = Vector{EquationSortKeyT}(undef, length(eqs)) + cache = Base.IdDict{SymbolicT, EquationSortKeyT}() + for (i, eq) in enumerate(eqs) + sortkeys[i] = get_equation_sort_key!(cache, eq, var2idx, canonical_ranks) + end + sortidxs = Base.sortperm(sortkeys) + eqs = eqs[sortidxs] + original_eqs = original_eqs[sortidxs] + symbolic_incidence = symbolic_incidence[sortidxs] + if !isempty(sources) + sources = sources[sortidxs] + end + end + + # build incidence graph + graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx) + # Identify unknowns that do not appear in any equations and are thus not present in # `fullvars`. The bindings and initial conditions for these variables should be removed. for v in fullvars @@ -449,6 +448,109 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS typeof(sys)[], sources, nothing) end +""" +Key used for sorting equations. Each element is a tuple of +(`canonical_rank` of variable, constant coefficient, exponent). +""" +const EquationSortKeyT = Vector{Tuple{Int, Float64, Float64}} + +function get_equation_sort_key!( + cache::Base.IdDict{SymbolicT, EquationSortKeyT}, eq::Equation, + var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int} + ) + return get_expression_sort_key!(cache, eq.rhs, var2idx, canonical_ranks) +end + +function get_expression_sort_key!( + cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT, + var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int} + ) + val = get(cache, expr, nothing) + val === nothing || return val + val = __get_expression_sort_key!(cache, expr, var2idx, canonical_ranks) + cache[expr] = val + return val +end + +function __get_expression_sort_key!( + cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT, + var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int} + ) + @match expr begin + # Use `0` as the rank for constants + BSImpl.Const(; val) => if val isa Real + return [(0, convert(Float64, val), 1.0)] + else + return eltype(EquationSortKeyT)[] + end + BSImpl.Sym(;) => begin + idx = get(var2idx, expr, nothing) + idx === nothing && return eltype(EquationSortKeyT)[] + return [(canonical_ranks[idx], 1.0, 1.0)] + end + BSImpl.AddMul(; coeff, dict, variant) => begin + result = eltype(EquationSortKeyT)[] + ks = collect(keys(dict)) + sort!(ks; lt = SU.:(<ₑ)) + if variant == SU.AddMulVariant.ADD + for k in ks + arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks) + v = dict[k] + if !(v isa Real) + append!(result, arg_k) + continue + end + v = convert(Float64, v) + for t in arg_k + push!(result, (t[1], t[2] * v, t[3])) + end + end + else + if coeff isa Real + cf = convert(Float64, coeff) + else + cf = 1.0 + end + for k in ks + arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks) + v = dict[k] + if !(v isa Real) + append!(result, arg_k) + continue + end + v = convert(Float64, v) + for t in arg_k + push!(result, (t[1], t[2] ^ v * cf, t[3] + v)) + end + end + end + return result + end + _ => begin + idx = get(var2idx, expr, nothing) + idx === nothing || return [(canonical_ranks[idx], 1.0, 1.0)] + f = operation(expr) + args = arguments(expr) + if f === (^) + base_key = get_expression_sort_key!(cache, args[1], var2idx, canonical_ranks) + @match args[2] begin + BSImpl.Const(; val) => if val isa Real + v = convert(Float64, val) + return map(k -> (k[1], k[2] ^ v, k[3] + v), base_key) + else + return vcat(base_key, get_expression_sort_key!(cache, args[2], var2idx, canonical_ranks)) + end + end + return base_key + end + result = eltype(EquationSortKeyT)[] + for arg in args + append!(result, get_expression_sort_key!(cache, arg, var2idx, canonical_ranks)) + end + return result + end + end +end """ $TYPEDSIGNATURES @@ -527,7 +629,7 @@ function canonical_sort_key(v::SymbolicT) end _ => nothing end - return (canonical_name(x), idxs, opsig) + return (opsig, canonical_name(x), idxs) end """ diff --git a/src/singularity_removal.jl b/src/singularity_removal.jl index 7d6e189..e6f987d 100644 --- a/src/singularity_removal.jl +++ b/src/singularity_removal.jl @@ -264,7 +264,16 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti} end end solvable_variables = findall(is_linear_variables) - var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing + sp = has_state_priorities(structure) ? get_state_priorities(structure) : nothing + cr = has_canonical_ranks(structure) ? get_canonical_ranks(structure) : nothing + var_priorities = if cr === nothing + sp + elseif sp === nothing + cr + else + big = maximum(cr; init = 0) + 1 + Int[sp[i] * big + cr[i] for i in eachindex(sp)] + end bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)