From 92b031517d149d360f54851de21a9797f36616a9 Mon Sep 17 00:00:00 2001 From: Justin Jung Date: Sat, 3 Jan 2026 20:28:05 -0800 Subject: [PATCH] Fix CUDA memory leak in decode_structure. The pae and ptm tensors were not moved to CPU. This is benign when the batch size is 1, but when the batch size is > 1 it causes memory to accumulate and eventually leads to OOM. --- esm/utils/decoding.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 3865717..0ce967a 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -161,8 +161,15 @@ def decode_structure( plddt = plddt[0, 1:-1].detach().cpu() ptm = decoder_output.get("ptm", None) + if ptm is not None: + ptm = ptm.detach().cpu() # fix memory leak pae = decoder_output.get("predicted_aligned_error", None) + if pae is not None: + pae = pae.detach().cpu() # fix memory leak + # free decoder output + del decoder_output + torch.cuda.empty_cache() chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence) chain = chain.infer_oxygen()