|
1487 | 1487 | }, |
1488 | 1488 | { |
1489 | 1489 | "cell_type": "code", |
1490 | | - "execution_count": 22, |
| 1490 | + "execution_count": null, |
1491 | 1491 | "id": "7d89d6d2", |
1492 | 1492 | "metadata": {}, |
1493 | 1493 | "outputs": [], |
|
1739 | 1739 | "\t\n", |
1740 | 1740 | "\tencoder.train()\n", |
1741 | 1741 | "\tdecoder.train()\n", |
1742 | | - "\t\n", |
| 1742 | + "\n", |
| 1743 | + "\n", |
| 1744 | + "\tfig2 = plt.figure(figsize=(6, 4))\n", |
| 1745 | + "\t#ramachandran plot\n", |
| 1746 | + "\tif 'angles' in sample_out and sample_out['angles'] is not None:\n", |
| 1747 | + "\t\ttrue_angles = data_sample['bondangles'].x.detach().cpu().numpy()\n", |
| 1748 | + "\t\tpred_angles = sample_out['angles'].detach().cpu().numpy()\n", |
| 1749 | + "\t\taxs2 = fig2.add_subplot(1, 1, 1)\n", |
| 1750 | + "\t\taxs2.scatter(true_angles[:, 0], true_angles[:, 1], label='True', color='blue', alpha=0.5, s=10)\n", |
| 1751 | + "\t\taxs2.scatter(pred_angles[:, 0], pred_angles[:, 1], label='Pred', color='red', alpha=0.5, s=10)\n", |
| 1752 | + "\t\taxs2.set_xlabel('Phi (N-Ca)')\n", |
| 1753 | + "\t\taxs2.set_ylabel('Psi (Ca-C)')\n", |
| 1754 | + "\t\taxs2.set_title('Ramachandran Plot')\n", |
| 1755 | + "\t\taxs2.legend()\n", |
| 1756 | + "\t\tfig2.tight_layout()\n", |
| 1757 | + "\t\t\n", |
| 1758 | + "\n", |
| 1759 | + "\tfig3, axs3 = plt.subplots(1, 3, figsize=(6, 4))\n", |
| 1760 | + "\t#separate angles plot\n", |
| 1761 | + "\tif 'angles' in sample_out and sample_out['angles'] is not None:\n", |
| 1762 | + "\t\ttrue_angles = data_sample['bondangles'].x.detach().cpu().numpy()\n", |
| 1763 | + "\t\tpred_angles = sample_out['angles'].detach().cpu().numpy()\n", |
| 1764 | + "\t\tangle_names = ['N-Ca-C', 'Ca-C-N', 'C-N-Ca']\n", |
| 1765 | + "\t\tangle_colors = ['r', 'g', 'b']\n", |
| 1766 | + "\t\t\n", |
| 1767 | + "\t\tfor i in range(3):\n", |
| 1768 | + "\t\t\taxs3[i].plot(true_angles[:, i], label='True ' + angle_names[i], \n", |
| 1769 | + "\t\t\t\t\t\t color=angle_colors[i], alpha=0.5)\n", |
| 1770 | + "\t\t\taxs3[i].plot(pred_angles[:, i], label='Pred ' + angle_names[i], \n", |
| 1771 | + "\t\t\t\t\t\t color=angle_colors[i], linestyle='--', alpha=0.5)\n", |
| 1772 | + "\t\t\taxs3[i].set_title(angle_names[i])\n", |
| 1773 | + "\t\t\taxs3[i].set_xlabel('Residue Index')\n", |
| 1774 | + "\t\t\taxs3[i].set_ylabel('Angle (radians)')\n", |
| 1775 | + "\t\t\taxs3[i].legend()\n", |
| 1776 | + "\t\tfig3.tight_layout()\n", |
| 1777 | + "\t\n", |
| 1778 | + "\t#confusion matrix for ss\n", |
| 1779 | + "\tfig4 = plt.figure(figsize=(6, 4))\n", |
| 1780 | + "\tif 'ss' in sample_out and sample_out['ss_pred'] is not None:\n", |
| 1781 | + "\t\tfrom sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n", |
| 1782 | + "\t\ttrue_ss = data_sample['ss'].x.argmax(dim=1).cpu().numpy()\n", |
| 1783 | + "\t\tpred_ss = sample_out['ss_pred'].argmax(dim=1).cpu().numpy()\n", |
| 1784 | + "\t\tcm = confusion_matrix(true_ss, pred_ss)\n", |
| 1785 | + "\t\tdisp = ConfusionMatrixDisplay(confusion_matrix=cm)\n", |
| 1786 | + "\t\tdisp.plot(ax=axs3[0], cmap='Blues')\n", |
| 1787 | + "\t\tfig4.set_title('SS Confusion Matrix')\n", |
| 1788 | + "\t\tfig4.set_xlabel('Predicted SS')\n", |
| 1789 | + "\t\tfig4.set_ylabel('True SS')\n", |
| 1790 | + "\t\tfig4.tight_layout()\n", |
| 1791 | + "\t\n", |
| 1792 | + "\t#confusion matrix for aa\n", |
| 1793 | + "\tfig5 = plt.figure(figsize=(6, 4))\n", |
| 1794 | + "\tif 'aa' in sample_out and sample_out['aa'] is not None:\n", |
| 1795 | + "\t\ttrue_aa = data_sample['AA'].x.argmax(dim=1).cpu().numpy()\n", |
| 1796 | + "\t\tpred_aa = sample_out['aa'].argmax(dim=1).cpu().numpy()\n", |
| 1797 | + "\t\tcm_aa = confusion_matrix(true_aa, pred_aa)\n", |
| 1798 | + "\t\tdisp_aa = ConfusionMatrixDisplay(confusion_matrix=cm_aa)\n", |
| 1799 | + "\t\tdisp_aa.plot(ax=axs3[1], cmap='Blues')\n", |
| 1800 | + "\t\tfig5.set_title('AA Confusion Matrix')\n", |
| 1801 | + "\t\tfig5.set_xlabel('Predicted AA')\n", |
| 1802 | + "\t\tfig5.set_ylabel('True AA')\n", |
| 1803 | + "\t\tfig5.tight_layout()\n", |
| 1804 | + "\n", |
| 1805 | + "\tsup_figs = [fig2, fig3, fig4, fig5]\n", |
| 1806 | + "\tmetrics_dict['sup_figs'] = sup_figs\n", |
| 1807 | + "\n", |
1743 | 1808 | "\treturn fig, metrics_dict , sample_out" |
1744 | 1809 | ] |
1745 | 1810 | }, |
|
0 commit comments