77import torch_geometric
88import multiprocessing as mp
99import argparse
10+ import numpy as np
1011
1112from foldtree2 .src import encoder as ecdr
1213from foldtree2 .src import mono_decoders
1718import tqdm
1819import pandas as pd
1920import os
20- import ete3
21+ import ete3
22+ from scipy import sparse
2123import sys
2224
2325class treebuilder ():
24- def __init__ ( self , model , mafftmat = None , submat = None , raxml_path = None , charmaps = None , ** kwargs ):
26+ def __init__ ( self , model , decoder_model = None , mafftmat = None , submat = None , raxml_path = None , charmaps = None , ** kwargs ):
2527
2628 #make fasta is shifted by 1 and goes from 1-248 included
2729 #0x01 – 0xFF excluding > (0x3E), = (0x3D), < (0x3C), - (0x2D), Space (0x20), Carriage Return (0x0d) and Line Feed (0x0a)
@@ -68,13 +70,14 @@ def __init__ ( self , model , mafftmat = None , submat = None , raxml_path= Non
6870 self .revmap = { v :k for k ,v in data ['char_position_map' ].items () }
6971 self .raxml_indices = data ['raxml_char_position_map' ]
7072 self .rev_raxml_indices = { v :k for k ,v in data ['raxml_char_position_map' ].items () }
71- self .revmap_raxml = { v : k for k , v in data [ 'raxml_char_position_map' ]. items () }
73+ self .revmap_raxml = self . raxml_indices
7274 self .raxmlchars = data ['raxml_charset' ]
7375
7476 self .ordset = set ([ ord (c ) for c in self .alphabet ])
7577 #load pickled model
7678 self .model = model
7779 self .encoder = torch .load (model , map_location = torch .device ('cpu' ) , weights_only = False )
80+ self .decoder = torch .load ( decoder_model , map_location = torch .device ('cpu' ) , weights_only = False ) if decoder_model is not None else None
7881 if 'aapropcsv' in kwargs and kwargs ['aapropcsv' ] is not None :
7982 self .converter = PDB2PyG (aapropcsv = kwargs ['aapropcsv' ])
8083 else :
@@ -87,8 +90,13 @@ def __init__ ( self , model , mafftmat = None , submat = None , raxml_path= Non
8790 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
8891 self .encoder = self .encoder .to (self .device )
8992 self .encoder .device = self .device
93+ if self .decoder is not None :
94+ self .decoder = self .decoder .to (self .device )
95+ self .decoder .device = self .device
9096
9197 self .encoder .eval ()
98+ if self .decoder is not None :
99+ self .decoder .eval ()
92100
93101
94102 #load the mafftmat and submat matrices
@@ -402,48 +410,46 @@ def ancestralfasta2df(self, outfasta ):
402410
def decoder_reconstruction(self, ords, verbose=False):
    """Decode a sequence of codebook indices into an amino-acid string.

    The indices are looked up in the encoder's vector-quantizer embedding
    table, wrapped in a single-chain ``HeteroData`` graph (backbone edges,
    positional encodings and god-node edges) and passed to ``self.decoder``.

    Args:
        ords: Codebook indices for one sequence. Anything accepted by
            ``torch.as_tensor`` works; presumably a 1-D sequence with one
            index per residue -- TODO confirm against the caller.
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        dict with a single key ``'aastr'``: the argmax amino-acid string
        decoded from the decoder's ``'aa'`` output.

    Raises:
        RuntimeError: if no decoder was loaded (``self.decoder is None``).
        KeyError: if the decoder output contains no ``'aa'`` entry.
    """
    # Fail early with a clear message instead of an opaque
    # "'NoneType' object is not callable" deep inside the method.
    if self.decoder is None:
        raise RuntimeError(
            'decoder_reconstruction requires a decoder model; '
            'pass decoder_model when constructing treebuilder'
        )

    ords = torch.as_tensor(ords, dtype=torch.long).to(self.device)
    # Inference only: no autograd graph is needed for lookup or decoding.
    with torch.no_grad():
        z = self.encoder.vector_quantizer.embeddings(ords).to(self.device)
    n_res = z.shape[0]

    data = HeteroData()
    data['res'].x = z
    # Single structure -> every residue belongs to batch element 0.
    data['res'].batch = torch.zeros(n_res, dtype=torch.long)

    # Generate a backbone for the decoder.
    backbone, backbone_rev = self.converter.get_backbone(chainlen=n_res)
    backbone = self.converter.sparse2pairs(sparse.csr_matrix(backbone))
    backbone_rev = self.converter.sparse2pairs(sparse.csr_matrix(backbone_rev))
    data['res', 'backbone', 'res'].edge_index = torch.tensor(backbone, dtype=torch.long)
    data['res', 'backbonerev', 'res'].edge_index = torch.tensor(backbone_rev, dtype=torch.long)

    positional_encoding = self.converter.get_positional_encoding(n_res, 256)
    data['positions'].x = torch.tensor(positional_encoding, dtype=torch.float32)

    # Add the god nodes: one extra node connected to every residue.
    residue_ids = torch.arange(n_res, dtype=torch.long)
    god_ids = torch.zeros(n_res, dtype=torch.long)
    data['godnode'].x = torch.ones((1, 5), dtype=torch.float32)
    data['godnode4decoder'].x = torch.ones((1, 5), dtype=torch.float32)
    # Each edge type gets its own tensor (torch.stack allocates a new one).
    data['godnode4decoder', 'informs', 'res'].edge_index = torch.stack([god_ids, residue_ids])
    data['res', 'informs', 'godnode4decoder'].edge_index = torch.stack([residue_ids, god_ids])
    data['res', 'informs', 'godnode'].edge_index = torch.stack([residue_ids, god_ids])

    data = data.to(self.device)

    # All ordered residue pairs (i, j), i-major, as a (2, n_res**2) tensor.
    allpairs = torch.stack(
        [residue_ids.repeat_interleave(n_res), residue_ids.repeat(n_res)]
    ).to(self.device)

    with torch.no_grad():
        out = self.decoder(data, allpairs)

    if 'aa' not in out:
        raise KeyError("decoder output has no 'aa' entry; cannot reconstruct the sequence")
    recon_x = out['aa'].detach().to('cpu')

    # Invert the decoder's amino-acid -> index table to map argmax indices
    # back to residue letters.
    amino_map = self.decoder.decoders['sequence_transformer'].amino_acid_indices
    revmap_aa = {v: k for k, v in amino_map.items()}
    aastr = ''.join(revmap_aa[idx] for idx in recon_x.argmax(dim=1).tolist())

    # NOTE: 'edge_probs' from the decoder is intentionally not returned;
    # callers build a DataFrame from this dict and only consume 'aastr'.
    return {'aastr': aastr}
447453
448454 def structs2tree (self , structs , outdir = None , ancestral = False , raxml_iterations = 20 , raxml_path = None , output_prefix = None , verbose = False , ** kwargs ):
449455 #encode the structures
@@ -489,17 +495,25 @@ def structs2tree(self, structs , outdir = None , ancestral = False , raxml_itera
489495 ancestral_fasta = self .ancestral2fasta ( ancestral_file )
490496 ancestral_df = self .ancestralfasta2df ( ancestral_fasta )
491497 #decode the ancestral sequence
498+ print (ancestral_df .head ())
492499 ords = ancestral_df .ord .values
500+ identifiers = ancestral_df .protid .values
501+ results = {}
493502 for l in tqdm .tqdm (range (ords .shape [0 ]), desc = 'decoding ancestral sequences' ):
494- res = self .decoder_reconstruction ( ords [l ] , verbose = verbose )
495- for key ,item in res .items ():
496- ancestral_df .loc [l , key ] = item
503+ res = self .decoder_reconstruction ( torch .tensor (ords [l ] , dtype = torch .long ).T , verbose = verbose )
504+ results .update ({ identifiers [l ] : res } )
505+ #create a new dataframe with the decoded sequences
506+ results = pd .DataFrame .from_dict ( results , orient = 'index' )
507+ print ('decoded ancestral sequences:' )
508+ print (results .head ())
509+ #merge with ancestral df
510+ ancestral_df = ancestral_df .merge ( results , left_on = 'protid' , right_index = True , how = 'left' )
497511 #write the ancestral dataframe to a file
498512 ancestral_df .to_csv ( ancestral_fasta .replace ('.aastr.fasta' , '.csv' ) )
499513 #write out aastr to a fasta
500514 with open ( ancestral_fasta , 'w' ) as f :
501515 for i in ancestral_df .index :
502- f .write ('>' + i + '\n ' + ancestral_df .loc [i ].aastr + '\n ' )
516+ f .write ('>' + ancestral_df . loc [ i ]. protid + '\n ' + ancestral_df .loc [i ].aastr + '\n ' )
503517 ancestral_fasta = ancestral_fasta
504518 else :
505519 ancestral_fasta = None
@@ -666,7 +680,7 @@ def main():
666680
667681
668682 # Create an instance of treebuilder
669- tb = treebuilder (model = encoder_path , mafftmat = args .mafftmat , submat = args .submat , raxml_path = args .raxmlpath ,
683+ tb = treebuilder (model = encoder_path , mafftmat = args .mafftmat , decoder_model = decoder_path , submat = args .submat , raxml_path = args .raxmlpath ,
670684 aapropcsv = args .aapropcsv , maffttext2hex = args .maffttext2hex , maffthex2text = args .maffthex2text , ncores = args .ncores , charmaps = args .charmaps , device = args .device )
671685
672686 # Generate tree from structures using the provided options
0 commit comments