{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Building your own models for RT prediction" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from peptdeep.model.featurize import (\n", " get_batch_aa_indices, \n", " get_batch_mod_feature\n", ")\n", "\n", "from peptdeep.settings import model_const\n", "\n", "import peptdeep.model.model_interface as model_base\n", "import peptdeep.model.building_block as building_block\n", "\n", "mod_feature_size = len(model_const['mod_elements'])\n", "\n", "import torch\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class RT_LSTM_Module(torch.nn.Module):\n", " def __init__(self, \n", " dropout=0.2\n", " ):\n", " super().__init__()\n", " \n", " self.dropout = torch.nn.Dropout(dropout)\n", " \n", " hidden = 128\n", " self.rt_encoder = building_block.Encoder_26AA_Mod_CNN_LSTM_AttnSum(\n", " hidden\n", " )\n", "\n", " self.rt_decoder = building_block.Decoder_Linear(\n", " hidden,\n", " 1\n", " )\n", "\n", " def forward(self, \n", " aa_indices, \n", " mod_x,\n", " ):\n", " x = self.rt_encoder(aa_indices, mod_x)\n", " x = self.dropout(x)\n", "\n", " return self.rt_decoder(x).squeeze(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class RT_Transformer_Module(torch.nn.Module):\n", " def __init__(self, \n", " dropout=0.2\n", " ):\n", " super().__init__()\n", " \n", " self.dropout = torch.nn.Dropout(dropout)\n", " \n", " hidden = 128\n", " self.encoder = building_block.Encoder_AA_Mod_Transformer_AttnSum(\n", " hidden\n", " )\n", "\n", " self.decoder = building_block.Decoder_Linear(\n", " hidden,1\n", " )\n", "\n", " def forward(self, \n", " aa_indices, \n", " mod_x,\n", " ):\n", " x = self.encoder(aa_indices, mod_x)\n", " x = self.dropout(x)\n", "\n", " return self.decoder(x).squeeze(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class RT_ModelInterface(model_base.ModelInterface):\n", " def __init__(self, \n", " model_class:torch.nn.Module=RT_LSTM_Module,\n", " dropout=0.1,\n", " ):\n", " super().__init__()\n", " self.build(\n", " model_class,\n", " dropout=dropout,\n", " )\n", " self.loss_func = torch.nn.L1Loss()\n", " self.target_column_to_train = 'rt_norm'\n", " self.target_column_to_predict = 'rt_pred'\n", "\n", " def _get_features_from_batch_df(self, \n", " batch_df: pd.DataFrame,\n", " ):\n", " aa_indices = torch.LongTensor(\n", " get_batch_aa_indices(\n", " batch_df['sequence'].values.astype('U')\n", " )\n", " )\n", " mod_x = torch.Tensor(\n", " get_batch_mod_feature(\n", " batch_df\n", " )\n", " )\n", "\n", " return aa_indices, mod_x\n", "\n", " def _get_targets_from_batch_df(self, \n", " batch_df: pd.DataFrame,\n", " ) -> torch.Tensor:\n", " return torch.Tensor(batch_df['rt_norm'].values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing the RT model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare training data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sequencepep_nameirtmodsmod_sitesnAArt_norm
0LGGNEQVTRRT-pep a-24.9290.000000
1GAGSSEPVTGLDAKRT-pep b0.00140.199488
2VEATFGVDESNAKRT-pep c12.39130.298671
3YILAGVENSKRT-pep d19.79100.357909
4TPVISGGPYEYRRT-pep e28.71120.429315
5TPVITGAPYEYRRT-pep f33.38120.466699
6DGLDAASYYAPVRRT-pep g42.26130.537784
7ADVTPADFSEWSKRT-pep h54.62130.636728
8GTFIIDPGGVIRRT-pep i70.52120.764009
9GTFIIDPAAVIRRT-pep k87.23120.897775
10LFLQFGAQGSPFLKRT-pep l100.00141.000000
\n", "
" ], "text/plain": [ " sequence pep_name irt mods mod_sites nAA rt_norm\n", "0 LGGNEQVTR RT-pep a -24.92 9 0.000000\n", "1 GAGSSEPVTGLDAK RT-pep b 0.00 14 0.199488\n", "2 VEATFGVDESNAK RT-pep c 12.39 13 0.298671\n", "3 YILAGVENSK RT-pep d 19.79 10 0.357909\n", "4 TPVISGGPYEYR RT-pep e 28.71 12 0.429315\n", "5 TPVITGAPYEYR RT-pep f 33.38 12 0.466699\n", "6 DGLDAASYYAPVR RT-pep g 42.26 13 0.537784\n", "7 ADVTPADFSEWSK RT-pep h 54.62 13 0.636728\n", "8 GTFIIDPGGVIR RT-pep i 70.52 12 0.764009\n", "9 GTFIIDPAAVIR RT-pep k 87.23 12 0.897775\n", "10 LFLQFGAQGSPFLK RT-pep l 100.00 14 1.000000" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from peptdeep.model.rt import irt_pep\n", "irt_pep['rt_norm'] = (irt_pep.irt - irt_pep.irt.min())/(irt_pep.irt.max()-irt_pep.irt.min())\n", "irt_pep" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device `gpu` is not available, set to `cpu`\n" ] } ], "source": [ "rt_model = RT_ModelInterface(model_class=RT_LSTM_Module)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test the untrained model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sequencepep_nameirtmodsmod_sitesnAArt_normrt_pred
0LGGNEQVTRRT-pep a-24.9290.0000000.0
1GAGSSEPVTGLDAKRT-pep b0.00140.1994880.0
2VEATFGVDESNAKRT-pep c12.39130.2986710.0
3YILAGVENSKRT-pep d19.79100.3579090.0
4TPVISGGPYEYRRT-pep e28.71120.4293150.0
5TPVITGAPYEYRRT-pep f33.38120.4666990.0
6DGLDAASYYAPVRRT-pep g42.26130.5377840.0
7ADVTPADFSEWSKRT-pep h54.62130.6367280.0
8GTFIIDPGGVIRRT-pep i70.52120.7640090.0
9GTFIIDPAAVIRRT-pep k87.23120.8977750.0
10LFLQFGAQGSPFLKRT-pep l100.00141.0000000.0
\n", "
" ], "text/plain": [ " sequence pep_name irt mods mod_sites nAA rt_norm rt_pred\n", "0 LGGNEQVTR RT-pep a -24.92 9 0.000000 0.0\n", "1 GAGSSEPVTGLDAK RT-pep b 0.00 14 0.199488 0.0\n", "2 VEATFGVDESNAK RT-pep c 12.39 13 0.298671 0.0\n", "3 YILAGVENSK RT-pep d 19.79 10 0.357909 0.0\n", "4 TPVISGGPYEYR RT-pep e 28.71 12 0.429315 0.0\n", "5 TPVITGAPYEYR RT-pep f 33.38 12 0.466699 0.0\n", "6 DGLDAASYYAPVR RT-pep g 42.26 13 0.537784 0.0\n", "7 ADVTPADFSEWSK RT-pep h 54.62 13 0.636728 0.0\n", "8 GTFIIDPGGVIR RT-pep i 70.52 12 0.764009 0.0\n", "9 GTFIIDPAAVIR RT-pep k 87.23 12 0.897775 0.0\n", "10 LFLQFGAQGSPFLK RT-pep l 100.00 14 1.000000 0.0" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rt_model.predict(irt_pep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test if training works" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "rt_model.train(irt_pep, epoch=100, verbose=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test if the model fits the irt_pep data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sequencepep_nameirtmodsmod_sitesnAArt_normrt_pred
0LGGNEQVTRRT-pep a-24.9290.0000000.000000
1GAGSSEPVTGLDAKRT-pep b0.00140.1994880.209159
2VEATFGVDESNAKRT-pep c12.39130.2986710.293867
3YILAGVENSKRT-pep d19.79100.3579090.349884
4TPVISGGPYEYRRT-pep e28.71120.4293150.416145
5TPVITGAPYEYRRT-pep f33.38120.4666990.462958
6DGLDAASYYAPVRRT-pep g42.26130.5377840.540334
7ADVTPADFSEWSKRT-pep h54.62130.6367280.638801
8GTFIIDPGGVIRRT-pep i70.52120.7640090.725222
9GTFIIDPAAVIRRT-pep k87.23120.8977750.882472
10LFLQFGAQGSPFLKRT-pep l100.00141.0000000.962103
\n", "
" ], "text/plain": [ " sequence pep_name irt mods mod_sites nAA rt_norm rt_pred\n", "0 LGGNEQVTR RT-pep a -24.92 9 0.000000 0.000000\n", "1 GAGSSEPVTGLDAK RT-pep b 0.00 14 0.199488 0.209159\n", "2 VEATFGVDESNAK RT-pep c 12.39 13 0.298671 0.293867\n", "3 YILAGVENSK RT-pep d 19.79 10 0.357909 0.349884\n", "4 TPVISGGPYEYR RT-pep e 28.71 12 0.429315 0.416145\n", "5 TPVITGAPYEYR RT-pep f 33.38 12 0.466699 0.462958\n", "6 DGLDAASYYAPVR RT-pep g 42.26 13 0.537784 0.540334\n", "7 ADVTPADFSEWSK RT-pep h 54.62 13 0.636728 0.638801\n", "8 GTFIIDPGGVIR RT-pep i 70.52 12 0.764009 0.725222\n", "9 GTFIIDPAAVIR RT-pep k 87.23 12 0.897775 0.882472\n", "10 LFLQFGAQGSPFLK RT-pep l 100.00 14 1.000000 0.962103" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rt_model.predict(irt_pep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get number of model parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/plain": [ "232448" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rt_model.get_parameter_num()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### It is easy to switch the model to Transformer. \n", "#### Users can add more nn.Modules without re-designing the AA/PTM feature extraction parts." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device `gpu` is not available, set to `cpu`\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sequencepep_nameirtmodsmod_sitesnAArt_normrt_pred
0LGGNEQVTRRT-pep a-24.9290.0000000.007334
1GAGSSEPVTGLDAKRT-pep b0.00140.1994880.209777
2VEATFGVDESNAKRT-pep c12.39130.2986710.350849
3YILAGVENSKRT-pep d19.79100.3579090.388612
4TPVISGGPYEYRRT-pep e28.71120.4293150.483431
5TPVITGAPYEYRRT-pep f33.38120.4666990.506625
6DGLDAASYYAPVRRT-pep g42.26130.5377840.578891
7ADVTPADFSEWSKRT-pep h54.62130.6367280.619564
8GTFIIDPGGVIRRT-pep i70.52120.7640090.818625
9GTFIIDPAAVIRRT-pep k87.23120.8977750.936355
10LFLQFGAQGSPFLKRT-pep l100.00141.0000001.094726
\n", "
" ], "text/plain": [ " sequence pep_name irt mods mod_sites nAA rt_norm rt_pred\n", "0 LGGNEQVTR RT-pep a -24.92 9 0.000000 0.007334\n", "1 GAGSSEPVTGLDAK RT-pep b 0.00 14 0.199488 0.209777\n", "2 VEATFGVDESNAK RT-pep c 12.39 13 0.298671 0.350849\n", "3 YILAGVENSK RT-pep d 19.79 10 0.357909 0.388612\n", "4 TPVISGGPYEYR RT-pep e 28.71 12 0.429315 0.483431\n", "5 TPVITGAPYEYR RT-pep f 33.38 12 0.466699 0.506625\n", "6 DGLDAASYYAPVR RT-pep g 42.26 13 0.537784 0.578891\n", "7 ADVTPADFSEWSK RT-pep h 54.62 13 0.636728 0.619564\n", "8 GTFIIDPGGVIR RT-pep i 70.52 12 0.764009 0.818625\n", "9 GTFIIDPAAVIR RT-pep k 87.23 12 0.897775 0.936355\n", "10 LFLQFGAQGSPFLK RT-pep l 100.00 14 1.000000 1.094726" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rt_model = RT_ModelInterface(model_class=RT_Transformer_Module)\n", "rt_model.train(irt_pep, epoch=50, warmup_epoch=20)\n", "rt_model.predict(irt_pep)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "data": { "text/plain": [ "817104" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rt_model.get_parameter_num()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.3 ('base')", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }