diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fbc8c6c..11970a08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,14 @@ In development +- `preprocess_ts` now always sorts the tree sequence, which may change the order + of mutations + +**Bugfixes** + +- Removed an assertion in `split_disjoint_nodes` that is not guaranteed + to hold since table sorting changes the order of mutations in tskit 1.0.0 onwards + ## [0.2.6] - 2026-03-06 Maintenance release. diff --git a/tests/test_util.py b/tests/test_util.py index d175921e..b43151b1 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -193,15 +193,44 @@ def test_inferred(self): sequence_length=1e6, random_seed=1, ) + # use a high mutation rate so as to get >1 mutation per site ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) sample_data = tsinfer.SampleData.from_tree_sequence(ts) - inferred_ts = tsinfer.infer(sample_data).simplify() + inferred_ts = tsinfer.infer(sample_data) split_ts = tsdate.util.split_disjoint_nodes(inferred_ts) assert self.has_disjoint_nodes(inferred_ts) assert not self.has_disjoint_nodes(split_ts) assert split_ts.num_edges == inferred_ts.num_edges assert split_ts.num_nodes > inferred_ts.num_nodes + def test_worked_example(self): + """ + The mutation table is reordered such that the mutation above 4 (7 after + splitting) is placed after the mutation above 5. 
+ """ + tables = tskit.TableCollection() + tables.nodes.set_columns( + time=[0, 0, 0, 0, 1, 1, 2], + flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 3, + ) + tables.edges.set_columns( + left=[0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2], + right=[1, 1, 3, 3, 1, 3, 2, 2, 3, 3, 3], + child=[0, 1, 2, 3, 4, 5, 0, 1, 0, 1, 4], + parent=[4, 4, 5, 5, 6, 6, 6, 6, 4, 4, 6], + ) + site_id = tables.sites.add_row(position=2.5, ancestral_state="A") + tables.mutations.add_row(site=site_id, node=4, time=1.5, derived_state="G") + tables.mutations.add_row(site=site_id, node=5, time=1.5, derived_state="G") + tables.sequence_length = 3.0 + tables.sort() + tables.build_index() + tables.compute_mutation_parents() + ts = tables.tree_sequence() + np.testing.assert_equal(ts.mutations_node, [4, 5]) + ts2 = tsdate.util.split_disjoint_nodes(ts) + np.testing.assert_equal(ts2.mutations_node, [5, 7]) + class TestPreprocessTs: def verify(self, ts, caplog, minimum_gap=None, erase_flanks=None, **kwargs): @@ -396,6 +425,43 @@ def test_sim_example(self): # Next assumes no breakpoints before first site or after last assert ts.num_trees == num_trees + first_empty + last_empty + @pytest.mark.parametrize("split_disjoint", [True, False]) + @pytest.mark.parametrize("keep_unary", [True, False]) + def test_mutation_order(self, split_disjoint, keep_unary): + """ + Check that mutations are in sorted order after preprocessing + """ + tables = tskit.TableCollection() + tables.nodes.set_columns( + time=[0, 0, 0, 0, 1, 1.5, 1.25, 2], + flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 4, + ) + tables.edges.set_columns( + left=[0, 0, 0, 0, 0, 0, 0], + right=[3, 3, 3, 3, 3, 3, 3], + child=[0, 1, 2, 3, 4, 5, 6], + parent=[4, 4, 6, 6, 5, 7, 7], + ) + site_id = tables.sites.add_row(position=2.5, ancestral_state="A") + tables.mutations.add_row( + site=site_id, node=5, time=tskit.UNKNOWN_TIME, derived_state="G" + ) + tables.mutations.add_row( + site=site_id, node=6, time=tskit.UNKNOWN_TIME, derived_state="G" + ) + tables.sequence_length = 3.0 + 
tables.sort() + tables.build_index() + tables.compute_mutation_parents() + ts0 = tables.tree_sequence() + tables.simplify(keep_unary=keep_unary) + tables.sort() + ts1 = tables.tree_sequence() + ts2 = tsdate.preprocess_ts( + ts0, split_disjoint=split_disjoint, keep_unary=keep_unary + ) + assert np.array_equal(ts1.mutations_node, ts2.mutations_node) + class TestUnaryNodeCheck: def test_inferred(self): @@ -430,3 +496,25 @@ def test_simulated(self): assert not tsdate.util.contains_unary_nodes(simplified_ts) with pytest.raises(ValueError, match="contains unary nodes"): tsdate.date(ts, mutation_rate=1e-8, method="variational_gamma") + + +class TestInferencePipeline: + """ + Test that tsinfer->preprocess_ts->tsdate runs through + """ + + def test_inference_pipeline(self): + ts = msprime.sim_ancestry( + 10, + population_size=1e4, + recombination_rate=1e-8, + sequence_length=1e6, + random_seed=1, + ) + print(tskit.__version__, tsinfer.__version__) + ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1) + sample_data = tsinfer.SampleData.from_tree_sequence(ts) + tsdate.date( + tsdate.preprocess_ts(tsinfer.infer(sample_data)), + mutation_rate=1e-8, + ) diff --git a/tsdate/util.py b/tsdate/util.py index 5594e34f..a31e0960 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -114,7 +114,7 @@ def preprocess_ts( :param \\**kwargs: All further keyword arguments are passed to the {meth}`tskit.TreeSequence.simplify` command. - :return: A tree sequence with gaps removed. + :return: A tree sequence with gaps removed and disjoint node segments split. 
:rtype: tskit.TreeSequence """ @@ -194,8 +194,13 @@ def preprocess_ts( record_provenance=False, **kwargs, ) + tables.sort() if split_disjoint: ts = split_disjoint_nodes(tables.tree_sequence(), record_provenance=False) + logger.info( + f"Split disjoint node segments from {tables.nodes.num_rows} " + f"nodes into {ts.num_nodes} nodes" + ) tables = ts.dump_tables() if record_provenance: provenance.record_provenance( @@ -582,11 +587,6 @@ def split_disjoint_nodes(ts, *, record_provenance=None): tables.sort() tables.build_index() tables.compute_mutation_parents() - - assert np.array_equal( - tables.nodes.time[tables.mutations.node], ts.nodes_time[ts.mutations_node] - ) - if record_provenance: provenance.record_provenance( tables,