diff --git a/esm/utils/decoding.py b/esm/utils/decoding.py index 3865717..0ce967a 100644 --- a/esm/utils/decoding.py +++ b/esm/utils/decoding.py @@ -161,8 +161,15 @@ def decode_structure( plddt = plddt[0, 1:-1].detach().cpu() ptm = decoder_output.get("ptm", None) + if ptm is not None: + ptm = ptm.detach().cpu() # fix memory leak pae = decoder_output.get("predicted_aligned_error", None) + if pae is not None: + pae = pae.detach().cpu() # fix memory leak + # free decoder output + del decoder_output + torch.cuda.empty_cache() chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence) chain = chain.infer_oxygen()