Skip to content
116 changes: 112 additions & 4 deletions lib/ModelingToolkitTearing/src/tearingstate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ Base.@kwdef mutable struct SystemStructure <: StateSelection.SystemStructure
var_types::Vector{VariableType}
"""State priorities corresponding to each variable in `fullvars`"""
state_priorities::Vector{Int}
"""
Canonical rank of each variable in `fullvars`: the rank of the variable's name in
lexicographic order. Used as a deterministic tie-break (after priorities) in tearing,
so that results do not depend on equation/declaration order.
"""
canonical_ranks::Vector{Int}
"""Whether the system is discrete."""
only_discrete::Bool
end
Expand All @@ -29,7 +35,8 @@ function Base.copy(structure::SystemStructure)
var_types = structure.var_types === nothing ? nothing : copy(structure.var_types)
SystemStructure(copy(structure.var_to_diff), copy(structure.eq_to_diff),
copy(structure.graph), copy(structure.solvable_graph),
var_types, copy(structure.state_priorities), structure.only_discrete)
var_types, copy(structure.state_priorities), copy(structure.canonical_ranks),
structure.only_discrete)
end

StateSelection.is_only_discrete(s::SystemStructure) = s.only_discrete
Expand Down Expand Up @@ -399,6 +406,7 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx)

state_priorities = build_state_priorities(sys, fullvars, var_to_diff)
canonical_ranks = build_canonical_ranks(fullvars)

# Identify unknowns that do not appear in any equations and are thus not present in
# `fullvars`. The bindings and initial conditions for these variables should be removed.
Expand All @@ -422,20 +430,120 @@ function TearingState(sys::System, source_info::Union{Nothing, MTKBase.EquationS
eq_to_diff = StateSelection.DiffGraph(nsrcs(graph))

structure = SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, state_priorities, false)
complete(graph), nothing, var_types, state_priorities,
canonical_ranks, false)
return TearingState(sys, fullvars, structure, Equation[], param_derivative_map,
no_deriv_params, original_eqs, Equation[], falses(length(fullvars)),
typeof(sys)[], sources)
end

"""
$TYPEDSIGNATURES

Total, allocation-light "canonical name" for any expression that may appear in
`fullvars`: the name for named variables (`Sym` / call-variable / `getindex`),
the operation's name for operator/function applications, and a fixed sentinel
symbol for the remaining structural variants. Key collisions are acceptable —
they only mean the canonical tie-break falls back to the original order among
the colliding entries.
"""
function canonical_name(x::SymbolicT)
@match x begin
BSImpl.Sym(; name) => name
BSImpl.Term(; f, args) && if f === getindex end => canonical_name(args[1])
BSImpl.Term(; f) => canonical_opname(f)
BSImpl.AddMul(; variant) => variant === SU.AddMulVariant.ADD ? :+ : :*
BSImpl.Div(;) => :/
BSImpl.ArrayOp(;) => Symbol("#arrayop")
BSImpl.Const(;) => Symbol("#const")
_ => Symbol("#expr")
end
end
function canonical_opname(@nospecialize(f))
f isa SymbolicT && return canonical_name(f)
f isa Function && return nameof(f)::Symbol
return nameof(typeof(f))::Symbol
end

"""
$TYPEDSIGNATURES

Structured canonical sort key for a `fullvars` entry: a tuple
`(name, indices, opsig)` where `name` is the [`canonical_name`](@ref) of the
underlying (array) variable, `indices` are the integer indices when `v` is a
scalarized array element (empty otherwise) and `opsig` encodes the operator
chain wrapping the variable (`Differential` → `1`; `Shift` → `2` followed by
its step count; any other single-argument operator → `3`). Comparing these
tuples orders variables deterministically regardless of equation/declaration
order, without stringifying symbolic expressions. Total: compound expressions
(e.g. multi-argument clock operators over non-variable arguments) key off
their operation name.
"""
function canonical_sort_key(v::SymbolicT)
# `opsig`/`idxs` are built as `Vector{Int}` (not growing tuples) so the key
# type is concrete and inferrable after the loop. Vectors compare
# lexicographically, so the tuple ordering is unchanged.
opsig = Int[]
x = v
while true
stripped = @match x begin
BSImpl.Term(; f, args) && if f isa Differential end => begin
push!(opsig, 1)
args[1]
end
BSImpl.Term(; f, args) && if f isa MTKBase.Shift end => begin
push!(opsig, 2, Int(f.steps))
args[1]
end
BSImpl.Term(; f, args) && if f isa SU.Operator && length(args) == 1 end => begin
push!(opsig, 3)
args[1]
end
_ => nothing
end
stripped === nothing && break
x = stripped
end
idxs = Int[]
@match x begin
BSImpl.Term(; f, args) && if f === getindex end => begin
for i in Iterators.drop(args, 1)
ival = SU.isconst(i) ? manual_dispatch_is_small_int(unwrap_const(i))::Int : 0
push!(idxs, ival)
end
x = args[1]
end
_ => nothing
end
return (canonical_name(x), idxs, opsig)
end

"""
$TYPEDSIGNATURES

Rank of each variable in `fullvars` under the [`canonical_sort_key`](@ref) order.
Used as a deterministic tie-break (after priorities) in tearing.
"""
function build_canonical_ranks(fullvars::Vector{SymbolicT})
return invperm(sortperm(map(canonical_sort_key, fullvars)))
end

function build_state_priorities(sys::System, fullvars::Vector{SymbolicT}, var_to_diff::StateSelection.DiffGraph)
# Cache the state priorities
sps = state_priorities(sys)
var_priorities = Int[]
sizehint!(var_priorities, length(fullvars))
for v in fullvars
arr, _ = MTKBase.split_indexed_var(v)
push!(var_priorities, get(sps, arr, 0))
# Component-granular lookup: a priority on the scalarized variable
# itself (e.g. `v_0[1] => 5`) takes precedence over a priority on its
# parent array variable. This allows distinguishing array components,
# which is impossible with the parent-array-only lookup.
p = get(sps, v, nothing)
if p === nothing
arr, _ = MTKBase.split_indexed_var(v)
p = get(sps, arr, 0)
end
push!(var_priorities, round(Int, p))
end

# Propagate priorities up the `var_to_diff` graph. Each variable's final priority is
Expand Down
4 changes: 3 additions & 1 deletion lib/ModelingToolkitTearing/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,14 @@ function permute(structure::SystemStructure, oldtonewvar::Vector{Int}, oldtonewe

var_types = similar(structure.var_types)
sps = similar(structure.state_priorities)
cranks = similar(structure.canonical_ranks)
for i in 𝑑vertices(structure.graph)
var_types[oldtonewvar[i]] = structure.var_types[i]
sps[oldtonewvar[i]] = structure.state_priorities[i]
cranks[oldtonewvar[i]] = structure.canonical_ranks[i]
end

return SystemStructure(var_to_diff, eq_to_diff, graph, solvable_graph, var_types, sps, structure.only_discrete)
return SystemStructure(var_to_diff, eq_to_diff, graph, solvable_graph, var_types, sps, cranks, structure.only_discrete)
end

"""
Expand Down
40 changes: 35 additions & 5 deletions src/carpanzano_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ function carpanzano_tear_scc!(
# way to implement algorithm A2 and analyze the benefits.

(; graph, solvable_graph) = structure
# Variable priorities (when available) act as tie-breaks for the tear-variable
# selection below: among otherwise-equivalent candidates, prefer marking the
# variable with the HIGHEST priority as algebraic (torn), mirroring the
# `state_priority` semantics of dummy-derivative state selection. With uniform
# priorities the behavior is unchanged.
varpriority = has_state_priorities(structure) ? get_state_priorities(structure) : nothing
# Canonical (name-rank) tie-break after priorities: makes the tear-variable
# selection deterministic under equation/declaration reordering.
canonrank = has_canonical_ranks(structure) ? get_canonical_ranks(structure) : nothing
# Sort key: prefer HIGHER priority, then SMALLER canonical rank.
tearkey = v -> (varpriority === nothing ? 0 : -varpriority[v],
canonrank === nothing ? 0 : canonrank[v])
# Find variables which cannot be solved for using any of the equations in this SCC,
# and remove them from `active_vars`.
filter!(Base.Fix1(any, in(active_eqs)) ∘ Base.Fix1(𝑑neighbors, solvable_graph), active_vars)
Expand Down Expand Up @@ -224,8 +236,12 @@ function carpanzano_tear_scc!(
end

# Find a variable in `active_vars` which is present in one of these equations
# and yet is not solvable in it.
# and yet is not solvable in it. With priorities available, scan all
# min-incidence equations and pick the highest-priority such variable
# (strict comparison: first-found wins ties, matching the unprioritized
# behavior when priorities are uniform).
found_algvar = false
alg_candidate = 0
for ieq in enodes_with_min_incidence
empty!(non_solvable_incidence)
append!(non_solvable_incidence, 𝑠neighbors(graph, ieq))
Expand All @@ -238,14 +254,23 @@ function carpanzano_tear_scc!(
# We didn't update the matching, and the algorithm requires
# all variables in `active_vars` be initially matched to `unassigned` so
# this automatically makes it algebraic.
found_algvar = true
delete!(active_vars, ivar)
break
if varpriority === nothing && canonrank === nothing
alg_candidate = ivar
found_algvar = true
break
elseif alg_candidate == 0 || tearkey(ivar) < tearkey(alg_candidate)
alg_candidate = ivar
end
end

found_algvar && break
end

if alg_candidate != 0
delete!(active_vars, alg_candidate)
found_algvar = true
end

found_algvar && continue

# Heuristic 2:
Expand All @@ -255,13 +280,18 @@ function carpanzano_tear_scc!(
alg_var = 0
max_incidence_cnt = typemin(Int)
min_solvable_cnt = typemax(Int)
best_key = (typemax(Int), typemax(Int))
for ivar in active_vars
cnt = count(in(active_eqs), 𝑑neighbors(graph, ivar))
solvable_cnt = count(in(active_eqs), 𝑑neighbors(solvable_graph, ivar))
if iszero(alg_var) || cnt > max_incidence_cnt || cnt == max_incidence_cnt && solvable_cnt < min_solvable_cnt
key = tearkey(ivar)
if iszero(alg_var) || cnt > max_incidence_cnt ||
cnt == max_incidence_cnt && (solvable_cnt < min_solvable_cnt ||
solvable_cnt == min_solvable_cnt && key < best_key)
alg_var = ivar
max_incidence_cnt = cnt
min_solvable_cnt = solvable_cnt
best_key = key
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ is_only_discrete(::SystemStructure) = false

has_state_priorities(::T) where {T <: SystemStructure} = hasfield(T, :state_priorities)
get_state_priorities(ss::SystemStructure) = ss.state_priorities
has_canonical_ranks(::T) where {T <: SystemStructure} = hasfield(T, :canonical_ranks)
get_canonical_ranks(ss::SystemStructure) = ss.canonical_ranks

"""
$TYPEDEF
Expand Down
20 changes: 16 additions & 4 deletions src/modia_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function try_assign_eq!(ict::IncrementalCycleTracker, vars, v_active, eq::Intege
end

function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int},
v_active::BitSet, isder′::F) where {F}
v_active::BitSet, isder′::F, varpriority::P = nothing) where {F, P}
check_der = isder′ !== nothing
if check_der
has_der = Ref(false)
Expand All @@ -38,6 +38,16 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
for eq in es # iterate only over equations that are not in eSolvedFixed
vs = Gsolvable[eq]
((length(vs) == 1) ⊻ only_single_solvable) && continue
if varpriority !== nothing && length(vs) > 1
# Prefer solving the equation for the candidate with the LOWEST
# priority, so that high-priority variables remain tear
# (iteration) variables — mirroring the `state_priority`
# semantics of dummy-derivative state selection. The sort is
# stable, so behavior is unchanged whenever priorities are equal
# (e.g. all default 0). Do not mutate the graph's adjacency list.
vs = copy(vs)
sort!(vs; by = varpriority, alg = Base.Sort.DEFAULT_STABLE)
end
if check_der
# if there're differentiated variables, then only consider them
try_assign_eq!(ict, vs, v_active, eq, isder)
Expand All @@ -54,8 +64,8 @@ function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int}
end

function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars,
isder::F) where {F}
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder)
isder::F, varpriority::P = nothing) where {F, P}
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder, varpriority)
for var in vars
var_eq_matching[var] = ict.graph.matching[var]
end
Expand Down Expand Up @@ -92,6 +102,8 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
# find them here [TODO: It would be good to have an explicit example of this.]

(; graph, solvable_graph) = structure
varpriority = has_state_priorities(structure) ?
Base.Fix1(getindex, get_state_priorities(structure)) : nothing
var_eq_matching = maximal_matching(graph, U,
srcfilter=eqfilter,
dstfilter=varfilter)
Expand Down Expand Up @@ -123,7 +135,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
end
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, ieqs,
filtered_vars,
isder)
isder, varpriority)
update_full_var_eq_matching!(graph, full_var_eq_matching, var_eq_matching, vars, remaining_eqs; varfilter)

# If the systems is overdetemined, we cannot assume the free equations
Expand Down
Loading