Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 122 additions & 20 deletions lib/ModelingToolkitTearing/src/tearingstate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,6 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

if sort_eqs
# sort equations lexicographically to reduce simplification issues
# depending on order due to NP-completeness of tearing. Sort on a
# bounded prefix of the printed form: the full `string` of an equation
# is exponential in the sharing depth of hash-consed expressions, and
# ties on the first 4096 bytes keep their original (deterministic)
# relative order since the default sort is stable.
sortidxs = Base.sortperm(map(Base.Fix2(bounded_string, 4096), eqs))
eqs = eqs[sortidxs]
original_eqs = original_eqs[sortidxs]
symbolic_incidence = symbolic_incidence[sortidxs]
if !isempty(sources)
sources = sources[sortidxs]
end
end

dervaridxs = OrderedSet{Int}()
add_intermediate_derivatives!(fullvars, dervaridxs, addvar!)
# Handle shifts - find lowest shift and add intermediates with derivative edges
Expand All @@ -414,12 +398,27 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
# build `var_to_diff`
var_to_diff = build_var_to_diff(fullvars, ndervars, var2idx, iv)

# build incidence graph
graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx)

state_priorities = build_state_priorities(sys, fullvars, var_to_diff)
canonical_ranks = build_canonical_ranks(fullvars)

if sort_eqs
sortkeys = Vector{EquationSortKeyT}(undef, length(eqs))
cache = Base.IdDict{SymbolicT, EquationSortKeyT}()
for (i, eq) in enumerate(eqs)
sortkeys[i] = get_equation_sort_key!(cache, eq, var2idx, canonical_ranks)
end
sortidxs = Base.sortperm(sortkeys)
eqs = eqs[sortidxs]
original_eqs = original_eqs[sortidxs]
symbolic_incidence = symbolic_incidence[sortidxs]
if !isempty(sources)
sources = sources[sortidxs]
end
end

# build incidence graph
graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx)

# Identify unknowns that do not appear in any equations and are thus not present in
# `fullvars`. The bindings and initial conditions for these variables should be removed.
for v in fullvars
Expand Down Expand Up @@ -449,6 +448,109 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
typeof(sys)[], sources, nothing)
end

"""
Key used for sorting equations. Each element is a tuple of
(`canonical_rank` of variable, constant coefficient, exponent).
"""
const EquationSortKeyT = Vector{Tuple{Int, Float64, Float64}}

function get_equation_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, eq::Equation,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
return get_expression_sort_key!(cache, eq.rhs, var2idx, canonical_ranks)
end

function get_expression_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
val = get(cache, expr, nothing)
val === nothing || return val
val = __get_expression_sort_key!(cache, expr, var2idx, canonical_ranks)
cache[expr] = val
return val
end

function __get_expression_sort_key!(
cache::Base.IdDict{SymbolicT, EquationSortKeyT}, expr::SymbolicT,
var2idx::Dict{SymbolicT, Int}, canonical_ranks::Vector{Int}
)
@match expr begin
# Use `0` as the rank for constants
BSImpl.Const(; val) => if val isa Real
return [(0, convert(Float64, val), 1.0)]
else
return eltype(EquationSortKeyT)[]
end
BSImpl.Sym(;) => begin
idx = get(var2idx, expr, nothing)
idx === nothing && return eltype(EquationSortKeyT)[]
return [(canonical_ranks[idx], 1.0, 1.0)]
end
BSImpl.AddMul(; coeff, dict, variant) => begin
result = eltype(EquationSortKeyT)[]
ks = collect(keys(dict))
sort!(ks; lt = SU.:(<ₑ))
if variant == SU.AddMulVariant.ADD
for k in ks
arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks)
v = dict[k]
if !(v isa Real)
append!(result, arg_k)
continue
end
v = convert(Float64, v)
for t in arg_k
push!(result, (t[1], t[2] * v, t[3]))
end
end
else
if coeff isa Real
cf = convert(Float64, coeff)
else
cf = 1.0
end
for k in ks
arg_k = get_expression_sort_key!(cache, k, var2idx, canonical_ranks)
v = dict[k]
if !(v isa Real)
append!(result, arg_k)
continue
end
v = convert(Float64, v)
for t in arg_k
push!(result, (t[1], t[2] ^ v * cf, t[3] + v))
end
end
end
return result
end
_ => begin
idx = get(var2idx, expr, nothing)
idx === nothing || return [(canonical_ranks[idx], 1.0, 1.0)]
f = operation(expr)
args = arguments(expr)
if f === (^)
base_key = get_expression_sort_key!(cache, args[1], var2idx, canonical_ranks)
@match args[2] begin
BSImpl.Const(; val) => if val isa Real
v = convert(Float64, val)
return map(k -> (k[1], k[2] ^ v, k[3] + v), base_key)
else
return vcat(base_key, get_expression_sort_key!(cache, args[2], var2idx, canonical_ranks))
end
end
return base_key
end
result = eltype(EquationSortKeyT)[]
for arg in args
append!(result, get_expression_sort_key!(cache, arg, var2idx, canonical_ranks))
end
return result
end
end
end
"""
$TYPEDSIGNATURES

Expand Down Expand Up @@ -527,7 +629,7 @@ function canonical_sort_key(v::SymbolicT)
end
_ => nothing
end
return (canonical_name(x), idxs, opsig)
return (opsig, canonical_name(x), idxs)
end

"""
Expand Down
11 changes: 10 additions & 1 deletion src/singularity_removal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,16 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti}
end
end
solvable_variables = findall(is_linear_variables)
var_priorities = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
sp = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
cr = has_canonical_ranks(structure) ? get_canonical_ranks(structure) : nothing
var_priorities = if cr === nothing
sp
elseif sp === nothing
cr
else
big = maximum(cr; init = 0) + 1
Int[sp[i] * big + cr[i] for i in eachindex(sp)]
end

bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff, var_priorities)

Expand Down
Loading