Skip to content

Commit 2d04d87

Browse files
Donglai Wei and claude committed
Use waterz.merge_function_to_scoring and dust_merge_from_region_graph
Remove 60+ lines of duplicated code from decode_waterz.py:
- _merge_function_to_scoring + _RG/_SV constants → waterz.merge_function_to_scoring
- _build_segment_counts → waterz._merge._build_segment_counts
- 20-line region graph inversion boilerplate → waterz.dust_merge_from_region_graph
- 3 bare print() → logger.info()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fc4cc45 commit 2d04d87

2 files changed

Lines changed: 17 additions & 79 deletions

File tree

connectomics/decoding/decoders/waterz.py

Lines changed: 8 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -21,67 +21,13 @@
2121

2222
try:
2323
import waterz
24+
from waterz import merge_function_to_scoring, dust_merge_from_region_graph
2425

2526
WATERZ_AVAILABLE = True
2627
except ImportError:
2728
WATERZ_AVAILABLE = False
2829

2930

30-
# ---------------------------------------------------------------------------
31-
# Shorthand -> C++ scoring function conversion
32-
# ---------------------------------------------------------------------------
33-
34-
_RG = "RegionGraphType"
35-
_SV = "ScoreValue"
36-
37-
38-
def _merge_function_to_scoring(shorthand: str) -> str:
39-
"""Convert a shorthand merge function name to a C++ scoring type string.
40-
41-
Supported shorthands (examples):
42-
aff50_his256 -> OneMinus<HistogramQuantileAffinity<RG, 50, SV, 256>>
43-
aff85_his256 -> OneMinus<HistogramQuantileAffinity<RG, 85, SV, 256>>
44-
aff50_his0 -> OneMinus<QuantileAffinity<RG, 50, SV>>
45-
max10 -> OneMinus<MeanMaxKAffinity<RG, 10, SV>>
46-
*_ran255 -> One255Minus<...> instead of OneMinus<...>
47-
"""
48-
parts = {tok[:3]: tok[3:] for tok in shorthand.split("_")}
49-
use_255 = parts.get("ran") == "255"
50-
wrapper = "One255Minus" if use_255 else "OneMinus"
51-
52-
if "aff" in parts:
53-
quantile = parts["aff"]
54-
his_bins = parts.get("his", "0")
55-
if his_bins and his_bins != "0":
56-
inner = f"HistogramQuantileAffinity<{_RG}, {quantile}, {_SV}, {his_bins}>"
57-
else:
58-
inner = f"QuantileAffinity<{_RG}, {quantile}, {_SV}>"
59-
return f"{wrapper}<{inner}>"
60-
61-
if "max" in parts:
62-
k = parts["max"]
63-
inner = f"MeanMaxKAffinity<{_RG}, {k}, {_SV}>"
64-
return f"{wrapper}<{inner}>"
65-
66-
# If it already looks like a C++ type string, pass through
67-
if "<" in shorthand:
68-
return shorthand
69-
70-
raise ValueError(
71-
f"Unknown merge_function shorthand: '{shorthand}'. "
72-
"Expected format like 'aff50_his256', 'aff85_his256', 'max10', etc."
73-
)
74-
75-
76-
def _build_segment_counts(seg: np.ndarray) -> np.ndarray:
77-
"""Build a dense counts array indexed by segment id."""
78-
ids, cnts = np.unique(seg, return_counts=True)
79-
max_id = int(ids.max()) if len(ids) else 0
80-
counts = np.zeros(max_id + 1, dtype=np.uint64)
81-
counts[ids] = cnts
82-
return counts
83-
84-
8531
def decode_waterz(
8632
predictions: np.ndarray,
8733
thresholds: Union[float, Sequence[float]] = 0.3,
@@ -282,7 +228,7 @@ def decode_waterz(
282228
thresholds_list = [_to_u8(t) for t in thresholds_list]
283229

284230
# Convert shorthand merge function to C++ scoring function string
285-
scoring_function = _merge_function_to_scoring(merge_function)
231+
scoring_function = merge_function_to_scoring(merge_function)
286232

287233
aff_low = _to_u8(aff_threshold[0]) if is_uint8 else float(aff_threshold[0])
288234
aff_high = _to_u8(aff_threshold[1]) if is_uint8 else float(aff_threshold[1])
@@ -334,31 +280,15 @@ def decode_waterz(
334280
border_mask = xy_mean < border_threshold
335281
n_removed = int(border_mask.sum())
336282
seg[border_mask] = 0
337-
print(f"border_threshold={border_threshold}: zeroed {n_removed} voxels")
283+
logger.info("border_threshold=%s: zeroed %d voxels", border_threshold, n_removed)
338284

339285
# Size+affinity dust merge reusing the agglomeration's region graph
340286
# (accumulated histogram statistics, properly root-mapped IDs).
341-
# Invert OneMinus/One255Minus scores to raw affinities.
342287
if do_dust_merge:
343288
seg = seg.astype(np.uint64, copy=False)
344-
n_edges = len(region_graph)
345-
rg_affs = np.empty(n_edges, dtype=np.float32)
346-
id1 = np.empty(n_edges, dtype=np.uint64)
347-
id2 = np.empty(n_edges, dtype=np.uint64)
348-
score_max = 255.0 if is_uint8 else 1.0
349-
for idx, edge in enumerate(region_graph):
350-
rg_affs[idx] = score_max - float(edge["score"])
351-
id1[idx] = int(edge["u"])
352-
id2[idx] = int(edge["v"])
353-
if n_edges:
354-
np.clip(rg_affs, 0.0, score_max, out=rg_affs)
355-
order = np.argsort(rg_affs)[::-1]
356-
rg_affs = np.ascontiguousarray(rg_affs[order])
357-
id1 = np.ascontiguousarray(id1[order])
358-
id2 = np.ascontiguousarray(id2[order])
359-
counts = _build_segment_counts(seg)
360-
waterz.merge_segments(
361-
seg, rg_affs, id1, id2, counts,
289+
dust_merge_from_region_graph(
290+
seg, region_graph,
291+
is_uint8=is_uint8,
362292
size_th=dust_merge_size,
363293
weight_th=dust_merge_affinity,
364294
dust_th=dust_remove_size,
@@ -368,7 +298,7 @@ def decode_waterz(
368298
from .branch_merge import branch_merge as _branch_merge
369299

370300
n_before = len(np.unique(seg)) - (1 if 0 in seg else 0)
371-
print(f"branch_merge: starting on {n_before} segments")
301+
logger.info("branch_merge: starting on %d segments", n_before)
372302
seg = _branch_merge(
373303
seg,
374304
affinities=affs,
@@ -380,7 +310,7 @@ def decode_waterz(
380310
channel_order="zyx", # already converted above
381311
)
382312
n_after = len(np.unique(seg)) - (1 if 0 in seg else 0)
383-
print(f"branch_merge: {n_before} -> {n_after} segments")
313+
logger.info("branch_merge: %d -> %d segments", n_before, n_after)
384314

385315
if min_instance_size > 0:
386316
from ..utils import remove_small_instances

connectomics/training/lightning/test_pipeline.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,15 @@ def _process_decoding_postprocessing(
791791
logger.info(f" Min: {decoded_predictions.min()}")
792792
logger.info(f" Max: {decoded_predictions.max()}")
793793
logger.info(f" Instances: {decoded_predictions.max()} (max label)")
794-
logger.info(f" Unique IDs: {len(np.unique(decoded_predictions))}")
794+
max_summary_voxels = 100_000_000
795+
if decoded_predictions.size <= max_summary_voxels:
796+
logger.info(f" Unique IDs: {len(np.unique(decoded_predictions))}")
797+
else:
798+
logger.info(
799+
" Unique IDs: skipped for large volume (%d voxels > %d)",
800+
decoded_predictions.size,
801+
max_summary_voxels,
802+
)
795803

796804
if save_final_predictions:
797805
logger.info("[STAGE: Saving Final Predictions]")

0 commit comments

Comments (0)