2121
2222try :
2323 import waterz
24+ from waterz import merge_function_to_scoring , dust_merge_from_region_graph
2425
2526 WATERZ_AVAILABLE = True
2627except ImportError :
2728 WATERZ_AVAILABLE = False
2829
2930
30- # ---------------------------------------------------------------------------
31- # Shorthand -> C++ scoring function conversion
32- # ---------------------------------------------------------------------------
33-
34- _RG = "RegionGraphType"
35- _SV = "ScoreValue"
36-
37-
38- def _merge_function_to_scoring (shorthand : str ) -> str :
39- """Convert a shorthand merge function name to a C++ scoring type string.
40-
41- Supported shorthands (examples):
42- aff50_his256 -> OneMinus<HistogramQuantileAffinity<RG, 50, SV, 256>>
43- aff85_his256 -> OneMinus<HistogramQuantileAffinity<RG, 85, SV, 256>>
44- aff50_his0 -> OneMinus<QuantileAffinity<RG, 50, SV>>
45- max10 -> OneMinus<MeanMaxKAffinity<RG, 10, SV>>
46- *_ran255 -> One255Minus<...> instead of OneMinus<...>
47- """
48- parts = {tok [:3 ]: tok [3 :] for tok in shorthand .split ("_" )}
49- use_255 = parts .get ("ran" ) == "255"
50- wrapper = "One255Minus" if use_255 else "OneMinus"
51-
52- if "aff" in parts :
53- quantile = parts ["aff" ]
54- his_bins = parts .get ("his" , "0" )
55- if his_bins and his_bins != "0" :
56- inner = f"HistogramQuantileAffinity<{ _RG } , { quantile } , { _SV } , { his_bins } >"
57- else :
58- inner = f"QuantileAffinity<{ _RG } , { quantile } , { _SV } >"
59- return f"{ wrapper } <{ inner } >"
60-
61- if "max" in parts :
62- k = parts ["max" ]
63- inner = f"MeanMaxKAffinity<{ _RG } , { k } , { _SV } >"
64- return f"{ wrapper } <{ inner } >"
65-
66- # If it already looks like a C++ type string, pass through
67- if "<" in shorthand :
68- return shorthand
69-
70- raise ValueError (
71- f"Unknown merge_function shorthand: '{ shorthand } '. "
72- "Expected format like 'aff50_his256', 'aff85_his256', 'max10', etc."
73- )
74-
75-
76- def _build_segment_counts (seg : np .ndarray ) -> np .ndarray :
77- """Build a dense counts array indexed by segment id."""
78- ids , cnts = np .unique (seg , return_counts = True )
79- max_id = int (ids .max ()) if len (ids ) else 0
80- counts = np .zeros (max_id + 1 , dtype = np .uint64 )
81- counts [ids ] = cnts
82- return counts
83-
84-
8531def decode_waterz (
8632 predictions : np .ndarray ,
8733 thresholds : Union [float , Sequence [float ]] = 0.3 ,
@@ -282,7 +228,7 @@ def decode_waterz(
282228 thresholds_list = [_to_u8 (t ) for t in thresholds_list ]
283229
284230 # Convert shorthand merge function to C++ scoring function string
285- scoring_function = _merge_function_to_scoring (merge_function )
231+ scoring_function = merge_function_to_scoring (merge_function )
286232
287233 aff_low = _to_u8 (aff_threshold [0 ]) if is_uint8 else float (aff_threshold [0 ])
288234 aff_high = _to_u8 (aff_threshold [1 ]) if is_uint8 else float (aff_threshold [1 ])
@@ -334,31 +280,15 @@ def decode_waterz(
334280 border_mask = xy_mean < border_threshold
335281 n_removed = int (border_mask .sum ())
336282 seg [border_mask ] = 0
337- print ( f "border_threshold={ border_threshold } : zeroed { n_removed } voxels" )
283+ logger . info ( "border_threshold=%s : zeroed %d voxels" , border_threshold , n_removed )
338284
339285 # Size+affinity dust merge reusing the agglomeration's region graph
340286 # (accumulated histogram statistics, properly root-mapped IDs).
341- # Invert OneMinus/One255Minus scores to raw affinities.
342287 if do_dust_merge :
343288 seg = seg .astype (np .uint64 , copy = False )
344- n_edges = len (region_graph )
345- rg_affs = np .empty (n_edges , dtype = np .float32 )
346- id1 = np .empty (n_edges , dtype = np .uint64 )
347- id2 = np .empty (n_edges , dtype = np .uint64 )
348- score_max = 255.0 if is_uint8 else 1.0
349- for idx , edge in enumerate (region_graph ):
350- rg_affs [idx ] = score_max - float (edge ["score" ])
351- id1 [idx ] = int (edge ["u" ])
352- id2 [idx ] = int (edge ["v" ])
353- if n_edges :
354- np .clip (rg_affs , 0.0 , score_max , out = rg_affs )
355- order = np .argsort (rg_affs )[::- 1 ]
356- rg_affs = np .ascontiguousarray (rg_affs [order ])
357- id1 = np .ascontiguousarray (id1 [order ])
358- id2 = np .ascontiguousarray (id2 [order ])
359- counts = _build_segment_counts (seg )
360- waterz .merge_segments (
361- seg , rg_affs , id1 , id2 , counts ,
289+ dust_merge_from_region_graph (
290+ seg , region_graph ,
291+ is_uint8 = is_uint8 ,
362292 size_th = dust_merge_size ,
363293 weight_th = dust_merge_affinity ,
364294 dust_th = dust_remove_size ,
@@ -368,7 +298,7 @@ def decode_waterz(
368298 from .branch_merge import branch_merge as _branch_merge
369299
370300 n_before = len (np .unique (seg )) - (1 if 0 in seg else 0 )
371- print ( f "branch_merge: starting on { n_before } segments" )
301+ logger . info ( "branch_merge: starting on %d segments" , n_before )
372302 seg = _branch_merge (
373303 seg ,
374304 affinities = affs ,
@@ -380,7 +310,7 @@ def decode_waterz(
380310 channel_order = "zyx" , # already converted above
381311 )
382312 n_after = len (np .unique (seg )) - (1 if 0 in seg else 0 )
383- print ( f "branch_merge: { n_before } -> { n_after } segments" )
313+ logger . info ( "branch_merge: %d -> %d segments" , n_before , n_after )
384314
385315 if min_instance_size > 0 :
386316 from ..utils import remove_small_instances
0 commit comments