From e2008f2fa3bbe0b31d235341bee1c4f7f4c5b2ef Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 25 Aug 2025 16:07:11 +1200 Subject: [PATCH 01/24] Output settings with id and major.minor.patch version --- src/segmentationstitcher/stitcher.py | 20 ++++++++++++++++---- tests/test_vagus.py | 5 +++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index eefa454..0bff40a 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -42,7 +42,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._segments = [] self._connections = [] self._max_distance = 0.0 - self._version = 1 # increment when new settings added to migrate older serialised settings + self._version = "1.0.0" # increment when new settings added to migrate older serialised settings with HierarchicalChangeManager(self._root_region): max_range_reciprocal_sum = 0.0 for segmentation_file_name in segmentation_file_names: @@ -87,16 +87,24 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo for annotation in self._annotations: annotation.set_category_change_callback(self._annotation_category_change) + SEGMENTATION_STITCHER_SETTINGS_ID = "segmentation stitcher settings" + def decode_settings(self, settings_in: dict): """ Update stitcher settings from dictionary of serialised settings. :param settings_in: Dictionary of settings as produced by encode_settings(). """ - assert settings_in.get("annotations") and settings_in.get("segments") and settings_in.get("version"), \ - "Stitcher.decode_settings: Invalid settings dictionary" - # settings_version = settings_in["version"] + settings_version = settings_in.get("version") + assert (settings_in.get("annotations") and settings_in.get("segments") and + (settings_in.get("id", self.SEGMENTATION_STITCHER_SETTINGS_ID) == + self.SEGMENTATION_STITCHER_SETTINGS_ID) and + settings_version), "Stitcher.decode_settings: Invalid settings dictionary" settings = self.encode_settings() settings.update(settings_in) + # migrate from integer version number to string "major#.minor#.patch#" + if isinstance(settings_version, int): + settings_version = settings["version"] = "1.0.0" + assert settings_version == "1.0.0" # future: migrate if version changes # update annotations and warn about differences processed_count = 0 @@ -170,6 +178,7 @@ def encode_settings(self) -> dict: :return: Dictionary of Stitcher settings ready to serialise to JSON. """ settings = { + "id": self.SEGMENTATION_STITCHER_SETTINGS_ID, "annotations": [annotation.encode_settings() for annotation in self._annotations], "connections": [connection.encode_settings() for connection in self._connections], "segments": [segment.encode_settings() for segment in self._segments], @@ -242,6 +251,9 @@ def get_segments(self): return self._segments def get_version(self): + """ + :return: Stitcher version number string "major#.minor#.patch#" + """ return self._version def stitch(self, region): diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 422872f..f86d991 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -36,7 +36,7 @@ def test_io_vagus1(self): segment12.set_translation(new_translation) annotations1 = stitcher1.get_annotations() self.assertEqual(7, len(annotations1)) - self.assertEqual(1, stitcher1.get_version()) + self.assertEqual("1.0.0", stitcher1.get_version()) annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) self.assertEqual("http://purl.obolibrary.org/obo/UBERON_0000124", annotation11.get_term()) @@ -81,7 +81,7 @@ def test_io_vagus1(self): settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) self.assertEqual(7, len(settings["annotations"])) - self.assertEqual(1, settings["version"]) + self.assertEqual("1.0.0", settings["version"]) assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][6]["category"]) @@ -148,6 +148,7 @@ def test_align_stitch_vagus1(self): output_region = stitcher.get_root_region().createRegion() stitcher.stitch(output_region) + self.assertEqual("1.0.0", stitcher.get_version()) fieldmodule = output_region.getFieldmodule() coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement() From df79681be61b15c10db476d7b469e1f35a298850 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 3 Sep 2025 22:28:50 +1200 Subject: [PATCH 02/24] Read user endpoints files Add methods to get range and midpoint of segments and connections. Sort segment names more naturally. --- src/segmentationstitcher/connection.py | 19 ++++ src/segmentationstitcher/segment.py | 133 ++++++++++++++++++++++++- src/segmentationstitcher/stitcher.py | 40 +++++++- 3 files changed, 182 insertions(+), 10 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 892c520..446a24b 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -5,6 +5,7 @@ add, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_vector_mult, mult, normalize, sub) from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) +from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range from cmlibs.utils.zinc.general import ChangeManager from cmlibs.utils.zinc.group import group_add_group_local_contents from cmlibs.zinc.element import Element, Elementbasis @@ -149,6 +150,24 @@ def get_linked_nodes(self): """ return self._linked_nodes + def get_coordinates_midpoint(self): + """ + Get midpoint of linked nodes, if any, which are transformed by the respective segments. + :return: Coordinates at the midpoint in their x, y, z range, or None if no linked nodes. + """ + minimums, maximums = self.get_coordinates_range() + if minimums and maximums: + return [0.5 * (minimum + maximum) for minimum, maximum in zip(minimums, maximums)] + return None + + def get_coordinates_range(self): + """ + Get x, y, z ranges of linked nodes in connection, which are transformed by the respective segments. + :return: Minimum coordinates, maximum coordinates, or None, None if no linked nodes. + """ + nodes = self._region.getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + return evaluate_field_nodeset_range(self._coordinates, nodes) + def optimise_transformation(self): """ Optimise transformation of second segment to align with position and direction of nearest points between diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index c9c6927..61ecd73 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -1,19 +1,26 @@ """ A segment of the segmentation data, generally from a separate image block. """ -from builtins import enumerate - -from cmlibs.maths.vectorops import cross, dot, magnitude, matrix_mult, mult, normalize, set_magnitude, sub +from cmlibs.maths.vectorops import ( + add, cross, dot, euler_to_rotation_matrix, magnitude, matrix_mult, matrix_vector_mult, mult, normalize, + set_magnitude, sub) from cmlibs.utils.zinc.field import ( - get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) -from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range + get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, + find_or_create_field_stored_string) +from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range, get_maximum_node_identifier from cmlibs.utils.zinc.group import group_add_group_local_contents, group_remove_group_local_contents from cmlibs.utils.zinc.general import ChangeManager from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.annotation import AnnotationCategory +import json +import logging import math +import os + + +logger = logging.getLogger(__name__) class Segment: @@ -94,6 +101,49 @@ def encode_settings(self) -> dict: } return settings + def define_endpoints(self, endpoints_file_name): + """ + Read endpoints labels and positions from the endpoints file as markers in the raw region. + These are used to associate externally recognised names with endpoints for later reporting. + :param endpoints_file_name: Name of file in Slicer 3D markups json format. + """ + if os.path.isfile(endpoints_file_name): + with open(endpoints_file_name, "r") as f: + try: + marker_file_data = json.loads(f.read()) + schema = marker_file_data.get("@schema") + markups = marker_file_data.get("markups") + if not (schema and ("slicer" in schema) and ("markups" in schema) and markups): + logger.error("Stitcher endpoints file " + endpoints_file_name + " is not a supported file type") + return + if len(markups) == 0: + logger.warning("Stitcher endpoints file" + endpoints_file_name + " has no markups") + for markup in markups: + coordinate_system = markup.get("coordinateSystem") + if coordinate_system != "LPS": + logger.warning("Stitcher endpoints file" + endpoints_file_name + + " unimplemented coordinateSystem name " + coordinate_system) + coordinate_units = markup.get("coordinateUnits") + control_points = markup.get("controlPoints") + if not (isinstance(control_points, list) and (len(control_points) > 0)): + logger.warning("Stitcher endpoints file" + endpoints_file_name + " has no controlPoints") + continue + # build list of marker labels and positions + marker_labels = [] + marker_positions = [] + for control_point in control_points: + marker_labels.append(control_point["label"]) + x = control_point["position"] + marker_positions.append(x) + generate_datapoints(self._raw_region, marker_positions, + field_names_and_values=[("marker_name", marker_labels)], + group_name="marker") + except json.JSONDecodeError as e: + logger.error("Stitcher endpoints file " + endpoints_file_name + + " exception reading json format " + str(e)) + else: + logger.error("Stitcher endpoints file " + endpoints_file_name + " not found") + def _get_element_node_maps(self): """ Get maps from 1-D elements to nodes and nodes to elements for the raw data. @@ -389,6 +439,19 @@ def get_end_point_fields(self): def get_name(self): return self._name + def get_coordinates_midpoint(self): + """ + :return: Coordinates at the midpoint in their x, y, z range. + """ + return [0.5 * (minimum + maximum) for minimum, maximum in zip(self._raw_minimums, self._raw_maximums)] + + def get_coordinates_range(self): + """ + Get x, y, z ranges of coordinates in raw data. + :return: Minimum coordinates, maximum coordinates. + """ + return self._raw_minimums, self._raw_maximums + def get_max_range(self): """ :return: Maximum range of raw coordinates on any axis x, y, z. @@ -396,6 +459,14 @@ def get_max_range(self): raw_range = [self._raw_maximums[c] - self._raw_minimums[c] for c in range(3)] return max(raw_range) + def transform_coordinates(self, position): + """ + :param position Coordinates x, y, z in the segment. + :return: Transformed position. + """ + rotation_matrix = euler_to_rotation_matrix([math.radians(deg) for deg in self._rotation]) + return add(matrix_vector_mult(rotation_matrix, position), self._translation) + def get_raw_region(self): """ Get the raw region, a child of base region, into which the raw segmentation was loaded. @@ -619,3 +690,55 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 # [a_inv[1][0] * a[0][0] + a_inv[1][1] * a[1][0], # a_inv[1][0] * a[0][1] + a_inv[1][1] * a[1][1]]) return start_x, end_x, mean_r, mean_projection_error + + +def generate_datapoints(region, px, start_data_identifier=None, coordinate_field_name="coordinates", + field_names_and_values=[], group_name=None): + """ + Generate a set of datapoints in the region. + :param region: Zinc Region. + :param px: Coordinates of data points. + :param start_data_identifier: Optional first datapoint identifier to use. + :param coordinate_field_name: Optional name of coordinate field to define, if omitted use "coordinates". + :param field_names_and_values: Optional lists of (field_name, list of values) for additional fields to + define on the datapoints. Values may be scalar or vector (list of lists) real, or string. + Must be same number of values as number of points. + :param group_name: Optional name of group to put new datapoints in. + :return: next datapoint identifier + """ + fieldmodule = region.getFieldmodule() + with ChangeManager(fieldmodule): + coordinates = find_or_create_field_coordinates(fieldmodule, name=coordinate_field_name) + group = find_or_create_field_group(fieldmodule, group_name) if group_name else None + + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + data_identifier = start_data_identifier if (start_data_identifier is not None) else \ + max(get_maximum_node_identifier(datapoints), 0) + 1 + data_group = group.getOrCreateNodesetGroup(datapoints) if group else datapoints + + nodetemplate = datapoints.createNodetemplate() + nodetemplate.defineField(coordinates) + fields_values = [] # (field, is_string, values) + for field_name, field_values in field_names_and_values: + is_string = isinstance(field_values[0], str) + if is_string: + field = find_or_create_field_stored_string(fieldmodule, field_name, managed=True) + else: + components_count = len(field_values[0]) if isinstance(field_values[0], list) else 1 + field = find_or_create_field_finite_element(fieldmodule, field_name, components_count, managed=True) + nodetemplate.defineField(field) + fields_values.append((field, is_string, field_values)) + + fieldcache = fieldmodule.createFieldcache() + for n, x in enumerate(px): + node = data_group.createNode(data_identifier, nodetemplate) + fieldcache.setNode(node) + coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) + for field, is_string, values in fields_values: + if is_string: + field.assignString(fieldcache, values[n]) + else: + field.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, values[n]) + data_identifier += 1 + + return data_identifier diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 0bff40a..0173865 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -15,8 +15,18 @@ from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations import copy +import logging import math from pathlib import Path +import re + + +logger = logging.getLogger(__name__) + + +def natural_sort_key(s): + # Split the string by numeric parts, converting numbers to integers + return [int(c) if c.isdigit() else c.lower() for c in re.split('([0-9]+)', s)] class Stitcher: @@ -24,30 +34,48 @@ class Stitcher: Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. """ - def __init__(self, segmentation_file_names: list, network_group1_keywords, network_group2_keywords): + def __init__(self, segmentation_file_names: list, network_group1_keywords, network_group2_keywords, + endpoints_file_names=None): """ :param segmentation_file_names: List of filenames containing raw segmentations in Zinc format. :param network_group1_keywords: List of keywords. Segmented networks annotated with any of these keywords are initially assigned to network group 1, allowing them to be stitched together. :param network_group2_keywords: List of keywords. Segmented networks annotated with any of these keywords are initially assigned to network group 2, allowing them to be stitched together. + :param endpoints_file_names: Optional list of files defining additional markers for applying external labels for + ends of networks. Slicer3D markup json files currently supported. Files must start with the same stem name + as the segmentation files to load into that segment. """ + self._segmentation_file_names = sorted(segmentation_file_names, key=natural_sort_key) + self._network_group1_keywords = copy.deepcopy(network_group1_keywords) + self._network_group2_keywords = copy.deepcopy(network_group2_keywords) + self._endpoints_file_names = endpoints_file_names if endpoints_file_names else [] self._context = Context("Segmentation Stitcher") self._root_region = self._context.getDefaultRegion() self._stitch_region = self._root_region.createRegion() self._annotations = [] - self._network_group1_keywords = copy.deepcopy(network_group1_keywords) - self._network_group2_keywords = copy.deepcopy(network_group2_keywords) self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] self._connections = [] self._max_distance = 0.0 self._version = "1.0.0" # increment when new settings added to migrate older serialised settings + unused_endpoints_file_names = copy.copy(self._endpoints_file_names) + unused_endpoints_file_name_stems = [Path(file_path).stem for file_path in unused_endpoints_file_names] with HierarchicalChangeManager(self._root_region): max_range_reciprocal_sum = 0.0 - for segmentation_file_name in segmentation_file_names: - name = Path(segmentation_file_name).name + for segmentation_file_name in self._segmentation_file_names: + file_path = Path(segmentation_file_name) + name = file_path.name segment = Segment(name, segmentation_file_name, self._root_region) + name_stem = file_path.stem + used_endpoints_file_indexes = [] + for ix, endpoints_file_name_stem in enumerate(unused_endpoints_file_name_stems): + if name_stem in endpoints_file_name_stem: + segment.define_endpoints(unused_endpoints_file_names[ix]) + used_endpoints_file_indexes.append(ix) + for ix in reversed(used_endpoints_file_indexes): + del unused_endpoints_file_name_stems[ix] + del unused_endpoints_file_names[ix] max_range_reciprocal_sum += 1.0 / segment.get_max_range() self._segments.append(segment) segment_annotations = region_get_annotations( @@ -86,6 +114,8 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo segment.update_annotation_category_groups(self._annotations) for annotation in self._annotations: annotation.set_category_change_callback(self._annotation_category_change) + for endpoints_file_name in unused_endpoints_file_names: + logger.warning('Stitcher: No segment matched to endpoint file: ' + endpoints_file_name) SEGMENTATION_STITCHER_SETTINGS_ID = "segmentation stitcher settings" From 63dc74592c9c5951da412ec22d873983528bdcf0 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 4 Sep 2025 13:43:07 +1200 Subject: [PATCH 03/24] Add category groups to working region --- src/segmentationstitcher/annotation.py | 6 ++++ src/segmentationstitcher/segment.py | 46 ++++++++++++++++++++------ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 220afbd..68bfbb3 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -24,6 +24,12 @@ def get_group_name(self): """ return '.' + self.name + def get_lower_name(self): + """ + :return: Lower case category name. + """ + return self.name.lower() + def is_connectable(self): return self in (self.INDEPENDENT_NETWORK, self.NETWORK_GROUP_1, self.NETWORK_GROUP_2) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 61ecd73..eb34c4a 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -73,6 +73,12 @@ def __init__(self, name, segmentation_file_name, root_region): self._working_best_fit_line_orientation = find_or_create_field_finite_element( self._working_fieldmodule, "best_fit_line_orientation", 9) self._working_end_group = find_or_create_field_group(self._working_fieldmodule, "active_ends") + for category in AnnotationCategory: + if category.is_connectable(): + group_name = category.get_group_name() + group = self._working_fieldmodule.createFieldGroup() + group.setName(group_name) + group.setManaged(True) self._element_node_ids, self._node_element_ids = self._get_element_node_maps() self._end_node_ids = self._get_end_node_ids() self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) @@ -531,6 +537,12 @@ def get_working_region(self): """ return self._working_region + def get_working_fieldmodule(self): + """ + :return: Zinc Fieldmodule for working region. + """ + return self._working_fieldmodule + def get_working_end_group(self): """ Get group from working region containing connectable end points in segment. @@ -538,6 +550,15 @@ def get_working_end_group(self): """ return self._working_end_group + def get_working_category_group(self, category): + """ + Get group from working region containing connectable end points in segment and in the supplied category. + :param category: AnnotationCategory. + :return: Zinc group containing connectable end points in that category. + """ + category_group = self._working_fieldmodule.findFieldByName(category.get_group_name()).castGroup() + return category_group if category_group.isValid() else None + def update_annotation_category(self, annotation, old_category=AnnotationCategory.EXCLUDE): """ Ensures special groups representing annotion categories contain via addition or removal the @@ -581,27 +602,32 @@ def _update_working_end_group(self): """ Ensure working end group contains all connectable end points. """ - connectable_node_groups = [] - for category in AnnotationCategory: - if category.is_connectable(): - category_group = self.get_category_group(category) - node_group = category_group.getNodesetGroup(self._raw_nodes) - if node_group.isValid() and (node_group.getSize() > 0): - connectable_node_groups.append(node_group) + # list of (raw_category_node_group, working_category_node_group) capable of connections with ChangeManager(self._working_fieldmodule): - self._working_end_group.clear() working_datapoints = \ self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + connectable_node_groups = [] + for category in AnnotationCategory: + if category.is_connectable(): + category_group = self.get_category_group(category) + category_node_group = category_group.getNodesetGroup(self._raw_nodes) + working_category_group = self.get_working_category_group(category) + if category_node_group.isValid() and working_category_group: + working_category_group.clear() + working_category_node_group = working_category_group.getOrCreateNodesetGroup(working_datapoints) + connectable_node_groups.append((category_node_group, working_category_node_group)) + self._working_end_group.clear() working_node_group = self._working_end_group.getOrCreateNodesetGroup(working_datapoints) working_nodeiterator = working_datapoints.createNodeiterator() working_node = working_nodeiterator.next() while working_node.isValid(): node_identifier = working_node.getIdentifier() raw_node = self._raw_nodes.findNodeByIdentifier(node_identifier) - for node_group in connectable_node_groups: + for node_group, working_category_node_group in connectable_node_groups: if node_group.containsNode(raw_node): working_node_group.addNode(working_node) - break; + working_category_node_group.addNode(working_node) + break working_node = working_nodeiterator.next() def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): From d80546bd1e762ee00322d8307327124af85ba62f Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 4 Sep 2025 21:26:24 +1200 Subject: [PATCH 04/24] Store mean segment length Handle segments with zero range --- src/segmentationstitcher/stitcher.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 0173865..04fa31a 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -63,6 +63,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo unused_endpoints_file_name_stems = [Path(file_path).stem for file_path in unused_endpoints_file_names] with HierarchicalChangeManager(self._root_region): max_range_reciprocal_sum = 0.0 + zero_range_segments_count = 0 for segmentation_file_name in self._segmentation_file_names: file_path = Path(segmentation_file_name) name = file_path.name @@ -76,7 +77,11 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo for ix in reversed(used_endpoints_file_indexes): del unused_endpoints_file_name_stems[ix] del unused_endpoints_file_names[ix] - max_range_reciprocal_sum += 1.0 / segment.get_max_range() + segment_max_range = segment.get_max_range() + if segment_max_range > 0.0: + max_range_reciprocal_sum += 1.0 / segment_max_range + else: + zero_range_segments_count += 1 self._segments.append(segment) segment_annotations = region_get_annotations( segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, @@ -106,9 +111,13 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo (annotation.get_name() != "marker")): # print("Exclude general annotation", annotation.get_name(), "with no term") annotation.set_category(AnnotationCategory.EXCLUDE) + self._mean_segment_length = 1.0 if self._segments: + if max_range_reciprocal_sum > 0.0: + self._mean_segment_length = ( + (len(self._segments) - zero_range_segments_count) / max_range_reciprocal_sum) with HierarchicalChangeManager(self._root_region): - self._max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum + self._max_distance = 0.25 * self._mean_segment_length for segment in self._segments: segment.create_end_point_directions(self._annotations, self._max_distance) segment.update_annotation_category_groups(self._annotations) @@ -280,6 +289,13 @@ def get_root_region(self): def get_segments(self): return self._segments + def get_mean_segment_length(self): + """ + Get representative mean segment length for sizing graphics and tolerances. + :return: Real length > 0.0. + """ + return self._mean_segment_length + def get_version(self): """ :return: Stitcher version number string "major#.minor#.patch#" From c8afb6a3c916ddf9be44898a9697e1a918978d53 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 9 Sep 2025 12:54:21 +1200 Subject: [PATCH 05/24] Auto align either segment in connection --- src/segmentationstitcher/connection.py | 68 +++++++++++++++----------- tests/test_vagus.py | 25 +++++++--- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 446a24b..2bce063 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -13,6 +13,10 @@ from scipy.optimize import minimize from segmentationstitcher.annotation import AnnotationCategory import math +import logging + + +logger = logging.getLogger(__name__) class Connection: @@ -168,11 +172,20 @@ def get_coordinates_range(self): nodes = self._region.getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) return evaluate_field_nodeset_range(self._coordinates, nodes) - def optimise_transformation(self): + def auto_align_segment(self, dependent_segment_index): """ - Optimise transformation of second segment to align with position and direction of nearest points between - both segments. + Optimise transformation of one connected segment relative to the other, by getting best fit + alignment and connection between nearest end points between them. + :param dependent_segment_index: Index of segment to optimise transformation of. """ + segments_count = len(self._segments) + if (dependent_segment_index < 0) or (dependent_segment_index >= segments_count): + logger.error("auto_align_segment. Segment index " + str(dependent_segment_index) + " out of range") + return + if segments_count != 2: + logger.error("auto_align_segment. Not implemented for " + str(segments_count) + " segments") + return + fixed_segment_index = 1 if (dependent_segment_index == 0) else 0 segment_end_point_data = [] initial_rotation = [] initial_rotation_matrix = [] @@ -197,7 +210,6 @@ def optimise_transformation(self): mean_coordinates = [] mean_directions = [] for s, segment in enumerate(self._segments): - total_weight = 0.0 distances = [] max_distance = None for node_id0, transformed_coordinates0, _, _, _, annotation0 in segment_end_point_data[s]: @@ -220,9 +232,7 @@ def optimise_transformation(self): nearby_proportion = 0.1 # proportion of max distance under which distance weighting is the same nearby_distance = max_distance * nearby_proportion sum_coordinates = [0.0, 0.0, 0.0] - sum_transformed_coordinates = [0.0, 0.0, 0.0] sum_direction = [0.0, 0.0, 0.0] - sum_transformed_direction = [0.0, 0.0, 0.0] total_weight = 0.0 for p, data in enumerate(segment_end_point_data[s]): distance = distances[p] @@ -250,7 +260,7 @@ def optimise_transformation(self): mean_transformed_coordinates.append(x) unit_mean_transformed_directions.append(normalize(d)) - # optimise transformation of second segment so mean coordinates and directions coincide + # optimise transformation of dependent segment so mean coordinates and directions coincide def rotation_objective(trial_rotation, *args): target_direction, source_direction, target_side_direction, source_side_direction = args @@ -259,22 +269,24 @@ def rotation_objective(trial_rotation, *args): trans_side_direction = matrix_vector_mult(trial_rotation_matrix, source_side_direction) return dot(trans_direction, target_direction) + dot(target_side_direction, trans_side_direction) - # note the result is dependent on the initial position, but final optimisation should reduced effect + # note the result is dependent on the initial position, but final optimisation should reduce effect # get a side direction to minimise the unconstrained twist from the current direction + dependent_segment = self._segments[dependent_segment_index] + dependent_segment_end_point_data = segment_end_point_data[dependent_segment_index] axis = [1.0, 0.0, 0.0] - if dot(unit_mean_transformed_directions[0], axis) < 0.1: + if dot(unit_mean_transformed_directions[fixed_segment_index], axis) < 0.1: axis = [0.0, 1.0, 0.0] - target_side = normalize(cross(unit_mean_transformed_directions[0], axis)) + target_side = normalize(cross(unit_mean_transformed_directions[fixed_segment_index], axis)) source_side = normalize( - cross(cross(target_side, unit_mean_transformed_directions[1]), unit_mean_transformed_directions[1])) - if initial_rotation_matrix[1]: + cross(cross(target_side, unit_mean_transformed_directions[dependent_segment_index]), unit_mean_transformed_directions[dependent_segment_index])) + if initial_rotation_matrix[dependent_segment_index]: transformed_source_side = source_side - inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[1]) + inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[dependent_segment_index]) source_side = matrix_vector_mult(inverse_rotation_matrix, transformed_source_side) - initial_angles = [math.radians(angle_degrees) for angle_degrees in self._segments[1].get_rotation()] + initial_angles = [math.radians(angle_degrees) for angle_degrees in dependent_segment.get_rotation()] side_weight = 0.01 # so side has only a small effect on objective res = minimize(rotation_objective, initial_angles, - args=(unit_mean_transformed_directions[0], unit_mean_directions[1], + args=(unit_mean_transformed_directions[fixed_segment_index], unit_mean_directions[dependent_segment_index], mult(target_side, side_weight), mult(source_side, side_weight)), method='Nelder-Mead', tol=0.001) if not res.success: @@ -282,28 +294,28 @@ def rotation_objective(trial_rotation, *args): return rotation = [math.degrees(angle_radians) for angle_radians in res.x] rotation_matrix = euler_to_rotation_matrix(res.x) - rotated_mean_coordinates = matrix_vector_mult(rotation_matrix, mean_coordinates[1]) - translation = sub(mean_transformed_coordinates[0], rotated_mean_coordinates) + rotated_mean_coordinates = matrix_vector_mult(rotation_matrix, mean_coordinates[dependent_segment_index]) + translation = sub(mean_transformed_coordinates[fixed_segment_index], rotated_mean_coordinates) # update transformed_coordinates in second segment data - for p, data in enumerate(segment_end_point_data[1]): + for p, data in enumerate(dependent_segment_end_point_data): coordinates = data[2] transformed_coordinates = add(matrix_vector_mult(rotation_matrix, coordinates), translation) - segment_end_point_data[1][p] = (data[0], transformed_coordinates, data[2], data[3], data[4], data[5]) - unit_transformed_direction = matrix_vector_mult(rotation_matrix, unit_mean_directions[1]) + dependent_segment_end_point_data[p] = (data[0], transformed_coordinates, data[2], data[3], data[4], data[5]) + unit_transformed_direction = matrix_vector_mult(rotation_matrix, unit_mean_directions[dependent_segment_index]) # translate along unit_transformed_direction so no overlap between points total_overlap = 0.0 for s, segment in enumerate(self._segments): max_overlap = 0.0 for data in segment_end_point_data[s]: - overlap = dot(sub(data[1], mean_transformed_coordinates[0]), unit_transformed_direction) - if s == 0: + overlap = dot(sub(data[1], mean_transformed_coordinates[fixed_segment_index]), unit_transformed_direction) + if s == fixed_segment_index: overlap = -overlap if overlap > max_overlap: max_overlap = overlap total_overlap += max_overlap translation = sub(translation, mult(unit_transformed_direction, total_overlap)) - self._segments[1].set_rotation(rotation, notify=False) - self._segments[1].set_translation(translation, notify=False) + dependent_segment.set_rotation(rotation, notify=False) + dependent_segment.set_translation(translation, notify=False) # GRC temp # score = self.build_links(build_link_objects=False) @@ -313,8 +325,8 @@ def rotation_objective(trial_rotation, *args): def links_objective(rotation_translation, *args): rotation = list(rotation_translation[:3]) translation = list(rotation_translation[3:]) - self._segments[1].set_rotation(rotation, notify=False) - self._segments[1].set_translation(translation, notify=False) + dependent_segment.set_rotation(rotation, notify=False) + dependent_segment.set_translation(translation, notify=False) score = self.build_links(build_link_objects=False) # print("rotation", rotation, "translation", translation, "score", score) return score @@ -329,9 +341,9 @@ def links_objective(rotation_translation, *args): return rotation = list(res.x[:3]) translation = list(res.x[3:]) - self._segments[1].set_rotation(rotation, notify=False) + dependent_segment.set_rotation(rotation, notify=False) # this will invoke build_links: - self._segments[1].set_translation(translation) + dependent_segment.set_translation(translation) def build_links(self, build_link_objects=True): """ diff --git a/tests/test_vagus.py b/tests/test_vagus.py index f86d991..f3e132d 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -130,15 +130,17 @@ def test_align_stitch_vagus1(self): connection01 = stitcher.create_connection([segments[0], segments[1]]) connection12 = stitcher.create_connection([segments[1], segments[2]]) - connection01.optimise_transformation() - assertAlmostEqualList(self, [-2.894576, -5.574263, -63.93093], segments[1].get_rotation(), delta=TOL) - assertAlmostEqualList(self, [4.88866, -0.01213587, 0.01357185], segments[1].get_translation(), delta=TOL) + connection01.auto_align_segment(1) + rotation = segments[1].get_rotation() + translation = segments[1].get_translation() + assertAlmostEqualList(self, [-2.894576, -5.574263, -63.93093], rotation, delta=TOL) + assertAlmostEqualList(self, [4.88866, -0.01213587, 0.01357185], translation, delta=TOL) linked_nodes01 = connection01.get_linked_nodes() self.assertEqual(linked_nodes01, { "Fascicle": [[22, 28], [35, 12], [40, 23]], "left vagus X nerve trunk": [[11, 1]]}) - connection12.optimise_transformation() + connection12.auto_align_segment(1) assertAlmostEqualList(self, [-4.919549, -2.280625, -13.52467], segments[2].get_rotation(), delta=TOL) assertAlmostEqualList(self, [9.543171, -0.3494296, 0.03930248], segments[2].get_translation(), delta=TOL) linked_nodes12 = connection12.get_linked_nodes() @@ -146,6 +148,17 @@ def test_align_stitch_vagus1(self): "Fascicle": [[22, 15], [38, 25]], "left vagus X nerve trunk": [[11, 1]]}) + # now align first segment relative to second + connection01.auto_align_segment(0) + rotation = segments[0].get_rotation() + translation = segments[0].get_translation() + assertAlmostEqualList(self, [0.5904670871359933, 0.327008456961372, 0.08272009867790592], rotation, delta=TOL) + assertAlmostEqualList(self, [0.0010985770233687066, -0.05096474638998973, 0.02847203766782994], translation, delta=TOL) + linked_nodes01 = connection01.get_linked_nodes() + self.assertEqual(linked_nodes01, { + "Fascicle": [[22, 28], [35, 12], [40, 23]], + "left vagus X nerve trunk": [[11, 1]]}) + output_region = stitcher.get_root_region().createRegion() stitcher.stitch(output_region) self.assertEqual("1.0.0", stitcher.get_version()) @@ -156,8 +169,8 @@ def test_align_stitch_vagus1(self): datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) mesh = fieldmodule.findMeshByDimension(1) minimums, maximums = evaluate_field_nodeset_range(coordinates, nodes) - assertAlmostEqualList(self, [0.04674543239403558, -1.5276719288528786, -0.5804178855490847], minimums, delta=TOL) - assertAlmostEqualList(self, [13.538987060134247, 1.11238124203403, 0.6470665850902932], maximums, delta=TOL) + assertAlmostEqualList(self, [0.04749509590306315, -1.5276719288528786, -0.5661785379561856], minimums, delta=TOL) + assertAlmostEqualList(self, [13.538987060134247, 1.102601773020926, 0.6470665850902932], maximums, delta=TOL) fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() self.assertTrue(fascicle.isValid()) From 1acbbb97072467cf74ba72d492e95c900733cd18 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 16 Sep 2025 12:19:03 +1200 Subject: [PATCH 06/24] Add new linking and alignment --- src/segmentationstitcher/connection.py | 540 ++++++++++++++++++------- src/segmentationstitcher/segment.py | 10 +- src/segmentationstitcher/stitcher.py | 9 +- 3 files changed, 398 insertions(+), 161 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 2bce063..62a590f 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -2,7 +2,8 @@ A connection between segments in the segmentation data. """ from cmlibs.maths.vectorops import ( - add, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_vector_mult, mult, normalize, sub) + add, axis_angle_to_rotation_matrix, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_mult, + matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, sub) from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range @@ -50,7 +51,7 @@ def __init__(self, segments, root_region, annotations, max_distance): group = fieldmodule.createFieldGroup() group.setName(group_name) group.setManaged(True) - self._linked_nodes = {} # dict: annotation name --> list of [segment0_node_identifier, segment1_node_identifier]] + self._annotation_links = {} # dict: annotation name --> list of {'locked': bool, 'node identifiers': list} for segment in self._segments: segment.add_transformation_change_callback(self._segment_transformation_change) @@ -67,14 +68,51 @@ def decode_settings(self, settings_in: dict): Update segment settings from JSON dict containing serialised settings. :param settings_in: Dictionary of settings as produced by encode_settings(). """ - settings_name = self._separator.join(settings_in["segments"]) + settings_name = self._separator.join(settings_in['segments']) assert settings_name == self._name # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) - linked_nodes = settings.get("linked nodes") - if isinstance(linked_nodes, dict): - self._linked_nodes = linked_nodes + # migrate from previous 'linked nodes' which had a list of node identifiers + linked_nodes = settings.get('linked nodes') + if linked_nodes is not None: + # migrate to new annotation links + annotation_links = {} + annotation_names = list(linked_nodes.keys()) + for annotation_name in annotation_names: + links = linked_nodes[annotation_name] + new_links = [] + if isinstance(links[0], list): + for node_identifiers in links: + new_links.append({'locked': False, 'node identifiers': node_identifiers}) + annotation_links[annotation_name] = new_links + del settings['linked nodes'] + settings['annotation links'] = annotation_links + else: + annotation_links = settings['annotation links'] + # check nodes exist for all links, otherwise remove stale links + segment_nodes = [segment.get_raw_region().getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + for segment in self._segments] + annotation_names = list(annotation_links.keys()) + for annotation_name in annotation_names: + links = annotation_links[annotation_name] + invalid_indexes = [] + for i, link in enumerate(links): + node_identifiers = link['node identifiers'] + invalid_link = False + for s, node_identifier in enumerate(node_identifiers): + if not segment_nodes[s].findNodeByIdentifier(node_identifier).isValid(): + logger.warning('Stitcher connection ' + self._name + ' annotation ' + annotation_name + + ' link missing node ' + str(node_identifier) + ' from segment ' + str(s + 1) + + '. Removing link.') + invalid_link = True + if invalid_link: + invalid_indexes.append(i) + for i in reversed(invalid_indexes): + links.pop(i) + if len(links) == 0: + del annotation_links[annotation_name] + self._annotation_links = annotation_links def encode_settings(self) -> dict: """ @@ -82,12 +120,12 @@ def encode_settings(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ settings = { - "segments": [segment.get_name() for segment in self._segments], - "linked nodes": self._linked_nodes + 'segments': [segment.get_name() for segment in self._segments], + 'annotation links': self._annotation_links } return settings - def printLog(self): + def printZincLog(self): logger = self._region.getContext().getLogger() for index in range(logger.getNumberOfMessages()): print(logger.getMessageTextAtIndex(index)) @@ -136,23 +174,29 @@ def _segment_transformation_change(self, segment): self.build_links() self.update_annotation_category_groups(self._annotations) - def add_linked_nodes(self, annotation, node_id0, node_id1): + def add_linked_nodes(self, annotation, node_id0, node_id1, locked=False): """ :param annotation: Annotation to use for link. :param node_id0: Node identifier to link from segment[0]. - :param node_id1: Node identifier to link from segment[1]. + :param node_id1: Node identifier to link from segment[1]. + :param locked: True if link is """ annotation_name = annotation.get_name() - annotation_linked_nodes = self._linked_nodes.get(annotation_name) - if not annotation_linked_nodes: - self._linked_nodes[annotation_name] = annotation_linked_nodes = [] - annotation_linked_nodes.append([node_id0, node_id1]) - - def get_linked_nodes(self): + links = self._annotation_links.get(annotation_name) + if not links: + # first inserts at the end + self._annotation_links[annotation_name] = links = [] + # then reinsert any other names which should be after name + for name in list(self._annotation_links.keys()): + if name > annotation_name: + self._annotation_links[name] = self._annotation_links.pop(name) + links.append({'locked': locked, 'node identifiers': [node_id0, node_id1]}) + + def get_annotation_links(self): """ :return: Map annotation name -> list of paired nodes from segment1 and segment2 """ - return self._linked_nodes + return self._annotation_links def get_coordinates_midpoint(self): """ @@ -172,11 +216,12 @@ def get_coordinates_range(self): nodes = self._region.getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) return evaluate_field_nodeset_range(self._coordinates, nodes) - def auto_align_segment(self, dependent_segment_index): + def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): """ Optimise transformation of one connected segment relative to the other, by getting best fit alignment and connection between nearest end points between them. :param dependent_segment_index: Index of segment to optimise transformation of. + :param minimum_gap: Minimum gap between aligned segments. """ segments_count = len(self._segments) if (dependent_segment_index < 0) or (dependent_segment_index >= segments_count): @@ -186,79 +231,116 @@ def auto_align_segment(self, dependent_segment_index): logger.error("auto_align_segment. Not implemented for " + str(segments_count) + " segments") return fixed_segment_index = 1 if (dependent_segment_index == 0) else 0 - segment_end_point_data = [] - initial_rotation = [] - initial_rotation_matrix = [] - for s, segment in enumerate(self._segments): - translation = segment.get_translation() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] - initial_rotation.append(rotation) - rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None - initial_rotation_matrix.append(rotation_matrix) - end_point_data = [] - raw_end_point_data = segment.get_end_point_data() - for node_id, data in raw_end_point_data.items(): - coordinates, direction, radius, annotation = data - transformed_coordinates = coordinates - if (annotation is not None) and annotation.get_category().is_connectable(): - if rotation_matrix: - transformed_coordinates = matrix_vector_mult(rotation_matrix, transformed_coordinates) - transformed_coordinates = add(transformed_coordinates, translation) - end_point_data.append((node_id, transformed_coordinates, coordinates, direction, radius, annotation)) - segment_end_point_data.append(end_point_data) - mean_coordinates = [] - mean_directions = [] - for s, segment in enumerate(self._segments): - distances = [] - max_distance = None - for node_id0, transformed_coordinates0, _, _, _, annotation0 in segment_end_point_data[s]: - category0 = annotation0.get_category() - distance = None - for node_id1, transformed_coordinates1, _, _, _, annotation1 in segment_end_point_data[s - 1]: - category1 = annotation1.get_category() - if (category0 != category1) or ( - (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): - continue # end points are not allowed to join - tmp_distance = magnitude(sub(transformed_coordinates0, transformed_coordinates1)) - if (distance is None) or (tmp_distance < distance): - distance = tmp_distance - if (distance is not None) and ((max_distance is None) or (distance > max_distance)): - max_distance = distance - distances.append(distance) - if max_distance is None: - print("Segmentation Stitcher. No connectable points to optimise transformation with") - return - nearby_proportion = 0.1 # proportion of max distance under which distance weighting is the same - nearby_distance = max_distance * nearby_proportion - sum_coordinates = [0.0, 0.0, 0.0] - sum_direction = [0.0, 0.0, 0.0] - total_weight = 0.0 - for p, data in enumerate(segment_end_point_data[s]): - distance = distances[p] - if distance is None: - continue - _, transformed_coordinates, coordinates, direction, radius, annotation = data - if distance < nearby_distance: - distance = nearby_distance - weight = annotation.get_align_weight() * radius * radius / (distance * distance) - sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) - sum_direction = add(sum_direction, mult(direction, weight)) - total_weight += weight - mean_coordinates.append(div(sum_coordinates, total_weight)) - mean_directions.append(div(sum_direction, total_weight)) - unit_mean_directions = [normalize(v) for v in mean_directions] - mean_transformed_coordinates = [] - unit_mean_transformed_directions = [] - for s, segment in enumerate(self._segments): - x = mean_coordinates[s] - d = mean_directions[s] - if initial_rotation_matrix[s]: - x = matrix_vector_mult(initial_rotation_matrix[s], x) - d = matrix_vector_mult(initial_rotation_matrix[s], d) - x = add(x, segment.get_translation()) - mean_transformed_coordinates.append(x) - unit_mean_transformed_directions.append(normalize(d)) + number_of_iterations = 2 # so second iteration starts reliably close + for iter in range(number_of_iterations): + # get segment transformations and apply to end points + segment_end_point_data = [] + initial_rotation_matrix = [] + for s, segment in enumerate(self._segments): + translation = segment.get_translation() + rotation_radians = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation_matrix = euler_to_rotation_matrix(rotation_radians) + initial_rotation_matrix.append(rotation_matrix) + end_point_data = [] + raw_end_point_data = segment.get_end_point_data() + for node_id, data in raw_end_point_data.items(): + coordinates, direction, radius, annotation = data + transformed_coordinates = coordinates + if (annotation is not None) and annotation.get_category().is_connectable(): + if rotation_matrix: + transformed_coordinates = matrix_vector_mult(rotation_matrix, transformed_coordinates) + transformed_coordinates = add(transformed_coordinates, translation) + end_point_data.append((node_id, transformed_coordinates, coordinates, direction, radius, annotation)) + segment_end_point_data.append(end_point_data) + + # get weighted mean end coordinates and directions of segment end points weighted by closeness to other segment + mean_end_locations = [] + mean_end_directions = [] # unit mean untransformed directions + far_proportion = 0.5 # proportion of max_distance above which distance weighting is zero + far_distance = self._max_distance * far_proportion + minimum_gap + for s, segment in enumerate(self._segments): + distances = [] # min transformed distance from end points of this segment to linkable end points in other + max_distance = None + remove_end_point_indexes = [] + for index0, data0 in enumerate(segment_end_point_data[s]): + node_id0, transformed_coordinates0, _, _, _, annotation0 = data0 + category0 = annotation0.get_category() + distance = None + for node_id1, transformed_coordinates1, _, _, _, annotation1 in segment_end_point_data[s - 1]: + category1 = annotation1.get_category() + if (category0 != category1) or ( + (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): + continue # end points are not allowed to join + tmp_distance = magnitude(sub(transformed_coordinates0, transformed_coordinates1)) + if (tmp_distance < far_distance) and ((distance is None) or (tmp_distance < distance)): + distance = tmp_distance + if (distance is not None) and ((max_distance is None) or (distance > max_distance)): + max_distance = distance + if distance is None: + remove_end_point_indexes.append(index0) + else: + distances.append(distance) # can be None + if max_distance is None: + logger.warning("Segmentation Stitcher. No linkable points to optimise transformation with") + return + for ix in reversed(remove_end_point_indexes): + del segment_end_point_data[s][ix] + sum_coordinates = [0.0, 0.0, 0.0] + sum_direction = [0.0, 0.0, 0.0] + total_weight = 0.0 + for distance, data in zip(distances, segment_end_point_data[s]): + _, _, coordinates, direction, radius, annotation = data + weight = annotation.get_align_weight() * radius * radius * (far_distance - distance) + sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) + sum_direction = add(sum_direction, mult(direction, weight)) + total_weight += weight + mean_end_direction = normalize(sum_direction) + mean_end_directions.append(mean_end_direction) + mean_coordinates = div(sum_coordinates, total_weight) + # get mean_end_locations at furthermost point in mean_end_direction + mean_projection = dot(mean_coordinates, mean_end_direction) + max_projection = mean_projection + for data in segment_end_point_data[s]: + coordinates = data[2] + projection = dot(coordinates, mean_end_direction) + if projection > max_projection: + max_projection = projection + # add half minimum gap to each side + offset = max_projection - mean_projection + 0.5 * minimum_gap + mean_end_locations.append(add(mean_coordinates, mult(mean_end_direction, offset))) + + # get angle axis transformation of dependent direction onto fixed direction + rotated_mean_end_directions = [ + matrix_vector_mult(initial_rotation_matrix[s], mean_end_directions[s]) for s in range(2)] + # need to reverse fixed direction so inline + axis = cross(rotated_mean_end_directions[dependent_segment_index], + [-d for d in rotated_mean_end_directions[fixed_segment_index]]) + mag_axis = magnitude(axis) + rotation_matrix = initial_rotation_matrix[dependent_segment_index] + if mag_axis > 1.0E-6: + axis = div(axis, mag_axis) + theta = math.asin(mag_axis) + axis_angle_rotation_matrix = axis_angle_to_rotation_matrix(axis, theta) + rotation_matrix = matrix_mult(axis_angle_rotation_matrix, rotation_matrix) + rotation_radians = rotation_matrix_to_euler(rotation_matrix) + rotation = [math.degrees(angle_radians) for angle_radians in rotation_radians] + else: + rotation = self._segments[dependent_segment_index].get_rotation() + dependent_rotated_end_location = matrix_vector_mult( + rotation_matrix, mean_end_locations[dependent_segment_index]) + fixed_rotated_end_location = add( + matrix_vector_mult(initial_rotation_matrix[fixed_segment_index], mean_end_locations[fixed_segment_index]), + self._segments[fixed_segment_index].get_translation()) + translation = sub(fixed_rotated_end_location, dependent_rotated_end_location) + + # first part: + dependent_segment = self._segments[dependent_segment_index] + dependent_segment.set_rotation(rotation, notify=False) + dependent_segment.set_translation(translation, notify=False) + + dependent_segment.set_translation(translation) # GRC temporary to force notification + return # optimise transformation of dependent segment so mean coordinates and directions coincide @@ -271,14 +353,13 @@ def rotation_objective(trial_rotation, *args): # note the result is dependent on the initial position, but final optimisation should reduce effect # get a side direction to minimise the unconstrained twist from the current direction - dependent_segment = self._segments[dependent_segment_index] dependent_segment_end_point_data = segment_end_point_data[dependent_segment_index] axis = [1.0, 0.0, 0.0] - if dot(unit_mean_transformed_directions[fixed_segment_index], axis) < 0.1: + if dot(transformed_mean_directions[fixed_segment_index], axis) < 0.1: axis = [0.0, 1.0, 0.0] - target_side = normalize(cross(unit_mean_transformed_directions[fixed_segment_index], axis)) + target_side = normalize(cross(transformed_mean_directions[fixed_segment_index], axis)) source_side = normalize( - cross(cross(target_side, unit_mean_transformed_directions[dependent_segment_index]), unit_mean_transformed_directions[dependent_segment_index])) + cross(cross(target_side, transformed_mean_directions[dependent_segment_index]), transformed_mean_directions[dependent_segment_index])) if initial_rotation_matrix[dependent_segment_index]: transformed_source_side = source_side inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[dependent_segment_index]) @@ -286,11 +367,11 @@ def rotation_objective(trial_rotation, *args): initial_angles = [math.radians(angle_degrees) for angle_degrees in dependent_segment.get_rotation()] side_weight = 0.01 # so side has only a small effect on objective res = minimize(rotation_objective, initial_angles, - args=(unit_mean_transformed_directions[fixed_segment_index], unit_mean_directions[dependent_segment_index], + args=(transformed_mean_directions[fixed_segment_index], unit_mean_directions[dependent_segment_index], mult(target_side, side_weight), mult(source_side, side_weight)), method='Nelder-Mead', tol=0.001) if not res.success: - print("Segmentation Stitcher. Could not optimise initial rotation") + logger.warning("Segmentation Stitcher. Could not optimise initial rotation") return rotation = [math.degrees(angle_radians) for angle_radians in res.x] rotation_matrix = euler_to_rotation_matrix(res.x) @@ -315,7 +396,11 @@ def rotation_objective(trial_rotation, *args): total_overlap += max_overlap translation = sub(translation, mult(unit_transformed_direction, total_overlap)) dependent_segment.set_rotation(rotation, notify=False) - dependent_segment.set_translation(translation, notify=False) + # dependent_segment.set_translation(translation, notify=False) + + # GRC rotation only: + dependent_segment.set_translation(translation) + return # GRC temp # score = self.build_links(build_link_objects=False) @@ -337,7 +422,7 @@ def links_objective(rotation_translation, *args): # method='Nelder-Mead' res = minimize(links_objective, initial_parameters, method='Powell') # , tol=TOL) if not res.success: - print("Segmentation Stitcher. Could not optimise final rotation and translation") + logger.warning("Segmentation Stitcher. Could not optimise final rotation and translation") return rotation = list(res.x[:3]) translation = list(res.x[3:]) @@ -352,10 +437,19 @@ def build_links(self, build_link_objects=True): :return: Total link score. """ total_score = 0.0 - remaining_radius_factor = 0.25 - self._linked_nodes = {} + + # remember locked tuples of linked nodes to re-attach in algorithm below + locked_node_identifiers = set() + annotation_names = list(self._annotation_links.keys()) + for annotation_name in annotation_names: + for link in self._annotation_links[annotation_name]: + if link['locked']: + locked_node_identifiers.add(tuple(link['node identifiers'])) + self._annotation_links = {} + # filter, transform and sort end point data from largest to smallest radius segment_sorted_end_point_data = [] + min_area = None for s, segment in enumerate(self._segments): translation = segment.get_translation() rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] @@ -365,76 +459,209 @@ def build_links(self, build_link_objects=True): end_point_data = segment.get_end_point_data() for node_id, data in end_point_data.items(): coordinates, direction, radius, annotation = data + area = math.pi * radius * radius + if (min_area is None) or (area < min_area): + min_area = area if (annotation is not None) and annotation.get_category().is_connectable(): if rotation_matrix: coordinates = matrix_vector_mult(rotation_matrix, coordinates) direction = matrix_vector_mult(rotation_matrix, direction) coordinates = add(coordinates, translation) - for i, data in enumerate(sorted_end_point_data): - if radius > data[3]: + if area > data[3]: break else: i = len(sorted_end_point_data) - sorted_end_point_data.insert(i, (node_id, coordinates, direction, radius, annotation)) + sorted_end_point_data.insert(i, [node_id, coordinates, direction, area, annotation]) segment_sorted_end_point_data.append(sorted_end_point_data) sorted_end_point_data0 = segment_sorted_end_point_data[0] sorted_end_point_data1 = segment_sorted_end_point_data[1] - - while len(sorted_end_point_data0): - end_point_data0 = sorted_end_point_data0[0] - node_id0, coordinates0, direction0, radius0, annotation0 = end_point_data0 - category0 = annotation0.get_category() - best_index1 = None - lowest_score = 0.0 - weight = None + min_area *= 0.5 # so reliably below smallest end point area + + # make 2D array of base score independent of area and exclusivity + base_scores0 = [] # index over segment 0 endpoints, then segment 1 + max_mag_delta_coordinates = 0.5 * self._max_distance + # below this proportion of max_mag_delta_coordinates the closeness score is the same: + min_relative_distance = 0.01 + worst_base_score = 2.0 + for index0, end_point_data0 in enumerate(sorted_end_point_data0): + node_id0, coordinates0, direction0, area0, annotation0 = end_point_data0 + scores1 = [] for index1, end_point_data1 in enumerate(sorted_end_point_data1): - node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 - category1 = annotation1.get_category() - # inter-segment links are only to the same annotation; links within category will be done separately + node_id1, coordinates1, direction1, area1, annotation1 = end_point_data1 + # presently only allow links between same annotation even within network group if annotation0 != annotation1: - continue # end points are not allowed to join + scores1.append(worst_base_score) + continue # end points have different annotation direction_score = math.fabs(1.0 + dot(direction0, direction1)) + # direction_score = -dot(direction0, direction1) if direction_score > 0.5: # arbitrary factor + scores1.append(worst_base_score) continue # end points are not pointing towards each other delta_coordinates = sub(coordinates1, coordinates0) mag_delta_coordinates = magnitude(delta_coordinates) - tdistance = dot(direction0, delta_coordinates) - ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) - if mag_delta_coordinates > (0.5 * self._max_distance): - continue # point is too far away - distance_score = ((tdistance * tdistance + 100.0 * ndistance * ndistance) / - (self._max_distance * self._max_distance)) - tfactor = math.exp(-1000.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor - penetration_distance_score = ((tfactor * tdistance * tdistance) / - (self._max_distance * self._max_distance)) - delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale - radius_score = delta_radius * delta_radius - score = radius0 * (10.0 * direction_score + distance_score + radius_score) - if (best_index1 is None) or (score < lowest_score): - best_index1 = index1 - weight = 0.5 * (annotation0.get_align_weight() + annotation1.get_align_weight()) - lowest_score = score + penetration_distance_score - if best_index1 is not None: - # if category0 != AnnotationCategory.NETWORK_GROUP_1: - total_score += weight * lowest_score - node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] - self.add_linked_nodes(annotation1, node_id0, node_id1) - remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) - if (radius0 > radius1) and (remaining_radius > remaining_radius_factor * radius0): - for i in range(1, len(sorted_end_point_data0)): - if remaining_radius > sorted_end_point_data0[i][3]: - break - # sorted_end_point_data0.insert(i, (node_id0, coordinates0, direction0, remaining_radius, annotation0)) - elif remaining_radius > (remaining_radius_factor * radius1): - for i in range(best_index1, len(sorted_end_point_data1)): - if remaining_radius > sorted_end_point_data1[i][3]: - break - # sorted_end_point_data1.insert(i, (node_id1, coordinates1, direction1, remaining_radius, annotation1)) - sorted_end_point_data1.pop(best_index1) - else: - total_score += radius0 * 20.0 # arbitrary factor - sorted_end_point_data0.pop(0) + if mag_delta_coordinates > max_mag_delta_coordinates: + scores1.append(worst_base_score) + continue # end point are too far away from each other + relative_distance = mag_delta_coordinates / max_mag_delta_coordinates + closeness_score = relative_distance + # if relative_distance < min_relative_distance: + # closeness_score = 1.0 + # else: + # closeness_score = min_relative_distance / relative_distance + # closeness_score = 1.0 - mag_delta_coordinates / max_mag_delta_coordinates + # closeness_score *= closeness_score # square it so more significant effect close by + # align_score is a measure of how in-line the other end points are with each point and direction + if mag_delta_coordinates == 0.0: + align_score = 0.5 + else: + t0 = dot(direction0, delta_coordinates) + n0 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t0 * t0) + t1 = dot(direction1, delta_coordinates) + n1 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t1 * t1) + # lowest align score is 0.5 + align_score = 0.5 + 0.25 * (n0 + n1) / mag_delta_coordinates + score = closeness_score * align_score * direction_score + scores1.append(score) + base_scores0.append(scores1) + + def get_minimums_ratio(scores): + """ + Get ratio of lowest / next lowest score as measure of 'only optiob' for first link to end point. + :param scores: + :return: + """ + if len(scores1) >= 2: + min1 = min2 = float('inf') + for score in scores1: + if score < min1: + min1, min2 = score, min1 + elif score < min2: + min2 = score + return min1 / min2 + return 1.0 + + base_scores1 = [[score1[index1] for score1 in base_scores0] for index1 in range(len(sorted_end_point_data1))] + exclusive_base_scores0 = [get_minimums_ratio(scores1) for scores1 in base_scores0] + exclusive_base_scores1 = [get_minimums_ratio(scores0) for scores0 in base_scores1] + links_count0 = [0.0] * len(exclusive_base_scores0) + links_count1 = [0.0] * len(exclusive_base_scores1) + + print("Start") + best_score = 1.0 + while best_score is not None: + best_score = None + best_area = 0.0 + best_indexes = None + locked = False + for index0, end_point_data0 in enumerate(sorted_end_point_data0): + node_id0 = end_point_data0[0] + area0 = end_point_data0[3] + base_scores1 = base_scores0[index0] + for index1, end_point_data1 in enumerate(sorted_end_point_data1): + base_score = base_scores1[index1] + if base_score >= worst_base_score: + continue + node_id1 = end_point_data1[0] + area1 = end_point_data1[3] + area = min(area0, area1) + indexes = (index0, index1) + node_identifiers = (node_id0, node_id1) + if node_identifiers in locked_node_identifiers: + best_area = max(min_area, area) # don't want area to get negative + best_score = base_score / best_area + best_indexes = indexes + locked_node_identifiers.remove(node_identifiers) + locked = True + break + else: + if area < min_area: + continue + score = base_score / area + # lower score for first links by factor indicating 'only option' + exclusive_base_scores = [] + if links_count0[index0] == 0: + exclusive_base_scores.append(exclusive_base_scores0[index0]) + if links_count1[index1] == 0: + exclusive_base_scores.append(exclusive_base_scores1[index1]) + if exclusive_base_scores: + score *= min(exclusive_base_scores) + if (best_score is None) or (score < best_score): + best_score = score + best_area = area + best_indexes = indexes + if locked: + break + if best_score is not None: + end_point_data0 = sorted_end_point_data0[best_indexes[0]] + node_id0 = end_point_data0[0] + annotation = end_point_data0[4] + end_point_data1 = sorted_end_point_data1[best_indexes[1]] + node_id1 = end_point_data1[0] + self.add_linked_nodes(annotation, node_id0, node_id1) + print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) + end_point_data0[3] -= best_area + end_point_data1[3] -= best_area + links_count0[best_indexes[0]] += 1 + links_count1[best_indexes[1]] += 1 + # penetration score is only used for total score used by auto align + penetration_score = 0.0 # GRC todo + total_score += best_score * best_area + penetration_score + + # while len(sorted_end_point_data0): + # end_point_data0 = sorted_end_point_data0[0] + # node_id0, coordinates0, direction0, radius0, annotation0 = end_point_data0 + # category0 = annotation0.get_category() + # best_index1 = None + # lowest_score = 0.0 + # weight = None + # for index1, end_point_data1 in enumerate(sorted_end_point_data1): + # node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 + # category1 = annotation1.get_category() + # # inter-segment links are only to the same annotation; links within category will be done separately + # if annotation0 != annotation1: + # continue # end points are not allowed to join + # direction_score = math.fabs(1.0 + dot(direction0, direction1)) + # if direction_score > 0.5: # arbitrary factor + # continue # end points are not pointing towards each other + # delta_coordinates = sub(coordinates1, coordinates0) + # mag_delta_coordinates = magnitude(delta_coordinates) + # tdistance = dot(direction0, delta_coordinates) + # ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) + # if mag_delta_coordinates > (0.5 * self._max_distance): + # continue # point is too far away + # distance_score = ((tdistance * tdistance + 100.0 * ndistance * ndistance) / + # (self._max_distance * self._max_distance)) + # tfactor = math.exp(-1000.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor + # penetration_distance_score = ((tfactor * tdistance * tdistance) / + # (self._max_distance * self._max_distance)) + # delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale + # radius_score = delta_radius * delta_radius + # score = radius0 * (10.0 * direction_score + distance_score + radius_score) + # if (best_index1 is None) or (score < lowest_score): + # best_index1 = index1 + # weight = 0.5 * (annotation0.get_align_weight() + annotation1.get_align_weight()) + # lowest_score = score + penetration_distance_score + # if best_index1 is not None: + # # if category0 != AnnotationCategory.NETWORK_GROUP_1: + # total_score += weight * lowest_score + # node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] + # self.add_linked_nodes(annotation1, node_id0, node_id1) + # remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) + # if (radius0 > radius1) and (remaining_radius > remaining_radius_factor * radius0): + # for i in range(1, len(sorted_end_point_data0)): + # if remaining_radius > sorted_end_point_data0[i][3]: + # break + # # sorted_end_point_data0.insert(i, (node_id0, coordinates0, direction0, remaining_radius, annotation0)) + # elif remaining_radius > (remaining_radius_factor * radius1): + # for i in range(best_index1, len(sorted_end_point_data1)): + # if remaining_radius > sorted_end_point_data1[i][3]: + # break + # # sorted_end_point_data1.insert(i, (node_id1, coordinates1, direction1, remaining_radius, annotation1)) + # sorted_end_point_data1.pop(best_index1) + # else: + # total_score += radius0 * 20.0 # arbitrary factor + # sorted_end_point_data0.pop(0) if build_link_objects: self._build_link_objects() @@ -484,13 +711,14 @@ def _build_link_objects(self): with (ChangeManager(fieldmodule)): mesh1d.destroyAllElements() nodes.destroyAllNodes() - for group_name, linked_nodes_list in self._linked_nodes.items(): - group = find_or_create_field_group(fieldmodule, group_name) + for annotation_name, links in self._annotation_links.items(): + group = find_or_create_field_group(fieldmodule, annotation_name) nodeset_group = group.getOrCreateNodesetGroup(nodes) mesh_group = group.getOrCreateMeshGroup(mesh1d) - for linked_nodes in linked_nodes_list: + for link in links: + node_identifiers = link['node identifiers'] cnode_ids = [None, None] - for s, snode_id in enumerate(linked_nodes): + for s, snode_id in enumerate(node_identifiers): cnode_ids[s] = snode_id_to_cnode_id[s].get(snode_id) if not cnode_ids[s]: snode = snodes[s].findNodeByIdentifier(snode_id) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index eb34c4a..2438755 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -140,6 +140,14 @@ def define_endpoints(self, endpoints_file_name): for control_point in control_points: marker_labels.append(control_point["label"]) x = control_point["position"] + if "C1L" in self._name: + x = [-1000.0 * x[1] - 4500.0, 1000.0 * x[0] + 2250.0, -1000.0 * x[2] + 11500.0] + elif any(s in self._name for s in ["T5L", "T6L"]): + x = [-1000.0 * x[1] - 9500.0, 1000.0 * x[0] + 5000.0, -1000.0 * x[2]] + elif any(s in self._name for s in ["C2L", "C4L", "T2L", "T3L", "T4L"]): + x = [-9.0 * x[1], 9.0 * x[0], -9.0 * x[2]] + else: + x = [9.0 * x[1], -9.0 * x[0], -9.0 * x[2]] marker_positions.append(x) generate_datapoints(self._raw_region, marker_positions, field_names_and_values=[("marker_name", marker_labels)], @@ -355,7 +363,7 @@ def _track_path(self, end_node_id, annotations, max_length=None): if add_path_mean_r > 0.0: aspect_ratio += add_path_length / add_path_mean_r # 2nd iteration of fit line removes outliers: - start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.5)[0:3] + start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.25)[0:3] return path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r def create_end_point_directions(self, annotations, max_distance): diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 04fa31a..8cd7a39 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -524,14 +524,15 @@ def _output_connection_elements(connection, segment_node_maps, annotation_groups connection_group = find_or_create_field_group(fieldmodule, connection.get_name()) mesh = fieldmodule.findMeshByDimension(1) connection_mesh_group = connection_group.getOrCreateMeshGroup(mesh) - linked_nodes = connection.get_linked_nodes() - for annotation_name, annotation_linked_nodes in linked_nodes.items(): + annotation_links = connection.get_annotation_links() + for annotation_name, links in annotation_links.items(): groups = annotation_groups.get(annotation_name) mesh_groups = [group.getOrCreateMeshGroup(mesh) for group in groups] mesh_groups.append(connection_mesh_group) - for segment_node_identifiers in annotation_linked_nodes: + for link in links: + node_identifiers = link['node identifiers'] element = mesh.createElement(element_identifier, elementtemplate) - element.setNodesByIdentifier(eft, [segment_node_maps[n][segment_node_identifiers[n]] for n in range(2)]) + element.setNodesByIdentifier(eft, [segment_node_maps[s][node_identifiers[s]] for s in range(2)]) for mesh_group in mesh_groups: mesh_group.addElement(element) element_identifier += 1 From 14d6dfc38f7da16a8d325785e35191652a318d83 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 23 Sep 2025 15:11:31 +1200 Subject: [PATCH 07/24] Improve building links --- src/segmentationstitcher/annotation.py | 44 +++++++++++++--- src/segmentationstitcher/connection.py | 72 +++++++++++++------------- src/segmentationstitcher/segment.py | 18 +++++-- src/segmentationstitcher/stitcher.py | 15 +++--- 4 files changed, 96 insertions(+), 53 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 68bfbb3..5d0d057 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -5,6 +5,10 @@ from cmlibs.utils.zinc.field import get_group_list from cmlibs.utils.zinc.group import group_get_highest_dimension, groups_have_same_local_contents from cmlibs.zinc.field import Field +import logging + + +logger = logging.getLogger(__name__) class AnnotationCategory(Enum): @@ -62,8 +66,8 @@ def decode_settings(self, settings_in: dict): assert (settings_in.get("name") == self._name) and (settings_in.get("term") == self._term) settings_dimension = settings_in.get("dimension") if settings_dimension != self._dimension: - print("WARNING: Segmentation Stitcher. Annotation with name", self._name, "term", self._term, - "was dimension ", settings_dimension, "in settings, is now ", self._dimension, + logger.warning("Segmentation Stitcher. Annotation with name " + self._name, " term " + str(self._term) + + "was dimension " + str(settings_dimension), "in settings, is now " + str(self._dimension) + ". Have input files changed?") settings_in["dimension"] = self._dimension # update current settings to gain new ones and override old ones @@ -152,6 +156,15 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo annotations = [] term_annotations = [] datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + # these terms have had slight mismatches with contents of url group, so explicitly matching: + segment_name = region.getParent().getName() + known_terms = { + "epineurium": "http://uri.interlex.org/base/ilx_0103892", + "left cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794142", + "right cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794141", + "left thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0787543", + "right thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0786664" + } for group in groups: # clean up name to remove case and leading/trailing whitespace name = group.getName().strip() @@ -165,6 +178,8 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo continue # empty group if lower_name.isdigit(): continue # ignore as these can never be valid annotation names + if ' 0.5: # arbitrary factor - scores1.append(worst_base_score) + dot_directions = dot(direction0, direction1) # -1.0 if perfectly pointing at each other + if dot_directions > 0.2: # arbitrary factor + base_scores1.append(worst_base_score) continue # end points are not pointing towards each other + direction_score = 0.2 + (0.8 / 1.2) * (1.0 + dot_directions) # minimum 0.2 delta_coordinates = sub(coordinates1, coordinates0) mag_delta_coordinates = magnitude(delta_coordinates) if mag_delta_coordinates > max_mag_delta_coordinates: - scores1.append(worst_base_score) + base_scores1.append(worst_base_score) continue # end point are too far away from each other relative_distance = mag_delta_coordinates / max_mag_delta_coordinates - closeness_score = relative_distance - # if relative_distance < min_relative_distance: - # closeness_score = 1.0 - # else: - # closeness_score = min_relative_distance / relative_distance - # closeness_score = 1.0 - mag_delta_coordinates / max_mag_delta_coordinates - # closeness_score *= closeness_score # square it so more significant effect close by - # align_score is a measure of how in-line the other end points are with each point and direction + closeness_score = max(relative_distance, min_relative_distance) if mag_delta_coordinates == 0.0: - align_score = 0.5 + align_score = 0.2 # minimum 0.2 else: t0 = dot(direction0, delta_coordinates) n0 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t0 * t0) t1 = dot(direction1, delta_coordinates) n1 = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - t1 * t1) - # lowest align score is 0.5 - align_score = 0.5 + 0.25 * (n0 + n1) / mag_delta_coordinates - score = closeness_score * align_score * direction_score - scores1.append(score) - base_scores0.append(scores1) + align_score = 0.2 + 0.4 * (n0 + n1) / mag_delta_coordinates # minimum 0.5 + base_score = closeness_score * align_score * direction_score + base_scores1.append(base_score) + base_scores0.append(base_scores1) def get_minimums_ratio(scores): """ - Get ratio of lowest / next lowest score as measure of 'only optiob' for first link to end point. + Get ratio of lowest / next lowest score as measure of 'only option' for first link to end point. :param scores: :return: """ - if len(scores1) >= 2: - min1 = min2 = float('inf') - for score in scores1: + inf = float('inf') + min1 = min2 = inf + for score in scores: + if score is not None: if score < min1: min1, min2 = score, min1 elif score < min2: min2 = score + if min2 is not inf: return min1 / min2 return 1.0 @@ -547,10 +542,11 @@ def get_minimums_ratio(scores): links_count0 = [0.0] * len(exclusive_base_scores0) links_count1 = [0.0] * len(exclusive_base_scores1) - print("Start") best_score = 1.0 + cut_off_base_score = 0.1 while best_score is not None: best_score = None + best_nonexclusive_score = None best_area = 0.0 best_indexes = None locked = False @@ -560,7 +556,7 @@ def get_minimums_ratio(scores): base_scores1 = base_scores0[index0] for index1, end_point_data1 in enumerate(sorted_end_point_data1): base_score = base_scores1[index1] - if base_score >= worst_base_score: + if base_score is None: continue node_id1 = end_point_data1[0] area1 = end_point_data1[3] @@ -569,15 +565,17 @@ def get_minimums_ratio(scores): node_identifiers = (node_id0, node_id1) if node_identifiers in locked_node_identifiers: best_area = max(min_area, area) # don't want area to get negative - best_score = base_score / best_area + best_nonexclusive_score = best_score = base_score / math.log(best_area) best_indexes = indexes locked_node_identifiers.remove(node_identifiers) locked = True break else: + if base_score > cut_off_base_score: + continue if area < min_area: continue - score = base_score / area + nonexclusive_score = score = base_score / math.log(area) # lower score for first links by factor indicating 'only option' exclusive_base_scores = [] if links_count0[index0] == 0: @@ -585,9 +583,14 @@ def get_minimums_ratio(scores): if links_count1[index1] == 0: exclusive_base_scores.append(exclusive_base_scores1[index1]) if exclusive_base_scores: - score *= min(exclusive_base_scores) + score *= min(exclusive_base_scores) ** 2.0 + else: + if area < (3.0 * min_area): + continue + score *= (links_count0[index0] + links_count1[index1] + 1) if (best_score is None) or (score < best_score): best_score = score + best_nonexclusive_score = nonexclusive_score best_area = area best_indexes = indexes if locked: @@ -599,14 +602,13 @@ def get_minimums_ratio(scores): end_point_data1 = sorted_end_point_data1[best_indexes[1]] node_id1 = end_point_data1[0] self.add_linked_nodes(annotation, node_id0, node_id1) - print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) + # print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) end_point_data0[3] -= best_area end_point_data1[3] -= best_area links_count0[best_indexes[0]] += 1 links_count1[best_indexes[1]] += 1 - # penetration score is only used for total score used by auto align - penetration_score = 0.0 # GRC todo - total_score += best_score * best_area + penetration_score + # total score is not affected by exclusive measure used to match 'only option' links + total_score += best_nonexclusive_score * best_area # while len(sorted_end_point_data0): # end_point_data0 = sorted_end_point_data0[0] diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 2438755..700cc7a 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -196,19 +196,29 @@ def _get_end_node_ids(self): def _element_id_to_group(self, element_id, annotations): """ - Get the first Annotation zinc Group containing raw element of supplied identifier. + Get the first Annotation zinc Group containing raw element of supplied identifier, prioritizing + any annotation group with term ids. :param node_id: Identifier of [end] node to query. :param annotations: Global list of all annotations. :return: Zinc Group, MeshGroup or None, None if not found. """ element = self._raw_mesh1d.findElementByIdentifier(element_id) + best_group = None + best_mesh_group = None for annotation in annotations: + has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) + if best_group and not has_term: + continue group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() if group.isValid(): mesh_group = group.getMeshGroup(self._raw_mesh1d) if mesh_group.isValid() and mesh_group.containsElement(element): - return group, mesh_group - return None, None + best_group = group + best_mesh_group = mesh_group + if has_term: + # print("Found group", annotation.get_name(), annotation.get_term()) + break + return best_group, best_mesh_group def _track_segment(self, start_node_id, start_element_id, max_length=None, min_element_count=None, min_aspect_ratio=None): @@ -390,6 +400,8 @@ def create_end_point_directions(self, annotations, max_distance): if tmp_annotation.get_name() == annotation_group_name: annotation = tmp_annotation break + else: + print("No annotation group for node", end_node_id) self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) # set up visualization objects. End direction datapoints have same identifiers as raw end nodes node = self._working_datapoints.createNode(end_node_id, nodetemplate) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 8cd7a39..61fe3bd 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -94,15 +94,16 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo if annotation.get_name() == name: existing_term = annotation.get_term() if term != existing_term: - print("Warning: Found existing annotation with name", name, - "but existing term", existing_term, "does not equal new term", term) + logger.warning("Segment " + name + ": Found existing annotation with name " + name + + " but existing term " + str(existing_term) + + " does not equal new term " + str(term)) if term and (existing_term is None): annotation.set_term(term) break # exists already if name > annotation.get_name(): index += 1 else: - # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), + # print("Add annotation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" @@ -156,8 +157,8 @@ def decode_settings(self, settings_in: dict): processed_count += 1 break else: - print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, - "in settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + + "in settings not found; ignoring. Have input files changed?") if processed_count != len(self._annotations): for annotation in self._annotations: name = annotation.get_name() @@ -166,8 +167,8 @@ def decode_settings(self, settings_in: dict): if (annotation_settings["name"] == name) and (annotation_settings["term"] == term): break else: - print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, - "not found in settings; using defaults. Have input files changed?") + logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + + "not found in settings; using defaults. Have input files changed?") # update segment settings and warn about differences processed_count = 0 From d734188bbbaac26ef203ffbdb69228287469dc4e Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 24 Sep 2025 14:34:41 +1200 Subject: [PATCH 08/24] Add methods to lock, unlock, select locked links --- src/segmentationstitcher/connection.py | 139 ++++++++++++------------- 1 file changed, 64 insertions(+), 75 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 023b72e..f6ffcd7 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -4,10 +4,11 @@ from cmlibs.maths.vectorops import ( add, axis_angle_to_rotation_matrix, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_mult, matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, sub) +from cmlibs.utils.zinc.scene import scene_get_or_create_selection_group from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range -from cmlibs.utils.zinc.general import ChangeManager +from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager from cmlibs.utils.zinc.group import group_add_group_local_contents from cmlibs.zinc.element import Element, Elementbasis from cmlibs.zinc.field import Field @@ -51,7 +52,7 @@ def __init__(self, segments, root_region, annotations, max_distance): group = fieldmodule.createFieldGroup() group.setName(group_name) group.setManaged(True) - self._annotation_links = {} # dict: annotation name --> list of {'locked': bool, 'node identifiers': list} + self._annotation_links = {} # dict: annotation name --> list of {'lock': bool, 'node identifiers': list} for segment in self._segments: segment.add_transformation_change_callback(self._segment_transformation_change) @@ -84,7 +85,7 @@ def decode_settings(self, settings_in: dict): new_links = [] if isinstance(links[0], list): for node_identifiers in links: - new_links.append({'locked': False, 'node identifiers': node_identifiers}) + new_links.append({'lock': False, 'node identifiers': node_identifiers}) annotation_links[annotation_name] = new_links del settings['linked nodes'] settings['annotation links'] = annotation_links @@ -174,12 +175,12 @@ def _segment_transformation_change(self, segment): self.build_links() self.update_annotation_category_groups(self._annotations) - def add_linked_nodes(self, annotation, node_id0, node_id1, locked=False): + def add_linked_nodes(self, annotation, node_id0, node_id1, lock=False): """ :param annotation: Annotation to use for link. :param node_id0: Node identifier to link from segment[0]. :param node_id1: Node identifier to link from segment[1]. - :param locked: True if link is + :param lock: True to keep link connected until unlocked. """ annotation_name = annotation.get_name() links = self._annotation_links.get(annotation_name) @@ -190,7 +191,7 @@ def add_linked_nodes(self, annotation, node_id0, node_id1, locked=False): for name in list(self._annotation_links.keys()): if name > annotation_name: self._annotation_links[name] = self._annotation_links.pop(name) - links.append({'locked': locked, 'node identifiers': [node_id0, node_id1]}) + links.append({'lock': lock, 'node identifiers': [node_id0, node_id1]}) def get_annotation_links(self): """ @@ -443,7 +444,7 @@ def build_links(self, build_link_objects=True): annotation_names = list(self._annotation_links.keys()) for annotation_name in annotation_names: for link in self._annotation_links[annotation_name]: - if link['locked']: + if link['lock']: locked_node_identifiers.add(tuple(link['node identifiers'])) self._annotation_links = {} @@ -484,7 +485,7 @@ def build_links(self, build_link_objects=True): max_mag_delta_coordinates = 0.5 * self._max_distance # below this proportion of max_mag_delta_coordinates the closeness score is the same: min_relative_distance = 0.01 - worst_base_score = None + worst_base_score = 10.0 for index0, end_point_data0 in enumerate(sorted_end_point_data0): node_id0, coordinates0, direction0, area0, annotation0 = end_point_data0 base_scores1 = [] @@ -495,15 +496,15 @@ def build_links(self, build_link_objects=True): base_scores1.append(worst_base_score) continue # end points have different annotation dot_directions = dot(direction0, direction1) # -1.0 if perfectly pointing at each other - if dot_directions > 0.2: # arbitrary factor - base_scores1.append(worst_base_score) - continue # end points are not pointing towards each other + # if dot_directions > 0.2: # arbitrary factor + # base_scores1.append(worst_base_score) + # continue # end points are not pointing towards each other direction_score = 0.2 + (0.8 / 1.2) * (1.0 + dot_directions) # minimum 0.2 delta_coordinates = sub(coordinates1, coordinates0) mag_delta_coordinates = magnitude(delta_coordinates) - if mag_delta_coordinates > max_mag_delta_coordinates: - base_scores1.append(worst_base_score) - continue # end point are too far away from each other + # if mag_delta_coordinates > max_mag_delta_coordinates: + # base_scores1.append(worst_base_score) + # continue # end point are too far away from each other relative_distance = mag_delta_coordinates / max_mag_delta_coordinates closeness_score = max(relative_distance, min_relative_distance) if mag_delta_coordinates == 0.0: @@ -549,15 +550,13 @@ def get_minimums_ratio(scores): best_nonexclusive_score = None best_area = 0.0 best_indexes = None - locked = False + lock = False for index0, end_point_data0 in enumerate(sorted_end_point_data0): node_id0 = end_point_data0[0] area0 = end_point_data0[3] base_scores1 = base_scores0[index0] for index1, end_point_data1 in enumerate(sorted_end_point_data1): base_score = base_scores1[index1] - if base_score is None: - continue node_id1 = end_point_data1[0] area1 = end_point_data1[3] area = min(area0, area1) @@ -568,7 +567,7 @@ def get_minimums_ratio(scores): best_nonexclusive_score = best_score = base_score / math.log(best_area) best_indexes = indexes locked_node_identifiers.remove(node_identifiers) - locked = True + lock = True break else: if base_score > cut_off_base_score: @@ -593,7 +592,7 @@ def get_minimums_ratio(scores): best_nonexclusive_score = nonexclusive_score best_area = area best_indexes = indexes - if locked: + if lock: break if best_score is not None: end_point_data0 = sorted_end_point_data0[best_indexes[0]] @@ -601,7 +600,7 @@ def get_minimums_ratio(scores): annotation = end_point_data0[4] end_point_data1 = sorted_end_point_data1[best_indexes[1]] node_id1 = end_point_data1[0] - self.add_linked_nodes(annotation, node_id0, node_id1) + self.add_linked_nodes(annotation, node_id0, node_id1, lock) # print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) end_point_data0[3] -= best_area end_point_data1[3] -= best_area @@ -610,61 +609,6 @@ def get_minimums_ratio(scores): # total score is not affected by exclusive measure used to match 'only option' links total_score += best_nonexclusive_score * best_area - # while len(sorted_end_point_data0): - # end_point_data0 = sorted_end_point_data0[0] - # node_id0, coordinates0, direction0, radius0, annotation0 = end_point_data0 - # category0 = annotation0.get_category() - # best_index1 = None - # lowest_score = 0.0 - # weight = None - # for index1, end_point_data1 in enumerate(sorted_end_point_data1): - # node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 - # category1 = annotation1.get_category() - # # inter-segment links are only to the same annotation; links within category will be done separately - # if annotation0 != annotation1: - # continue # end points are not allowed to join - # direction_score = math.fabs(1.0 + dot(direction0, direction1)) - # if direction_score > 0.5: # arbitrary factor - # continue # end points are not pointing towards each other - # delta_coordinates = sub(coordinates1, coordinates0) - # mag_delta_coordinates = magnitude(delta_coordinates) - # tdistance = dot(direction0, delta_coordinates) - # ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) - # if mag_delta_coordinates > (0.5 * self._max_distance): - # continue # point is too far away - # distance_score = ((tdistance * tdistance + 100.0 * ndistance * ndistance) / - # (self._max_distance * self._max_distance)) - # tfactor = math.exp(-1000.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor - # penetration_distance_score = ((tfactor * tdistance * tdistance) / - # (self._max_distance * self._max_distance)) - # delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale - # radius_score = delta_radius * delta_radius - # score = radius0 * (10.0 * direction_score + distance_score + radius_score) - # if (best_index1 is None) or (score < lowest_score): - # best_index1 = index1 - # weight = 0.5 * (annotation0.get_align_weight() + annotation1.get_align_weight()) - # lowest_score = score + penetration_distance_score - # if best_index1 is not None: - # # if category0 != AnnotationCategory.NETWORK_GROUP_1: - # total_score += weight * lowest_score - # node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] - # self.add_linked_nodes(annotation1, node_id0, node_id1) - # remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) - # if (radius0 > radius1) and (remaining_radius > remaining_radius_factor * radius0): - # for i in range(1, len(sorted_end_point_data0)): - # if remaining_radius > sorted_end_point_data0[i][3]: - # break - # # sorted_end_point_data0.insert(i, (node_id0, coordinates0, direction0, remaining_radius, annotation0)) - # elif remaining_radius > (remaining_radius_factor * radius1): - # for i in range(best_index1, len(sorted_end_point_data1)): - # if remaining_radius > sorted_end_point_data1[i][3]: - # break - # # sorted_end_point_data1.insert(i, (node_id1, coordinates1, direction1, remaining_radius, annotation1)) - # sorted_end_point_data1.pop(best_index1) - # else: - # total_score += radius0 * 20.0 # arbitrary factor - # sorted_end_point_data0.pop(0) - if build_link_objects: self._build_link_objects() @@ -738,6 +682,51 @@ def _build_link_objects(self): element.setNodesByIdentifier(eft, cnode_ids) element_identifier += 1 + def set_link_locking_from_selection(self, lock: bool): + """ + Lock or unlock links for nodes matching any selected visualization elements. + :param lock: True to lock, False to unlock. + """ + root_scene = self._region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return + element_identifier = 1 + for annotation_name, links in self._annotation_links.items(): + for link in links: + link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() + if link_selected: + link['lock'] = lock + element_identifier += 1 + + def add_locked_links_to_selection(self): + """ + Add locked links to the scene selection. + """ + root_region = self._region.getRoot() + root_scene = root_region.getScene() + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + # create selection on demand if any links have a lock + root_selection_group = None + selection_mesh_group = None + element_identifier = 1 + with ChangeManager(root_scene), HierarchicalChangeManager(root_region): + for annotation_name, links in self._annotation_links.items(): + for link in links: + if link['lock']: + if not selection_mesh_group: + root_selection_group = scene_get_or_create_selection_group(root_scene) + selection_mesh_group = root_selection_group.getOrCreateMeshGroup(mesh1d) + link_element = mesh1d.findElementByIdentifier(element_identifier) + selection_mesh_group.addElement(link_element) + element_identifier += 1 + def update_annotation_category_groups(self, annotations): """ Rebuild all annotation category groups e.g. after loading settings. From f6f508a38dd4ce6e3466f628eed945caccf96e0e Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 24 Sep 2025 19:27:55 +1200 Subject: [PATCH 09/24] Handle having no radius or rgb values --- src/segmentationstitcher/stitcher.py | 30 +++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 61fe3bd..2751ec0 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -10,6 +10,7 @@ from cmlibs.zinc.element import Element, Elementbasis from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node +from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.connection import Connection from segmentationstitcher.segment import Segment from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations @@ -358,6 +359,7 @@ def stitch(self, region): segment_node_maps = [{} for segment in self._segments] # maps from segment node id to output node id # stitch segments in order of connections, followed by unconnected segments + default_radius = self._max_distance * 0.01 for connection in self._connections: segment_node_map_pair = [segment_node_maps[self._segments.index(segment)] for segment in connection.get_segments()] @@ -367,7 +369,7 @@ def stitch(self, region): node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier) output_segment_elements = True processed_segments.append(segment) if segment is connection.get_segments()[1]: @@ -386,7 +388,7 @@ def stitch(self, region): node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier) element_identifier = _output_segment_elements( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, @@ -401,7 +403,21 @@ def write_output_segmentation_file(self, file_name): def _output_segment_nodes_and_markers( segment, segment_node_map, annotation_groups, fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, - nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier): + nodetemplate, marker_nodetemplate, default_radius, node_identifier, datapoint_identifier): + """ + :param segment: The segment to output. + :param segment_node_map: maps from segment node id to output node id + :param annotation_groups: map from annotation name to list of Zinc groups (2nd is term group) + :param fieldmodule: Fieldmodule for output region. + :param fieldcache: Fieldcache for output region. + :param coordinates: Coordinates field. + :param radius: Radius field. + :param rgb: Optional rgb field. + :param default_radius: Radius value to use if not define on nodes or datapoints. + :param node_identifier: starting node identfier. + :param datapoint_identifier: starting datapoint identifier. + :return: Next node_identifier, next datapoint_identifier + """ raw_region = segment.get_raw_region() raw_fieldmodule = raw_region.getFieldmodule() raw_coordinates = raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -444,9 +460,13 @@ def _output_segment_nodes_and_markers( x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + if result != RESULT_OK: + r = default_radius radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) if rgb: result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + if result != RESULT_OK: + rgb_value = [1.0, 1.0, 1.0] rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) segment_node_map[raw_node_identifier] = node_identifier segment_node_group.addNode(node) @@ -465,9 +485,13 @@ def _output_segment_nodes_and_markers( x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + if result != RESULT_OK: + r = default_radius radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) if rgb: result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + if result != RESULT_OK: + rgb_value = [1.0, 1.0, 1.0] rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) name = raw_marker_name.evaluateString(raw_fieldcache) marker_name.assignString(fieldcache, name) From a00d28db7fba1f98a617c03fb4c6f9b77c465a2f Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 26 Sep 2025 13:58:39 +1200 Subject: [PATCH 10/24] Reimplement auto-align segment Use radians for rotations internally. Use common utility for rotating about a centre point and axis. --- src/segmentationstitcher/connection.py | 166 +++++++++---------------- src/segmentationstitcher/segment.py | 64 ++++++++-- src/segmentationstitcher/stitcher.py | 3 +- tests/test_vagus.py | 67 ++++++---- 4 files changed, 158 insertions(+), 142 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index f6ffcd7..8ab6e14 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -3,7 +3,7 @@ """ from cmlibs.maths.vectorops import ( add, axis_angle_to_rotation_matrix, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_mult, - matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, sub) + matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, set_magnitude, sub) from cmlibs.utils.zinc.scene import scene_get_or_create_selection_group from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) @@ -233,6 +233,8 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): return fixed_segment_index = 1 if (dependent_segment_index == 0) else 0 + fixed_transformed_end_location = None + fixed_transformed_end_direction = None number_of_iterations = 2 # so second iteration starts reliably close for iter in range(number_of_iterations): # get segment transformations and apply to end points @@ -240,8 +242,7 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): initial_rotation_matrix = [] for s, segment in enumerate(self._segments): translation = segment.get_translation() - rotation_radians = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] - rotation_matrix = euler_to_rotation_matrix(rotation_radians) + rotation_matrix = euler_to_rotation_matrix(segment.get_rotation_radians()) initial_rotation_matrix.append(rotation_matrix) end_point_data = [] raw_end_point_data = segment.get_end_point_data() @@ -258,8 +259,8 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): # get weighted mean end coordinates and directions of segment end points weighted by closeness to other segment mean_end_locations = [] mean_end_directions = [] # unit mean untransformed directions - far_proportion = 0.5 # proportion of max_distance above which distance weighting is zero - far_distance = self._max_distance * far_proportion + minimum_gap + # distance above which distance weighting is zero + far_distance = self._max_distance + minimum_gap for s, segment in enumerate(self._segments): distances = [] # min transformed distance from end points of this segment to linkable end points in other max_distance = None @@ -312,124 +313,75 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): mean_end_locations.append(add(mean_coordinates, mult(mean_end_direction, offset))) # get angle axis transformation of dependent direction onto fixed direction + dependent_segment = self._segments[dependent_segment_index] rotated_mean_end_directions = [ matrix_vector_mult(initial_rotation_matrix[s], mean_end_directions[s]) for s in range(2)] + fixed_transformed_end_direction = rotated_mean_end_directions[fixed_segment_index] # need to reverse fixed direction so inline axis = cross(rotated_mean_end_directions[dependent_segment_index], [-d for d in rotated_mean_end_directions[fixed_segment_index]]) mag_axis = magnitude(axis) - rotation_matrix = initial_rotation_matrix[dependent_segment_index] if mag_axis > 1.0E-6: axis = div(axis, mag_axis) - theta = math.asin(mag_axis) - axis_angle_rotation_matrix = axis_angle_to_rotation_matrix(axis, theta) - rotation_matrix = matrix_mult(axis_angle_rotation_matrix, rotation_matrix) - rotation_radians = rotation_matrix_to_euler(rotation_matrix) - rotation = [math.degrees(angle_radians) for angle_radians in rotation_radians] + angle_radians = math.asin(mag_axis) + centre = mean_end_locations[dependent_segment_index] + dependent_segment.rotate_about_point_axis(centre, axis, angle_radians, notify=False) + dependent_rotation_matrix = euler_to_rotation_matrix(dependent_segment.get_rotation_radians()) else: - rotation = self._segments[dependent_segment_index].get_rotation() + dependent_rotation_matrix = initial_rotation_matrix[dependent_segment_index] dependent_rotated_end_location = matrix_vector_mult( - rotation_matrix, mean_end_locations[dependent_segment_index]) - fixed_rotated_end_location = add( + dependent_rotation_matrix, mean_end_locations[dependent_segment_index]) + fixed_transformed_end_location = add( matrix_vector_mult(initial_rotation_matrix[fixed_segment_index], mean_end_locations[fixed_segment_index]), self._segments[fixed_segment_index].get_translation()) - translation = sub(fixed_rotated_end_location, dependent_rotated_end_location) - - # first part: - dependent_segment = self._segments[dependent_segment_index] - dependent_segment.set_rotation(rotation, notify=False) + translation = sub(fixed_transformed_end_location, dependent_rotated_end_location) dependent_segment.set_translation(translation, notify=False) - dependent_segment.set_translation(translation) # GRC temporary to force notification - return + # first stage only + # dependent_segment.set_translation(translation) # force notification + # return - # optimise transformation of dependent segment so mean coordinates and directions coincide - - def rotation_objective(trial_rotation, *args): - target_direction, source_direction, target_side_direction, source_side_direction = args - trial_rotation_matrix = euler_to_rotation_matrix(trial_rotation) - trans_direction = matrix_vector_mult(trial_rotation_matrix, source_direction) - trans_side_direction = matrix_vector_mult(trial_rotation_matrix, source_side_direction) - return dot(trans_direction, target_direction) + dot(target_side_direction, trans_side_direction) - - # note the result is dependent on the initial position, but final optimisation should reduce effect - # get a side direction to minimise the unconstrained twist from the current direction - dependent_segment_end_point_data = segment_end_point_data[dependent_segment_index] - axis = [1.0, 0.0, 0.0] - if dot(transformed_mean_directions[fixed_segment_index], axis) < 0.1: - axis = [0.0, 1.0, 0.0] - target_side = normalize(cross(transformed_mean_directions[fixed_segment_index], axis)) - source_side = normalize( - cross(cross(target_side, transformed_mean_directions[dependent_segment_index]), transformed_mean_directions[dependent_segment_index])) - if initial_rotation_matrix[dependent_segment_index]: - transformed_source_side = source_side - inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[dependent_segment_index]) - source_side = matrix_vector_mult(inverse_rotation_matrix, transformed_source_side) - initial_angles = [math.radians(angle_degrees) for angle_degrees in dependent_segment.get_rotation()] - side_weight = 0.01 # so side has only a small effect on objective - res = minimize(rotation_objective, initial_angles, - args=(transformed_mean_directions[fixed_segment_index], unit_mean_directions[dependent_segment_index], - mult(target_side, side_weight), mult(source_side, side_weight)), - method='Nelder-Mead', tol=0.001) - if not res.success: - logger.warning("Segmentation Stitcher. Could not optimise initial rotation") - return - rotation = [math.degrees(angle_radians) for angle_radians in res.x] - rotation_matrix = euler_to_rotation_matrix(res.x) - rotated_mean_coordinates = matrix_vector_mult(rotation_matrix, mean_coordinates[dependent_segment_index]) - translation = sub(mean_transformed_coordinates[fixed_segment_index], rotated_mean_coordinates) - # update transformed_coordinates in second segment data - for p, data in enumerate(dependent_segment_end_point_data): - coordinates = data[2] - transformed_coordinates = add(matrix_vector_mult(rotation_matrix, coordinates), translation) - dependent_segment_end_point_data[p] = (data[0], transformed_coordinates, data[2], data[3], data[4], data[5]) - unit_transformed_direction = matrix_vector_mult(rotation_matrix, unit_mean_directions[dependent_segment_index]) - # translate along unit_transformed_direction so no overlap between points - total_overlap = 0.0 - for s, segment in enumerate(self._segments): - max_overlap = 0.0 - for data in segment_end_point_data[s]: - overlap = dot(sub(data[1], mean_transformed_coordinates[fixed_segment_index]), unit_transformed_direction) - if s == fixed_segment_index: - overlap = -overlap - if overlap > max_overlap: - max_overlap = overlap - total_overlap += max_overlap - translation = sub(translation, mult(unit_transformed_direction, total_overlap)) - dependent_segment.set_rotation(rotation, notify=False) - # dependent_segment.set_translation(translation, notify=False) - - # GRC rotation only: - dependent_segment.set_translation(translation) - return + # optimise rotation and translation in plane - # GRC temp - # score = self.build_links(build_link_objects=False) - # print("part 1 rotation", rotation, "translation", translation, "score", score) + centre = fixed_transformed_end_location + axis3 = fixed_transformed_end_direction + # get 2 orthogonal axes for translations, scaled by max_distance so parameter scale similar to rotation radians: + axis1 = cross([1.0, 0.0, 0.0], axis3) + if magnitude(axis1) < 0.1: + axis1 = cross([0.0, 1.0, 0.0], axis3) + axis1 = set_magnitude(axis1, 0.5 * self._max_distance) + axis2 = cross(axis3, axis1) + initial_rotation = dependent_segment.get_rotation_radians() + initial_translation = dependent_segment.get_translation() - # optimise angles and translation def links_objective(rotation_translation, *args): - rotation = list(rotation_translation[:3]) - translation = list(rotation_translation[3:]) - dependent_segment.set_rotation(rotation, notify=False) - dependent_segment.set_translation(translation, notify=False) + angle_radians = rotation_translation[0] + translation1 = rotation_translation[1] + translation2 = rotation_translation[2] + dependent_segment.set_rotation_radians(initial_rotation, notify=False) + dependent_segment.set_translation(initial_translation, notify=False) + dependent_segment.rotate_about_point_axis(centre, axis3, angle_radians, notify=False) + dependent_segment.translate(add(mult(axis1, translation1), mult(axis2, translation2)), notify=False) score = self.build_links(build_link_objects=False) - # print("rotation", rotation, "translation", translation, "score", score) + # print(rotation_translation, "score", score) return score - initial_parameters = rotation + translation - initial_score = links_objective(initial_parameters, ()) - # TOL = initial_score * 1.0E-6 - # method='Nelder-Mead' - res = minimize(links_objective, initial_parameters, method='Powell') # , tol=TOL) - if not res.success: - logger.warning("Segmentation Stitcher. Could not optimise final rotation and translation") - return - rotation = list(res.x[:3]) - translation = list(res.x[3:]) - dependent_segment.set_rotation(rotation, notify=False) - # this will invoke build_links: - dependent_segment.set_translation(translation) + initial_rotation_translation = [0.0, 0.0, 0.0] + res = minimize(links_objective, initial_rotation_translation, + args=(), + method='Nelder-Mead', # method='Powell', + bounds=[(-0.5, 0.5), (-0.5, 0.5), (-0.5, 0.5)]) # , tol=TOL) + if res.success: + links_objective(res.x) # to ensure the last values are converted to rotation and translation + # this will invoke build_links and build_link_objects: + dependent_segment.set_translation(dependent_segment.get_translation()) + else: + logger.warning("Segmentation Stitcher. Could not optimise rotation and translation") + # restore transformation + dependent_segment.set_rotation_radians(initial_rotation, notify=False) + # this will invoke build_links and build_link_objects: + dependent_segment.set_translation(initial_translation) + return def build_links(self, build_link_objects=True): """ @@ -453,7 +405,7 @@ def build_links(self, build_link_objects=True): min_area = None for s, segment in enumerate(self._segments): translation = segment.get_translation() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation = segment.get_rotation_radians() rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None sorted_end_point_data = [] @@ -484,7 +436,7 @@ def build_links(self, build_link_objects=True): base_scores0 = [] # index over segment 0 endpoints, then segment 1 max_mag_delta_coordinates = 0.5 * self._max_distance # below this proportion of max_mag_delta_coordinates the closeness score is the same: - min_relative_distance = 0.01 + min_relative_distance = 0.0001 worst_base_score = 10.0 for index0, end_point_data0 in enumerate(sorted_end_point_data0): node_id0, coordinates0, direction0, area0, annotation0 = end_point_data0 @@ -564,7 +516,7 @@ def get_minimums_ratio(scores): node_identifiers = (node_id0, node_id1) if node_identifiers in locked_node_identifiers: best_area = max(min_area, area) # don't want area to get negative - best_nonexclusive_score = best_score = base_score / math.log(best_area) + best_nonexclusive_score = best_score = base_score / math.sqrt(best_area) best_indexes = indexes locked_node_identifiers.remove(node_identifiers) lock = True @@ -574,7 +526,7 @@ def get_minimums_ratio(scores): continue if area < min_area: continue - nonexclusive_score = score = base_score / math.log(area) + nonexclusive_score = score = base_score / math.sqrt(area) # lower score for first links by factor indicating 'only option' exclusive_base_scores = [] if links_count0[index0] == 0: @@ -639,7 +591,7 @@ def _build_link_objects(self): snodes.append(sfieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)) sfieldcache.append(sfieldmodule.createFieldcache()) tr_coordinates = sfieldmodule.findFieldByName("coordinates").castFiniteElement() - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation = segment.get_rotation_radians() if rotation != [0.0, 0.0, 0.0]: rotation_matrix = euler_to_rotation_matrix(rotation) tr_coordinates = sfieldmodule.createFieldMatrixMultiply( diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 700cc7a..ee3b4e9 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -2,8 +2,9 @@ A segment of the segmentation data, generally from a separate image block. """ from cmlibs.maths.vectorops import ( - add, cross, dot, euler_to_rotation_matrix, magnitude, matrix_mult, matrix_vector_mult, mult, normalize, - set_magnitude, sub) + add, axis_angle_to_rotation_matrix, cross, dot, euler_to_rotation_matrix, magnitude, matrix_mult, + matrix_vector_mult, mult, normalize, rotation_matrix_to_euler, set_magnitude, sub) + from cmlibs.utils.zinc.field import ( get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, find_or_create_field_stored_string) @@ -14,6 +15,7 @@ from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.annotation import AnnotationCategory +import copy import json import logging import math @@ -92,7 +94,7 @@ def decode_settings(self, settings_in: dict): # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) - self._rotation = settings["rotation"] + self._rotation = [math.radians(deg) for deg in settings["rotation"]] self._translation = settings["translation"] def encode_settings(self) -> dict: @@ -102,7 +104,7 @@ def encode_settings(self) -> dict: """ settings = { "name": self._name, - "rotation": self._rotation, + "rotation": [math.degrees(rad) for rad in self._rotation], "translation": self._translation } return settings @@ -490,7 +492,7 @@ def transform_coordinates(self, position): :param position Coordinates x, y, z in the segment. :return: Transformed position. """ - rotation_matrix = euler_to_rotation_matrix([math.radians(deg) for deg in self._rotation]) + rotation_matrix = euler_to_rotation_matrix(self._rotation) return add(matrix_vector_mult(rotation_matrix, position), self._translation) def get_raw_region(self): @@ -522,17 +524,48 @@ def _transformation_change(self): for transformation_change_callback in self._transformation_change_callbacks: transformation_change_callback(self) - def get_rotation(self): + def get_rotation_radians(self): return self._rotation - def set_rotation(self, rotation, notify=True): + def set_rotation_radians(self, rotation, notify=True): """ Set segment rotation, which applies before translation. - :param rotation: Rotation as list of 3 Euler angles in degrees. + :param rotation: Rotation as list of 3 Euler angles in radians. :param notify: Set to False to avoid notification to clients if setting translation afterwards. """ assert len(rotation) == 3 - self._rotation = rotation + self._rotation = copy.copy(rotation) + if notify: + self._transformation_change() + + def get_rotation_degrees(self): + return [math.degrees(rad) for rad in self._rotation] + + def set_rotation_degrees(self, rotation, notify=True): + """ + Set segment rotation, which applies before translation. + :param rotation: Rotation as list of 3 Euler angles in degrees. + :param notify: Set to False to avoid notification to clients if setting translation afterwards. + """ + self.set_rotation_radians([math.radians(deg) for deg in rotation], notify) + + def rotate_about_point_axis(self, centre, axis, angle_radians, notify=True): + """ + Update rotation and translation parameters to include a subsequent rotation about a centre. + :param centre: Centre of subsequent rotation (after initial rotation and translation applied). + :param axis: Axis of subsequent rotation (after initial rotation and translation applied). + :param angle_radians: Rotation in radians in a right hand sense about axis. + :param notify: Set to False to avoid notification to clients if setting rotation afterwards. + """ + mat1 = euler_to_rotation_matrix(self._rotation) + centre_translation1 = matrix_vector_mult(mat1, centre) + mat2 = axis_angle_to_rotation_matrix(axis, angle_radians) + product_mat = matrix_mult(mat2, mat1) + centre_translation2 = matrix_vector_mult(product_mat, centre) + self._rotation = rotation_matrix_to_euler(product_mat) + # correct translation of centre by new rotation: + centre_offset = sub(centre_translation1, centre_translation2) + self._translation = add(self._translation, centre_offset) if notify: self._transformation_change() @@ -546,7 +579,18 @@ def set_translation(self, translation, notify=True): :param notify: Set to False to avoid notification to clients if setting rotation afterwards. """ assert len(translation) == 3 - self._translation = translation + self._translation = copy.copy(translation) + if notify: + self._transformation_change() + + def translate(self, offset, notify=True): + """ + :param offset: 3 value to add to translation + :param notify: Set to False to avoid notification to clients if setting rotation afterwards. + """ + assert len(offset) == 3 + for c in range(3): + self._translation[c] += offset[c] if notify: self._transformation_change() diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 2751ec0..44e7880 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -425,8 +425,7 @@ def _output_segment_nodes_and_markers( raw_rgb = raw_fieldmodule.findFieldByName("rgb").castFiniteElement() if rgb else None raw_marker_name = raw_fieldmodule.findFieldByName("marker_name").castStoredString() raw_nodes = raw_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) - rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] - rotation_matrix = euler_to_rotation_matrix(rotation) + rotation_matrix = euler_to_rotation_matrix(segment.get_rotation_radians()) translation = segment.get_translation() raw_groups = get_group_list(raw_fieldmodule) raw_nodeset_groups = [] diff --git a/tests/test_vagus.py b/tests/test_vagus.py index f3e132d..864da5b 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -110,7 +110,7 @@ def test_align_stitch_vagus1(self): stitcher = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) segments = stitcher.get_segments() - segments[1].set_rotation([0.0, -10.0, -60.0]) + segments[1].set_rotation_degrees([0.0, -10.0, -60.0]) segments[1].set_translation([5.0, 0.0, 0.0]) segments[2].set_translation([10.0, 0.0, 0.5]) @@ -130,34 +130,55 @@ def test_align_stitch_vagus1(self): connection01 = stitcher.create_connection([segments[0], segments[1]]) connection12 = stitcher.create_connection([segments[1], segments[2]]) + expected_annotation_links01 = { + "Fascicle": [ + {'lock': False, + 'node identifiers': [22, 28]}, + {'lock': False, + 'node identifiers': [35, 12]}, + {'lock': False, + 'node identifiers': [40, 23]}], + "left vagus X nerve trunk": [ + {'lock': False, + 'node identifiers': [11, 1]}]} + connection01.auto_align_segment(1) - rotation = segments[1].get_rotation() + rotation = segments[1].get_rotation_degrees() translation = segments[1].get_translation() - assertAlmostEqualList(self, [-2.894576, -5.574263, -63.93093], rotation, delta=TOL) - assertAlmostEqualList(self, [4.88866, -0.01213587, 0.01357185], translation, delta=TOL) - linked_nodes01 = connection01.get_linked_nodes() - self.assertEqual(linked_nodes01, { - "Fascicle": [[22, 28], [35, 12], [40, 23]], - "left vagus X nerve trunk": [[11, 1]]}) + assertAlmostEqualList(self, [-4.459501969125895, -8.161074730792063, -58.089501540814254], rotation, delta=TOL) + assertAlmostEqualList(self, [4.901057529124233, 0.004805043555627213, -0.04779580320829241], + translation, delta=TOL) + annotation_links01 = connection01.get_annotation_links() + self.assertEqual(expected_annotation_links01, annotation_links01) + + expected_annotation_links12 = { + "Fascicle": [ + {'lock': False, + 'node identifiers': [38, 25]}, + {'lock': False, + 'node identifiers': [22, 15]}], + "left vagus X nerve trunk": [ + {'lock': False, + 'node identifiers': [11, 1]}]} connection12.auto_align_segment(1) - assertAlmostEqualList(self, [-4.919549, -2.280625, -13.52467], segments[2].get_rotation(), delta=TOL) - assertAlmostEqualList(self, [9.543171, -0.3494296, 0.03930248], segments[2].get_translation(), delta=TOL) - linked_nodes12 = connection12.get_linked_nodes() - self.assertEqual(linked_nodes12, { - "Fascicle": [[22, 15], [38, 25]], - "left vagus X nerve trunk": [[11, 1]]}) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [-3.216043371586617, -5.467042596782779, -0.4267779669299892], rotation, delta=TOL) + assertAlmostEqualList(self, [9.537442541080164, -0.3524223146102781, 0.28070488408317984], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) # now align first segment relative to second connection01.auto_align_segment(0) - rotation = segments[0].get_rotation() + rotation = segments[0].get_rotation_degrees() translation = segments[0].get_translation() - assertAlmostEqualList(self, [0.5904670871359933, 0.327008456961372, 0.08272009867790592], rotation, delta=TOL) - assertAlmostEqualList(self, [0.0010985770233687066, -0.05096474638998973, 0.02847203766782994], translation, delta=TOL) - linked_nodes01 = connection01.get_linked_nodes() - self.assertEqual(linked_nodes01, { - "Fascicle": [[22, 28], [35, 12], [40, 23]], - "left vagus X nerve trunk": [[11, 1]]}) + assertAlmostEqualList(self, [0.0022017172050087866, -0.05083254291897361, 1.5180006139100206], + rotation, delta=TOL) + assertAlmostEqualList(self, [-1.7169803818076998e-06, -0.00037724526155702106, -0.0029090218733886335], + translation, delta=TOL) + annotation_links01 = connection01.get_annotation_links() + self.assertEqual(expected_annotation_links01, annotation_links01) output_region = stitcher.get_root_region().createRegion() stitcher.stitch(output_region) @@ -169,8 +190,8 @@ def test_align_stitch_vagus1(self): datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) mesh = fieldmodule.findMeshByDimension(1) minimums, maximums = evaluate_field_nodeset_range(coordinates, nodes) - assertAlmostEqualList(self, [0.04749509590306315, -1.5276719288528786, -0.5661785379561856], minimums, delta=TOL) - assertAlmostEqualList(self, [13.538987060134247, 1.102601773020926, 0.6470665850902932], maximums, delta=TOL) + assertAlmostEqualList(self, [0.04678894233410661, -1.3448619475857166, -0.5849221355942552], minimums, delta=TOL) + assertAlmostEqualList(self, [13.528908286654149, 1.12292211593189, 1.0461133166576715], maximums, delta=TOL) fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() self.assertTrue(fascicle.isValid()) From 3b15eb09afa9ff207a3aa99fb588e91b6c1d12b4 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 30 Sep 2025 16:14:52 +1300 Subject: [PATCH 11/24] Add method for linking selected end points --- src/segmentationstitcher/connection.py | 31 +++++++++++++++++++++++--- src/segmentationstitcher/segment.py | 26 +++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 8ab6e14..aa02144 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -175,8 +175,10 @@ def _segment_transformation_change(self, segment): self.build_links() self.update_annotation_category_groups(self._annotations) - def add_linked_nodes(self, annotation, node_id0, node_id1, lock=False): + def set_linked_nodes(self, annotation, node_id0, node_id1, lock=False): """ + Ensure there is a link between node_id0 and node_id1 for annotation with the chosen lock state. + If link already exists, updates the lock state only. :param annotation: Annotation to use for link. :param node_id0: Node identifier to link from segment[0]. :param node_id1: Node identifier to link from segment[1]. @@ -191,7 +193,12 @@ def add_linked_nodes(self, annotation, node_id0, node_id1, lock=False): for name in list(self._annotation_links.keys()): if name > annotation_name: self._annotation_links[name] = self._annotation_links.pop(name) - links.append({'lock': lock, 'node identifiers': [node_id0, node_id1]}) + node_identifiers = [node_id0, node_id1] + for link in links: + if link['node identifiers'] == node_identifiers: + link['lock'] = lock + return + links.append({'lock': lock, 'node identifiers': node_identifiers}) def get_annotation_links(self): """ @@ -552,7 +559,7 @@ def get_minimums_ratio(scores): annotation = end_point_data0[4] end_point_data1 = sorted_end_point_data1[best_indexes[1]] node_id1 = end_point_data1[0] - self.add_linked_nodes(annotation, node_id0, node_id1, lock) + self.set_linked_nodes(annotation, node_id0, node_id1, lock) # print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) end_point_data0[3] -= best_area end_point_data1[3] -= best_area @@ -634,6 +641,24 @@ def _build_link_objects(self): element.setNodesByIdentifier(eft, cnode_ids) element_identifier += 1 + def link_and_lock_selected_ends(self): + """ + Create and lock links between all permutations of selected end points in selected elements of each segment. + """ + end_node_identifiers0, end_annotations0 = self._segments[0].get_selected_end_points() + end_node_identifiers1, end_annotations1 = self._segments[1].get_selected_end_points() + new_links_count = 0 + for node_id0, annotation0 in zip(end_node_identifiers0, end_annotations0): + for node_id1, annotation1 in zip(end_node_identifiers1, end_annotations1): + if annotation0 == annotation1: + self.set_linked_nodes(annotation0, node_id0, node_id1, lock=True) + new_links_count += 1 + if new_links_count: + self.build_links() + self.update_annotation_category_groups(self._annotations) + else: + logger.warning('Connection ' + self._name + '. Link and lock selected ends. No valid links exist') + def set_link_locking_from_selection(self, lock: bool): """ Lock or unlock links for nodes matching any selected visualization elements. diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index ee3b4e9..fc0c6b9 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -694,6 +694,32 @@ def _update_working_end_group(self): break working_node = working_nodeiterator.next() + def get_selected_end_points(self): + """ + Get end points in selected elements in segment. + :return: List of node identifiers, list of annotations for each. + """ + root_scene = self._base_region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return [], [] + fieldmodule = self._raw_region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return [], [] + node_ids = [] + annotations = [] + for node_id in self._end_node_ids: + element_id = self._node_element_ids[node_id][0] + element = selection_mesh_group.findElementByIdentifier(element_id) + if element.isValid(): + node_ids.append(node_id) + annotation = self._end_point_data[node_id][3] + annotations.append(annotation) + return node_ids, annotations + + def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): """ Compute best fit line to path coordinates, and mean radius of unfiltered points. From 7147525e50b5a712f1d684b20215e28c5bb76d53 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 2 Oct 2025 15:50:29 +1300 Subject: [PATCH 12/24] Add 2 stage auto align Sort links into node identifier order --- src/segmentationstitcher/connection.py | 63 +++++++++++++++++++------- tests/test_vagus.py | 30 +++++++++++- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index aa02144..ce5d4ea 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -194,11 +194,17 @@ def set_linked_nodes(self, annotation, node_id0, node_id1, lock=False): if name > annotation_name: self._annotation_links[name] = self._annotation_links.pop(name) node_identifiers = [node_id0, node_id1] - for link in links: + for index, link in enumerate(links): + if link['node identifiers'] < node_identifiers: + continue if link['node identifiers'] == node_identifiers: link['lock'] = lock return - links.append({'lock': lock, 'node identifiers': node_identifiers}) + break + else: + index = len(links) + # insert in order of lowest first then second node identifier + links.insert(index, {'lock': lock, 'node identifiers': node_identifiers}) def get_annotation_links(self): """ @@ -224,13 +230,19 @@ def get_coordinates_range(self): nodes = self._region.getFieldmodule().findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) return evaluate_field_nodeset_range(self._coordinates, nodes) - def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): + def auto_align_segment(self, dependent_segment_index, phase1_align=True, gap_distance=0.0, phase_2_optimize=True): """ Optimise transformation of one connected segment relative to the other, by getting best fit alignment and connection between nearest end points between them. :param dependent_segment_index: Index of segment to optimise transformation of. - :param minimum_gap: Minimum gap between aligned segments. - """ + :param phase1_align: True if performing phase 1 align ends. + :param gap_distance: Gap distance to apply in phase 1. Can be negative to overlap. + :param phase_2_optimize: True if performing phase 2 optimize transformation in plane. + """ + max_gap_distance = 0.5 * self._max_distance + if math.fabs(gap_distance) > max_gap_distance: + logger.warning("Auto align gap distance is too large, limiting to " + str(max_gap_distance)) + gap_distance = math.copysign(max_gap_distance, gap_distance) segments_count = len(self._segments) if (dependent_segment_index < 0) or (dependent_segment_index >= segments_count): logger.error("auto_align_segment. Segment index " + str(dependent_segment_index) + " out of range") @@ -239,6 +251,7 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): logger.error("auto_align_segment. Not implemented for " + str(segments_count) + " segments") return fixed_segment_index = 1 if (dependent_segment_index == 0) else 0 + dependent_segment = self._segments[dependent_segment_index] fixed_transformed_end_location = None fixed_transformed_end_direction = None @@ -267,7 +280,7 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): mean_end_locations = [] mean_end_directions = [] # unit mean untransformed directions # distance above which distance weighting is zero - far_distance = self._max_distance + minimum_gap + far_distance = self._max_distance + gap_distance for s, segment in enumerate(self._segments): distances = [] # min transformed distance from end points of this segment to linkable end points in other max_distance = None @@ -315,12 +328,16 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): projection = dot(coordinates, mean_end_direction) if projection > max_projection: max_projection = projection - # add half minimum gap to each side - offset = max_projection - mean_projection + 0.5 * minimum_gap + offset = max_projection - mean_projection + # add gap_distance to fixed side + if s == fixed_segment_index: + offset += gap_distance mean_end_locations.append(add(mean_coordinates, mult(mean_end_direction, offset))) + if not phase1_align: + break # not transforming here and no need for multiple iterations + # get angle axis transformation of dependent direction onto fixed direction - dependent_segment = self._segments[dependent_segment_index] rotated_mean_end_directions = [ matrix_vector_mult(initial_rotation_matrix[s], mean_end_directions[s]) for s in range(2)] fixed_transformed_end_direction = rotated_mean_end_directions[fixed_segment_index] @@ -344,14 +361,18 @@ def auto_align_segment(self, dependent_segment_index, minimum_gap=0.0): translation = sub(fixed_transformed_end_location, dependent_rotated_end_location) dependent_segment.set_translation(translation, notify=False) - # first stage only - # dependent_segment.set_translation(translation) # force notification - # return + if not phase_2_optimize: + if phase1_align: + dependent_segment.set_translation(translation) # force notification + return # optimise rotation and translation in plane - centre = fixed_transformed_end_location - axis3 = fixed_transformed_end_direction + translation = dependent_segment.get_translation() + rotation_matrix = euler_to_rotation_matrix(dependent_segment.get_rotation_radians()) + centre = add(matrix_vector_mult(rotation_matrix, mean_end_locations[dependent_segment_index]), translation) + axis3 = [-c for c in matrix_vector_mult(rotation_matrix, mean_end_directions[dependent_segment_index])] + # get 2 orthogonal axes for translations, scaled by max_distance so parameter scale similar to rotation radians: axis1 = cross([1.0, 0.0, 0.0], axis3) if magnitude(axis1) < 0.1: @@ -374,10 +395,11 @@ def links_objective(rotation_translation, *args): return score initial_rotation_translation = [0.0, 0.0, 0.0] + # 0.75 ~ 43 degrees res = minimize(links_objective, initial_rotation_translation, args=(), method='Nelder-Mead', # method='Powell', - bounds=[(-0.5, 0.5), (-0.5, 0.5), (-0.5, 0.5)]) # , tol=TOL) + bounds=[(-0.75, 0.75), (-0.75, 0.75), (-0.75, 0.75)]) # , tol=TOL) if res.success: links_objective(res.x) # to ensure the last values are converted to rotation and translation # this will invoke build_links and build_link_objects: @@ -675,10 +697,17 @@ def set_link_locking_from_selection(self, lock: bool): return element_identifier = 1 for annotation_name, links in self._annotation_links.items(): + for annotation in self._annotations: + if annotation.get_name() == annotation_name: + break + else: + logger.error('Segmentation stitcher connect ' + self._name + + ': No annotation of name ' + annotation_name) + continue for link in links: link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() - if link_selected: - link['lock'] = lock + node_id0, node_id1 = link['node_identifiers'] + self.set_linked_nodes(annotation, node_id0, node_id1, lock=True) element_identifier += 1 def add_locked_links_to_selection(self): diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 864da5b..75a00a8 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -154,9 +154,9 @@ def test_align_stitch_vagus1(self): expected_annotation_links12 = { "Fascicle": [ {'lock': False, - 'node identifiers': [38, 25]}, + 'node identifiers': [22, 15]}, {'lock': False, - 'node identifiers': [22, 15]}], + 'node identifiers': [38, 25]}], "left vagus X nerve trunk": [ {'lock': False, 'node identifiers': [11, 1]}]} @@ -206,5 +206,31 @@ def test_align_stitch_vagus1(self): marker_datapoint_group = marker.getNodesetGroup(datapoints) self.assertEqual(marker_datapoint_group.getSize(), 5) + # try auto-align with gap in 2 stages + + segments[0].set_rotation_degrees([0.0, 0.0, 0.0]) + segments[0].set_translation([0.0, 0.0, 0.0]) + segments[1].set_rotation_degrees([0.0, -10.0, -60.0]) + segments[1].set_translation([5.0, 0.0, 0.0]) + segments[2].set_rotation_degrees([0.0, 0.0, 40.0]) + segments[2].set_translation([10.0, 0.0, 0.5]) + + connection12.auto_align_segment(1, phase1_align=True, gap_distance=0.1, phase_2_optimize=False) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [2.7968079813220417, -7.433708312768542, 39.583915044651825], rotation, delta=TOL) + assertAlmostEqualList(self, [9.734631815723224, -0.028181186581394506, 0.505539399215602], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) + + connection12.auto_align_segment(1, phase1_align=False, gap_distance=0.1, phase_2_optimize=True) + rotation = segments[2].get_rotation_degrees() + translation = segments[2].get_translation() + assertAlmostEqualList(self, [1.1774294709982658, -7.223345962981031, -3.1504154683525827], rotation, delta=TOL) + assertAlmostEqualList(self, [9.735859443921962, -0.003902802918894957, 0.4936970140282092], translation, delta=TOL) + annotation_links12 = connection12.get_annotation_links() + self.assertEqual(expected_annotation_links12, annotation_links12) + + if __name__ == "__main__": unittest.main() From 5d3670880ccdb5a3e2b5ea75cf33d3769afe9ed9 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 2 Oct 2025 16:26:35 +1300 Subject: [PATCH 13/24] Fix link locking from selection --- src/segmentationstitcher/connection.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index ce5d4ea..c0fab0e 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -697,17 +697,10 @@ def set_link_locking_from_selection(self, lock: bool): return element_identifier = 1 for annotation_name, links in self._annotation_links.items(): - for annotation in self._annotations: - if annotation.get_name() == annotation_name: - break - else: - logger.error('Segmentation stitcher connect ' + self._name + - ': No annotation of name ' + annotation_name) - continue for link in links: link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() - node_id0, node_id1 = link['node_identifiers'] - self.set_linked_nodes(annotation, node_id0, node_id1, lock=True) + if link_selected: + link['lock'] = lock element_identifier += 1 def add_locked_links_to_selection(self): From cb775398566f7f4512b8c2e71ac7a4ee96ca62f1 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 2 Oct 2025 17:03:43 +1300 Subject: [PATCH 14/24] Fix best fit line side direction --- src/segmentationstitcher/segment.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index fc0c6b9..4be7ecd 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -413,10 +413,11 @@ def create_end_point_directions(self, annotations, max_distance): self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, radius_direction) direction1 = sub(end_x, start_x) - axis = [1.0, 0.0, 0.0] - if dot(normalize(direction1), axis) < 0.1: - axis = [0.0, 1.0, 0.0] - direction2 = set_magnitude(cross(axis, direction1), mean_r) + norm_direction1 = normalize(direction1) + for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: + if math.fabs(dot(norm_direction1, side)) < 0.1: + break + direction2 = set_magnitude(cross(side, direction1), mean_r) direction3 = set_magnitude(cross(direction1, direction2), mean_r) self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, direction1 + direction2 + direction3) From e912deaa5f9d92929ce15225a5a034a3dcc2b262 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 28 Nov 2025 15:10:31 +1300 Subject: [PATCH 15/24] Handle zero radius and single point paths --- src/segmentationstitcher/segment.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 4be7ecd..d1d0135 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -413,12 +413,15 @@ def create_end_point_directions(self, annotations, max_distance): self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, radius_direction) direction1 = sub(end_x, start_x) - norm_direction1 = normalize(direction1) - for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: - if math.fabs(dot(norm_direction1, side)) < 0.1: - break - direction2 = set_magnitude(cross(side, direction1), mean_r) - direction3 = set_magnitude(cross(direction1, direction2), mean_r) + if (magnitude(direction1) > 0.0) and (mean_r > 0.0): + norm_direction1 = normalize(direction1) + for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: + if math.fabs(dot(norm_direction1, side)) < 0.1: + break + direction2 = set_magnitude(cross(side, direction1), mean_r) + direction3 = set_magnitude(cross(direction1, direction2), mean_r) + else: + direction2 = direction3 = [0.0, 0.0, 0.0] self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, direction1 + direction2 + direction3) From 7803059d2ea6c67b1e5e13dc3e7869e4efd9908b Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 16 Jan 2026 14:20:05 +1300 Subject: [PATCH 16/24] Allow links between different annotations in same network group Use second segment annotation for these mixed links --- src/segmentationstitcher/annotation.py | 16 ++++++++++++++++ src/segmentationstitcher/connection.py | 14 ++++++++------ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 5d0d057..c747d77 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -37,6 +37,9 @@ def get_lower_name(self): def is_connectable(self): return self in (self.INDEPENDENT_NETWORK, self.NETWORK_GROUP_1, self.NETWORK_GROUP_2) + def is_connectable_different_annotation(self): + return self.is_connectable() and not (self == self.INDEPENDENT_NETWORK) + class Annotation: """ @@ -135,6 +138,19 @@ def set_term(self, term): assert self._term is None self._term = term + def is_connectable_with(self, other_annotation): + """ + Query whether ends annotated with self and other_annotation can be connected. + :param other_annotation: Annotation Annotation object. + :return: True if self and other_annotation are allowed to be connected by a link. + """ + if self._category.is_connectable() and (self._category == other_annotation.get_category()): + if other_annotation is self: + return True + elif self._category.is_connectable_different_annotation(): + return True + return False + def region_get_annotations(region, network_group1_keywords, network_group2_keywords, term_keywords): """ diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index c0fab0e..3068921 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -472,10 +472,10 @@ def build_links(self, build_link_objects=True): base_scores1 = [] for index1, end_point_data1 in enumerate(sorted_end_point_data1): node_id1, coordinates1, direction1, area1, annotation1 = end_point_data1 - # presently only allow links between same annotation even within network group - if annotation0 != annotation1: + if not annotation0.is_connectable_with(annotation1): + # use worst score for end points which cannot connect base_scores1.append(worst_base_score) - continue # end points have different annotation + continue dot_directions = dot(direction0, direction1) # -1.0 if perfectly pointing at each other # if dot_directions > 0.2: # arbitrary factor # base_scores1.append(worst_base_score) @@ -578,9 +578,10 @@ def get_minimums_ratio(scores): if best_score is not None: end_point_data0 = sorted_end_point_data0[best_indexes[0]] node_id0 = end_point_data0[0] - annotation = end_point_data0[4] end_point_data1 = sorted_end_point_data1[best_indexes[1]] node_id1 = end_point_data1[0] + # use annotation1 as it may be a branch which should logically own the link + annotation = end_point_data1[4] self.set_linked_nodes(annotation, node_id0, node_id1, lock) # print("Link nodes", node_id0, node_id1, "score", best_score, "area", best_area, end_point_data0[-1].get_name()) end_point_data0[3] -= best_area @@ -672,8 +673,9 @@ def link_and_lock_selected_ends(self): new_links_count = 0 for node_id0, annotation0 in zip(end_node_identifiers0, end_annotations0): for node_id1, annotation1 in zip(end_node_identifiers1, end_annotations1): - if annotation0 == annotation1: - self.set_linked_nodes(annotation0, node_id0, node_id1, lock=True) + if annotation0.is_connectable_with(annotation1): + # use annotation1 as it may be a branch which should logically own the link + self.set_linked_nodes(annotation1, node_id0, node_id1, lock=True) new_links_count += 1 if new_links_count: self.build_links() From dcc323eac8497d7746e24580fb8ab62ca1c17418 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 12 Feb 2026 13:45:42 +1300 Subject: [PATCH 17/24] Add interior end points where annotation changes Add remove selected links function --- src/segmentationstitcher/connection.py | 45 ++++- src/segmentationstitcher/segment.py | 263 +++++++++++++++---------- src/segmentationstitcher/stitcher.py | 42 +--- tests/resources/vagus-segment2.exf | 19 +- tests/test_vagus.py | 22 ++- 5 files changed, 231 insertions(+), 160 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 3068921..2b8bd49 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -173,7 +173,7 @@ def get_segments(self): def _segment_transformation_change(self, segment): self.build_links() - self.update_annotation_category_groups(self._annotations) + self.update_annotation_category_groups() def set_linked_nodes(self, annotation, node_id0, node_id1, lock=False): """ @@ -421,12 +421,12 @@ def build_links(self, build_link_objects=True): total_score = 0.0 # remember locked tuples of linked nodes to re-attach in algorithm below - locked_node_identifiers = set() + locked_node_identifiers_list = [] annotation_names = list(self._annotation_links.keys()) for annotation_name in annotation_names: for link in self._annotation_links[annotation_name]: if link['lock']: - locked_node_identifiers.add(tuple(link['node identifiers'])) + locked_node_identifiers_list.append(tuple(link['node identifiers'])) self._annotation_links = {} # filter, transform and sort end point data from largest to smallest radius @@ -543,11 +543,11 @@ def get_minimums_ratio(scores): area = min(area0, area1) indexes = (index0, index1) node_identifiers = (node_id0, node_id1) - if node_identifiers in locked_node_identifiers: + if node_identifiers in locked_node_identifiers_list: best_area = max(min_area, area) # don't want area to get negative best_nonexclusive_score = best_score = base_score / math.sqrt(best_area) best_indexes = indexes - locked_node_identifiers.remove(node_identifiers) + locked_node_identifiers_list.remove(node_identifiers) lock = True break else: @@ -679,10 +679,38 @@ def link_and_lock_selected_ends(self): new_links_count += 1 if new_links_count: self.build_links() - self.update_annotation_category_groups(self._annotations) + self.update_annotation_category_groups() else: logger.warning('Connection ' + self._name + '. Link and lock selected ends. No valid links exist') + def remove_selected_links(self): + """ + Unlock and remove links corresponding to selected visualization elements in this connection. + This does not prevent links from being automatically re-found after either segment is moved. + """ + root_scene = self._region.getRoot().getScene() + root_selection_group = root_scene.getSelectionField().castGroup() + if not root_selection_group.isValid(): + return + fieldmodule = self._region.getFieldmodule() + mesh1d = fieldmodule.findMeshByDimension(1) + selection_mesh_group = root_selection_group.getMeshGroup(mesh1d) + if not selection_mesh_group.isValid(): + return + element_identifier = 1 + for annotation_name, links in self._annotation_links.items(): + remove_link_indexes = [] + for link in links: + link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() + if link_selected: + remove_link_indexes.insert(0, element_identifier - 1) # reverse order + element_identifier += 1 + if remove_link_indexes: + for remove_link_index in remove_link_indexes: + del links[remove_link_index] + self._build_link_objects() + self.update_annotation_category_groups() + def set_link_locking_from_selection(self, lock: bool): """ Lock or unlock links for nodes matching any selected visualization elements. @@ -728,10 +756,9 @@ def add_locked_links_to_selection(self): selection_mesh_group.addElement(link_element) element_identifier += 1 - def update_annotation_category_groups(self, annotations): + def update_annotation_category_groups(self): """ Rebuild all annotation category groups e.g. after loading settings. - :param annotations: List of all annotations from stitcher. """ fieldmodule = self._region.getFieldmodule() with ChangeManager(fieldmodule): @@ -739,7 +766,7 @@ def update_annotation_category_groups(self, annotations): for category in AnnotationCategory: category_group = self.get_category_group(category) category_group.clear() - for annotation in annotations: + for annotation in self._annotations: annotation_group = self.get_annotation_group(annotation) if annotation_group: category_group = self.get_category_group(annotation.get_category()) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index d1d0135..dc9cc57 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -14,7 +14,8 @@ from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK -from segmentationstitcher.annotation import AnnotationCategory +from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations + import copy import json import logging @@ -30,11 +31,19 @@ class Segment: A segment of the segmentation data, generally from a separate image block. """ - def __init__(self, name, segmentation_file_name, root_region): + def __init__(self, name, segmentation_file_name, root_region, annotations, + network_group1_keywords, network_group2_keywords, term_keywords): """ :param name: Unique name of segment, usually derived from the file name. :param segmentation_file_name: Path and file name of raw segmentation file, in Zinc format. :param root_region: Zinc root region to create segment region under. + :param annotations: List of all annotations to add to. + :param network_group1_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 1, allowing them to be stitched together. + :param network_group2_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 2, allowing them to be stitched together. + :param term_keywords: List of term keywords; if found in group name, group is considered to mark the term for + the separately named group. """ self._name = name self._segmentation_file_name = segmentation_file_name @@ -45,6 +54,7 @@ def __init__(self, name, segmentation_file_name, root_region): # the raw region contains the original segment data which is not modified apart from building # groups to categorise data for stitching a visualisation, including selecting for display. self._raw_region = self._base_region.createChild("raw") + self._annotations = annotations result = self._raw_region.readFile(segmentation_file_name) assert result == RESULT_OK, \ "Could not read segmentation file " + segmentation_file_name @@ -82,9 +92,34 @@ def __init__(self, name, segmentation_file_name, root_region): group.setName(group_name) group.setManaged(True) self._element_node_ids, self._node_element_ids = self._get_element_node_maps() - self._end_node_ids = self._get_end_node_ids() + # following are determing by client call to create_end_point_directions() + self._end_node_ids = [] # mesh end points = nodes in only 1 element + self._interior_end_node_ids = [] # certain interior points where annotation changes, also connectable self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) + segment_annotations = region_get_annotations( + self._raw_region, network_group1_keywords, network_group2_keywords, term_keywords) + for segment_annotation in segment_annotations: + name = segment_annotation.get_name() + term = segment_annotation.get_term() + index = 0 + for annotation in self._annotations: + if annotation.get_name() == name: + existing_term = annotation.get_term() + if term != existing_term: + logger.warning("Segment " + name + ": Found existing annotation with name " + name + + " but existing term " + str(existing_term) + + " does not equal new term " + str(term)) + if term and (existing_term is None): + annotation.set_term(term) + break # exists already + if name > annotation.get_name(): + index += 1 + else: + # print("Add annotation name", name, "term", term, "dim", segment_annotation.get_dimension(), + # "category", segment_annotation.get_category()) + self._annotations.insert(index, segment_annotation) + def decode_settings(self, settings_in: dict): """ Update segment settings from JSON dict containing serialised settings. @@ -186,50 +221,40 @@ def _get_element_node_maps(self): element = elem_iter.next() return element_node_ids, node_element_ids - def _get_end_node_ids(self): - """ - :return: List of identifiers of nodes at end points i.e. in only 1 element. - """ - end_node_ids = [] - for node_id, element_ids in self._node_element_ids.items(): - if len(element_ids) == 1: - end_node_ids.append(node_id) - return end_node_ids - - def _element_id_to_group(self, element_id, annotations): + def _element_id_to_annotation(self, element_id): """ - Get the first Annotation zinc Group containing raw element of supplied identifier, prioritizing - any annotation group with term ids. - :param node_id: Identifier of [end] node to query. - :param annotations: Global list of all annotations. - :return: Zinc Group, MeshGroup or None, None if not found. + Get first Annotation containing raw element of supplied identifier, prioritizing annotations with term ids. + :param element_id: Identifier of element from raw region to query. + :return: Annotation or None if not found. """ element = self._raw_mesh1d.findElementByIdentifier(element_id) - best_group = None - best_mesh_group = None - for annotation in annotations: + for annotation in self._annotations: has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) - if best_group and not has_term: - continue - group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() - if group.isValid(): - mesh_group = group.getMeshGroup(self._raw_mesh1d) - if mesh_group.isValid() and mesh_group.containsElement(element): - best_group = group - best_mesh_group = mesh_group - if has_term: - # print("Found group", annotation.get_name(), annotation.get_term()) - break - return best_group, best_mesh_group + if has_term: + group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if group.isValid(): + mesh_group = group.getMeshGroup(self._raw_mesh1d) + if mesh_group.isValid() and mesh_group.containsElement(element): + return annotation + for annotation in self._annotations: + has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) + if not has_term: + group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if group.isValid(): + mesh_group = group.getMeshGroup(self._raw_mesh1d) + if mesh_group.isValid() and mesh_group.containsElement(element): + return annotation + return None - def _track_segment(self, start_node_id, start_element_id, + def _track_segment(self, start_node_id, start_element_id, annotation, max_length=None, min_element_count=None, min_aspect_ratio=None): """ Get coordinates and radii along segment from start_node_id in start_element_id, proceeding - first to other local node in element, until junction, end point or max_distance is tracked. + first to other local node in element, until junction, end point, annotation change or max_distance is tracked. Can finish earlier if min_element_count, min_aspect_ratio reached, but both must be reached if both in use. :param start_node_id: First node in path. :param start_element_id: Element containing start_node_id and another node to be added. + :param annotation: Annotation which element needs to be in. :param max_length: Maximum length to track from first node coordinates, or None for no limit. :param min_element_count: Minimum number of elements to track, or None to not test. :param min_aspect_ratio: Minimum ratio of length / mean radius to end tracking, or None to not test. @@ -241,7 +266,7 @@ def _track_segment(self, start_node_id, start_element_id, path_coordinates = [] path_radii = [] path_node_ids = [] - lastNode = False + is_last_node = False sum_r = 0.0 while True: if node_id in path_node_ids: @@ -259,7 +284,7 @@ def _track_segment(self, start_node_id, start_element_id, r = 1.0 path_radii.append(r) sum_r += r - if lastNode: + if is_last_node: break point_count = len(path_coordinates) if point_count > 1: @@ -279,29 +304,32 @@ def _track_segment(self, start_node_id, start_element_id, node_id = node_ids[1] if (node_ids[0] == node_id) else node_ids[0] element_ids = self._node_element_ids[node_id] if len(element_ids) != 2: - lastNode = True + is_last_node = True continue element_id = element_ids[1] if (element_ids[0] == element_id) else element_ids[0] + # Future: more efficient to check element is in annotation mesh group? + next_annotation = self._element_id_to_annotation(element_id) + if next_annotation != annotation: + is_last_node = True return path_coordinates, path_radii, path_node_ids, element_id - def _track_path(self, end_node_id, annotations, max_length=None): + def _track_path(self, start_node_id, start_element_id, max_length=None): """ - Get coordinates and radii along path from end_node_id, continuing along - branches if in similar direction. - :param end_node_id: End node identifier to track from. Must be in only one element. - :param annotations: Global list of all annotations. + Get coordinates and radii along path from start_node_id, across start_element_id, continuing along branches if + in a similar direction. Stops at another end point (1 parent element) or annotation change. + :param start_node_id: Start node identifier to track from. + :param start_element_id: Start element identifier to track across. :param max_length: Maximum length to track along, or None for no limit. - :return: coordinates list, radius list, path node ids, path group, start_x, end_x, mean_r + :return: coordinates list, radius list, path node ids, path annotation, start_x, end_x, mean_r """ - element_ids = self._node_element_ids[end_node_id] - assert len(element_ids) == 1 - path_group = self._element_id_to_group(element_ids[0], annotations)[0] + path_annotation = self._element_id_to_annotation(start_element_id) path_coordinates = [] path_radii = [] path_node_ids = [] path_mean_r = None - stop_node_id = end_node_id + stop_node_id = start_node_id + element_ids = [start_element_id] stop_element_id = None start_x = None end_x = None @@ -323,11 +351,11 @@ def _track_path(self, end_node_id, annotations, max_length=None): for element_id in element_ids: if element_id == stop_element_id: continue - segment_group = self._element_id_to_group(element_id, annotations)[0] - if path_group and (segment_group != path_group): + segment_annotation = self._element_id_to_annotation(element_id) + if path_annotation and (segment_annotation != path_annotation): continue segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id = self._track_segment( - stop_node_id, element_id, + stop_node_id, element_id, path_annotation, max_length=max_length - length, min_element_count=min_element_count - element_count, min_aspect_ratio=min_aspect_ratio - aspect_ratio) @@ -376,54 +404,71 @@ def _track_path(self, end_node_id, annotations, max_length=None): aspect_ratio += add_path_length / add_path_mean_r # 2nd iteration of fit line removes outliers: start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.25)[0:3] - return path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r + return path_coordinates, path_radii, path_node_ids, path_annotation, start_x, end_x, mean_r - def create_end_point_directions(self, annotations, max_distance): + def create_end_point_directions(self, max_distance): """ Track mean directions of network end points and create working objects for visualisation. - :param annotations: Global list of all annotations. :param max_distance: Maximum length to track back from end point. Stored for link tolerance. """ - nodetemplate = self._working_datapoints.createNodetemplate() - nodetemplate.defineField(self._working_coordinates) - nodetemplate.defineField(self._working_radius_direction) - nodetemplate.defineField(self._working_best_fit_line_orientation) - fieldcache = self._working_fieldmodule.createFieldcache() - self._end_point_data = {} - for end_node_id in self._end_node_ids: - path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r =( - self._track_path(end_node_id, annotations, max_distance)) - # Future: want to extend length to be equivalent to path_coordinates - direction = sub(start_x, end_x) - annotation = None - annotation_group_name = path_group.getName() if path_group else None - if annotation_group_name: - for tmp_annotation in annotations: - if tmp_annotation.get_name() == annotation_group_name: - annotation = tmp_annotation + # following are determined here: + self._end_node_ids = [] # mesh end points = nodes in only 1 element + self._interior_end_node_ids = [] # certain interior points where annotation changes, also connectable + self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) + for node_id, element_ids in self._node_element_ids.items(): + element_count = len(element_ids) + if element_count == 1: + self._end_node_ids.append(node_id) + with ChangeManager(self._working_fieldmodule): + nodetemplate = self._working_datapoints.createNodetemplate() + nodetemplate.defineField(self._working_coordinates) + nodetemplate.defineField(self._working_radius_direction) + nodetemplate.defineField(self._working_best_fit_line_orientation) + fieldcache = self._working_fieldmodule.createFieldcache() + for end_node_id in self._end_node_ids: + end_element_id = self._node_element_ids[end_node_id][0] + while True: + path_coordinates, path_radii, path_node_ids, annotation, start_x, end_x, mean_r =( + self._track_path(end_node_id, end_element_id, max_distance)) + # Future: want to extend length to be equivalent to path_coordinates + direction = sub(start_x, end_x) + if not annotation: + print("No annotation group for node", end_node_id) + self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) + # set up visualization objects. End direction datapoints have same identifiers as raw end nodes + node = self._working_datapoints.createNode(end_node_id, nodetemplate) + fieldcache.setNode(node) + radius_direction = set_magnitude(direction, mean_r) + self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) + self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, + radius_direction) + direction1 = sub(end_x, start_x) + if (magnitude(direction1) > 0.0) and (mean_r > 0.0): + norm_direction1 = normalize(direction1) + for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: + if math.fabs(dot(norm_direction1, side)) < 0.1: + break + direction2 = set_magnitude(cross(side, direction1), mean_r) + direction3 = set_magnitude(cross(direction1, direction2), mean_r) + else: + direction2 = direction3 = [0.0, 0.0, 0.0] + self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, + direction1 + direction2 + direction3) + # determine if path ended on a change of annotation = interior end node + end_node_id = path_node_ids[-1] + end_element_id = None + element_ids = self._node_element_ids[end_node_id] + # simple 2 element junctions on nodes not already identified as interior + if (len(element_ids) != 2) or (end_node_id in self._interior_end_node_ids): break - else: - print("No annotation group for node", end_node_id) - self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) - # set up visualization objects. End direction datapoints have same identifiers as raw end nodes - node = self._working_datapoints.createNode(end_node_id, nodetemplate) - fieldcache.setNode(node) - radius_direction = set_magnitude(direction, mean_r) - self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) - self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, - radius_direction) - direction1 = sub(end_x, start_x) - if (magnitude(direction1) > 0.0) and (mean_r > 0.0): - norm_direction1 = normalize(direction1) - for side in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: - if math.fabs(dot(norm_direction1, side)) < 0.1: + for element_id in element_ids: + tmp_annotation = self._element_id_to_annotation(element_id) + if tmp_annotation != annotation: + end_element_id = element_id + self._interior_end_node_ids.append(end_node_id) + break + else: break - direction2 = set_magnitude(cross(side, direction1), mean_r) - direction3 = set_magnitude(cross(direction1, direction2), mean_r) - else: - direction2 = direction3 = [0.0, 0.0, 0.0] - self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, - direction1 + direction2 + direction3) def get_end_point_data(self): """ @@ -649,17 +694,16 @@ def update_annotation_category(self, annotation, old_category=AnnotationCategory group_add_group_local_contents(new_category_group, annotation_group) self._update_working_end_group() - def update_annotation_category_groups(self, annotations): + def update_annotation_category_groups(self): """ Rebuild all annotation category groups e.g. after loading settings. - :param annotations: List of all annotations from stitcher. """ with ChangeManager(self._raw_fieldmodule): # clear all category groups for category in AnnotationCategory: category_group = self.get_category_group(category) category_group.clear() - for annotation in annotations: + for annotation in self._annotations: annotation_group = self.get_annotation_group(annotation) if annotation_group: category_group = self.get_category_group(annotation.get_category()) @@ -700,8 +744,9 @@ def _update_working_end_group(self): def get_selected_end_points(self): """ - Get end points in selected elements in segment. - :return: List of node identifiers, list of annotations for each. + Get end points in selected elements in segment. This includes node points with a single parent element, and + certain interior node points where the annotation changed. + :return: List of node identifiers, list of annotations for each node. """ root_scene = self._base_region.getRoot().getScene() root_selection_group = root_scene.getSelectionField().castGroup() @@ -713,16 +758,18 @@ def get_selected_end_points(self): if not selection_mesh_group.isValid(): return [], [] node_ids = [] - annotations = [] - for node_id in self._end_node_ids: - element_id = self._node_element_ids[node_id][0] - element = selection_mesh_group.findElementByIdentifier(element_id) - if element.isValid(): - node_ids.append(node_id) - annotation = self._end_point_data[node_id][3] - annotations.append(annotation) - return node_ids, annotations - + node_annotations = [] + elementiterator = selection_mesh_group.createElementiterator() + element = elementiterator.next() + while element.isValid(): + element_id = element.getIdentifier() + for node_id in self._element_node_ids[element_id]: + if (node_id in self._end_node_ids) or (node_id in self._interior_end_node_ids): + if node_id not in node_ids: + node_ids.append(node_id) + node_annotations.append(self._end_point_data[node_id][3]) + element = elementiterator.next() + return node_ids, node_annotations def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): """ diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 44e7880..28084a2 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -13,7 +13,7 @@ from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.connection import Connection from segmentationstitcher.segment import Segment -from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations +from segmentationstitcher.annotation import AnnotationCategory import copy import logging @@ -54,7 +54,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._context = Context("Segmentation Stitcher") self._root_region = self._context.getDefaultRegion() self._stitch_region = self._root_region.createRegion() - self._annotations = [] + self._annotations = [] # note all segments and connections share this common list self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] self._connections = [] @@ -68,7 +68,8 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo for segmentation_file_name in self._segmentation_file_names: file_path = Path(segmentation_file_name) name = file_path.name - segment = Segment(name, segmentation_file_name, self._root_region) + segment = Segment(name, segmentation_file_name, self._root_region, self._annotations, + self._network_group1_keywords, self._network_group2_keywords, self._term_keywords) name_stem = file_path.stem used_endpoints_file_indexes = [] for ix, endpoints_file_name_stem in enumerate(unused_endpoints_file_name_stems): @@ -84,29 +85,6 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo else: zero_range_segments_count += 1 self._segments.append(segment) - segment_annotations = region_get_annotations( - segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, - self._term_keywords) - for segment_annotation in segment_annotations: - name = segment_annotation.get_name() - term = segment_annotation.get_term() - index = 0 - for annotation in self._annotations: - if annotation.get_name() == name: - existing_term = annotation.get_term() - if term != existing_term: - logger.warning("Segment " + name + ": Found existing annotation with name " + name + - " but existing term " + str(existing_term) + - " does not equal new term " + str(term)) - if term and (existing_term is None): - annotation.set_term(term) - break # exists already - if name > annotation.get_name(): - index += 1 - else: - # print("Add annotation name", name, "term", term, "dim", segment_annotation.get_dimension(), - # "category", segment_annotation.get_category()) - self._annotations.insert(index, segment_annotation) # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" for annotation in self._annotations: if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and @@ -121,8 +99,8 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo with HierarchicalChangeManager(self._root_region): self._max_distance = 0.25 * self._mean_segment_length for segment in self._segments: - segment.create_end_point_directions(self._annotations, self._max_distance) - segment.update_annotation_category_groups(self._annotations) + segment.create_end_point_directions(self._max_distance) + segment.update_annotation_category_groups() for annotation in self._annotations: annotation.set_category_change_callback(self._annotation_category_change) for endpoints_file_name in unused_endpoints_file_names: @@ -210,9 +188,9 @@ def decode_settings(self, settings_in: dict): with HierarchicalChangeManager(self._root_region): for segment in self._segments: - segment.update_annotation_category_groups(self._annotations) + segment.update_annotation_category_groups() for connection in self._connections: - connection.update_annotation_category_groups(self._annotations) + connection.update_annotation_category_groups() def encode_settings(self) -> dict: """ @@ -239,7 +217,7 @@ def _annotation_category_change(self, annotation, old_category): segment.update_annotation_category(annotation, old_category) for connection in self._connections: connection.build_links(self._max_distance) - connection.update_annotation_category_groups(self._annotations) + connection.update_annotation_category_groups() def get_annotations(self): return self._annotations @@ -265,7 +243,7 @@ def create_connection(self, segments, connection_settings={}): connection.decode_settings(connection_settings) self._connections.append(connection) connection.build_links() - connection.update_annotation_category_groups(self._annotations) + connection.update_annotation_category_groups() return connection def delete_connection(self, connection): diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index 03d340a..6ae51b0 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -381,6 +381,11 @@ Node: 74 3.870986066902640e-02 -5.955710515963688e-01 1.000000000000000e-02 +Node: 75 + 4.349351322583406e+00 + -1.214784316068148e-00 + -2.929821893185177e-01 + 3.000000000000000e-01 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -394,7 +399,7 @@ Shape. Dimension=0 3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 -Node: 75 +Node: 76 2.398918879869167e+00 -1.432597145502751e-01 2.241355737421549e-01 @@ -646,6 +651,9 @@ Element: 70 Element: 71 Nodes: 74 63 +Element: 72 + Nodes: + 11 75 Group name: 00001 !#nodeset nodes Node group: @@ -702,6 +710,13 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..35 +Group name: left A branch END +!#nodeset nodes +Node group: +11,75 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +72 Group name: left vagus X nerve trunk !#nodeset nodes Node group: @@ -712,4 +727,4 @@ Element group: Group name: marker !#nodeset datapoints Node group: -75 +76 diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 75a00a8..a45e789 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -35,7 +35,7 @@ def test_io_vagus1(self): assertAlmostEqualList(self, zero, segment12.get_translation(), delta=TOL) segment12.set_translation(new_translation) annotations1 = stitcher1.get_annotations() - self.assertEqual(7, len(annotations1)) + self.assertEqual(8, len(annotations1)) self.assertEqual("1.0.0", stitcher1.get_version()) annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) @@ -46,10 +46,14 @@ def test_io_vagus1(self): self.assertEqual("http://uri.interlex.org/base/ilx_0738426", annotation12.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_2, annotation12.get_category()) annotation15 = annotations1[4] - self.assertEqual("left vagus X nerve trunk", annotation15.get_name()) - self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation15.get_term()) + self.assertEqual("left A branch END", annotation15.get_name()) + self.assertIsNone(annotation15.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation15.get_category()) - annotation17 = annotations1[6] + annotation16 = annotations1[5] + self.assertEqual("left vagus X nerve trunk", annotation16.get_name()) + self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation16.get_term()) + self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation16.get_category()) + annotation17 = annotations1[7] self.assertEqual("unknown", annotation17.get_name()) self.assertEqual(AnnotationCategory.EXCLUDE, annotation17.get_category()) @@ -80,10 +84,10 @@ def test_io_vagus1(self): settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) - self.assertEqual(7, len(settings["annotations"])) + self.assertEqual(8, len(settings["annotations"])) self.assertEqual("1.0.0", settings["version"]) assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][6]["category"]) + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][7]["category"]) stitcher2 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) stitcher2.decode_settings(settings) @@ -91,8 +95,8 @@ def test_io_vagus1(self): segment22 = segments2[1] assertAlmostEqualList(self, new_translation, segment22.get_translation(), delta=TOL) annotations2 = stitcher2.get_annotations() - annotation27 = annotations2[6] - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation27.get_category()) + annotation28 = annotations2[7] + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation28.get_category()) def test_align_stitch_vagus1(self): """ @@ -191,7 +195,7 @@ def test_align_stitch_vagus1(self): mesh = fieldmodule.findMeshByDimension(1) minimums, maximums = evaluate_field_nodeset_range(coordinates, nodes) assertAlmostEqualList(self, [0.04678894233410661, -1.3448619475857166, -0.5849221355942552], minimums, delta=TOL) - assertAlmostEqualList(self, [13.528908286654149, 1.12292211593189, 1.0461133166576715], maximums, delta=TOL) + assertAlmostEqualList(self, [13.528908286654149, 1.12292211593189, 1.4370793304399627], maximums, delta=TOL) fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() self.assertTrue(fascicle.isValid()) From cc29781fa517bc1e60345956a13d307a9e44516f Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 12 Feb 2026 13:53:29 +1300 Subject: [PATCH 18/24] Comment out user end points tweaks --- src/segmentationstitcher/segment.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index dc9cc57..8e92632 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -177,14 +177,14 @@ def define_endpoints(self, endpoints_file_name): for control_point in control_points: marker_labels.append(control_point["label"]) x = control_point["position"] - if "C1L" in self._name: - x = [-1000.0 * x[1] - 4500.0, 1000.0 * x[0] + 2250.0, -1000.0 * x[2] + 11500.0] - elif any(s in self._name for s in ["T5L", "T6L"]): - x = [-1000.0 * x[1] - 9500.0, 1000.0 * x[0] + 5000.0, -1000.0 * x[2]] - elif any(s in self._name for s in ["C2L", "C4L", "T2L", "T3L", "T4L"]): - x = [-9.0 * x[1], 9.0 * x[0], -9.0 * x[2]] - else: - x = [9.0 * x[1], -9.0 * x[0], -9.0 * x[2]] + # if "C1L" in self._name: + # x = [-1000.0 * x[1] - 4500.0, 1000.0 * x[0] + 2250.0, -1000.0 * x[2] + 11500.0] + # elif any(s in self._name for s in ["T5L", "T6L"]): + # x = [-1000.0 * x[1] - 9500.0, 1000.0 * x[0] + 5000.0, -1000.0 * x[2]] + # elif any(s in self._name for s in ["C2L", "C4L", "T2L", "T3L", "T4L"]): + # x = [-9.0 * x[1], 9.0 * x[0], -9.0 * x[2]] + # else: + # x = [9.0 * x[1], -9.0 * x[0], -9.0 * x[2]] marker_positions.append(x) generate_datapoints(self._raw_region, marker_positions, field_names_and_values=[("marker_name", marker_labels)], From a19f33b3dbe3d13c0665ef407e00331bb40bce0a Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 13 Feb 2026 15:16:03 +1300 Subject: [PATCH 19/24] Improve association of annotation terms --- src/segmentationstitcher/annotation.py | 24 ++++++++++++++++++++---- src/segmentationstitcher/segment.py | 10 +++++----- src/segmentationstitcher/stitcher.py | 24 ++++++++++++------------ 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index c747d77..160f196 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -70,7 +70,7 @@ def decode_settings(self, settings_in: dict): settings_dimension = settings_in.get("dimension") if settings_dimension != self._dimension: logger.warning("Segmentation Stitcher. Annotation with name " + self._name, " term " + str(self._term) + - "was dimension " + str(settings_dimension), "in settings, is now " + str(self._dimension) + + " was dimension " + str(settings_dimension), "in settings, is now " + str(self._dimension) + ". Have input files changed?") settings_in["dimension"] = self._dimension # update current settings to gain new ones and override old ones @@ -196,7 +196,7 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo continue # ignore as these can never be valid annotation names if ' annotation.get_name(): + if annotation_name > annotation.get_name(): index += 1 else: - # print("Add annotation name", name, "term", term, "dim", segment_annotation.get_dimension(), + # print("Add annotation name", annotation_name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 28084a2..cb0be21 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -137,7 +137,7 @@ def decode_settings(self, settings_in: dict): break else: logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + - "in settings not found; ignoring. Have input files changed?") + " in settings not found; ignoring. Have input files changed?") if processed_count != len(self._annotations): for annotation in self._annotations: name = annotation.get_name() @@ -147,29 +147,29 @@ def decode_settings(self, settings_in: dict): break else: logger.warning("Segmentation Stitcher. Annotation with name " + name + " term " + str(term) + - "not found in settings; using defaults. Have input files changed?") + " not found in settings; using defaults. Have input files changed?") # update segment settings and warn about differences processed_count = 0 for segment_settings in settings["segments"]: - name = segment_settings["name"] + segment_name = segment_settings["name"] for segment in self._segments: - if segment.get_name() == name: + if segment.get_name() == segment_name: segment.decode_settings(segment_settings) processed_count += 1 break else: - print("WARNING: Segmentation Stitcher. Segment with name", name, - "in settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " in settings not found; ignoring. Have input files changed?") if processed_count != len(self._segments): for segment in self._segments: - name = segment.get_name() + segment_name = segment.get_name() for segment_settings in settings["segments"]: - if segment_settings["name"] == name: + if segment_settings["name"] == segment_name: break else: - print("WARNING: Segmentation Stitcher. Segment with name", name, - "not found in settings; using defaults. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " not found in settings; using defaults. Have input files changed?") # create connections from stitcher settings' connection serialisations assert len(self._connections) == 0, "Cannot decode connections after any exist" @@ -181,8 +181,8 @@ def decode_settings(self, settings_in: dict): connection_segments.append(segment) break else: - print("WARNING: Segmentation Stitcher. Segment with name", segment_name, - "in connection settings not found; ignoring. Have input files changed?") + logger.warning("Segmentation Stitcher. Segment with name " + segment_name + + " in connection settings not found; ignoring. Have input files changed?") if len(connection_segments) >= 2: connection = self.create_connection(connection_segments, connection_settings) From 9c34f48bd7ebc0dd9e614f818934cde653b64269 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 16 Feb 2026 16:43:03 +1300 Subject: [PATCH 20/24] Improve association of annotation terms 2 Fix remove links --- src/segmentationstitcher/annotation.py | 24 ++++++++++++++++++------ src/segmentationstitcher/connection.py | 4 +++- src/segmentationstitcher/segment.py | 5 ++--- tests/resources/vagus-segment1.exf | 4 ++-- tests/resources/vagus-segment2.exf | 4 ++-- tests/resources/vagus-segment3.exf | 4 ++-- tests/test_vagus.py | 4 ++-- 7 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 160f196..cb908da 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -138,6 +138,12 @@ def set_term(self, term): assert self._term is None self._term = term + def clear_term(self): + """ + Clear term to None, call in cases of mismatched terms for the same group name. + """ + self._term = None + def is_connectable_with(self, other_annotation): """ Query whether ends annotated with self and other_annotation can be connected. @@ -179,7 +185,9 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo "left cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794142", "right cervical vagus nerve": "http://uri.interlex.org/base/ilx_0794141", "left thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0787543", - "right thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0786664" + "right thoracic vagus nerve": "http://uri.interlex.org/base/ilx_0786664", + "left vagus x nerve trunk": "http://uri.interlex.org/base/ilx_0736691", + "right vagus x nerve trunk": "http://uri.interlex.org/base/ilx_0730515" } for group in groups: # clean up name to remove case and leading/trailing whitespace @@ -206,7 +214,7 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo if keyword in lower_name: category = AnnotationCategory.NETWORK_GROUP_2 break - term = known_terms.get(name) + term = known_terms.get(lower_name) annotation = Annotation(name, term, dimension, category) is_term = False if category in (AnnotationCategory.GENERAL, AnnotationCategory.EXCLUDE): @@ -237,9 +245,12 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo term = term_annotation.get_name() term_group = fieldmodule.findFieldByName(term).castGroup() dimension = term_annotation.get_dimension() + term_matched = False for annotation in sorted_annotations: if annotation.get_dimension() != dimension: continue + if annotation.get_category() == AnnotationCategory.EXCLUDE: + continue name = annotation.get_name() name_group = fieldmodule.findFieldByName(name).castGroup() if groups_have_same_local_contents(name_group, term_group): @@ -253,16 +264,17 @@ def region_get_annotations(region, network_group1_keywords, network_group2_keywo # logger.info("Segment " + segment_name + ": " + "Annotation name " + name + # " discovered term " + term + ".") annotation.set_term(term) - break + term_matched = True + # do not break to allow all groups with matching contents to get the term else: known_term = known_terms.get(name.lower()) if known_term == term: logger.warning("Segment " + segment_name + ": " + "Known annotation name " + name + " and term " + term + " groups differ. Using name group.") - break - else: + term_matched = True + if not term_matched: logger.warning("Segment " + segment_name + ": " + - ". Did not find matching annotation name for term" + term + ". Adding separate annotation.") + ". Did not find matching annotation name for term " + term + ". Adding separate annotation.") term_annotation.set_term(term) index = 0 for annotation in annotations: diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 2b8bd49..c0cd95c 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -700,11 +700,13 @@ def remove_selected_links(self): element_identifier = 1 for annotation_name, links in self._annotation_links.items(): remove_link_indexes = [] + index = 0 for link in links: link_selected = selection_mesh_group.findElementByIdentifier(element_identifier).isValid() if link_selected: - remove_link_indexes.insert(0, element_identifier - 1) # reverse order + remove_link_indexes.insert(0, index) # reverse order element_identifier += 1 + index += 1 if remove_link_indexes: for remove_link_index in remove_link_indexes: del links[remove_link_index] diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index eb01d73..ac81440 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -109,9 +109,8 @@ def __init__(self, name, segmentation_file_name, root_region, annotations, if term != existing_term: logger.warning("Segment " + name + ": Found existing annotation with name " + annotation_name + " but existing term " + str(existing_term) + - " does not equal new term " + str(term)) - if term and (existing_term is None): - annotation.set_term(term) + " does not equal new term " + str(term) + ". Clearing term") + annotation.clear_term() break # exists already if annotation_name > annotation.get_name(): index += 1 diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf index d34997a..0b8f0a4 100644 --- a/tests/resources/vagus-segment1.exf +++ b/tests/resources/vagus-segment1.exf @@ -770,14 +770,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..38,43..46 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 49..84 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 47..82 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..11 diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index 6ae51b0..f6ddc15 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -689,14 +689,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 11..35 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 39..74 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 36..71 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..11 diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf index 55e9809..7d5ed2f 100644 --- a/tests/resources/vagus-segment3.exf +++ b/tests/resources/vagus-segment3.exf @@ -632,14 +632,14 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 14..38 -Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +Group name: http://uri.interlex.org/base/ilx_0103892 !#nodeset nodes Node group: 43..68 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 40..65 -Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +Group name: http://uri.interlex.org/base/ilx_0736691 !#nodeset nodes Node group: 1..10 diff --git a/tests/test_vagus.py b/tests/test_vagus.py index a45e789..66cd109 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -39,7 +39,7 @@ def test_io_vagus1(self): self.assertEqual("1.0.0", stitcher1.get_version()) annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) - self.assertEqual("http://purl.obolibrary.org/obo/UBERON_0000124", annotation11.get_term()) + self.assertEqual("http://uri.interlex.org/base/ilx_0103892", annotation11.get_term()) self.assertEqual(AnnotationCategory.GENERAL, annotation11.get_category()) annotation12 = annotations1[1] self.assertEqual("Fascicle", annotation12.get_name()) @@ -51,7 +51,7 @@ def test_io_vagus1(self): self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation15.get_category()) annotation16 = annotations1[5] self.assertEqual("left vagus X nerve trunk", annotation16.get_name()) - self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation16.get_term()) + self.assertEqual('http://uri.interlex.org/base/ilx_0736691', annotation16.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation16.get_category()) annotation17 = annotations1[7] self.assertEqual("unknown", annotation17.get_name()) From 62da86b1202006be0d8f48a486a2d49b5f8b50f2 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 16 Feb 2026 17:44:25 +1300 Subject: [PATCH 21/24] Do not build links on loading existing connections This stops removed links from being refound --- src/segmentationstitcher/connection.py | 6 +++--- src/segmentationstitcher/stitcher.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index c0cd95c..62d194e 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -592,11 +592,11 @@ def get_minimums_ratio(scores): total_score += best_nonexclusive_score * best_area if build_link_objects: - self._build_link_objects() + self.build_link_objects() return total_score - def _build_link_objects(self): + def build_link_objects(self): """ Make link nodes/elements for visualisation. """ @@ -710,7 +710,7 @@ def remove_selected_links(self): if remove_link_indexes: for remove_link_index in remove_link_indexes: del links[remove_link_index] - self._build_link_objects() + self.build_link_objects() self.update_annotation_category_groups() def set_link_locking_from_selection(self, lock: bool): diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index cb0be21..6a190b2 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -184,7 +184,7 @@ def decode_settings(self, settings_in: dict): logger.warning("Segmentation Stitcher. Segment with name " + segment_name + " in connection settings not found; ignoring. Have input files changed?") if len(connection_segments) >= 2: - connection = self.create_connection(connection_segments, connection_settings) + connection = self.create_connection(connection_segments, connection_settings, build_links=False) with HierarchicalChangeManager(self._root_region): for segment in self._segments: @@ -222,10 +222,12 @@ def _annotation_category_change(self, annotation, old_category): def get_annotations(self): return self._annotations - def create_connection(self, segments, connection_settings={}): + def create_connection(self, segments, connection_settings={}, build_links=True): """ :param segments: List of 2 Stitcher Segment objects to connect. :param connection_settings: Optional serialisation of connection to read before building links. + :param build_links: If True (default) automatically build any links. If False, called when decoding settings, + only the link graphics are made for existing links, which is required to not rebuild removed links. :return: Connection object or None if invalid segments or connection between segments already exists """ if len(segments) != 2: @@ -242,7 +244,10 @@ def create_connection(self, segments, connection_settings={}): if connection_settings: connection.decode_settings(connection_settings) self._connections.append(connection) - connection.build_links() + if build_links: + connection.build_links() + else: + connection.build_link_objects() connection.update_annotation_category_groups() return connection From ee6f3f57fec6292c33ecd91883cb01bcfbc22e4b Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 18 Feb 2026 00:05:57 +1300 Subject: [PATCH 22/24] Support interior end points with multiple branches Prioritise end point annotations that are connectable --- src/segmentationstitcher/annotation.py | 6 ++ src/segmentationstitcher/segment.py | 97 ++++++++++++++++---------- 2 files changed, 68 insertions(+), 35 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index cb908da..efb00b8 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -144,6 +144,12 @@ def clear_term(self): """ self._term = None + def is_connectable(self): + """ + :return: True if annotation's category is connectable. + """ + return self._category.is_connectable() + def is_connectable_with(self, other_annotation): """ Query whether ends annotated with self and other_annotation can be connected. diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index ac81440..799d1b7 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -222,22 +222,28 @@ def _get_element_node_maps(self): def _element_id_to_annotation(self, element_id): """ - Get first Annotation containing raw element of supplied identifier, prioritizing annotations with term ids. + Get first Annotation containing raw element of supplied identifier, prioritizing connectable annotations + then those with term ids. :param element_id: Identifier of element from raw region to query. :return: Annotation or None if not found. """ element = self._raw_mesh1d.findElementByIdentifier(element_id) - for annotation in self._annotations: - has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) - if has_term: - group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() - if group.isValid(): - mesh_group = group.getMeshGroup(self._raw_mesh1d) - if mesh_group.isValid() and mesh_group.containsElement(element): - return annotation - for annotation in self._annotations: - has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) - if not has_term: + for priority in range(3): + for annotation in self._annotations: + is_connectable = annotation.is_connectable() + if priority == 0: + if not is_connectable: + continue + else: + if is_connectable: + continue + has_term = (annotation.get_term() is not None) and (not "http" in annotation.get_name()) + if priority == 1: + if not has_term: + continue + else: + if has_term: + continue group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() if group.isValid(): mesh_group = group.getMeshGroup(self._raw_mesh1d) @@ -259,7 +265,6 @@ def _track_segment(self, start_node_id, start_element_id, annotation, :param min_aspect_ratio: Minimum ratio of length / mean radius to end tracking, or None to not test. :return: coordinates list, radius list, node id list, endElementId """ - self._element_node_ids, self._node_element_ids node_id = start_node_id element_id = start_element_id path_coordinates = [] @@ -329,6 +334,8 @@ def _track_path(self, start_node_id, start_element_id, max_length=None): path_mean_r = None stop_node_id = start_node_id element_ids = [start_element_id] + # node_ids = self._element_node_ids[start_element_id] + # start_node_index = 0 if (node_ids[0] == start_node_id) else -1 stop_element_id = None start_x = None end_x = None @@ -350,9 +357,12 @@ def _track_path(self, start_node_id, start_element_id, max_length=None): for element_id in element_ids: if element_id == stop_element_id: continue - segment_annotation = self._element_id_to_annotation(element_id) - if path_annotation and (segment_annotation != path_annotation): + element_annotation = self._element_id_to_annotation(element_id) + if path_annotation and (element_annotation != path_annotation): continue + # node_ids = self._element_node_ids[element_id] + # if node_ids[start_node_index] != stop_node_id: + # continue # change of element orientation: stops if same-named branches eminate from a junction segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id = self._track_segment( stop_node_id, element_id, path_annotation, max_length=max_length - length, @@ -424,9 +434,18 @@ def create_end_point_directions(self, max_distance): nodetemplate.defineField(self._working_radius_direction) nodetemplate.defineField(self._working_best_fit_line_orientation) fieldcache = self._working_fieldmodule.createFieldcache() - for end_node_id in self._end_node_ids: - end_element_id = self._node_element_ids[end_node_id][0] - while True: + interior_end_node_element_ids = {} # map from interior end node to untracked element ids + for interior in (False, True): + for end_node_id in (sorted(interior_end_node_element_ids) if interior else self._end_node_ids): + if interior: + end_element_ids = interior_end_node_element_ids[end_node_id] + if len(end_element_ids) != 1: + continue # not a valid interior end node + self._interior_end_node_ids.append(end_node_id) + end_element_id = end_element_ids[0] + else: + end_element_id = self._node_element_ids[end_node_id][0] + path_coordinates, path_radii, path_node_ids, annotation, start_x, end_x, mean_r =( self._track_path(end_node_id, end_element_id, max_distance)) # Future: want to extend length to be equivalent to path_coordinates @@ -451,23 +470,31 @@ def create_end_point_directions(self, max_distance): direction3 = set_magnitude(cross(direction1, direction2), mean_r) else: direction2 = direction3 = [0.0, 0.0, 0.0] - self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, - direction1 + direction2 + direction3) - # determine if path ended on a change of annotation = interior end node - end_node_id = path_node_ids[-1] - end_element_id = None - element_ids = self._node_element_ids[end_node_id] - # simple 2 element junctions on nodes not already identified as interior - if (len(element_ids) != 2) or (end_node_id in self._interior_end_node_ids): - break - for element_id in element_ids: - tmp_annotation = self._element_id_to_annotation(element_id) - if tmp_annotation != annotation: - end_element_id = element_id - self._interior_end_node_ids.append(end_node_id) - break - else: - break + self._working_best_fit_line_orientation.setNodeParameters( + fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, direction1 + direction2 + direction3) + + if not interior: + stop_node_id = path_node_ids[-1] + if stop_node_id not in self._end_node_ids: + # determine the stop element for this path + stop_element_id = None + prev_node_id = path_node_ids[-2] + element_ids = self._node_element_ids[stop_node_id] + for element_id in element_ids: + node_ids = self._element_node_ids[element_id] + if prev_node_id in node_ids: + stop_element_id = element_id + break + # determine if path ended on a change of annotation = interior end node + for element_id in element_ids: + element_annotation = self._element_id_to_annotation(element_id) + if element_annotation != annotation: + end_element_ids = interior_end_node_element_ids.get(stop_node_id) + if not end_element_ids: + interior_end_node_element_ids[stop_node_id] = end_element_ids =\ + copy.copy(element_ids) + end_element_ids.remove(stop_element_id) + break def get_end_point_data(self): """ From cefd22ffb3534dfeaacfd5d866d07de60f64e960 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 18 Feb 2026 17:16:47 +1300 Subject: [PATCH 23/24] Support segment ignore orientation --- src/segmentationstitcher/segment.py | 13 +++++++++++ src/segmentationstitcher/stitcher.py | 19 ++++++++++++---- tests/resources/vagus-segment1.exf | 23 +++++++++++-------- tests/resources/vagus-segment2.exf | 20 +++------------- tests/resources/vagus-segment3.exf | 23 +++++++++++-------- tests/test_vagus.py | 34 ++++++++++++++++------------ 6 files changed, 75 insertions(+), 57 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 799d1b7..b791e87 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -68,6 +68,7 @@ def __init__(self, name, segmentation_file_name, root_region, annotations, group.setManaged(True) self._rotation = [0.0, 0.0, 0.0] self._translation = [0.0, 0.0, 0.0] + self._ignore_orientation = False self._transformation_change_callbacks = [] self._raw_fieldcache = self._raw_fieldmodule.createFieldcache() self._raw_coordinates = self._raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -130,6 +131,7 @@ def decode_settings(self, settings_in: dict): settings.update(settings_in) self._rotation = [math.radians(deg) for deg in settings["rotation"]] self._translation = settings["translation"] + self._ignore_orientation = settings["ignore orientation"] def encode_settings(self) -> dict: """ @@ -137,6 +139,7 @@ def encode_settings(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ settings = { + "ignore orientation": self._ignore_orientation, "name": self._name, "rotation": [math.degrees(rad) for rad in self._rotation], "translation": self._translation @@ -658,6 +661,16 @@ def set_translation(self, translation, notify=True): if notify: self._transformation_change() + def is_ignore_orientation(self): + return self._ignore_orientation + + def set_ignore_orientation(self, ignore_orientation: bool): + """ + :param ignore_orientation: If True, on export put all groups starting with 'orientation' in segment into a group + 'ignore orientation' for subsequent tools to ignore orientation. + """ + self._ignore_orientation = ignore_orientation + def translate(self, offset, notify=True): """ :param offset: 3 value to add to translation diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 6a190b2..de6c36b 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -59,7 +59,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._segments = [] self._connections = [] self._max_distance = 0.0 - self._version = "1.0.0" # increment when new settings added to migrate older serialised settings + self._version = "1.1.0" # increment when new settings added to migrate older serialised settings unused_endpoints_file_names = copy.copy(self._endpoints_file_names) unused_endpoints_file_name_stems = [Path(file_path).stem for file_path in unused_endpoints_file_names] with HierarchicalChangeManager(self._root_region): @@ -85,10 +85,12 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo else: zero_range_segments_count += 1 self._segments.append(segment) - # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" + # by default put all GENERAL annotations without terms into the EXCLUDE category, + # except "marker" and those starting with "orientation" for annotation in self._annotations: + annotation_name = annotation.get_name() if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and - (annotation.get_name() != "marker")): + ((annotation_name != "marker") and not annotation_name.startswith("orientation"))): # print("Exclude general annotation", annotation.get_name(), "with no term") annotation.set_category(AnnotationCategory.EXCLUDE) self._mean_segment_length = 1.0 @@ -123,7 +125,7 @@ def decode_settings(self, settings_in: dict): # migrate from integer version number to string "major#.minor#.patch#" if isinstance(settings_version, int): settings_version = settings["version"] = "1.0.0" - assert settings_version == "1.0.0" # future: migrate if version changes + # assert settings_version == "1.1.0" # future: migrate if version changes # update annotations and warn about differences processed_count = 0 @@ -418,6 +420,10 @@ def _output_segment_nodes_and_markers( segment_group = find_or_create_field_group(fieldmodule, segment.get_name()) segment_node_group = segment_group.getOrCreateNodesetGroup(nodes) segment_datapoint_group = segment_group.getOrCreateNodesetGroup(datapoints) + orientation_ignore_group = find_or_create_field_group(fieldmodule, "orientation ignore") \ + if segment.is_ignore_orientation() else None + orientation_ignore_nodeset_group =\ + orientation_ignore_group.getOrCreateNodesetGroup(nodes) if orientation_ignore_group else None for raw_group in raw_groups: group_name = raw_group.getName() groups = annotation_groups.get(group_name) @@ -425,7 +431,10 @@ def _output_segment_nodes_and_markers( raw_nodeset_group = raw_group.getNodesetGroup(raw_nodes) if raw_nodeset_group.isValid() and (raw_nodeset_group.getSize() > 0): raw_nodeset_groups.append(raw_nodeset_group) - nodeset_group_lists.append([group.getOrCreateNodesetGroup(nodes) for group in groups]) + nodeset_group_list = [group.getOrCreateNodesetGroup(nodes) for group in groups] + if orientation_ignore_group and group_name.startswith('orientation'): + nodeset_group_list.append(orientation_ignore_nodeset_group) + nodeset_group_lists.append(nodeset_group_list) raw_fieldcache = raw_fieldmodule.createFieldcache() raw_nodeiterator = raw_nodes.createNodeiterator() raw_node = raw_nodeiterator.next() diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf index 0b8f0a4..c3edbf6 100644 --- a/tests/resources/vagus-segment1.exf +++ b/tests/resources/vagus-segment1.exf @@ -431,6 +431,11 @@ Node: 84 -4.364982277527066e-01 -3.626326906335586e-01 1.000000000000000e-02 +Node: 85 + 2.532597965468619e+00 + 1.278581726279602e-01 + 5.305649151562682e-01 + 1.000000000000000e-02 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -444,18 +449,12 @@ Shape. Dimension=0 3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 -Node: 85 +Node: 1 9.014003757666196e-01 -3.826434512807377e-03 -1.393280944888287e-02 - "landmark 1" - 0.01 -Node: 86 - 2.532597965468619e+00 - 1.278581726279602e-01 - 5.305649151562682e-01 - orientation - 0.01 + "left level of inferior border of jugular foramen on the vagus nerve" + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -808,4 +807,8 @@ Element group: Group name: marker !#nodeset datapoints Node group: -85..86 +1 +Group name: orientation anterior +!#nodeset nodes +Node group: +85 diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index f6ddc15..a2e152f 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -386,25 +386,11 @@ Node: 75 -1.214784316068148e-00 -2.929821893185177e-01 3.000000000000000e-01 -!#nodeset datapoints -Define node template: node2 -Shape. Dimension=0 -#Fields=3 -1) coordinates, coordinate, rectangular cartesian, real, #Components=3 - x. #Values=1 (value) - y. #Values=1 (value) - z. #Values=1 (value) -2) marker_name, field, string, #Components=1 - 1. #Values=1 (value) -3) radius, field, rectangular cartesian, real, #Components=1 - 1. #Values=1 (value) -Node template: node2 Node: 76 2.398918879869167e+00 -1.432597145502751e-01 2.241355737421549e-01 - orientation - 0.01 + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -724,7 +710,7 @@ Node group: !#mesh mesh1d, dimension=1, nodeset=nodes Element group: 1..10 -Group name: marker -!#nodeset datapoints +Group name: orientation anterior +!#nodeset nodes Node group: 76 diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf index 7d5ed2f..cc25c37 100644 --- a/tests/resources/vagus-segment3.exf +++ b/tests/resources/vagus-segment3.exf @@ -351,6 +351,11 @@ Node: 68 -4.359756419261278e-01 -5.458579284260732e-01 1.000000000000000e-02 +Node: 69 + 3.000000000000000e+00 + 0.000000000000000e+00 + 7.000000000000000e-01 + 1.000000000000000e-02 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -364,18 +369,12 @@ Shape. Dimension=0 3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 -Node: 69 - 3.000000000000000e+00 - 0.000000000000000e+00 - 7.000000000000000e-01 - orientation - 0.01 -Node: 70 +Node: 1 1.599724956533351e+00 3.788960603141545e-03 4.423695723249146e-03 - "landmark 2" - 0.01 + "left level of superior border of the clavicle on the vagus nerve" + 1.000000000000000e-02 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -670,7 +669,11 @@ Element group: Group name: marker !#nodeset datapoints Node group: -69..70 +1 +Group name: orientation anterior +!#nodeset nodes +Node group: +69 Group name: unknown !#nodeset nodes Node group: diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 66cd109..9fec4f2 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -35,8 +35,8 @@ def test_io_vagus1(self): assertAlmostEqualList(self, zero, segment12.get_translation(), delta=TOL) segment12.set_translation(new_translation) annotations1 = stitcher1.get_annotations() - self.assertEqual(8, len(annotations1)) - self.assertEqual("1.0.0", stitcher1.get_version()) + self.assertEqual(9, len(annotations1)) + self.assertEqual("1.1.0", stitcher1.get_version()) annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) self.assertEqual("http://uri.interlex.org/base/ilx_0103892", annotation11.get_term()) @@ -54,8 +54,12 @@ def test_io_vagus1(self): self.assertEqual('http://uri.interlex.org/base/ilx_0736691', annotation16.get_term()) self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation16.get_category()) annotation17 = annotations1[7] - self.assertEqual("unknown", annotation17.get_name()) - self.assertEqual(AnnotationCategory.EXCLUDE, annotation17.get_category()) + self.assertEqual("orientation anterior", annotation17.get_name()) + self.assertEqual(AnnotationCategory.GENERAL, annotation17.get_category()) + annotation18 = annotations1[8] + self.assertEqual("unknown", annotation18.get_name()) + self.assertEqual(AnnotationCategory.EXCLUDE, annotation18.get_category()) + stitcher1.create_connection([segments1[0], segments1[1]]) connections = stitcher1.get_connections() @@ -73,10 +77,10 @@ def test_io_vagus1(self): self.assertEqual(1, exclude13_mesh_group.getSize()) self.assertEqual(26, general13_mesh_group.getSize()) self.assertFalse(indep13_mesh_group.isValid()) - annotation17_group = segment13.get_annotation_group(annotation17) - annotation17_mesh_group = annotation17_group.getMeshGroup(mesh1d) - self.assertEqual(1, annotation17_mesh_group.getSize()) - annotation17.set_category(AnnotationCategory.INDEPENDENT_NETWORK) + annotation18_group = segment13.get_annotation_group(annotation18) + annotation18_mesh_group = annotation18_group.getMeshGroup(mesh1d) + self.assertEqual(1, annotation18_mesh_group.getSize()) + annotation18.set_category(AnnotationCategory.INDEPENDENT_NETWORK) indep13_mesh_group = indep13_group.getMeshGroup(mesh1d) self.assertEqual(0, exclude13_mesh_group.getSize()) self.assertEqual(26, general13_mesh_group.getSize()) @@ -84,10 +88,10 @@ def test_io_vagus1(self): settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) - self.assertEqual(8, len(settings["annotations"])) - self.assertEqual("1.0.0", settings["version"]) + self.assertEqual(9, len(settings["annotations"])) + self.assertEqual("1.1.0", settings["version"]) assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][7]["category"]) + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][8]["category"]) stitcher2 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) stitcher2.decode_settings(settings) @@ -95,8 +99,8 @@ def test_io_vagus1(self): segment22 = segments2[1] assertAlmostEqualList(self, new_translation, segment22.get_translation(), delta=TOL) annotations2 = stitcher2.get_annotations() - annotation28 = annotations2[7] - self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation28.get_category()) + annotation29 = annotations2[8] + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation29.get_category()) def test_align_stitch_vagus1(self): """ @@ -186,7 +190,7 @@ def test_align_stitch_vagus1(self): output_region = stitcher.get_root_region().createRegion() stitcher.stitch(output_region) - self.assertEqual("1.0.0", stitcher.get_version()) + self.assertEqual("1.1.0", stitcher.get_version()) fieldmodule = output_region.getFieldmodule() coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement() @@ -208,7 +212,7 @@ def test_align_stitch_vagus1(self): marker = fieldmodule.findFieldByName("marker").castGroup() self.assertTrue(marker.isValid()) marker_datapoint_group = marker.getNodesetGroup(datapoints) - self.assertEqual(marker_datapoint_group.getSize(), 5) + self.assertEqual(marker_datapoint_group.getSize(), 2) # try auto-align with gap in 2 stages From ef7c0590fd851357b4b7960d8b9e277de113bd79 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 20 Feb 2026 12:36:22 +1300 Subject: [PATCH 24/24] Docstring fixes --- src/segmentationstitcher/annotation.py | 2 +- src/segmentationstitcher/connection.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index efb00b8..601feef 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -153,7 +153,7 @@ def is_connectable(self): def is_connectable_with(self, other_annotation): """ Query whether ends annotated with self and other_annotation can be connected. - :param other_annotation: Annotation Annotation object. + :param other_annotation: Another Annotation object. :return: True if self and other_annotation are allowed to be connected by a link. """ if self._category.is_connectable() and (self._category == other_annotation.get_category()): diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 62d194e..ed21707 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -503,8 +503,7 @@ def build_links(self, build_link_objects=True): def get_minimums_ratio(scores): """ Get ratio of lowest / next lowest score as measure of 'only option' for first link to end point. - :param scores: - :return: + :param scores: List of real values or None to ignore an entry. """ inf = float('inf') min1 = min2 = inf