Skip to content

Commit cf037b6

Browse files
committed
add paircounts
1 parent 7bd0a51 commit cf037b6

10 files changed

Lines changed: 708 additions & 466 deletions

foldtree2/notebooks/benchmarks/alphabet_Information_content_benchmark.ipynb

Lines changed: 639 additions & 463 deletions
Large diffs are not rendered by default.

foldtree2/notebooks/experiments/test_monodecoders.ipynb

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,7 +1487,7 @@
14871487
},
14881488
{
14891489
"cell_type": "code",
1490-
"execution_count": 22,
1490+
"execution_count": null,
14911491
"id": "7d89d6d2",
14921492
"metadata": {},
14931493
"outputs": [],
@@ -1739,7 +1739,72 @@
17391739
"\t\n",
17401740
"\tencoder.train()\n",
17411741
"\tdecoder.train()\n",
1742-
"\t\n",
1742+
"\n",
1743+
"\n",
1744+
"\tfig2 = plt.figure(figsize=(6, 4))\n",
1745+
"\t#ramachandran plot\n",
1746+
"\tif 'angles' in sample_out and sample_out['angles'] is not None:\n",
1747+
"\t\ttrue_angles = data_sample['bondangles'].x.detach().cpu().numpy()\n",
1748+
"\t\tpred_angles = sample_out['angles'].detach().cpu().numpy()\n",
1749+
"\t\taxs2 = fig2.add_subplot(1, 1, 1)\n",
1750+
"\t\taxs2.scatter(true_angles[:, 0], true_angles[:, 1], label='True', color='blue', alpha=0.5, s=10)\n",
1751+
"\t\taxs2.scatter(pred_angles[:, 0], pred_angles[:, 1], label='Pred', color='red', alpha=0.5, s=10)\n",
1752+
"\t\taxs2.set_xlabel('Phi (N-Ca)')\n",
1753+
"\t\taxs2.set_ylabel('Psi (Ca-C)')\n",
1754+
"\t\taxs2.set_title('Ramachandran Plot')\n",
1755+
"\t\taxs2.legend()\n",
1756+
"\t\tfig2.tight_layout()\n",
1757+
"\t\t\n",
1758+
"\n",
1759+
"\tfig3, axs3 = plt.subplots(1, 3, figsize=(6, 4))\n",
1760+
"\t#separate angles plot\n",
1761+
"\tif 'angles' in sample_out and sample_out['angles'] is not None:\n",
1762+
"\t\ttrue_angles = data_sample['bondangles'].x.detach().cpu().numpy()\n",
1763+
"\t\tpred_angles = sample_out['angles'].detach().cpu().numpy()\n",
1764+
"\t\tangle_names = ['N-Ca-C', 'Ca-C-N', 'C-N-Ca']\n",
1765+
"\t\tangle_colors = ['r', 'g', 'b']\n",
1766+
"\t\t\n",
1767+
"\t\tfor i in range(3):\n",
1768+
"\t\t\taxs3[i].plot(true_angles[:, i], label='True ' + angle_names[i], \n",
1769+
"\t\t\t\t\t\t color=angle_colors[i], alpha=0.5)\n",
1770+
"\t\t\taxs3[i].plot(pred_angles[:, i], label='Pred ' + angle_names[i], \n",
1771+
"\t\t\t\t\t\t color=angle_colors[i], linestyle='--', alpha=0.5)\n",
1772+
"\t\t\taxs3[i].set_title(angle_names[i])\n",
1773+
"\t\t\taxs3[i].set_xlabel('Residue Index')\n",
1774+
"\t\t\taxs3[i].set_ylabel('Angle (radians)')\n",
1775+
"\t\t\taxs3[i].legend()\n",
1776+
"\t\tfig3.tight_layout()\n",
1777+
"\t\n",
1778+
"\t#confusion matrix for ss\n",
1779+
"\tfig4 = plt.figure(figsize=(6, 4))\n",
1780+
"\tif 'ss' in sample_out and sample_out['ss_pred'] is not None:\n",
1781+
"\t\tfrom sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
1782+
"\t\ttrue_ss = data_sample['ss'].x.argmax(dim=1).cpu().numpy()\n",
1783+
"\t\tpred_ss = sample_out['ss_pred'].argmax(dim=1).cpu().numpy()\n",
1784+
"\t\tcm = confusion_matrix(true_ss, pred_ss)\n",
1785+
"\t\tdisp = ConfusionMatrixDisplay(confusion_matrix=cm)\n",
1786+
"\t\tdisp.plot(ax=axs3[0], cmap='Blues')\n",
1787+
"\t\tfig4.set_title('SS Confusion Matrix')\n",
1788+
"\t\tfig4.set_xlabel('Predicted SS')\n",
1789+
"\t\tfig4.set_ylabel('True SS')\n",
1790+
"\t\tfig4.tight_layout()\n",
1791+
"\t\n",
1792+
"\t#confusion matrix for aa\n",
1793+
"\tfig5 = plt.figure(figsize=(6, 4))\n",
1794+
"\tif 'aa' in sample_out and sample_out['aa'] is not None:\n",
1795+
"\t\ttrue_aa = data_sample['AA'].x.argmax(dim=1).cpu().numpy()\n",
1796+
"\t\tpred_aa = sample_out['aa'].argmax(dim=1).cpu().numpy()\n",
1797+
"\t\tcm_aa = confusion_matrix(true_aa, pred_aa)\n",
1798+
"\t\tdisp_aa = ConfusionMatrixDisplay(confusion_matrix=cm_aa)\n",
1799+
"\t\tdisp_aa.plot(ax=axs3[1], cmap='Blues')\n",
1800+
"\t\tfig5.set_title('AA Confusion Matrix')\n",
1801+
"\t\tfig5.set_xlabel('Predicted AA')\n",
1802+
"\t\tfig5.set_ylabel('True AA')\n",
1803+
"\t\tfig5.tight_layout()\n",
1804+
"\n",
1805+
"\tsup_figs = [fig2, fig3, fig4, fig5]\n",
1806+
"\tmetrics_dict['sup_figs'] = sup_figs\n",
1807+
"\n",
17431808
"\treturn fig, metrics_dict , sample_out"
17441809
]
17451810
},
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ dependencies = [
4949
"statsmodels",
5050
"pydssp",
5151
"ete3",
52-
"pyyaml"
52+
"pyyaml",
53+
"ProDy"
5354
]
5455

5556
[project.scripts]

0 commit comments

Comments
 (0)