Skip to content

Commit ee2e4a0

Browse files
author
dmoi
committed
fixing ft2 treebuilder ancestral reconstructions
1 parent b9936e0 commit ee2e4a0

2 files changed

Lines changed: 53 additions & 33 deletions

File tree

foldtree2/ft2treebuilder.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch_geometric
88
import multiprocessing as mp
99
import argparse
10+
import numpy as np
1011

1112
from foldtree2.src import encoder as ecdr
1213
from foldtree2.src import mono_decoders
@@ -17,11 +18,12 @@
1718
import tqdm
1819
import pandas as pd
1920
import os
20-
import ete3
21+
import ete3
22+
from scipy import sparse
2123
import sys
2224

2325
class treebuilder():
24-
def __init__ ( self , model , mafftmat = None , submat = None , raxml_path= None, charmaps=None, **kwargs ):
26+
def __init__ ( self , model , decoder_model = None, mafftmat = None , submat = None , raxml_path= None, charmaps=None, **kwargs ):
2527

2628
#make fasta is shifted by 1 and goes from 1-248 included
2729
#0x01 – 0xFF excluding > (0x3E), = (0x3D), < (0x3C), - (0x2D), Space (0x20), Carriage Return (0x0d) and Line Feed (0x0a)
@@ -68,13 +70,14 @@ def __init__ ( self , model , mafftmat = None , submat = None , raxml_path= Non
6870
self.revmap = { v:k for k,v in data['char_position_map'].items() }
6971
self.raxml_indices = data['raxml_char_position_map']
7072
self.rev_raxml_indices = { v:k for k,v in data['raxml_char_position_map'].items() }
71-
self.revmap_raxml = { v:k for k,v in data['raxml_char_position_map'].items() }
73+
self.revmap_raxml = self.raxml_indices
7274
self.raxmlchars = data['raxml_charset']
7375

7476
self.ordset = set([ ord(c) for c in self.alphabet ])
7577
#load pickled model
7678
self.model = model
7779
self.encoder = torch.load(model , map_location=torch.device('cpu') , weights_only=False)
80+
self.decoder = torch.load( decoder_model , map_location=torch.device('cpu') , weights_only=False ) if decoder_model is not None else None
7881
if 'aapropcsv' in kwargs and kwargs['aapropcsv'] is not None:
7982
self.converter = PDB2PyG(aapropcsv=kwargs['aapropcsv'])
8083
else:
@@ -87,8 +90,13 @@ def __init__ ( self , model , mafftmat = None , submat = None , raxml_path= Non
8790
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8891
self.encoder = self.encoder.to(self.device)
8992
self.encoder.device = self.device
93+
if self.decoder is not None:
94+
self.decoder = self.decoder.to(self.device)
95+
self.decoder.device = self.device
9096

9197
self.encoder.eval()
98+
if self.decoder is not None:
99+
self.decoder.eval()
92100

93101

94102
#load the mafftmat and submat matrices
@@ -402,48 +410,46 @@ def ancestralfasta2df(self, outfasta ):
402410

403411
def decoder_reconstruction( self, ords , verbose = False):
    """Decode a sequence of structural-alphabet indices into an amino-acid string.

    The indices are looked up in the encoder's vector-quantizer codebook, wrapped
    in a minimal ``HeteroData`` graph (backbone edges, positional encodings and the
    two "godnode" hub nodes) and passed through ``self.decoder``.

    Args:
        ords: torch tensor of codebook indices for one sequence (it is moved to
            ``self.device`` here). Presumably 1-D with one entry per residue --
            TODO confirm against the caller.
        verbose: accepted for interface compatibility; currently unused.

    Returns:
        dict with a single key ``'aastr'`` holding the decoded amino-acid string.

    Raises:
        RuntimeError: if the treebuilder was created without a decoder model.
        ValueError: if the decoder output carries no ``'aa'`` entry.
    """
    # Fail early with a clear message instead of "'NoneType' object is not callable"
    # after the whole graph has been built.
    if self.decoder is None:
        raise RuntimeError('decoder_reconstruction requires a decoder; pass decoder_model when creating the treebuilder')

    ords = ords.to( self.device )
    z = self.encoder.vector_quantizer.embeddings( ords ).to( self.device )
    nres = z.shape[0]
    positions = torch.arange( nres , dtype=torch.long )

    data = HeteroData()
    data['res'].x = z
    # single graph: every residue belongs to batch 0
    data['res'].batch = torch.zeros( nres , dtype=torch.long )

    # generate a linear backbone (i -> i+1 and its reverse) for the decoder
    backbone, backbone_rev = self.converter.get_backbone( chainlen=nres )
    backbone = self.converter.sparse2pairs( sparse.csr_matrix(backbone) )
    backbone_rev = self.converter.sparse2pairs( sparse.csr_matrix(backbone_rev) )
    data['res','backbone','res'].edge_index = torch.tensor(backbone, dtype=torch.long ).to( self.device )
    data['res','backbonerev','res'].edge_index = torch.tensor(backbone_rev, dtype=torch.long ).to( self.device )

    positional_encoding = self.converter.get_positional_encoding( nres , 256 )
    data['positions'].x = torch.tensor( positional_encoding, dtype=torch.float32).to( self.device )

    # add the godnodes: one hub node (index 0) connected to every residue
    data['godnode'].x = torch.ones( (1,5) , dtype=torch.float32 ).to( self.device )
    data['godnode4decoder'].x = torch.ones( (1,5) , dtype=torch.float32 ).to( self.device )
    data['godnode4decoder', 'informs', 'res'].edge_index = torch.stack( [ torch.zeros_like(positions) , positions ] ).to( self.device )
    data['res', 'informs', 'godnode4decoder'].edge_index = torch.stack( [ positions , torch.zeros_like(positions) ] ).to( self.device )
    data['res', 'informs', 'godnode'].edge_index = torch.stack( [ positions , torch.zeros_like(positions) ] ).to( self.device )
    data = data.to( self.device )

    # every ordered residue pair (i, j), i-major -- built with tensor ops rather
    # than a Python list of n^2 pairs
    allpairs = torch.stack( [ positions.repeat_interleave(nres) , positions.repeat(nres) ] ).to( self.device )

    out = self.decoder( data , allpairs )
    if 'aa' not in out:
        raise ValueError("decoder output has no 'aa' entry; cannot reconstruct the amino acid sequence")
    recon_x = out['aa'].detach().to('cpu')

    # map the arg-max class of every residue back to its amino-acid letter
    amino_map = self.decoder.decoders['sequence_transformer'].amino_acid_indices
    revmap_aa = { v:k for k,v in amino_map.items() }
    aastr = ''.join( revmap_aa[int(idx.item())] for idx in recon_x.argmax(dim=1) )

    # edge probabilities are deliberately not returned (see caller: results are
    # merged into a dataframe of per-node strings)
    return { 'aastr' : aastr }
447453

448454
def structs2tree(self, structs , outdir = None , ancestral = False , raxml_iterations = 20 , raxml_path = None , output_prefix = None , verbose = False , **kwargs ):
449455
#encode the structures
@@ -489,17 +495,25 @@ def structs2tree(self, structs , outdir = None , ancestral = False , raxml_itera
489495
ancestral_fasta = self.ancestral2fasta( ancestral_file )
490496
ancestral_df = self.ancestralfasta2df( ancestral_fasta )
491497
#decode the ancestral sequence
498+
print(ancestral_df.head())
492499
ords = ancestral_df.ord.values
500+
identifiers = ancestral_df.protid.values
501+
results = {}
493502
for l in tqdm.tqdm(range(ords.shape[0]), desc='decoding ancestral sequences'):
494-
res = self.decoder_reconstruction( ords[l] , verbose = verbose)
495-
for key,item in res.items():
496-
ancestral_df.loc[l , key] = item
503+
res = self.decoder_reconstruction( torch.tensor(ords[l] , dtype=torch.long).T , verbose = verbose)
504+
results.update({ identifiers[l] : res } )
505+
#create a new dataframe with the decoded sequences
506+
results = pd.DataFrame.from_dict( results , orient='index' )
507+
print('decoded ancestral sequences:')
508+
print(results.head())
509+
#merge with ancestral df
510+
ancestral_df = ancestral_df.merge( results , left_on='protid' , right_index=True , how='left' )
497511
#write the ancestral dataframe to a file
498512
ancestral_df.to_csv( ancestral_fasta.replace('.aastr.fasta' , '.csv') )
499513
#write out aastr to a fasta
500514
with open( ancestral_fasta , 'w') as f:
501515
for i in ancestral_df.index:
502-
f.write('>' + i + '\n' + ancestral_df.loc[i].aastr + '\n')
516+
f.write('>' + ancestral_df.loc[i].protid + '\n' + ancestral_df.loc[i].aastr + '\n')
503517
ancestral_fasta = ancestral_fasta
504518
else:
505519
ancestral_fasta = None
@@ -666,7 +680,7 @@ def main():
666680

667681

668682
# Create an instance of treebuilder
669-
tb = treebuilder(model=encoder_path, mafftmat=args.mafftmat, submat=args.submat , raxml_path=args.raxmlpath,
683+
tb = treebuilder(model=encoder_path, mafftmat=args.mafftmat, decoder_model=decoder_path, submat=args.submat , raxml_path=args.raxmlpath,
670684
aapropcsv=args.aapropcsv, maffttext2hex=args.maffttext2hex, maffthex2text=args.maffthex2text, ncores=args.ncores , charmaps=args.charmaps , device=args.device)
671685

672686
# Generate tree from structures using the provided options

foldtree2/src/pdbgraph.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,15 @@ def get_closest(chain):
213213
return contact_mat
214214

215215
@staticmethod
216-
def get_backbone(chain):
217-
backbone_mat = np.zeros((len(chain), len(chain)))
218-
backbone_rev_mat = np.zeros((len(chain), len(chain)))
216+
def get_backbone(chain = None , chainlen = None):
    """Build forward and reverse backbone adjacency matrices for a linear chain.

    Args:
        chain: any sized object; its ``len`` gives the number of residues.
            Ignored when ``chainlen`` is given.
        chainlen: number of residues. Takes precedence over ``chain``.

    Returns:
        ``(backbone_mat, backbone_rev_mat)``: two ``(n, n)`` float arrays where
        ``backbone_mat[i + 1, i] == 1`` and ``backbone_rev_mat[i, i + 1] == 1``
        and every other entry is 0.

    Raises:
        ValueError: if neither ``chain`` nor ``chainlen`` is provided.
    """
    if chainlen is None:
        if chain is None:
            # raising a bare string is itself a TypeError in Python 3
            raise ValueError('provide chain or chainlen')
        chainlen = len(chain)
    backbone_mat = np.zeros((chainlen, chainlen))
    backbone_rev_mat = np.zeros((chainlen, chainlen))
    # filling the diagonal of the shifted views sets the sub-/super-diagonal
    np.fill_diagonal(backbone_mat[1:], 1)
    np.fill_diagonal(backbone_rev_mat[:, 1:], 1)
    return backbone_mat, backbone_rev_mat

0 commit comments

Comments
 (0)