diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 220afbd..601feef 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -5,6 +5,10 @@ from cmlibs.utils.zinc.field import get_group_list from cmlibs.utils.zinc.group import group_get_highest_dimension, groups_have_same_local_contents from cmlibs.zinc.field import Field +import logging + + +logger = logging.getLogger(__name__) class AnnotationCategory(Enum): @@ -24,9 +28,18 @@ def get_group_name(self): """ return '.' + self.name + def get_lower_name(self): + """ + :return: Lower case category name. + """ + return self.name.lower() + def is_connectable(self): return self in (self.INDEPENDENT_NETWORK, self.NETWORK_GROUP_1, self.NETWORK_GROUP_2) + def is_connectable_different_annotation(self): + return self.is_connectable() and not (self == self.INDEPENDENT_NETWORK) + class Annotation: """ @@ -56,8 +69,8 @@ def decode_settings(self, settings_in: dict): assert (settings_in.get("name") == self._name) and (settings_in.get("term") == self._term) settings_dimension = settings_in.get("dimension") if settings_dimension != self._dimension: - print("WARNING: Segmentation Stitcher. Annotation with name", self._name, "term", self._term, - "was dimension ", settings_dimension, "in settings, is now ", self._dimension, + logger.warning("Segmentation Stitcher. Annotation with name " + self._name, " term " + str(self._term) + + " was dimension " + str(settings_dimension), "in settings, is now " + str(self._dimension) + ". Have input files changed?") settings_in["dimension"] = self._dimension # update current settings to gain new ones and override old ones @@ -125,6 +138,31 @@ def set_term(self, term): assert self._term is None self._term = term + def clear_term(self): + """ + Clear term to None, call in cases of mismatched terms for the same group name. + """ + self._term = None + + def is_connectable(self): + """ + :return: True if annotation's category is connectable. + """ + return self._category.is_connectable() + + def is_connectable_with(self, other_annotation): + """ + Query whether ends annotated with self and other_annotation can be connected. + :param other_annotation: Another Annotation object. + :return: True if self and other_annotation are allowed to be connected by a link. + """ + if self._category.is_connectable() and (self._category == other_annotation.get_category()): + if other_annotation is self: + return True + elif self._category.is_connectable_different_annotation(): + return True + return False + def region_get_annotations(region, network_group1_keywords, network_group2_keywords, term_keywords): """ @@ -146,6 +184,17 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo annotations = [] term_annotations = [] datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + # these terms have had slight mismatches with contents of url group, so explicitly matching: + segment_name = region.getParent().getName() + known_terms = { + "epineurium": "http://uri.interlex.org/base/ilx_0103892", + "left cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794142", + "right cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794141", + "left thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0787543", + "right thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0786664", + "left vagus x nerve trunk": "http://uri.interlex.org/base/ilx_0736691", + "right vagus x nerve trunk": "http://uri.interlex.org/base/ilx_0730515" + } for group in groups: # clean up name to remove case and leading/trailing whitespace name = group.getName().strip() @@ -159,7 +208,9 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo continue # empty group if lower_name.isdigit(): continue # ignore as these can never be valid annotation names - category = AnnotationCategory.GENERAL + if ' list of [segment0_node_identifier, segment1_node_identifier]] + self._annotation_links = {} # dict: annotation name --> list of {'lock': bool, 'node identifiers': list} for segment in self._segments: segment.add_transformation_change_callback(self._segment_transformation_change) @@ -62,14 +69,51 @@ def decode_settings(self, settings_in: dict): Update segment settings from JSON dict containing serialised settings. :param settings_in: Dictionary of settings as produced by encode_settings(). """ - settings_name = self._separator.join(settings_in["segments"]) + settings_name = self._separator.join(settings_in['segments']) assert settings_name == self._name # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) - linked_nodes = settings.get("linked nodes") - if isinstance(linked_nodes, dict): - self._linked_nodes = linked_nodes + # migrate from previous 'linked nodes' which had a list of node identifiers + linked_nodes = settings.get('linked nodes') + if linked_nodes is not None: + # migrate to new annotation links + annotation_links = {} + annotation_names = list(linked_nodes.keys()) + for annotation_name in annotation_names: + links = linked_nodes[annotation_name] + new_links = [] + if isinstance(links[0], list): + for node_identifiers in links: + new_links.append({'lock': False, 'node identifiers': node_identifiers}) + annotation_links[annotation_name] = new_links + del settings['linked nodes'] + settings['annotation links'] = annotation_links + else: + annotation_links = settings['annotation links'] + # check nodes exist for all links, otherwise remove stale links + segment_nodes = [segment.get_raw_region().getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + for segment in self._segments] + annotation_names = list(annotation_links.keys()) + for annotation_name in annotation_names: + links = annotation_links[annotation_name] + invalid_indexes = [] + for i, link in enumerate(links): + node_identifiers = link['node identifiers'] + invalid_link = False + for s, node_identifier in enumerate(node_identifiers): + if not segment_nodes[s].findNodeByIdentifier(node_identifier).isValid(): + logger.warning('Stitcher connection ' + self._name + ' annotation ' + annotation_name + + ' link missing node ' + str(node_identifier) + ' from segment ' + str(s + 1) + + '. Removing link.') + invalid_link = True + if invalid_link: + invalid_indexes.append(i) + for i in reversed(invalid_indexes): + links.pop(i) + if len(links) == 0: + del annotation_links[annotation_name] + self._annotation_links = annotation_links def encode_settings(self) -> dict: """ @@ -77,12 +121,12 @@ def encode_settings(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ settings = { - "segments": [segment.get_name() for segment in self._segments], - "linked nodes": self._linked_nodes + 'segments': [segment.get_name() for segment in self._segments], + 'annotation links': self._annotation_links } return settings - def printLog(self): + def printZincLog(self): logger = self._region.getContext().getLogger() for index in range(logger.getNumberOfMessages()): print(logger.getMessageTextAtIndex(index)) @@ -129,190 +173,244 @@ def get_segments(self): def _segment_transformation_change(self, segment): self.build_links() - self.update_annotation_category_groups(self._annotations) + self.update_annotation_category_groups() - def add_linked_nodes(self, annotation, node_id0, node_id1): + def set_linked_nodes(self, annotation, node_id0, node_id1, lock=False): """ + Ensure there is a link between node_id0 and node_id1 for annotation with the chosen lock state. + If link already exists, updates the lock state only. :param annotation: Annotation to use for link. :param node_id0: Node identifier to link from segment[0]. - :param node_id1: Node identifier to link from segment[1]. + :param node_id1: Node identifier to link from segment[1]. + :param lock: True to keep link connected until unlocked. """ annotation_name = annotation.get_name() - annotation_linked_nodes = self._linked_nodes.get(annotation_name) - if not annotation_linked_nodes: - self._linked_nodes[annotation_name] = annotation_linked_nodes = [] - annotation_linked_nodes.append([node_id0, node_id1]) + links = self._annotation_links.get(annotation_name) + if not links: + # first inserts at the end + self._annotation_links[annotation_name] = links = [] + # then reinsert any other names which should be after name + for name in list(self._annotation_links.keys()): + if name > annotation_name: + self._annotation_links[name] = self._annotation_links.pop(name) + node_identifiers = [node_id0, node_id1] + for index, link in enumerate(links): + if link['node identifiers'] < node_identifiers: + continue + if link['node identifiers'] == node_identifiers: + link['lock'] = lock + return + break + else: + index = len(links) + # insert in order of lowest first then second node identifier + links.insert(index, {'lock': lock, 'node identifiers': node_identifiers}) - def get_linked_nodes(self): + def get_annotation_links(self): """ :return: Map annotation name -> list of paired nodes from segment1 and segment2 """ - return self._linked_nodes + return self._annotation_links - def optimise_transformation(self): + def get_coordinates_midpoint(self): """ - Optimise transformation of second segment to align with position and direction of nearest points between - both segments. + Get midpoint of linked nodes, if any, which are transformed by the respective segments. + :return: Coordinates at the midpoint in their x, y, z range, or None if no linked nodes. """ - segment_end_point_data = [] - initial_rotation = [] - initial_rotation_matrix = [] - for s, segment in enumerate(self._segments): - translation = segment.get_translation() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] - initial_rotation.append(rotation) - rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None - initial_rotation_matrix.append(rotation_matrix) - end_point_data = [] - raw_end_point_data = segment.get_end_point_data() - for node_id, data in raw_end_point_data.items(): - coordinates, direction, radius, annotation = data - transformed_coordinates = coordinates - if (annotation is not None) and annotation.get_category().is_connectable(): - if rotation_matrix: - transformed_coordinates = matrix_vector_mult(rotation_matrix, transformed_coordinates) - transformed_coordinates = add(transformed_coordinates, translation) - end_point_data.append((node_id, transformed_coordinates, coordinates, direction, radius, annotation)) - segment_end_point_data.append(end_point_data) + minimums, maximums = self.get_coordinates_range() + if minimums and maximums: + return [0.5 * (minimum + maximum) for minimum, maximum in zip(minimums, maximums)] + return None - mean_coordinates = [] - mean_directions = [] - for s, segment in enumerate(self._segments): - total_weight = 0.0 - distances = [] - max_distance = None - for node_id0, transformed_coordinates0, _, _, _, annotation0 in segment_end_point_data[s]: - category0 = annotation0.get_category() - distance = None - for node_id1, transformed_coordinates1, _, _, _, annotation1 in segment_end_point_data[s - 1]: - category1 = annotation1.get_category() - if (category0 != category1) or ( - (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): - continue # end points are not allowed to join - tmp_distance = magnitude(sub(transformed_coordinates0, transformed_coordinates1)) - if (distance is None) or (tmp_distance < distance): - distance = tmp_distance - if (distance is not None) and ((max_distance is None) or (distance > max_distance)): - max_distance = distance - distances.append(distance) - if max_distance is None: - print("Segmentation Stitcher. No connectable points to optimise transformation with") - return - nearby_proportion = 0.1 # proportion of max distance under which distance weighting is the same - nearby_distance = max_distance * nearby_proportion - sum_coordinates = [0.0, 0.0, 0.0] - sum_transformed_coordinates = [0.0, 0.0, 0.0] - sum_direction = [0.0, 0.0, 0.0] - sum_transformed_direction = [0.0, 0.0, 0.0] - total_weight = 0.0 - for p, data in enumerate(segment_end_point_data[s]): - distance = distances[p] - if distance is None: - continue - _, transformed_coordinates, coordinates, direction, radius, annotation = data - if distance < nearby_distance: - distance = nearby_distance - weight = annotation.get_align_weight() * radius * radius / (distance * distance) - sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) - sum_direction = add(sum_direction, mult(direction, weight)) - total_weight += weight - mean_coordinates.append(div(sum_coordinates, total_weight)) - mean_directions.append(div(sum_direction, total_weight)) - unit_mean_directions = [normalize(v) for v in mean_directions] - mean_transformed_coordinates = [] - unit_mean_transformed_directions = [] - for s, segment in enumerate(self._segments): - x = mean_coordinates[s] - d = mean_directions[s] - if initial_rotation_matrix[s]: - x = matrix_vector_mult(initial_rotation_matrix[s], x) - d = matrix_vector_mult(initial_rotation_matrix[s], d) - x = add(x, segment.get_translation()) - mean_transformed_coordinates.append(x) - unit_mean_transformed_directions.append(normalize(d)) - - # optimise transformation of second segment so mean coordinates and directions coincide - - def rotation_objective(trial_rotation, *args): - target_direction, source_direction, target_side_direction, source_side_direction = args - trial_rotation_matrix = euler_to_rotation_matrix(trial_rotation) - trans_direction = matrix_vector_mult(trial_rotation_matrix, source_direction) - trans_side_direction = matrix_vector_mult(trial_rotation_matrix, source_side_direction) - return dot(trans_direction, target_direction) + dot(target_side_direction, trans_side_direction) - - # note the result is dependent on the initial position, but final optimisation should reduced effect - # get a side direction to minimise the unconstrained twist from the current direction - axis = [1.0, 0.0, 0.0] - if dot(unit_mean_transformed_directions[0], axis) < 0.1: - axis = [0.0, 1.0, 0.0] - target_side = normalize(cross(unit_mean_transformed_directions[0], axis)) - source_side = normalize( - cross(cross(target_side, unit_mean_transformed_directions[1]), unit_mean_transformed_directions[1])) - if initial_rotation_matrix[1]: - transformed_source_side = source_side - inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[1]) - source_side = matrix_vector_mult(inverse_rotation_matrix, transformed_source_side) - initial_angles = [math.radians(angle_degrees) for angle_degrees in self._segments[1].get_rotation()] - side_weight = 0.01 # so side has only a small effect on objective - res = minimize(rotation_objective, initial_angles, - args=(unit_mean_transformed_directions[0], unit_mean_directions[1], - mult(target_side, side_weight), mult(source_side, side_weight)), - method='Nelder-Mead', tol=0.001) - if not res.success: - print("Segmentation Stitcher. Could not optimise initial rotation") + def get_coordinates_range(self): + """ + Get x, y, z ranges of linked nodes in connection, which are transformed by the respective segments. + :return: Minimum coordinates, maximum coordinates, or None, None if no linked nodes. + """ + nodes = self._region.getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + return evaluate_field_nodeset_range(self._coordinates, nodes) + + def auto_align_segment(self, dependent_segment_index, phase1_align=True, gap_distance=0.0, phase_2_optimize=True): + """ + Optimise transformation of one connected segment relative to the other, by getting best fit + alignment and connection between nearest end points between them. + :param dependent_segment_index: Index of segment to optimise transformation of. + :param phase1_align: True if performing phase 1 align ends. + :param gap_distance: Gap distance to apply in phase 1. Can be negative to overlap. + :param phase_2_optimize: True if performing phase 2 optimize transformation in plane. + """ + max_gap_distance = 0.5 * self._max_distance + if math.fabs(gap_distance) > max_gap_distance: + logger.warning("Auto align gap distance is too large, limiting to " + str(max_gap_distance)) + gap_distance = math.copysign(max_gap_distance, gap_distance) + segments_count = len(self._segments) + if (dependent_segment_index < 0) or (dependent_segment_index >= segments_count): + logger.error("auto_align_segment. Segment index " + str(dependent_segment_index) + " out of range") return - rotation = [math.degrees(angle_radians) for angle_radians in res.x] - rotation_matrix = euler_to_rotation_matrix(res.x) - rotated_mean_coordinates = matrix_vector_mult(rotation_matrix, mean_coordinates[1]) - translation = sub(mean_transformed_coordinates[0], rotated_mean_coordinates) - # update transformed_coordinates in second segment data - for p, data in enumerate(segment_end_point_data[1]): - coordinates = data[2] - transformed_coordinates = add(matrix_vector_mult(rotation_matrix, coordinates), translation) - segment_end_point_data[1][p] = (data[0], transformed_coordinates, data[2], data[3], data[4], data[5]) - unit_transformed_direction = matrix_vector_mult(rotation_matrix, unit_mean_directions[1]) - # translate along unit_transformed_direction so no overlap between points - total_overlap = 0.0 - for s, segment in enumerate(self._segments): - max_overlap = 0.0 - for data in segment_end_point_data[s]: - overlap = dot(sub(data[1], mean_transformed_coordinates[0]), unit_transformed_direction) - if s == 0: - overlap = -overlap - if overlap > max_overlap: - max_overlap = overlap - total_overlap += max_overlap - translation = sub(translation, mult(unit_transformed_direction, total_overlap)) - self._segments[1].set_rotation(rotation, notify=False) - self._segments[1].set_translation(translation, notify=False) - - # GRC temp - # score = self.build_links(build_link_objects=False) - # print("part 1 rotation", rotation, "translation", translation, "score", score) - - # optimise angles and translation + if segments_count != 2: + logger.error("auto_align_segment. Not implemented for " + str(segments_count) + " segments") + return + fixed_segment_index = 1 if (dependent_segment_index == 0) else 0 + dependent_segment = self._segments[dependent_segment_index] + + fixed_transformed_end_location = None + fixed_transformed_end_direction = None + number_of_iterations = 2 # so second iteration starts reliably close + for iter in range(number_of_iterations): + # get segment transformations and apply to end points + segment_end_point_data = [] + initial_rotation_matrix = [] + for s, segment in enumerate(self._segments): + translation = segment.get_translation() + rotation_matrix = euler_to_rotation_matrix(segment.get_rotation_radians()) + initial_rotation_matrix.append(rotation_matrix) + end_point_data = [] + raw_end_point_data = segment.get_end_point_data() + for node_id, data in raw_end_point_data.items(): + coordinates, direction, radius, annotation = data + transformed_coordinates = coordinates + if (annotation is not None) and annotation.get_category().is_connectable(): + if rotation_matrix: + transformed_coordinates = matrix_vector_mult(rotation_matrix, transformed_coordinates) + transformed_coordinates = add(transformed_coordinates, translation) + end_point_data.append((node_id, transformed_coordinates, coordinates, direction, radius, annotation)) + segment_end_point_data.append(end_point_data) + + # get weighted mean end coordinates and directions of segment end points weighted by closeness to other segment + mean_end_locations = [] + mean_end_directions = [] # unit mean untransformed directions + # distance above which distance weighting is zero + far_distance = self._max_distance + gap_distance + for s, segment in enumerate(self._segments): + distances = [] # min transformed distance from end points of this segment to linkable end points in other + max_distance = None + remove_end_point_indexes = [] + for index0, data0 in enumerate(segment_end_point_data[s]): + node_id0, transformed_coordinates0, _, _, _, annotation0 = data0 + category0 = annotation0.get_category() + distance = None + for node_id1, transformed_coordinates1, _, _, _, annotation1 in segment_end_point_data[s - 1]: + category1 = annotation1.get_category() + if (category0 != category1) or ( + (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): + continue # end points are not allowed to join + tmp_distance = magnitude(sub(transformed_coordinates0, transformed_coordinates1)) + if (tmp_distance < far_distance) and ((distance is None) or (tmp_distance < distance)): + distance = tmp_distance + if (distance is not None) and ((max_distance is None) or (distance > max_distance)): + max_distance = distance + if distance is None: + remove_end_point_indexes.append(index0) + else: + distances.append(distance) # can be None + if max_distance is None: + logger.warning("Segmentation Stitcher. No linkable points to optimise transformation with") + return + for ix in reversed(remove_end_point_indexes): + del segment_end_point_data[s][ix] + sum_coordinates = [0.0, 0.0, 0.0] + sum_direction = [0.0, 0.0, 0.0] + total_weight = 0.0 + for distance, data in zip(distances, segment_end_point_data[s]): + _, _, coordinates, direction, radius, annotation = data + weight = annotation.get_align_weight() * radius * radius * (far_distance - distance) + sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) + sum_direction = add(sum_direction, mult(direction, weight)) + total_weight += weight + mean_end_direction = normalize(sum_direction) + mean_end_directions.append(mean_end_direction) + mean_coordinates = div(sum_coordinates, total_weight) + # get mean_end_locations at furthermost point in mean_end_direction + mean_projection = dot(mean_coordinates, mean_end_direction) + max_projection = mean_projection + for data in segment_end_point_data[s]: + coordinates = data[2] + projection = dot(coordinates, mean_end_direction) + if projection > max_projection: + max_projection = projection + offset = max_projection - mean_projection + # add gap_distance to fixed side + if s == fixed_segment_index: + offset += gap_distance + mean_end_locations.append(add(mean_coordinates, mult(mean_end_direction, offset))) + + if not phase1_align: + break # not transforming here and no need for multiple iterations + + # get angle axis transformation of dependent direction onto fixed direction + rotated_mean_end_directions = [ + matrix_vector_mult(initial_rotation_matrix[s], mean_end_directions[s]) for s in range(2)] + fixed_transformed_end_direction = rotated_mean_end_directions[fixed_segment_index] + # need to reverse fixed direction so inline + axis = cross(rotated_mean_end_directions[dependent_segment_index], + [-d for d in rotated_mean_end_directions[fixed_segment_index]]) + mag_axis = magnitude(axis) + if mag_axis > 1.0E-6: + axis = div(axis, mag_axis) + angle_radians = math.asin(mag_axis) + centre = mean_end_locations[dependent_segment_index] + dependent_segment.rotate_about_point_axis(centre, axis, angle_radians, notify=False) + dependent_rotation_matrix = euler_to_rotation_matrix(dependent_segment.get_rotation_radians()) + else: + dependent_rotation_matrix = initial_rotation_matrix[dependent_segment_index] + dependent_rotated_end_location = matrix_vector_mult( + dependent_rotation_matrix, mean_end_locations[dependent_segment_index]) + fixed_transformed_end_location = add( + matrix_vector_mult(initial_rotation_matrix[fixed_segment_index], mean_end_locations[fixed_segment_index]), + self._segments[fixed_segment_index].get_translation()) + translation = sub(fixed_transformed_end_location, dependent_rotated_end_location) + dependent_segment.set_translation(translation, notify=False) + + if not phase_2_optimize: + if phase1_align: + dependent_segment.set_translation(translation) # force notification + return + + # optimise rotation and translation in plane + + translation = dependent_segment.get_translation() + rotation_matrix = euler_to_rotation_matrix(dependent_segment.get_rotation_radians()) + centre = add(matrix_vector_mult(rotation_matrix, mean_end_locations[dependent_segment_index]), translation) + axis3 = [-c for c in matrix_vector_mult(rotation_matrix, mean_end_directions[dependent_segment_index])] + + # get 2 orthogonal axes for translations, scaled by max_distance so parameter scale similar to rotation radians: + axis1 = cross([1.0, 0.0, 0.0], axis3) + if magnitude(axis1) < 0.1: + axis1 = cross([0.0, 1.0, 0.0], axis3) + axis1 = set_magnitude(axis1, 0.5 * self._max_distance) + axis2 = cross(axis3, axis1) + initial_rotation = dependent_segment.get_rotation_radians() + initial_translation = dependent_segment.get_translation() + def links_objective(rotation_translation, *args): - rotation = list(rotation_translation[:3]) - translation = list(rotation_translation[3:]) - self._segments[1].set_rotation(rotation, notify=False) - self._segments[1].set_translation(translation, notify=False) + angle_radians = rotation_translation[0] + translation1 = rotation_translation[1] + translation2 = rotation_translation[2] + dependent_segment.set_rotation_radians(initial_rotation, notify=False) + dependent_segment.set_translation(initial_translation, notify=False) + dependent_segment.rotate_about_point_axis(centre, axis3, angle_radians, notify=False) + dependent_segment.translate(add(mult(axis1, translation1), mult(axis2, translation2)), notify=False) score = self.build_links(build_link_objects=False) - # print("rotation", rotation, "translation", translation, "score", score) + # print(rotation_translation, "score", score) return score - initial_parameters = rotation + translation - initial_score = links_objective(initial_parameters, ()) - # TOL = initial_score * 1.0E-6 - # method='Nelder-Mead' - res = minimize(links_objective, initial_parameters, method='Powell') # , tol=TOL) - if not res.success: - print("Segmentation Stitcher. Could not optimise final rotation and translation") - return - rotation = list(res.x[:3]) - translation = list(res.x[3:]) - self._segments[1].set_rotation(rotation, notify=False) - # this will invoke build_links: - self._segments[1].set_translation(translation) + initial_rotation_translation = [0.0, 0.0, 0.0] + # 0.75 ~ 43 degrees + res = minimize(links_objective, initial_rotation_translation, + args=(), + method='Nelder-Mead', # method='Powell', + bounds=[(-0.75, 0.75), (-0.75, 0.75), (-0.75, 0.75)]) # , tol=TOL) + if res.success: + links_objective(res.x) # to ensure the last values are converted to rotation and translation + # this will invoke build_links and build_link_objects: + dependent_segment.set_translation(dependent_segment.get_translation()) + else: + logger.warning("Segmentation Stitcher. Could not optimise rotation and translation") + # restore transformation + dependent_segment.set_rotation_radians(initial_rotation, notify=False) + # this will invoke build_links and build_link_objects: + dependent_segment.set_translation(initial_translation) + return def build_links(self, build_link_objects=True): """ @@ -321,96 +419,183 @@ def build_links(self, build_link_objects=True): :return: Total link score. """ total_score = 0.0 - remaining_radius_factor = 0.25 - self._linked_nodes = {} + + # remember locked tuples of linked nodes to re-attach in algorithm below + locked_node_identifiers_list = [] + annotation_names = list(self._annotation_links.keys()) + for annotation_name in annotation_names: + for link in self._annotation_links[annotation_name]: + if link['lock']: + locked_node_identifiers_list.append(tuple(link['node identifiers'])) + self._annotation_links = {} + # filter, transform and sort end point data from largest to smallest radius segment_sorted_end_point_data = [] + min_area = None for s, segment in enumerate(self._segments): translation = segment.get_translation() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation = segment.get_rotation_radians() rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None sorted_end_point_data = [] end_point_data = segment.get_end_point_data() for node_id, data in end_point_data.items(): coordinates, direction, radius, annotation = data + area = math.pi * radius * radius + if (min_area is None) or (area < min_area): + min_area = area if (annotation is not None) and annotation.get_category().is_connectable(): if rotation_matrix: coordinates = matrix_vector_mult(rotation_matrix, coordinates) direction = matrix_vector_mult(rotation_matrix, direction) coordinates = add(coordinates, translation) - for i, data in enumerate(sorted_end_point_data): - if radius > data[3]: + if area > data[3]: break else: i = len(sorted_end_point_data) - sorted_end_point_data.insert(i, (node_id, coordinates, direction, radius, annotation)) + sorted_end_point_data.insert(i, [node_id, coordinates, direction, area, annotation]) segment_sorted_end_point_data.append(sorted_end_point_data) sorted_end_point_data0 = segment_sorted_end_point_data[0] sorted_end_point_data1 = segment_sorted_end_point_data[1] - - while len(sorted_end_point_data0): - end_point_data0 = sorted_end_point_data0[0] - node_id0, coordinates0, direction0, radius0, annotation0 = end_point_data0 - category0 = annotation0.get_category() - best_index1 = None - lowest_score = 0.0 - weight = None + min_area *= 0.5 # so reliably below smallest end point area + + # print("Connection", self._name) + # make 2D array of base score independent of area and exclusivity + base_scores0 = [] # index over segment 0 endpoints, then segment 1 + max_mag_delta_coordinates = 0.5 * self._max_distance + # below this proportion of max_mag_delta_coordinates the closeness score is the same: + min_relative_distance = 0.0001 + worst_base_score = 10.0 + for index0, end_point_data0 in enumerate(sorted_end_point_data0): + node_id0, coordinates0, direction0, area0, annotation0 = end_point_data0 + base_scores1 = [] for index1, end_point_data1 in enumerate(sorted_end_point_data1): - node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 - category1 = annotation1.get_category() - # inter-segment links are only to the same annotation; links within category will be done separately - if annotation0 != annotation1: - continue # end points are not allowed to join - direction_score = math.fabs(1.0 + dot(direction0, direction1)) - if direction_score > 0.5: # arbitrary factor - continue # end points are not pointing towards each other + node_id1, coordinates1, direction1, area1, annotation1 = end_point_data1 + if not annotation0.is_connectable_with(annotation1): + # use worst score for end points which cannot connect + base_scores1.append(worst_base_score) + continue + dot_directions = dot(direction0, direction1) # -1.0 if perfectly pointing at each other + # if dot_directions > 0.2: # arbitrary factor + # base_scores1.append(worst_base_score) + # continue # end points are not pointing towards each other + direction_score = 0.2 + (0.8 / 1.2) * (1.0 + dot_directions) # minimum 0.2 delta_coordinates = sub(coordinates1, coordinates0) mag_delta_coordinates = magnitude(delta_coordinates) - tdistance = dot(direction0, delta_coordinates) - ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) - if mag_delta_coordinates > (0.5 * self._max_distance): - continue # point is too far away - distance_score = ((tdistance * tdistance + 100.0 * ndistance * ndistance) / - (self._max_distance * self._max_distance)) - tfactor = math.exp(-1000.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor - penetration_distance_score = ((tfactor * tdistance * tdistance) / - (self._max_distance * self._max_distance)) - delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale - radius_score = delta_radius * delta_radius - score = radius0 * (10.0 * direction_score + distance_score + radius_score) - if (best_index1 is None) or (score < lowest_score): - best_index1 = index1 - weight = 0.5 * (annotation0.get_align_weight() + annotation1.get_align_weight()) - lowest_score = score + penetration_distance_score - if best_index1 is not None: - # if category0 != AnnotationCategory.NETWORK_GROUP_1: - total_score += weight * lowest_score - node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] - self.add_linked_nodes(annotation1, node_id0, node_id1) - remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) - if (radius0 > radius1) and (remaining_radius > remaining_radius_factor * radius0): - for i in range(1, len(sorted_end_point_data0)): - if remaining_radius > sorted_end_point_data0[i][3]: - break - # sorted_end_point_data0.insert(i, (node_id0, coordinates0, direction0, remaining_radius, annotation0)) - elif remaining_radius > (remaining_radius_factor * radius1): - for i in range(best_index1, len(sorted_end_point_data1)): - if remaining_radius > sorted_end_point_data1[i][3]: - break - # sorted_end_point_data1.insert(i, (node_id1, coordinates1, direction1, remaining_radius, annotation1)) - sorted_end_point_data1.pop(best_index1) - else: - total_score += radius0 * 20.0 # arbitrary factor - sorted_end_point_data0.pop(0) + # if mag_delta_coordinates > max_mag_delta_coordinates: + # base_scores1.append(worst_base_score) + # continue # end point are too far away from each other + relative_distance = mag_delta_coordinates / max_mag_delta_coordinates + closeness_score = max(relative_distance, min_relative_distance) + if mag_delta_coordinates == 0.0: + align_score = 0.2 # minimum 0.2 + else: + t0 = dot(direction0, delta_coordinates) + n0 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t0 * t0) + t1 = dot(direction1, delta_coordinates) + n1 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t1 * t1) + align_score = 0.2 + 0.4 * (n0 + n1) / mag_delta_coordinates # minimum 0.5 + base_score = closeness_score * align_score * direction_score + base_scores1.append(base_score) + base_scores0.append(base_scores1) + + def get_minimums_ratio(scores): + """ + Get ratio of lowest / next lowest score as measure of 'only option' for first link to end point. + :param scores: List of real values or None to ignore an entry. + """ + inf = float('inf') + min1 = min2 = inf + for score in scores: + if score is not None: + if score < min1: + min1, min2 = score, min1 + elif score < min2: + min2 = score + if min2 is not inf: + return min1 / min2 + return 1.0 + + base_scores1 = [[score1[index1] for score1 in base_scores0] for index1 in range(len(sorted_end_point_data1))] + exclusive_base_scores0 = [get_minimums_ratio(scores1) for scores1 in base_scores0] + exclusive_base_scores1 = [get_minimums_ratio(scores0) for scores0 in base_scores1] + links_count0 = [0.0] * len(exclusive_base_scores0) + links_count1 = [0.0] * len(exclusive_base_scores1) + + best_score = 1.0 + cut_off_base_score = 0.1 + while best_score is not None: + best_score = None + best_nonexclusive_score = None + best_area = 0.0 + best_indexes = None + lock = False + for index0, end_point_data0 in enumerate(sorted_end_point_data0): + node_id0 = end_point_data0[0] + area0 = end_point_data0[3] + base_scores1 = base_scores0[index0] + for index1, end_point_data1 in enumerate(sorted_end_point_data1): + base_score = base_scores1[index1] + node_id1 = end_point_data1[0] + area1 = end_point_data1[3] + area = min(area0, area1) + indexes = (index0, index1) + node_identifiers = (node_id0, node_id1) + if node_identifiers in locked_node_identifiers_list: + best_area = max(min_area, area) # don't want area to get negative + best_nonexclusive_score = best_score = base_score / math.sqrt(best_area) + best_indexes = indexes + locked_node_identifiers_list.remove(node_identifiers) + lock = True + break + else: + if base_score > cut_off_base_score: + continue + if area < min_area: + continue + nonexclusive_score = score = base_score / math.sqrt(area) + # lower score for first links by factor indicating 'only option' + exclusive_base_scores = [] + if links_count0[index0] == 0: + exclusive_base_scores.append(exclusive_base_scores0[index0]) + if links_count1[index1] == 0: + exclusive_base_scores.append(exclusive_base_scores1[index1]) + if exclusive_base_scores: + score *= min(exclusive_base_scores) ** 2.0 + else: + if area < (3.0 * min_area): + continue + score *= (links_count0[index0] + links_count1[index1] + 1) + if (best_score is None) or (score < best_score): + best_score = score + best_nonexclusive_score = nonexclusive_score + best_area = area + best_indexes = indexes + if lock: + break + if best_score is not None: + end_point_data0 = sorted_end_point_data0[best_indexes[0]] + node_id0 = end_point_data0[0] + end_point_data1 = sorted_end_point_data1[best_indexes[1]] + node_id1 = end_point_data1[0] + # use annotation1 as it may be a branch which should logically own the link + annotation = end_point_data1[4] + self.set_linked_nodes(annotation, node_id0, node_id1, lock) + # print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) + end_point_data0[3] -= best_area + end_point_data1[3] -= best_area + links_count0[best_indexes[0]] += 1 + links_count1[best_indexes[1]] += 1 + # total score is not affected by exclusive measure used to match 'only option' links + total_score += best_nonexclusive_score * best_area if build_link_objects: - self._build_link_objects() + self.build_link_objects() return total_score - def _build_link_objects(self): + def build_link_objects(self): """ Make link nodes/elements for visualisation. """ @@ -435,7 +620,7 @@ def _build_link_objects(self): snodes.append(sfieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)) sfieldcache.append(sfieldmodule.createFieldcache()) tr_coordinates = sfieldmodule.findFieldByName("coordinates").castFiniteElement() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation = segment.get_rotation_radians() if rotation != [0.0, 0.0, 0.0]: rotation_matrix = euler_to_rotation_matrix(rotation) tr_coordinates = sfieldmodule.createFieldMatrixMultiply( @@ -453,13 +638,14 @@ def _build_link_objects(self): with (ChangeManager(fieldmodule)): mesh1d.destroyAllElements() nodes.destroyAllNodes() - for group_name, linked_nodes_list in self._linked_nodes.items(): - group = find_or_create_field_group(fieldmodule, group_name) + for annotation_name, links in self._annotation_links.items(): + group = find_or_create_field_group(fieldmodule, annotation_name) nodeset_group = group.getOrCreateNodesetGroup(nodes) mesh_group = group.getOrCreateMeshGroup(mesh1d) - for linked_nodes in linked_nodes_list: + for link in links: + node_identifiers = link['node identifiers'] cnode_ids = [None, None] - for s, snode_id in enumerate(linked_nodes): + for s, snode_id in enumerate(node_identifiers): cnode_ids[s] = snode_id_to_cnode_id[s].get(snode_id) if not cnode_ids[s]: snode = snodes[s].findNodeByIdentifier(snode_id) @@ -477,10 +663,103 @@ def _build_link_objects(self): element.setNodesByIdentifier(eft, cnode_ids) element_identifier += 1 - def update_annotation_category_groups(self, annotations): + def link_and_lock_selected_ends(self): + """ + Create and lock links between all permutations of selected end points in selected elements of each segment. + """ + end_node_identifiers0, end_annotations0 = self._segments[0].get_selected_end_points() + end_node_identifiers1, end_annotations1 = self._segments[1].get_selected_end_points() + new_links_count = 0 + for node_id0, annotation0 in zip(end_node_identifiers0, end_annotations0): + for node_id1, annotation1 in zip(end_node_identifiers1, end_annotations1): + if annotation0.is_connectable_with(annotation1): + # use annotation1 as it may be a branch which should logically own the link + self.set_linked_nodes(annotation1, node_id0, node_id1, lock=True) + new_links_count += 1 + if new_links_count: + self.build_links() + self.update_annotation_category_groups() + else: + logger.warning('Connection ' + self._name + '. Link and lock selected ends. No valid links exist') + + def remove_selected_links(self): + """ + Unlock and remove links corresponding to selected visualization elements in this connection. + This does not prevent links from being automatically re-found after either segment is moved. + """ + root_scene = self._region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return + element_identifier = 1 + for annotation_name, links in self._annotation_links.items(): + remove_link_indexes = [] + index = 0 + for link in links: + link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() + if link_selected: + remove_link_indexes.insert(0, index) # reverse order + element_identifier += 1 + index += 1 + if remove_link_indexes: + for remove_link_index in remove_link_indexes: + del links[remove_link_index] + self.build_link_objects() + self.update_annotation_category_groups() + + def set_link_locking_from_selection(self, lock: bool): + """ + Lock or unlock links for nodes matching any selected visualization elements. + :param lock: True to lock, False to unlock. + """ + root_scene = self._region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return + element_identifier = 1 + for annotation_name, links in self._annotation_links.items(): + for link in links: + link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() + if link_selected: + link['lock'] = lock + element_identifier += 1 + + def add_locked_links_to_selection(self): + """ + Add locked links to the scene selection. + """ + root_region = self._region.getRoot() + root_scene = root_region.getScene() + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + # create selection on demand if any links have a lock + root_selection_group = None + selection_mesh_group = None + element_identifier = 1 + with ChangeManager(root_scene), HierarchicalChangeManager(root_region): + for annotation_name, links in self._annotation_links.items(): + for link in links: + if link['lock']: + if not selection_mesh_group: + root_selection_group = scene_get_or_create_selection_group(root_scene) + selection_mesh_group = root_selection_group.getOrCreateMeshGroup(mesh1d) + link_element = mesh1d.findElementByIdentifier(element_identifier) + selection_mesh_group.addElement(link_element) + element_identifier += 1 + + def update_annotation_category_groups(self): """ Rebuild all annotation category groups e.g. after loading settings. - :param annotations: List of all annotations from stitcher. """ fieldmodule = self._region.getFieldmodule() with ChangeManager(fieldmodule): @@ -488,7 +767,7 @@ def update_annotation_category_groups(self, annotations): for category in AnnotationCategory: category_group = self.get_category_group(category) category_group.clear() - for annotation in annotations: + for annotation in self._annotations: annotation_group = self.get_annotation_group(annotation) if annotation_group: category_group = self.get_category_group(annotation.get_category()) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index c9c6927..b791e87 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -1,19 +1,29 @@ """ A segment of the segmentation data, generally from a separate image block. """ -from builtins import enumerate +from cmlibs.maths.vectorops import ( + add, axis_angle_to_rotation_matrix, cross, dot, euler_to_rotation_matrix, magnitude, matrix_mult, + matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, set_magnitude, sub) -from cmlibs.maths.vectorops import cross, dot, magnitude, matrix_mult, mult, normalize, set_magnitude, sub from cmlibs.utils.zinc.field import ( - get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) -from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range + get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, + find_or_create_field_stored_string) +from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range, get_maximum_node_identifier from cmlibs.utils.zinc.group import group_add_group_local_contents, group_remove_group_local_contents from cmlibs.utils.zinc.general import ChangeManager from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK -from segmentationstitcher.annotation import AnnotationCategory +from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations + +import copy +import json +import logging import math +import os + + +logger = logging.getLogger(__name__) class Segment: @@ -21,11 +31,19 @@ class Segment: A segment of the segmentation data, generally from a separate image block. """ - def __init__(self, name, segmentation_file_name, root_region): + def __init__(self, name, segmentation_file_name, root_region, annotations, + network_group1_keywords, network_group2_keywords, term_keywords): """ :param name: Unique name of segment, usually derived from the file name. :param segmentation_file_name: Path and file name of raw segmentation file, in Zinc format. :param root_region: Zinc root region to create segment region under. + :param annotations: List of all annotations to add to. + :param network_group1_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 1, allowing them to be stitched together. + :param network_group2_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 2, allowing them to be stitched together. + :param term_keywords: List of term keywords; if found in group name, group is considered to mark the term for + the separately named group. """ self._name = name self._segmentation_file_name = segmentation_file_name @@ -36,6 +54,7 @@ def __init__(self, name, segmentation_file_name, root_region): # the raw region contains the original segment data which is not modified apart from building # groups to categorise data for stitching a visualisation, including selecting for display. self._raw_region = self._base_region.createChild("raw") + self._annotations = annotations result = self._raw_region.readFile(segmentation_file_name) assert result == RESULT_OK, \ "Could not read segmentation file " + segmentation_file_name @@ -49,6 +68,7 @@ def __init__(self, name, segmentation_file_name, root_region): group.setManaged(True) self._rotation = [0.0, 0.0, 0.0] self._translation = [0.0, 0.0, 0.0] + self._ignore_orientation = False self._transformation_change_callbacks = [] self._raw_fieldcache = self._raw_fieldmodule.createFieldcache() self._raw_coordinates = self._raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -66,10 +86,40 @@ def __init__(self, name, segmentation_file_name, root_region): self._working_best_fit_line_orientation = find_or_create_field_finite_element( self._working_fieldmodule, "best_fit_line_orientation", 9) self._working_end_group = find_or_create_field_group(self._working_fieldmodule, "active_ends") + for category in AnnotationCategory: + if category.is_connectable(): + group_name = category.get_group_name() + group = self._working_fieldmodule.createFieldGroup() + group.setName(group_name) + group.setManaged(True) self._element_node_ids, self._node_element_ids = self._get_element_node_maps() - self._end_node_ids = self._get_end_node_ids() + # following are determing by client call to create_end_point_directions() + self._end_node_ids = [] # mesh end points = nodes in only 1 element + self._interior_end_node_ids = [] # certain interior points where annotation changes, also connectable self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) + segment_annotations = region_get_annotations( + self._raw_region, network_group1_keywords, network_group2_keywords, term_keywords) + for segment_annotation in segment_annotations: + annotation_name = segment_annotation.get_name() + term = segment_annotation.get_term() + index = 0 + for annotation in self._annotations: + if annotation.get_name() == annotation_name: + existing_term = annotation.get_term() + if term != existing_term: + logger.warning("Segment " + name + ": Found existing annotation with name " + annotation_name + + " but existing term " + str(existing_term) + + " does not equal new term " + str(term) + ". Clearing term") + annotation.clear_term() + break # exists already + if annotation_name > annotation.get_name(): + index += 1 + else: + # print("Add annotation name", annotation_name, "term", term, "dim", segment_annotation.get_dimension(), + # "category", segment_annotation.get_category()) + self._annotations.insert(index, segment_annotation) + def decode_settings(self, settings_in: dict): """ Update segment settings from JSON dict containing serialised settings. @@ -79,8 +129,9 @@ def decode_settings(self, settings_in: dict): # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) - self._rotation = settings["rotation"] + self._rotation = [math.radians(deg) for deg in settings["rotation"]] self._translation = settings["translation"] + self._ignore_orientation = settings["ignore orientation"] def encode_settings(self) -> dict: """ @@ -88,12 +139,64 @@ def encode_settings(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ settings = { + "ignore orientation": self._ignore_orientation, "name": self._name, - "rotation": self._rotation, + "rotation": [math.degrees(rad) for rad in self._rotation], "translation": self._translation } return settings + def define_endpoints(self, endpoints_file_name): + """ + Read endpoints labels and positions from the endpoints file as markers in the raw region. + These are used to associate externally recognised names with endpoints for later reporting. + :param endpoints_file_name: Name of file in Slicer 3D markups json format. + """ + if os.path.isfile(endpoints_file_name): + with open(endpoints_file_name, "r") as f: + try: + marker_file_data = json.loads(f.read()) + schema = marker_file_data.get("@schema") + markups = marker_file_data.get("markups") + if not (schema and ("slicer" in schema) and ("markups" in schema) and markups): + logger.error("Stitcher endpoints file " + endpoints_file_name + " is not a supported file type") + return + if len(markups) == 0: + logger.warning("Stitcher endpoints file" + endpoints_file_name + " has no markups") + for markup in markups: + coordinate_system = markup.get("coordinateSystem") + if coordinate_system != "LPS": + logger.warning("Stitcher endpoints file" + endpoints_file_name + + " unimplemented coordinateSystem name " + coordinate_system) + coordinate_units = markup.get("coordinateUnits") + control_points = markup.get("controlPoints") + if not (isinstance(control_points, list) and (len(control_points) > 0)): + logger.warning("Stitcher endpoints file" + endpoints_file_name + " has no controlPoints") + continue + # build list of marker labels and positions + marker_labels = [] + marker_positions = [] + for control_point in control_points: + marker_labels.append(control_point["label"]) + x = control_point["position"] + # if "C1L" in self._name: + # x = [-1000.0 * x[1] - 4500.0, 1000.0 * x[0] + 2250.0, -1000.0 * x[2] + 11500.0] + # elif any(s in self._name for s in ["T5L", "T6L"]): + # x = [-1000.0 * x[1] - 9500.0, 1000.0 * x[0] + 5000.0, -1000.0 * x[2]] + # elif any(s in self._name for s in ["C2L", "C4L", "T2L", "T3L", "T4L"]): + # x = [-9.0 * x[1], 9.0 * x[0], -9.0 * x[2]] + # else: + # x = [9.0 * x[1], -9.0 * x[0], -9.0 * x[2]] + marker_positions.append(x) + generate_datapoints(self._raw_region, marker_positions, + field_names_and_values=[("marker_name", marker_labels)], + group_name="marker") + except json.JSONDecodeError as e: + logger.error("Stitcher endpoints file " + endpoints_file_name + + " exception reading json format " + str(e)) + else: + logger.error("Stitcher endpoints file " + endpoints_file_name + " not found") + def _get_element_node_maps(self): """ Get maps from 1-D elements to nodes and nodes to elements for the raw data. @@ -120,52 +223,57 @@ def _get_element_node_maps(self): element = elem_iter.next() return element_node_ids, node_element_ids - def _get_end_node_ids(self): - """ - :return: List of identifiers of nodes at end points i.e. in only 1 element. + def _element_id_to_annotation(self, element_id): """ - end_node_ids = [] - for node_id, element_ids in self._node_element_ids.items(): - if len(element_ids) == 1: - end_node_ids.append(node_id) - return end_node_ids - - def _element_id_to_group(self, element_id, annotations): - """ - Get the first Annotation zinc Group containing raw element of supplied identifier. - :param node_id: Identifier of [end] node to query. - :param annotations: Global list of all annotations. - :return: Zinc Group, MeshGroup or None, None if not found. + Get first Annotation containing raw element of supplied identifier, prioritizing connectable annotations + then those with term ids. + :param element_id: Identifier of element from raw region to query. + :return: Annotation or None if not found. """ element = self._raw_mesh1d.findElementByIdentifier(element_id) - for annotation in annotations: - group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() - if group.isValid(): - mesh_group = group.getMeshGroup(self._raw_mesh1d) - if mesh_group.isValid() and mesh_group.containsElement(element): - return group, mesh_group - return None, None - - def _track_segment(self, start_node_id, start_element_id, + for priority in range(3): + for annotation in self._annotations: + is_connectable = annotation.is_connectable() + if priority == 0: + if not is_connectable: + continue + else: + if is_connectable: + continue + has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) + if priority == 1: + if not has_term: + continue + else: + if has_term: + continue + group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if group.isValid(): + mesh_group = group.getMeshGroup(self._raw_mesh1d) + if mesh_group.isValid() and mesh_group.containsElement(element): + return annotation + return None + + def _track_segment(self, start_node_id, start_element_id, annotation, max_length=None, min_element_count=None, min_aspect_ratio=None): """ Get coordinates and radii along segment from start_node_id in start_element_id, proceeding - first to other local node in element, until junction, end point or max_distance is tracked. + first to other local node in element, until junction, end point, annotation change or max_distance is tracked. Can finish earlier if min_element_count, min_aspect_ratio reached, but both must be reached if both in use. :param start_node_id: First node in path. :param start_element_id: Element containing start_node_id and another node to be added. + :param annotation: Annotation which element needs to be in. :param max_length: Maximum length to track from first node coordinates, or None for no limit. :param min_element_count: Minimum number of elements to track, or None to not test. :param min_aspect_ratio: Minimum ratio of length / mean radius to end tracking, or None to not test. :return: coordinates list, radius list, node id list, endElementId """ - self._element_node_ids, self._node_element_ids node_id = start_node_id element_id = start_element_id path_coordinates = [] path_radii = [] path_node_ids = [] - lastNode = False + is_last_node = False sum_r = 0.0 while True: if node_id in path_node_ids: @@ -183,7 +291,7 @@ def _track_segment(self, start_node_id, start_element_id, r = 1.0 path_radii.append(r) sum_r += r - if lastNode: + if is_last_node: break point_count = len(path_coordinates) if point_count > 1: @@ -203,29 +311,34 @@ def _track_segment(self, start_node_id, start_element_id, node_id = node_ids[1] if (node_ids[0] == node_id) else node_ids[0] element_ids = self._node_element_ids[node_id] if len(element_ids) != 2: - lastNode = True + is_last_node = True continue element_id = element_ids[1] if (element_ids[0] == element_id) else element_ids[0] + # Future: more efficient to check element is in annotation mesh group? + next_annotation = self._element_id_to_annotation(element_id) + if next_annotation != annotation: + is_last_node = True return path_coordinates, path_radii, path_node_ids, element_id - def _track_path(self, end_node_id, annotations, max_length=None): + def _track_path(self, start_node_id, start_element_id, max_length=None): """ - Get coordinates and radii along path from end_node_id, continuing along - branches if in similar direction. - :param end_node_id: End node identifier to track from. Must be in only one element. - :param annotations: Global list of all annotations. + Get coordinates and radii along path from start_node_id, across start_element_id, continuing along branches if + in a similar direction. Stops at another end point (1 parent element) or annotation change. + :param start_node_id: Start node identifier to track from. + :param start_element_id: Start element identifier to track across. :param max_length: Maximum length to track along, or None for no limit. - :return: coordinates list, radius list, path node ids, path group, start_x, end_x, mean_r + :return: coordinates list, radius list, path node ids, path annotation, start_x, end_x, mean_r """ - element_ids = self._node_element_ids[end_node_id] - assert len(element_ids) == 1 - path_group = self._element_id_to_group(element_ids[0], annotations)[0] + path_annotation = self._element_id_to_annotation(start_element_id) path_coordinates = [] path_radii = [] path_node_ids = [] path_mean_r = None - stop_node_id = end_node_id + stop_node_id = start_node_id + element_ids = [start_element_id] + # node_ids = self._element_node_ids[start_element_id] + # start_node_index = 0 if (node_ids[0] == start_node_id) else -1 stop_element_id = None start_x = None end_x = None @@ -247,11 +360,14 @@ def _track_path(self, end_node_id, annotations, max_length=None): for element_id in element_ids: if element_id == stop_element_id: continue - segment_group = self._element_id_to_group(element_id, annotations)[0] - if path_group and (segment_group != path_group): + element_annotation = self._element_id_to_annotation(element_id) + if path_annotation and (element_annotation != path_annotation): continue + # node_ids = self._element_node_ids[element_id] + # if node_ids[start_node_index] != stop_node_id: + # continue # change of element orientation: stops if same-named branches eminate from a junction segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id = self._track_segment( - stop_node_id, element_id, + stop_node_id, element_id, path_annotation, max_length=max_length - length, min_element_count=min_element_count - element_count, min_aspect_ratio=min_aspect_ratio - aspect_ratio) @@ -299,49 +415,89 @@ def _track_path(self, end_node_id, annotations, max_length=None): if add_path_mean_r > 0.0: aspect_ratio += add_path_length / add_path_mean_r # 2nd iteration of fit line removes outliers: - start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.5)[0:3] - return path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r + start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.25)[0:3] + return path_coordinates, path_radii, path_node_ids, path_annotation, start_x, end_x, mean_r - def create_end_point_directions(self, annotations, max_distance): + def create_end_point_directions(self, max_distance): """ Track mean directions of network end points and create working objects for visualisation. - :param annotations: Global list of all annotations. :param max_distance: Maximum length to track back from end point. Stored for link tolerance. """ - nodetemplate = self._working_datapoints.createNodetemplate() - nodetemplate.defineField(self._working_coordinates) - nodetemplate.defineField(self._working_radius_direction) - nodetemplate.defineField(self._working_best_fit_line_orientation) - fieldcache = self._working_fieldmodule.createFieldcache() - self._end_point_data = {} - for end_node_id in self._end_node_ids: - path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r =( - self._track_path(end_node_id, annotations, max_distance)) - # Future: want to extend length to be equivalent to path_coordinates - direction = sub(start_x, end_x) - annotation = None - annotation_group_name = path_group.getName() if path_group else None - if annotation_group_name: - for tmp_annotation in annotations: - if tmp_annotation.get_name() == annotation_group_name: - annotation = tmp_annotation - break - self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) - # set up visualization objects. End direction datapoints have same identifiers as raw end nodes - node = self._working_datapoints.createNode(end_node_id, nodetemplate) - fieldcache.setNode(node) - radius_direction = set_magnitude(direction, mean_r) - self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) - self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, - radius_direction) - direction1 = sub(end_x, start_x) - axis = [1.0, 0.0, 0.0] - if dot(normalize(direction1), axis) < 0.1: - axis = [0.0, 1.0, 0.0] - direction2 = set_magnitude(cross(axis, direction1), mean_r) - direction3 = set_magnitude(cross(direction1, direction2), mean_r) - self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, - direction1 + direction2 + direction3) + # following are determined here: + self._end_node_ids = [] # mesh end points = nodes in only 1 element + self._interior_end_node_ids = [] # certain interior points where annotation changes, also connectable + self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) + for node_id, element_ids in self._node_element_ids.items(): + element_count = len(element_ids) + if element_count == 1: + self._end_node_ids.append(node_id) + with ChangeManager(self._working_fieldmodule): + nodetemplate = self._working_datapoints.createNodetemplate() + nodetemplate.defineField(self._working_coordinates) + nodetemplate.defineField(self._working_radius_direction) + nodetemplate.defineField(self._working_best_fit_line_orientation) + fieldcache = self._working_fieldmodule.createFieldcache() + interior_end_node_element_ids = {} # map from interior end node to untracked element ids + for interior in (False, True): + for end_node_id in (sorted(interior_end_node_element_ids) if interior else self._end_node_ids): + if interior: + end_element_ids = interior_end_node_element_ids[end_node_id] + if len(end_element_ids) != 1: + continue # not a valid interior end node + self._interior_end_node_ids.append(end_node_id) + end_element_id = end_element_ids[0] + else: + end_element_id = self._node_element_ids[end_node_id][0] + + path_coordinates, path_radii, path_node_ids, annotation, start_x, end_x, mean_r =( + self._track_path(end_node_id, end_element_id, max_distance)) + # Future: want to extend length to be equivalent to path_coordinates + direction = sub(start_x, end_x) + if not annotation: + print("No annotation group for node", end_node_id) + self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) + # set up visualization objects. End direction datapoints have same identifiers as raw end nodes + node = self._working_datapoints.createNode(end_node_id, nodetemplate) + fieldcache.setNode(node) + radius_direction = set_magnitude(direction, mean_r) + self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) + self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, + radius_direction) + direction1 = sub(end_x, start_x) + if (magnitude(direction1) > 0.0) and (mean_r > 0.0): + norm_direction1 = normalize(direction1) + for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: + if math.fabs(dot(norm_direction1, side)) < 0.1: + break + direction2 = set_magnitude(cross(side, direction1), mean_r) + direction3 = set_magnitude(cross(direction1, direction2), mean_r) + else: + direction2 = direction3 = [0.0, 0.0, 0.0] + self._working_best_fit_line_orientation.setNodeParameters( + fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, direction1 + direction2 + direction3) + + if not interior: + stop_node_id = path_node_ids[-1] + if stop_node_id not in self._end_node_ids: + # determine the stop element for this path + stop_element_id = None + prev_node_id = path_node_ids[-2] + element_ids = self._node_element_ids[stop_node_id] + for element_id in element_ids: + node_ids = self._element_node_ids[element_id] + if prev_node_id in node_ids: + stop_element_id = element_id + break + # determine if path ended on a change of annotation = interior end node + for element_id in element_ids: + element_annotation = self._element_id_to_annotation(element_id) + if element_annotation != annotation: + end_element_ids = interior_end_node_element_ids.get(stop_node_id) + if not end_element_ids: + interior_end_node_element_ids[stop_node_id] = end_element_ids =\ + copy.copy(element_ids) + end_element_ids.remove(stop_element_id) + break def get_end_point_data(self): """ @@ -389,6 +545,19 @@ def get_end_point_fields(self): def get_name(self): return self._name + def get_coordinates_midpoint(self): + """ + :return: Coordinates at the midpoint in their x, y, z range. + """ + return [0.5 * (minimum + maximum) for minimum, maximum in zip(self._raw_minimums, self._raw_maximums)] + + def get_coordinates_range(self): + """ + Get x, y, z ranges of coordinates in raw data. + :return: Minimum coordinates, maximum coordinates. + """ + return self._raw_minimums, self._raw_maximums + def get_max_range(self): """ :return: Maximum range of raw coordinates on any axis x, y, z. @@ -396,6 +565,14 @@ def get_max_range(self): raw_range = [self._raw_maximums[c] - self._raw_minimums[c] for c in range(3)] return max(raw_range) + def transform_coordinates(self, position): + """ + :param position Coordinates x, y, z in the segment. + :return: Transformed position. + """ + rotation_matrix = euler_to_rotation_matrix(self._rotation) + return add(matrix_vector_mult(rotation_matrix, position), self._translation) + def get_raw_region(self): """ Get the raw region, a child of base region, into which the raw segmentation was loaded. @@ -425,17 +602,48 @@ def _transformation_change(self): for transformation_change_callback in self._transformation_change_callbacks: transformation_change_callback(self) - def get_rotation(self): + def get_rotation_radians(self): return self._rotation - def set_rotation(self, rotation, notify=True): + def set_rotation_radians(self, rotation, notify=True): """ Set segment rotation, which applies before translation. - :param rotation: Rotation as list of 3 Euler angles in degrees. + :param rotation: Rotation as list of 3 Euler angles in radians. :param notify: Set to False to avoid notification to clients if setting translation afterwards. """ assert len(rotation) == 3 - self._rotation = rotation + self._rotation = copy.copy(rotation) + if notify: + self._transformation_change() + + def get_rotation_degrees(self): + return [math.degrees(rad) for rad in self._rotation] + + def set_rotation_degrees(self, rotation, notify=True): + """ + Set segment rotation, which applies before translation. + :param rotation: Rotation as list of 3 Euler angles in degrees. + :param notify: Set to False to avoid notification to clients if setting translation afterwards. + """ + self.set_rotation_radians([math.radians(deg) for deg in rotation], notify) + + def rotate_about_point_axis(self, centre, axis, angle_radians, notify=True): + """ + Update rotation and translation parameters to include a subsequent rotation about a centre. + :param centre: Centre of subsequent rotation (after initial rotation and translation applied). + :param axis: Axis of subsequent rotation (after initial rotation and translation applied). + :param angle_radians: Rotation in radians in a right hand sense about axis. + :param notify: Set to False to avoid notification to clients if setting rotation afterwards. + """ + mat1 = euler_to_rotation_matrix(self._rotation) + centre_translation1 = matrix_vector_mult(mat1, centre) + mat2 = axis_angle_to_rotation_matrix(axis, angle_radians) + product_mat = matrix_mult(mat2, mat1) + centre_translation2 = matrix_vector_mult(product_mat, centre) + self._rotation = rotation_matrix_to_euler(product_mat) + # correct translation of centre by new rotation: + centre_offset = sub(centre_translation1, centre_translation2) + self._translation = add(self._translation, centre_offset) if notify: self._transformation_change() @@ -449,7 +657,28 @@ def set_translation(self, translation, notify=True): :param notify: Set to False to avoid notification to clients if setting rotation afterwards. """ assert len(translation) == 3 - self._translation = translation + self._translation = copy.copy(translation) + if notify: + self._transformation_change() + + def is_ignore_orientation(self): + return self._ignore_orientation + + def set_ignore_orientation(self, ignore_orientation: bool): + """ + :param ignore_orientation: If True, on export put all groups starting with 'orientation' in segment into a group + 'ignore orientation' for subsequent tools to ignore orientation. + """ + self._ignore_orientation = ignore_orientation + + def translate(self, offset, notify=True): + """ + :param offset: 3 value to add to translation + :param notify: Set to False to avoid notification to clients if setting rotation afterwards. + """ + assert len(offset) == 3 + for c in range(3): + self._translation[c] += offset[c] if notify: self._transformation_change() @@ -460,6 +689,12 @@ def get_working_region(self): """ return self._working_region + def get_working_fieldmodule(self): + """ + :return: Zinc Fieldmodule for working region. + """ + return self._working_fieldmodule + def get_working_end_group(self): """ Get group from working region containing connectable end points in segment. @@ -467,6 +702,15 @@ def get_working_end_group(self): """ return self._working_end_group + def get_working_category_group(self, category): + """ + Get group from working region containing connectable end points in segment and in the supplied category. + :param category: AnnotationCategory. + :return: Zinc group containing connectable end points in that category. + """ + category_group = self._working_fieldmodule.findFieldByName(category.get_group_name()).castGroup() + return category_group if category_group.isValid() else None + def update_annotation_category(self, annotation, old_category=AnnotationCategory.EXCLUDE): """ Ensures special groups representing annotion categories contain via addition or removal the @@ -489,17 +733,16 @@ def update_annotation_category(self, annotation, old_category=AnnotationCategory group_add_group_local_contents(new_category_group, annotation_group) self._update_working_end_group() - def update_annotation_category_groups(self, annotations): + def update_annotation_category_groups(self): """ Rebuild all annotation category groups e.g. after loading settings. - :param annotations: List of all annotations from stitcher. """ with ChangeManager(self._raw_fieldmodule): # clear all category groups for category in AnnotationCategory: category_group = self.get_category_group(category) category_group.clear() - for annotation in annotations: + for annotation in self._annotations: annotation_group = self.get_annotation_group(annotation) if annotation_group: category_group = self.get_category_group(annotation.get_category()) @@ -510,29 +753,63 @@ def _update_working_end_group(self): """ Ensure working end group contains all connectable end points. """ - connectable_node_groups = [] - for category in AnnotationCategory: - if category.is_connectable(): - category_group = self.get_category_group(category) - node_group = category_group.getNodesetGroup(self._raw_nodes) - if node_group.isValid() and (node_group.getSize() > 0): - connectable_node_groups.append(node_group) + # list of (raw_category_node_group, working_category_node_group) capable of connections with ChangeManager(self._working_fieldmodule): - self._working_end_group.clear() working_datapoints = \ self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + connectable_node_groups = [] + for category in AnnotationCategory: + if category.is_connectable(): + category_group = self.get_category_group(category) + category_node_group = category_group.getNodesetGroup(self._raw_nodes) + working_category_group = self.get_working_category_group(category) + if category_node_group.isValid() and working_category_group: + working_category_group.clear() + working_category_node_group = working_category_group.getOrCreateNodesetGroup(working_datapoints) + connectable_node_groups.append((category_node_group, working_category_node_group)) + self._working_end_group.clear() working_node_group = self._working_end_group.getOrCreateNodesetGroup(working_datapoints) working_nodeiterator = working_datapoints.createNodeiterator() working_node = working_nodeiterator.next() while working_node.isValid(): node_identifier = working_node.getIdentifier() raw_node = self._raw_nodes.findNodeByIdentifier(node_identifier) - for node_group in connectable_node_groups: + for node_group, working_category_node_group in connectable_node_groups: if node_group.containsNode(raw_node): working_node_group.addNode(working_node) - break; + working_category_node_group.addNode(working_node) + break working_node = working_nodeiterator.next() + def get_selected_end_points(self): + """ + Get end points in selected elements in segment. This includes node points with a single parent element, and + certain interior node points where the annotation changed. + :return: List of node identifiers, list of annotations for each node. + """ + root_scene = self._base_region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return [], [] + fieldmodule = self._raw_region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return [], [] + node_ids = [] + node_annotations = [] + elementiterator = selection_mesh_group.createElementiterator() + element = elementiterator.next() + while element.isValid(): + element_id = element.getIdentifier() + for node_id in self._element_node_ids[element_id]: + if (node_id in self._end_node_ids) or (node_id in self._interior_end_node_ids): + if node_id not in node_ids: + node_ids.append(node_id) + node_annotations.append(self._end_point_data[node_id][3]) + element = elementiterator.next() + return node_ids, node_annotations + def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): """ Compute best fit line to path coordinates, and mean radius of unfiltered points. @@ -619,3 +896,55 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 # [a_inv[1][0] * a[0][0] + a_inv[1][1] * a[1][0], # a_inv[1][0] * a[0][1] + a_inv[1][1] * a[1][1]]) return start_x, end_x, mean_r, mean_projection_error + + +def generate_datapoints(region, px, start_data_identifier=None, coordinate_field_name="coordinates", + field_names_and_values=[], group_name=None): + """ + Generate a set of datapoints in the region. + :param region: Zinc Region. + :param px: Coordinates of data points. + :param start_data_identifier: Optional first datapoint identifier to use. + :param coordinate_field_name: Optional name of coordinate field to define, if omitted use "coordinates". + :param field_names_and_values: Optional lists of (field_name, list of values) for additional fields to + define on the datapoints. Values may be scalar or vector (list of lists) real, or string. + Must be same number of values as number of points. + :param group_name: Optional name of group to put new datapoints in. + :return: next datapoint identifier + """ + fieldmodule = region.getFieldmodule() + with ChangeManager(fieldmodule): + coordinates = find_or_create_field_coordinates(fieldmodule, name=coordinate_field_name) + group = find_or_create_field_group(fieldmodule, group_name) if group_name else None + + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + data_identifier = start_data_identifier if (start_data_identifier is not None) else \ + max(get_maximum_node_identifier(datapoints), 0) + 1 + data_group = group.getOrCreateNodesetGroup(datapoints) if group else datapoints + + nodetemplate = datapoints.createNodetemplate() + nodetemplate.defineField(coordinates) + fields_values = [] # (field, is_string, values) + for field_name, field_values in field_names_and_values: + is_string = isinstance(field_values[0], str) + if is_string: + field = find_or_create_field_stored_string(fieldmodule, field_name, managed=True) + else: + components_count = len(field_values[0]) if isinstance(field_values[0], list) else 1 + field = find_or_create_field_finite_element(fieldmodule, field_name, components_count, managed=True) + nodetemplate.defineField(field) + fields_values.append((field, is_string, field_values)) + + fieldcache = fieldmodule.createFieldcache() + for n, x in enumerate(px): + node = data_group.createNode(data_identifier, nodetemplate) + fieldcache.setNode(node) + coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) + for field, is_string, values in fields_values: + if is_string: + field.assignString(fieldcache, values[n]) + else: + field.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, values[n]) + data_identifier += 1 + + return data_identifier diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index eefa454..de6c36b 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -10,13 +10,24 @@ from cmlibs.zinc.element import Element, Elementbasis from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node +from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.connection import Connection from segmentationstitcher.segment import Segment -from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations +from segmentationstitcher.annotation import AnnotationCategory import copy +import logging import math from pathlib import Path +import re + + +logger = logging.getLogger(__name__) + + +def natural_sort_key(s): + # Split the string by numeric parts, converting numbers to integers + return [int(c) if c.isdigit() else c.lower() for c in re.split('([0-9]+)', s)] class Stitcher: @@ -24,79 +35,97 @@ class Stitcher: Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. """ - def __init__(self, segmentation_file_names: list, network_group1_keywords, network_group2_keywords): + def __init__(self, segmentation_file_names: list, network_group1_keywords, network_group2_keywords, + endpoints_file_names=None): """ :param segmentation_file_names: List of filenames containing raw segmentations in Zinc format. :param network_group1_keywords: List of keywords. Segmented networks annotated with any of these keywords are initially assigned to network group 1, allowing them to be stitched together. :param network_group2_keywords: List of keywords. Segmented networks annotated with any of these keywords are initially assigned to network group 2, allowing them to be stitched together. + :param endpoints_file_names: Optional list of files defining additional markers for applying external labels for + ends of networks. Slicer3D markup json files currently supported. Files must start with the same stem name + as the segmentation files to load into that segment. """ + self._segmentation_file_names = sorted(segmentation_file_names, key=natural_sort_key) + self._network_group1_keywords = copy.deepcopy(network_group1_keywords) + self._network_group2_keywords = copy.deepcopy(network_group2_keywords) + self._endpoints_file_names = endpoints_file_names if endpoints_file_names else [] self._context = Context("Segmentation Stitcher") self._root_region = self._context.getDefaultRegion() self._stitch_region = self._root_region.createRegion() - self._annotations = [] - self._network_group1_keywords = copy.deepcopy(network_group1_keywords) - self._network_group2_keywords = copy.deepcopy(network_group2_keywords) + self._annotations = [] # note all segments and connections share this common list self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] self._connections = [] self._max_distance = 0.0 - self._version = 1 # increment when new settings added to migrate older serialised settings + self._version = "1.1.0" # increment when new settings added to migrate older serialised settings + unused_endpoints_file_names = copy.copy(self._endpoints_file_names) + unused_endpoints_file_name_stems = [Path(file_path).stem for file_path in unused_endpoints_file_names] with HierarchicalChangeManager(self._root_region): max_range_reciprocal_sum = 0.0 - for segmentation_file_name in segmentation_file_names: - name = Path(segmentation_file_name).name - segment = Segment(name, segmentation_file_name, self._root_region) - max_range_reciprocal_sum += 1.0 / segment.get_max_range() + zero_range_segments_count = 0 + for segmentation_file_name in self._segmentation_file_names: + file_path = Path(segmentation_file_name) + name = file_path.name + segment = Segment(name, segmentation_file_name, self._root_region, self._annotations, + self._network_group1_keywords, self._network_group2_keywords, self._term_keywords) + name_stem = file_path.stem + used_endpoints_file_indexes = [] + for ix, endpoints_file_name_stem in enumerate(unused_endpoints_file_name_stems): + if name_stem in endpoints_file_name_stem: + segment.define_endpoints(unused_endpoints_file_names[ix]) + used_endpoints_file_indexes.append(ix) + for ix in reversed(used_endpoints_file_indexes): + del unused_endpoints_file_name_stems[ix] + del unused_endpoints_file_names[ix] + segment_max_range = segment.get_max_range() + if segment_max_range > 0.0: + max_range_reciprocal_sum += 1.0 / segment_max_range + else: + zero_range_segments_count += 1 self._segments.append(segment) - segment_annotations = region_get_annotations( - segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, - self._term_keywords) - for segment_annotation in segment_annotations: - name = segment_annotation.get_name() - term = segment_annotation.get_term() - index = 0 - for annotation in self._annotations: - if annotation.get_name() == name: - existing_term = annotation.get_term() - if term != existing_term: - print("Warning: Found existing annotation with name", name, - "but existing term", existing_term, "does not equal new term", term) - if term and (existing_term is None): - annotation.set_term(term) - break # exists already - if name > annotation.get_name(): - index += 1 - else: - # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), - # "category", segment_annotation.get_category()) - self._annotations.insert(index, segment_annotation) - # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" + # by default put all GENERAL annotations without terms into the EXCLUDE category, + # except "marker" and those starting with "orientation" for annotation in self._annotations: + annotation_name = annotation.get_name() if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and - (annotation.get_name() != "marker")): + ((annotation_name != "marker") and not annotation_name.startswith("orientation"))): # print("Exclude general annotation", annotation.get_name(), "with no term") annotation.set_category(AnnotationCategory.EXCLUDE) + self._mean_segment_length = 1.0 if self._segments: + if max_range_reciprocal_sum > 0.0: + self._mean_segment_length = ( + (len(self._segments) - zero_range_segments_count) / max_range_reciprocal_sum) with HierarchicalChangeManager(self._root_region): - self._max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum + self._max_distance = 0.25 * self._mean_segment_length for segment in self._segments: - segment.create_end_point_directions(self._annotations, self._max_distance) - segment.update_annotation_category_groups(self._annotations) + segment.create_end_point_directions(self._max_distance) + segment.update_annotation_category_groups() for annotation in self._annotations: annotation.set_category_change_callback(self._annotation_category_change) + for endpoints_file_name in unused_endpoints_file_names: + logger.warning('Stitcher: No segment matched to endpoint file: ' + endpoints_file_name) + + SEGMENTATION_STITCHER_SETTINGS_ID = "segmentation stitcher settings" def decode_settings(self, settings_in: dict): """ Update stitcher settings from dictionary of serialised settings. :param settings_in: Dictionary of settings as produced by encode_settings(). """ - assert settings_in.get("annotations") and settings_in.get("segments") and settings_in.get("version"), \ - "Stitcher.decode_settings: Invalid settings dictionary" - # settings_version = settings_in["version"] + settings_version = settings_in.get("version") + assert (settings_in.get("annotations") and settings_in.get("segments") and + (settings_in.get("id", self.SEGMENTATION_STITCHER_SETTINGS_ID) == + self.SEGMENTATION_STITCHER_SETTINGS_ID) and + settings_version), "Stitcher.decode_settings: Invalid settings dictionary" settings = self.encode_settings() settings.update(settings_in) + # migrate from integer version number to string "major#.minor#.patch#" + if isinstance(settings_version, int): + settings_version = settings["version"] = "1.0.0" + # assert settings_version == "1.1.0" # future: migrate if version changes # update annotations and warn about differences processed_count = 0 @@ -109,8 +138,8 @@ def decode_settings(self, settings_in: dict): processed_count += 1 break else: - print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, - "in settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + + " in settings not found; ignoring. Have input files changed?") if processed_count != len(self._annotations): for annotation in self._annotations: name = annotation.get_name() @@ -119,30 +148,30 @@ def decode_settings(self, settings_in: dict): if (annotation_settings["name"] == name) and (annotation_settings["term"] == term): break else: - print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, - "not found in settings; using defaults. Have input files changed?") + logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + + " not found in settings; using defaults. Have input files changed?") # update segment settings and warn about differences processed_count = 0 for segment_settings in settings["segments"]: - name = segment_settings["name"] + segment_name = segment_settings["name"] for segment in self._segments: - if segment.get_name() == name: + if segment.get_name() == segment_name: segment.decode_settings(segment_settings) processed_count += 1 break else: - print("WARNING: Segmentation Stitcher. Segment with name", name, - "in settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " in settings not found; ignoring. Have input files changed?") if processed_count != len(self._segments): for segment in self._segments: - name = segment.get_name() + segment_name = segment.get_name() for segment_settings in settings["segments"]: - if segment_settings["name"] == name: + if segment_settings["name"] == segment_name: break else: - print("WARNING: Segmentation Stitcher. Segment with name", name, - "not found in settings; using defaults. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " not found in settings; using defaults. Have input files changed?") # create connections from stitcher settings' connection serialisations assert len(self._connections) == 0, "Cannot decode connections after any exist" @@ -154,22 +183,23 @@ def decode_settings(self, settings_in: dict): connection_segments.append(segment) break else: - print("WARNING: Segmentation Stitcher. Segment with name", segment_name, - "in connection settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " in connection settings not found; ignoring. Have input files changed?") if len(connection_segments) >= 2: - connection = self.create_connection(connection_segments, connection_settings) + connection = self.create_connection(connection_segments, connection_settings, build_links=False) with HierarchicalChangeManager(self._root_region): for segment in self._segments: - segment.update_annotation_category_groups(self._annotations) + segment.update_annotation_category_groups() for connection in self._connections: - connection.update_annotation_category_groups(self._annotations) + connection.update_annotation_category_groups() def encode_settings(self) -> dict: """ :return: Dictionary of Stitcher settings ready to serialise to JSON. """ settings = { + "id": self.SEGMENTATION_STITCHER_SETTINGS_ID, "annotations": [annotation.encode_settings() for annotation in self._annotations], "connections": [connection.encode_settings() for connection in self._connections], "segments": [segment.encode_settings() for segment in self._segments], @@ -189,15 +219,17 @@ def _annotation_category_change(self, annotation, old_category): segment.update_annotation_category(annotation, old_category) for connection in self._connections: connection.build_links(self._max_distance) - connection.update_annotation_category_groups(self._annotations) + connection.update_annotation_category_groups() def get_annotations(self): return self._annotations - def create_connection(self, segments, connection_settings={}): + def create_connection(self, segments, connection_settings={}, build_links=True): """ :param segments: List of 2 Stitcher Segment objects to connect. :param connection_settings: Optional serialisation of connection to read before building links. + :param build_links: If True (default) automatically build any links. If False, called when decoding settings, + only the link graphics are made for existing links, which is required to not rebuild removed links. :return: Connection object or None if invalid segments or connection between segments already exists """ if len(segments) != 2: @@ -214,8 +246,11 @@ def create_connection(self, segments, connection_settings={}): if connection_settings: connection.decode_settings(connection_settings) self._connections.append(connection) - connection.build_links() - connection.update_annotation_category_groups(self._annotations) + if build_links: + connection.build_links() + else: + connection.build_link_objects() + connection.update_annotation_category_groups() return connection def delete_connection(self, connection): @@ -241,7 +276,17 @@ def get_root_region(self): def get_segments(self): return self._segments + def get_mean_segment_length(self): + """ + Get representative mean segment length for sizing graphics and tolerances. + :return: Real length > 0.0. + """ + return self._mean_segment_length + def get_version(self): + """ + :return: Stitcher version number string "major#.minor#.patch#" + """ return self._version def stitch(self, region): @@ -299,6 +344,7 @@ def stitch(self, region): segment_node_maps = [{} for segment in self._segments] # maps from segment node id to output node id # stitch segments in order of connections, followed by unconnected segments + default_radius = self._max_distance * 0.01 for connection in self._connections: segment_node_map_pair = [segment_node_maps[self._segments.index(segment)] for segment in connection.get_segments()] @@ -308,7 +354,7 @@ def stitch(self, region): node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier) output_segment_elements = True processed_segments.append(segment) if segment is connection.get_segments()[1]: @@ -327,7 +373,7 @@ def stitch(self, region): node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier) element_identifier = _output_segment_elements( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, @@ -342,7 +388,21 @@ def write_output_segmentation_file(self, file_name): def _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier): + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier): + """ + :param segment: The segment to output. + :param segment_node_map: maps from segment node id to output node id + :param annotation_groups: map from annotation name to list of Zinc groups (2nd is term group) + :param fieldmodule: Fieldmodule for output region. + :param fieldcache: Fieldcache for output region. + :param coordinates: Coordinates field. + :param radius: Radius field. + :param rgb: Optional rgb field. + :param default_radius: Radius value to use if not define on nodes or datapoints. + :param node_identifier: starting node identfier. + :param datapoint_identifier: starting datapoint identifier. + :return: Next node_identifier, next datapoint_identifier + """ raw_region = segment.get_raw_region() raw_fieldmodule = raw_region.getFieldmodule() raw_coordinates = raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -350,8 +410,7 @@ def _output_segment_nodes_and_markers( raw_rgb = raw_fieldmodule.findFieldByName("rgb").castFiniteElement() if rgb else None raw_marker_name = raw_fieldmodule.findFieldByName("marker_name").castStoredString() raw_nodes = raw_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] - rotation_matrix = euler_to_rotation_matrix(rotation) + rotation_matrix = euler_to_rotation_matrix(segment.get_rotation_radians()) translation = segment.get_translation() raw_groups = get_group_list(raw_fieldmodule) raw_nodeset_groups = [] @@ -361,6 +420,10 @@ def _output_segment_nodes_and_markers( segment_group = find_or_create_field_group(fieldmodule, segment.get_name()) segment_node_group = segment_group.getOrCreateNodesetGroup(nodes) segment_datapoint_group = segment_group.getOrCreateNodesetGroup(datapoints) + orientation_ignore_group = find_or_create_field_group(fieldmodule, "orientation ignore") \ + if segment.is_ignore_orientation() else None + orientation_ignore_nodeset_group =\ + orientation_ignore_group.getOrCreateNodesetGroup(nodes) if orientation_ignore_group else None for raw_group in raw_groups: group_name = raw_group.getName() groups = annotation_groups.get(group_name) @@ -368,7 +431,10 @@ def _output_segment_nodes_and_markers( raw_nodeset_group = raw_group.getNodesetGroup(raw_nodes) if raw_nodeset_group.isValid() and (raw_nodeset_group.getSize() > 0): raw_nodeset_groups.append(raw_nodeset_group) - nodeset_group_lists.append([group.getOrCreateNodesetGroup(nodes) for group in groups]) + nodeset_group_list = [group.getOrCreateNodesetGroup(nodes) for group in groups] + if orientation_ignore_group and group_name.startswith('orientation'): + nodeset_group_list.append(orientation_ignore_nodeset_group) + nodeset_group_lists.append(nodeset_group_list) raw_fieldcache = raw_fieldmodule.createFieldcache() raw_nodeiterator = raw_nodes.createNodeiterator() raw_node = raw_nodeiterator.next() @@ -385,9 +451,13 @@ def _output_segment_nodes_and_markers( x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + if result != RESULT_OK: + r = default_radius radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) if rgb: result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + if result != RESULT_OK: + rgb_value = [1.0, 1.0, 1.0] rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) segment_node_map[raw_node_identifier] = node_identifier segment_node_group.addNode(node) @@ -406,9 +476,13 @@ def _output_segment_nodes_and_markers( x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + if result != RESULT_OK: + r = default_radius radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) if rgb: result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + if result != RESULT_OK: + rgb_value = [1.0, 1.0, 1.0] rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) name = raw_marker_name.evaluateString(raw_fieldcache) marker_name.assignString(fieldcache, name) @@ -466,14 +540,15 @@ def _output_connection_elements(connection, segment_node_maps, annotation_groups connection_group = find_or_create_field_group(fieldmodule, connection.get_name()) mesh = fieldmodule.findMeshByDimension(1) connection_mesh_group = connection_group.getOrCreateMeshGroup(mesh) - linked_nodes = connection.get_linked_nodes() - for annotation_name, annotation_linked_nodes in linked_nodes.items(): + annotation_links = connection.get_annotation_links() + for annotation_name, links in annotation_links.items(): groups = annotation_groups.get(annotation_name) mesh_groups = [group.getOrCreateMeshGroup(mesh) for group in groups] mesh_groups.append(connection_mesh_group) - for segment_node_identifiers in annotation_linked_nodes: + for link in links: + node_identifiers = link['node identifiers'] element = mesh.createElement(element_identifier, elementtemplate) - element.setNodesByIdentifier(eft, [segment_node_maps[n][segment_node_identifiers[n]] for n in range(2)]) + element.setNodesByIdentifier(eft, [segment_node_maps[s][node_identifiers[s]] for s in range(2)]) for mesh_group in mesh_groups: mesh_group.addElement(element) element_identifier += 1 diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf index d34997a..c3edbf6 100644 --- a/tests/resources/vagus-segment1.exf +++ b/tests/resources/vagus-segment1.exf @@ -431,6 +431,11 @@ Node: 84 -4.364982277527066e-01 -3.626326906335586e-01 1.000000000000000e-02 +Node: 85 + 2.532597965468619e+00 + 1.278581726279602e-01 + 5.305649151562682e-01 + 1.000000000000000e-02 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -444,18 +449,12 @@ Shape. Dimension=0 3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 -Node: 85 +Node: 1 9.014003757666196e-01 -3.826434512807377e-03 -1.393280944888287e-02 - "landmark 1" - 0.01 -Node: 86 - 2.532597965468619e+00 - 1.278581726279602e-01 - 5.305649151562682e-01 - orientation - 0.01 + "left level of inferior border of jugular foramen on the vagus nerve" + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -770,14 +769,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..38,43..46 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 49..84 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 47..82 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..11 @@ -808,4 +807,8 @@ Element group: Group name: marker !#nodeset datapoints Node group: -85..86 +1 +Group name: orientation anterior +!#nodeset nodes +Node group: +85 diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index 03d340a..a2e152f 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -381,25 +381,16 @@ Node: 74 3.870986066902640e-02 -5.955710515963688e-01 1.000000000000000e-02 -!#nodeset datapoints -Define node template: node2 -Shape. Dimension=0 -#Fields=3 -1) coordinates, coordinate, rectangular cartesian, real, #Components=3 - x. #Values=1 (value) - y. #Values=1 (value) - z. #Values=1 (value) -2) marker_name, field, string, #Components=1 - 1. #Values=1 (value) -3) radius, field, rectangular cartesian, real, #Components=1 - 1. #Values=1 (value) -Node template: node2 Node: 75 + 4.349351322583406e+00 + -1.214784316068148e-00 + -2.929821893185177e-01 + 3.000000000000000e-01 +Node: 76 2.398918879869167e+00 -1.432597145502751e-01 2.241355737421549e-01 - orientation - 0.01 + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -646,6 +637,9 @@ Element: 70 Element: 71 Nodes: 74 63 +Element: 72 + Nodes: + 11 75 Group name: 00001 !#nodeset nodes Node group: @@ -681,14 +675,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..35 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 39..74 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 36..71 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..11 @@ -702,6 +696,13 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..35 +Group name: left A branch END +!#nodeset nodes +Node group: +11,75 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +72 Group name: left vagus X nerve trunk !#nodeset nodes Node group: @@ -709,7 +710,7 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 1..10 -Group name: marker -!#nodeset datapoints +Group name: orientation anterior +!#nodeset nodes Node group: -75 +76 diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf index 55e9809..cc25c37 100644 --- a/tests/resources/vagus-segment3.exf +++ b/tests/resources/vagus-segment3.exf @@ -351,6 +351,11 @@ Node: 68 -4.359756419261278e-01 -5.458579284260732e-01 1.000000000000000e-02 +Node: 69 + 3.000000000000000e+00 + 0.000000000000000e+00 + 7.000000000000000e-01 + 1.000000000000000e-02 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -364,18 +369,12 @@ Shape. Dimension=0 3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 -Node: 69 - 3.000000000000000e+00 - 0.000000000000000e+00 - 7.000000000000000e-01 - orientation - 0.01 -Node: 70 +Node: 1 1.599724956533351e+00 3.788960603141545e-03 4.423695723249146e-03 - "landmark 2" - 0.01 + "left level of superior border of the clavicle on the vagus nerve" + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -632,14 +631,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 14..38 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 43..68 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 40..65 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..10 @@ -670,7 +669,11 @@ Element group: Group name: marker !#nodeset datapoints Node group: -69..70 +1 +Group name: orientation anterior +!#nodeset nodes +Node group: +69 Group name: unknown !#nodeset nodes Node group: diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 422872f..9fec4f2 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -35,23 +35,31 @@ def test_io_vagus1(self): assertAlmostEqualList(self, zero, segment12.get_translation(), delta=TOL) segment12.set_translation(new_translation) annotations1 = stitcher1.get_annotations() - self.assertEqual(7, len(annotations1)) - self.assertEqual(1, stitcher1.get_version()) + self.assertEqual(9, len(annotations1)) + self.assertEqual("1.1.0", stitcher1.get_version()) annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) - self.assertEqual("http://purl.obolibrary.org/obo/UBERON_0000124", annotation11.get_term()) + self.assertEqual("http://uri.interlex.org/base/ilx_0103892", annotation11.get_term()) self.assertEqual(AnnotationCategory.GENERAL, annotation11.get_category()) annotation12 = annotations1[1] self.assertEqual("Fascicle", annotation12.get_name()) self.assertEqual("http://uri.interlex.org/base/ilx_0738426", annotation12.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_2, annotation12.get_category()) annotation15 = annotations1[4] - self.assertEqual("left vagus X nerve trunk", annotation15.get_name()) - self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation15.get_term()) + self.assertEqual("left A branch END", annotation15.get_name()) + self.assertIsNone(annotation15.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation15.get_category()) - annotation17 = annotations1[6] - self.assertEqual("unknown", annotation17.get_name()) - self.assertEqual(AnnotationCategory.EXCLUDE, annotation17.get_category()) + annotation16 = annotations1[5] + self.assertEqual("left vagus X nerve trunk", annotation16.get_name()) + self.assertEqual('http://uri.interlex.org/base/ilx_0736691', annotation16.get_term()) + self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation16.get_category()) + annotation17 = annotations1[7] + self.assertEqual("orientation anterior", annotation17.get_name()) + self.assertEqual(AnnotationCategory.GENERAL, annotation17.get_category()) + annotation18 = annotations1[8] + self.assertEqual("unknown", annotation18.get_name()) + self.assertEqual(AnnotationCategory.EXCLUDE, annotation18.get_category()) + stitcher1.create_connection([segments1[0], segments1[1]]) connections = stitcher1.get_connections() @@ -69,10 +77,10 @@ def test_io_vagus1(self): self.assertEqual(1, exclude13_mesh_group.getSize()) self.assertEqual(26, general13_mesh_group.getSize()) self.assertFalse(indep13_mesh_group.isValid()) - annotation17_group = segment13.get_annotation_group(annotation17) - annotation17_mesh_group = annotation17_group.getMeshGroup(mesh1d) - self.assertEqual(1, annotation17_mesh_group.getSize()) - annotation17.set_category(AnnotationCategory.INDEPENDENT_NETWORK) + annotation18_group = segment13.get_annotation_group(annotation18) + annotation18_mesh_group = annotation18_group.getMeshGroup(mesh1d) + self.assertEqual(1, annotation18_mesh_group.getSize()) + annotation18.set_category(AnnotationCategory.INDEPENDENT_NETWORK) indep13_mesh_group = indep13_group.getMeshGroup(mesh1d) self.assertEqual(0, exclude13_mesh_group.getSize()) self.assertEqual(26, general13_mesh_group.getSize()) @@ -80,10 +88,10 @@ def test_io_vagus1(self): settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) - self.assertEqual(7, len(settings["annotations"])) - self.assertEqual(1, settings["version"]) + self.assertEqual(9, len(settings["annotations"])) + self.assertEqual("1.1.0", settings["version"]) assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][6]["category"]) + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][8]["category"]) stitcher2 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) stitcher2.decode_settings(settings) @@ -91,8 +99,8 @@ def test_io_vagus1(self): segment22 = segments2[1] assertAlmostEqualList(self, new_translation, segment22.get_translation(), delta=TOL) annotations2 = stitcher2.get_annotations() - annotation27 = annotations2[6] - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation27.get_category()) + annotation29 = annotations2[8] + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation29.get_category()) def test_align_stitch_vagus1(self): """ @@ -110,7 +118,7 @@ def test_align_stitch_vagus1(self): stitcher = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) segments = stitcher.get_segments() - segments[1].set_rotation([0.0, -10.0, -60.0]) + segments[1].set_rotation_degrees([0.0, -10.0, -60.0]) segments[1].set_translation([5.0, 0.0, 0.0]) segments[2].set_translation([10.0, 0.0, 0.5]) @@ -130,24 +138,59 @@ def test_align_stitch_vagus1(self): connection01 = stitcher.create_connection([segments[0], segments[1]]) connection12 = stitcher.create_connection([segments[1], segments[2]]) - connection01.optimise_transformation() - assertAlmostEqualList(self, [-2.894576, -5.574263, -63.93093], segments[1].get_rotation(), delta=TOL) - assertAlmostEqualList(self, [4.88866, -0.01213587, 0.01357185], segments[1].get_translation(), delta=TOL) - linked_nodes01 = connection01.get_linked_nodes() - self.assertEqual(linked_nodes01, { - "Fascicle": [[22, 28], [35, 12], [40, 23]], - "left vagus X nerve trunk": [[11, 1]]}) - - connection12.optimise_transformation() - assertAlmostEqualList(self, [-4.919549, -2.280625, -13.52467], segments[2].get_rotation(), delta=TOL) - assertAlmostEqualList(self, [9.543171, -0.3494296, 0.03930248], segments[2].get_translation(), delta=TOL) - linked_nodes12 = connection12.get_linked_nodes() - self.assertEqual(linked_nodes12, { - "Fascicle": [[22, 15], [38, 25]], - "left vagus X nerve trunk": [[11, 1]]}) + expected_annotation_links01 = { + "Fascicle": [ + {'lock': False, + 'node identifiers': [22, 28]}, + {'lock': False, + 'node identifiers': [35, 12]}, + {'lock': False, + 'node identifiers': [40, 23]}], + "left vagus X nerve trunk": [ + {'lock': False, + 'node identifiers': [11, 1]}]} + + connection01.auto_align_segment(1) + rotation = segments[1].get_rotation_degrees() + translation = segments[1].get_translation() + assertAlmostEqualList(self, [-4.459501969125895, -8.161074730792063, -58.089501540814254], rotation, delta=TOL) + assertAlmostEqualList(self, [4.901057529124233, 0.004805043555627213, -0.04779580320829241], + translation, delta=TOL) + annotation_links01 = connection01.get_annotation_links() + self.assertEqual(expected_annotation_links01, annotation_links01) + + expected_annotation_links12 = { + "Fascicle": [ + {'lock': False, + 'node identifiers': [22, 15]}, + {'lock': False, + 'node identifiers': [38, 25]}], + "left vagus X nerve trunk": [ + {'lock': False, + 'node identifiers': [11, 1]}]} + + connection12.auto_align_segment(1) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [-3.216043371586617, -5.467042596782779, -0.4267779669299892], rotation, delta=TOL) + assertAlmostEqualList(self, [9.537442541080164, -0.3524223146102781, 0.28070488408317984], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) + + # now align first segment relative to second + connection01.auto_align_segment(0) + rotation = segments[0].get_rotation_degrees() + translation = segments[0].get_translation() + assertAlmostEqualList(self, [0.0022017172050087866, -0.05083254291897361, 1.5180006139100206], + rotation, delta=TOL) + assertAlmostEqualList(self, [-1.7169803818076998e-06, -0.00037724526155702106, -0.0029090218733886335], + translation, delta=TOL) + annotation_links01 = connection01.get_annotation_links() + self.assertEqual(expected_annotation_links01, annotation_links01) output_region = stitcher.get_root_region().createRegion() stitcher.stitch(output_region) + self.assertEqual("1.1.0", stitcher.get_version()) fieldmodule = output_region.getFieldmodule() coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -155,8 +198,8 @@ def test_align_stitch_vagus1(self): datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) mesh = fieldmodule.findMeshByDimension(1) minimums, maximums = evaluate_field_nodeset_range(coordinates, nodes) - assertAlmostEqualList(self, [0.04674543239403558, -1.5276719288528786, -0.5804178855490847], minimums, delta=TOL) - assertAlmostEqualList(self, [13.538987060134247, 1.11238124203403, 0.6470665850902932], maximums, delta=TOL) + assertAlmostEqualList(self, [0.04678894233410661, -1.3448619475857166, -0.5849221355942552], minimums, delta=TOL) + assertAlmostEqualList(self, [13.528908286654149, 1.12292211593189, 1.4370793304399627], maximums, delta=TOL) fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() self.assertTrue(fascicle.isValid()) @@ -169,7 +212,33 @@ def test_align_stitch_vagus1(self): marker = fieldmodule.findFieldByName("marker").castGroup() self.assertTrue(marker.isValid()) marker_datapoint_group = marker.getNodesetGroup(datapoints) - self.assertEqual(marker_datapoint_group.getSize(), 5) + self.assertEqual(marker_datapoint_group.getSize(), 2) + + # try auto-align with gap in 2 stages + + segments[0].set_rotation_degrees([0.0, 0.0, 0.0]) + segments[0].set_translation([0.0, 0.0, 0.0]) + segments[1].set_rotation_degrees([0.0, -10.0, -60.0]) + segments[1].set_translation([5.0, 0.0, 0.0]) + segments[2].set_rotation_degrees([0.0, 0.0, 40.0]) + segments[2].set_translation([10.0, 0.0, 0.5]) + + connection12.auto_align_segment(1, phase1_align=True, gap_distance=0.1, phase_2_optimize=False) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [2.7968079813220417, -7.433708312768542, 39.583915044651825], rotation, delta=TOL) + assertAlmostEqualList(self, [9.734631815723224, -0.028181186581394506, 0.505539399215602], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) + + connection12.auto_align_segment(1, phase1_align=False, gap_distance=0.1, phase_2_optimize=True) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [1.1774294709982658, -7.223345962981031, -3.1504154683525827], rotation, delta=TOL) + assertAlmostEqualList(self, [9.735859443921962, -0.003902802918894957, 0.4936970140282092], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) + if __name__ == "__main__": unittest.main()