@@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):
10271027
10281028def _truncate_priors (ts , priors , progress = False ):
10291029 """
1030- Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1031- if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1032- sequence
1030+ Truncate priors for all nonfixed nodes
1031+ so they conform to the age of fixed nodes in the tree sequence
10331032 """
10341033 tables = ts .tables
1035- truncate_nodes = priors .nonfixed_node_ids ()
1036- # ensure truncate_nodes is ordered by node time
1037- truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
10381034
10391035 fixed_nodes = priors .fixed_node_ids ()
10401036 fixed_times = tables .nodes .time [fixed_nodes ]
@@ -1050,29 +1046,32 @@ def _truncate_priors(ts, priors, progress=False):
10501046 constrained_min_times = np .zeros_like (tables .nodes .time )
10511047 # Set the min times of fixed nodes to those in the tree sequence
10521048 constrained_min_times [fixed_nodes ] = fixed_times
1053- constrained_max_times = np .full_like (constrained_min_times , np .inf )
1054-
1055- parents = tables .edges .parent
1056- nd_children = tables .edges .child [np .argsort (parents )]
1057- parents = sorted (parents )
1058- parents_unique = np .unique (parents , return_index = True )
1059- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1060- for index , nd in tqdm (
1061- enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1049+
1050+ # Traverse through the ARG, ensuring children come before parents.
1051+ # This can be done by iterating over groups of edges with the same parent
1052+ new_parent_edge_idx = np .concatenate (
1053+ (
1054+ [0 ],
1055+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1056+ [tables .edges .num_rows ],
1057+ )
1058+ )
1059+ for edges_start , edges_end in zip (
1060+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
10621061 ):
1063- if index + 1 != len ( truncate_nodes ):
1064- children_index = np . arange ( parent_indices [ index ], parent_indices [ index + 1 ])
1065- else :
1066- children_index = np . arange ( parent_indices [ index ], ts . num_edges )
1067- children = nd_children [ children_index ]
1068- time = np . max ( constrained_min_times [ children ])
1069- # The constrained time of the node should be the age of the oldest child
1070- if constrained_min_times [ nd ] <= time :
1071- constrained_min_times [ nd ] = time
1072- nearest_time = np . argmin ( np . abs ( timepoints - time ))
1073- lookup_index = priors .row_lookup [ int ( nd )]
1074- grid_data [ lookup_index ][: nearest_time ] = zero_value
1075- assert np . all ( constrained_min_times < constrained_max_times )
1062+ parent = tables . edges . parent [ edges_start ]
1063+ child_ids = tables . edges . child [ edges_start : edges_end ] # May contain dups
1064+ oldest_child_time = np . max ( constrained_min_times [ child_ids ])
1065+ if oldest_child_time > constrained_min_times [ parent ]:
1066+ if priors . is_fixed ( parent ):
1067+ raise ValueError (
1068+ "Invalid fixed times: time for"
1069+ + f"fixed node { parent } is younger than some of its descendants"
1070+ )
1071+ constrained_min_times [ parent ] = oldest_child_time
1072+ if constrained_min_times [ parent ] > 0 and not priors .is_fixed ( parent ):
1073+ nearest_time = np . argmin ( np . abs ( timepoints - constrained_min_times [ parent ]))
1074+ grid_data [ priors . row_lookup [ parent ]][: nearest_time ] = zero_value
10761075
10771076 rowmax = grid_data [:, 1 :].max (axis = 1 )
10781077 if priors .probability_space == "linear" :
@@ -1132,7 +1131,7 @@ def build_grid(
11321131 :param dict node_var_override: is a dict mapping node IDs to a variance value.
11331132 Any nodes listed here will be treated as non-fixed nodes whose prior is not
11341133 calculated from the conditional coalescent but instead are allocated a prior
1135- whose mean is thenode time in the tree sequence and whose variance is the
1134+ whose mean is the node time in the tree sequence and whose variance is the
11361135 value in this dictionary. This allows sample nodes to be treated as nonfixed
11371136 nodes, and therefore dated. If ``None`` (default) then all sample nodes are
11381137 treated as occurring ata fixed time (as if this were an empty dict).
0 commit comments