Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

In development

- `preprocess_ts` now always sorts the tree sequence, which may change the order
of mutations

**Bugfixes**

- Removed an assertion in `split_disjoint_nodes` that is not guaranteed
to hold, since table sorting changes the order of mutations in tskit 1.0.0 onwards

## [0.2.6] - 2026-03-06

Maintenance release.
Expand Down
90 changes: 89 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,44 @@ def test_inferred(self):
sequence_length=1e6,
random_seed=1,
)
# use a high mutation rate so as to get >1 mutation per site
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
inferred_ts = tsinfer.infer(sample_data).simplify()
inferred_ts = tsinfer.infer(sample_data)
split_ts = tsdate.util.split_disjoint_nodes(inferred_ts)
assert self.has_disjoint_nodes(inferred_ts)
assert not self.has_disjoint_nodes(split_ts)
assert split_ts.num_edges == inferred_ts.num_edges
assert split_ts.num_nodes > inferred_ts.num_nodes

def test_worked_example(self):
    """
    Hand-built case in which splitting changes the mutation order: the
    mutation above node 4 (node 7 once split) ends up after the mutation
    above node 5.
    """
    # One row per edge: (left, right, child, parent).
    edge_rows = [
        (0, 1, 0, 4),
        (0, 1, 1, 4),
        (0, 3, 2, 5),
        (0, 3, 3, 5),
        (0, 1, 4, 6),
        (0, 3, 5, 6),
        (1, 2, 0, 6),
        (1, 2, 1, 6),
        (2, 3, 0, 4),
        (2, 3, 1, 4),
        (2, 3, 4, 6),
    ]
    lefts, rights, children, parents = (list(col) for col in zip(*edge_rows))

    tables = tskit.TableCollection()
    tables.nodes.set_columns(
        time=[0, 0, 0, 0, 1, 1, 2],
        flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 3,
    )
    tables.edges.set_columns(
        left=lefts,
        right=rights,
        child=children,
        parent=parents,
    )
    site_id = tables.sites.add_row(position=2.5, ancestral_state="A")
    # Two mutations at the same site and time, above nodes 4 and 5.
    for mutated_node in (4, 5):
        tables.mutations.add_row(
            site=site_id, node=mutated_node, time=1.5, derived_state="G"
        )
    tables.sequence_length = 3.0
    tables.sort()
    tables.build_index()
    tables.compute_mutation_parents()

    original_ts = tables.tree_sequence()
    np.testing.assert_equal(original_ts.mutations_node, [4, 5])
    split_ts = tsdate.util.split_disjoint_nodes(original_ts)
    np.testing.assert_equal(split_ts.mutations_node, [5, 7])


class TestPreprocessTs:
def verify(self, ts, caplog, minimum_gap=None, erase_flanks=None, **kwargs):
Expand Down Expand Up @@ -396,6 +425,43 @@ def test_sim_example(self):
# Next assumes no breakpoints before first site or after last
assert ts.num_trees == num_trees + first_empty + last_empty

@pytest.mark.parametrize("split_disjoint", [True, False])
@pytest.mark.parametrize("keep_unary", [True, False])
def test_mutation_order(self, split_disjoint, keep_unary):
    """
    The mutation node order produced by preprocess_ts must equal the order
    obtained by simplifying and then sorting the same tables.
    """
    num_edges = 7
    tables = tskit.TableCollection()
    tables.nodes.set_columns(
        time=[0, 0, 0, 0, 1, 1.5, 1.25, 2],
        flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 4,
    )
    # Every edge spans the whole sequence [0, 3).
    tables.edges.set_columns(
        left=[0] * num_edges,
        right=[3] * num_edges,
        child=list(range(num_edges)),
        parent=[4, 4, 6, 6, 5, 7, 7],
    )
    site_id = tables.sites.add_row(position=2.5, ancestral_state="A")
    # Two mutations with unknown times at one site, above nodes 5 and 6.
    for mutated_node in (5, 6):
        tables.mutations.add_row(
            site=site_id,
            node=mutated_node,
            time=tskit.UNKNOWN_TIME,
            derived_state="G",
        )
    tables.sequence_length = 3.0
    tables.sort()
    tables.build_index()
    tables.compute_mutation_parents()

    # Snapshot the input before the tables are simplified in place.
    unprocessed_ts = tables.tree_sequence()
    tables.simplify(keep_unary=keep_unary)
    tables.sort()
    expected_ts = tables.tree_sequence()

    preprocessed_ts = tsdate.preprocess_ts(
        unprocessed_ts, split_disjoint=split_disjoint, keep_unary=keep_unary
    )
    assert np.array_equal(expected_ts.mutations_node, preprocessed_ts.mutations_node)


class TestUnaryNodeCheck:
def test_inferred(self):
Expand Down Expand Up @@ -430,3 +496,25 @@ def test_simulated(self):
assert not tsdate.util.contains_unary_nodes(simplified_ts)
with pytest.raises(ValueError, match="contains unary nodes"):
tsdate.date(ts, mutation_rate=1e-8, method="variational_gamma")


class TestInferencePipeline:
    """
    Smoke test: the tsinfer -> preprocess_ts -> tsdate chain runs without error
    """

    def test_inference_pipeline(self):
        """Simulate, infer, preprocess and date; pass if nothing raises."""
        ancestry_ts = msprime.sim_ancestry(
            10,
            population_size=1e4,
            recombination_rate=1e-8,
            sequence_length=1e6,
            random_seed=1,
        )
        # Emit the library versions in use.
        print(tskit.__version__, tsinfer.__version__)
        mutated_ts = msprime.sim_mutations(ancestry_ts, rate=1e-8, random_seed=1)
        sample_data = tsinfer.SampleData.from_tree_sequence(mutated_ts)
        inferred_ts = tsinfer.infer(sample_data)
        preprocessed_ts = tsdate.preprocess_ts(inferred_ts)
        tsdate.date(preprocessed_ts, mutation_rate=1e-8)
12 changes: 6 additions & 6 deletions tsdate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def preprocess_ts(
:param \\**kwargs: All further keyword arguments are passed to the
{meth}`tskit.TreeSequence.simplify` command.

:return: A tree sequence with gaps removed.
:return: A tree sequence with gaps removed and disjoint node segments split.
:rtype: tskit.TreeSequence
"""

Expand Down Expand Up @@ -194,8 +194,13 @@ def preprocess_ts(
record_provenance=False,
**kwargs,
)
tables.sort()
if split_disjoint:
ts = split_disjoint_nodes(tables.tree_sequence(), record_provenance=False)
logger.info(
f"Split disjoint node segments from {tables.nodes.num_rows} "
f"nodes into {ts.num_nodes} nodes"
)
tables = ts.dump_tables()
if record_provenance:
provenance.record_provenance(
Expand Down Expand Up @@ -582,11 +587,6 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
tables.sort()
tables.build_index()
tables.compute_mutation_parents()

assert np.array_equal(
tables.nodes.time[tables.mutations.node], ts.nodes_time[ts.mutations_node]
)

if record_provenance:
provenance.record_provenance(
tables,
Expand Down
Loading