Skip to content

Commit 24d2cbf

Browse files
Donglai Wei and claude
authored and committed
saved_prediction_path always runs decoding (treat as raw prediction)
No fake TTA suffix needed — test pipeline checks saved_prediction_path directly to decide whether to run decoding. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1bce0b7 commit 24d2cbf

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

connectomics/training/lightning/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,9 +604,7 @@ def _load_cached_predictions(
604604
pred = read_volume(str(pred_file), dataset="main")
605605
if pred.ndim < 4:
606606
pred = pred[np.newaxis, ...]
607-
# Return a TTA-like suffix so the pipeline treats this as
608-
# intermediate predictions (runs decoding), not final.
609-
return pred, True, "_tta_x1_prediction.h5"
607+
return pred, True, "_prediction.h5"
610608
else:
611609
raise FileNotFoundError(
612610
f"inference.saved_prediction_path not found: {pred_file}"

connectomics/training/lightning/test_pipeline.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -977,10 +977,15 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
977977
mode,
978978
)
979979

980-
# *_decoding*.h5 = already decoded (skip decoding)
981-
# *_prediction*.h5 with TTA suffix = intermediate (run decoding)
980+
# Determine whether loaded predictions need decoding:
981+
# - saved_prediction_path → always run decoding (it's raw affinity)
982+
# - *_decoding*.h5 → already decoded, skip decoding
983+
# - *_prediction*.h5 with TTA suffix → intermediate, run decoding
984+
# - other → final, skip decoding
985+
_saved_pred = getattr(getattr(module.cfg, "inference", None), "saved_prediction_path", "")
986+
_from_saved_path = bool(loaded_from_file and _saved_pred)
982987
_is_decoding_file = loaded_from_file and "_decoding" in (loaded_suffix or "")
983-
loaded_final_predictions = loaded_from_file and (
988+
loaded_final_predictions = loaded_from_file and not _from_saved_path and (
984989
_is_decoding_file or not is_tta_cache_suffix(loaded_suffix)
985990
)
986991
loaded_intermediate_predictions = loaded_from_file and not loaded_final_predictions

0 commit comments

Comments (0)