Source code for peptdeep.model.building_block
import torch
import numpy as np
# BERT from huggingface
from transformers.models.bert.modeling_bert import BertEncoder
from peptdeep.settings import model_const
from peptdeep.settings import global_settings as settings
# Import-time side effect: limit torch to 2 threads for CPU intra-op parallelism.
torch.set_num_threads(2)
# Number of entries in model_const["mod_elements"]; used below as the width
# of the modification feature vector.
mod_feature_size = len(model_const["mod_elements"])
# Size of the instrument one-hot / embedding table.
max_instrument_num = model_const["max_instrument_num"]
# Fragment types and maximum fragment charge, read from the global settings.
frag_types = settings["model"]["frag_types"]
max_frag_charge = settings["model"]["max_frag_charge"]
# One ion type per (fragment type, charge) combination.
num_ion_types = len(frag_types) * max_frag_charge
# Vocabulary size used by `aa_embedding` and `aa_one_hot`.
aa_embedding_size = model_const["aa_embedding_size"]
[docs]
def aa_embedding(hidden_size):
    """
    Build the embedding layer for amino-acid indices.

    Parameters
    ----------
    hidden_size : int
        Size of each embedding vector.

    Returns
    -------
    torch.nn.Embedding
        Embedding with `aa_embedding_size` entries; index 0 is the padding index.
    """
    return torch.nn.Embedding(aa_embedding_size, hidden_size, padding_idx=0)


def ascii_embedding(hidden_size):
    """
    Build the embedding layer for amino acids given as ASCII codes.

    Several classes in this module (e.g. `AA_Mod_Embedding`) call
    `ascii_embedding`, and their docstrings describe the input as
    "128 ASCII codes"; the table size and padding index follow from that
    and from `aa_embedding`.

    Parameters
    ----------
    hidden_size : int
        Size of each embedding vector.

    Returns
    -------
    torch.nn.Embedding
        Embedding with 128 entries; index 0 is the padding index.
    """
    return torch.nn.Embedding(128, hidden_size, padding_idx=0)
[docs]
def aa_one_hot(aa_indices, *cat_others):
    """
    One-hot encode amino-acid indices and append extra feature tensors.

    Parameters
    ----------
    aa_indices : torch.Tensor
        Integer indices; one-hot encoded with `aa_embedding_size` classes.
    *cat_others : torch.Tensor
        Further tensors concatenated after the one-hot block along dim 2.

    Returns
    -------
    torch.Tensor
        Concatenation of the one-hot tensor and `cat_others` along dim 2.
    """
    pieces = [torch.nn.functional.one_hot(aa_indices, aa_embedding_size)]
    pieces.extend(cat_others)
    return torch.cat(pieces, 2)
[docs]
def instrument_embedding(hidden_size):
    """
    Build the embedding layer for instrument indices.

    Parameters
    ----------
    hidden_size : int
        Size of each embedding vector.

    Returns
    -------
    torch.nn.Embedding
        Embedding with `max_instrument_num` entries.
    """
    embedding = torch.nn.Embedding(
        num_embeddings=max_instrument_num,
        embedding_dim=hidden_size,
    )
    return embedding
[docs]
def xavier_param(*shape):
    """
    Create a non-trainable parameter initialised with Xavier-uniform values.

    Parameters
    ----------
    *shape : int
        Dimensions of the parameter tensor.

    Returns
    -------
    torch.nn.Parameter
        Parameter of the requested shape with `requires_grad=False`.
    """
    storage = torch.empty(shape)
    param = torch.nn.Parameter(storage, requires_grad=False)
    torch.nn.init.xavier_uniform_(param)
    return param
[docs]
def invert_attention_mask(
    encoder_attention_mask: torch.Tensor, dtype=torch.float32
) -> torch.FloatTensor:
    """
    See `invert_attention_mask()` in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L737.
    Invert an attention mask (e.g., switches 0. and 1.) and scale it so that
    masked positions hold the most negative value of `dtype` while enabled
    positions hold 0.

    Args:
        encoder_attention_mask (`torch.Tensor`): An attention mask, either
            2-D (batch, seq_len) or 3-D (batch, from_len, to_len).
        dtype: Floating dtype of the returned mask.

    Returns:
        `torch.Tensor`: The inverted attention mask with a broadcastable
        head dimension inserted at dim 1 (4-D).

    Raises:
        ValueError: If the mask is neither 2-D nor 3-D.
    """
    n_dims = encoder_attention_mask.dim()
    if n_dims == 3:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
    elif n_dims == 2:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"attention mask must be 2-D or 3-D, got {n_dims} dimension(s)"
        )
    encoder_extended_attention_mask = encoder_extended_attention_mask.to(
        dtype=dtype
    )  # fp16 compatibility
    encoder_extended_attention_mask = (
        1.0 - encoder_extended_attention_mask
    ) * torch.finfo(dtype).min
    return encoder_extended_attention_mask
# Alias of `xavier_param`; SeqLSTM and SeqGRU use it to create their initial RNN states.
init_state = xavier_param
[docs]
class SeqCNN_MultiKernel(torch.nn.Module):
    """
    Sequence feature extractor built from three `torch.nn.Conv1d` layers
    with kernel sizes 3, 5 and 7. The untouched input (the kernel-size-1
    "residue connection") and the three convolution outputs are
    concatenated along the feature axis.
    """

    def __init__(self, out_features: int):
        """
        Parameters
        ----------
        out_features : int
            Must be divided by 4. Each branch carries `out_features // 4`
            channels.

        Raises
        ------
        ValueError
            "out_features must be divided by 4"
        """
        super().__init__()
        hidden, remainder = divmod(out_features, 4)
        if remainder != 0:
            raise ValueError("out_features must be divided by 4")
        # Same-length convolutions: padding is half the (odd) kernel size.
        for attr, kernel in (("cnn_short", 3), ("cnn_medium", 5), ("cnn_long", 7)):
            setattr(
                self,
                attr,
                torch.nn.Conv1d(hidden, hidden, kernel_size=kernel, padding=kernel // 2),
            )

    def forward(self, x):
        """
        Apply the three convolutions and concatenate them with the input.

        `x` is transposed on dims 1 and 2 before the convolutions and back
        afterwards, so the feature axis is dim 2 on input and output.
        """
        channels_first = x.transpose(1, 2)
        branches = [channels_first]
        for conv in (self.cnn_short, self.cnn_medium, self.cnn_long):
            branches.append(conv(channels_first))
        return torch.cat(branches, dim=1).transpose(1, 2)
# legacy
[docs]
class SeqCNN(torch.nn.Module):
    """
    Sequence feature extractor built from three `torch.nn.Conv1d` layers
    with kernel sizes 3, 5 and 7. The input itself (the kernel-size-1
    "residue connection") is concatenated with the three convolution
    outputs, so the output feature size is 4*embedding_hidden.
    """

    def __init__(self, embedding_hidden):
        """
        Parameters
        ----------
        embedding_hidden : int
            Number of input features; every convolution keeps this width.
        """
        super().__init__()
        # Same-length convolutions: padding is half the (odd) kernel size.
        for attr, kernel in (("cnn_short", 3), ("cnn_medium", 5), ("cnn_long", 7)):
            conv = torch.nn.Conv1d(
                embedding_hidden,
                embedding_hidden,
                kernel_size=kernel,
                padding=kernel // 2,
            )
            setattr(self, attr, conv)

    def forward(self, x):
        """
        Apply the three convolutions and concatenate them with the input.

        `x` is transposed on dims 1 and 2 before the convolutions and back
        afterwards, so the feature axis is dim 2 on input and output.
        """
        channels_first = x.transpose(1, 2)
        short = self.cnn_short(channels_first)
        medium = self.cnn_medium(channels_first)
        long_ = self.cnn_long(channels_first)
        stacked = torch.cat((channels_first, short, medium, long_), dim=1)
        return stacked.transpose(1, 2)
[docs]
class Seq_Transformer(torch.nn.Module):
    """
    Using PyTorch built-in Transformer layers
    """

    def __init__(self, in_features, hidden_features, nheads=8, nlayers=2, dropout=0.1):
        """
        Parameters
        ----------
        in_features : int
            Model (input and output) feature size.
        hidden_features : int
            Feed-forward size inside each encoder layer.
        nheads : int, optional
            Number of attention heads. Defaults to 8.
        nlayers : int, optional
            Number of stacked encoder layers. Defaults to 2.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        """
        super().__init__()
        encoder_layers = torch.nn.TransformerEncoderLayer(
            in_features, nheads, hidden_features, dropout
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)

    def forward(self, x):
        """
        Run the encoder over a batch-first tensor.

        The encoder layers are built without `batch_first`, so they expect
        (seq_len, batch, dim); the input is permuted there and back.

        Parameters
        ----------
        x : torch.Tensor
            Assumed shape (batch, seq_len, in_features), as used by the
            other blocks of this module.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as `x`.
        """
        return self.transformer_encoder(x.permute(1, 0, 2)).permute(1, 0, 2)
[docs]
class Hidden_Transformer(torch.nn.Module):
    """
    Transformer NN based on pytorch's built-in TransformerLayer class
    """

    def __init__(self, hidden, hidden_expand=4, nheads=8, nlayers=4, dropout=0.1):
        """
        Parameters
        ----------
        hidden : int
            Input and output feature size.
        hidden_expand : int, optional
            Feed-forward size is `hidden * hidden_expand`. Defaults to 4.
        nheads : int, optional
            Number of attention heads. Defaults to 8.
        nlayers : int, optional
            Number of encoder layers. Defaults to 4.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        """
        super().__init__()
        # NOTE: the attribute name (sic) is kept as-is so state_dict keys stay stable.
        self.transormer = Seq_Transformer(
            hidden,
            hidden * hidden_expand,
            nheads=nheads,
            nlayers=nlayers,
            dropout=dropout,
        )

    def forward(self, x):
        """Apply the wrapped `Seq_Transformer` to `x` and return its output."""
        return self.transormer(x)
class _Pseudo_Bert_Config:
def __init__(
self,
hidden_dim=256,
intermediate_size=1024,
num_attention_heads=8,
num_bert_layers=4,
dropout=0.1,
output_attentions=False,
):
self.add_cross_attention = False
self.chunk_size_feed_forward = 0
self.is_decoder = False
self.seq_len_dim = 1
self.training = False
self.hidden_act = "gelu"
self.hidden_dropout_prob = dropout
self.attention_probs_dropout_prob = dropout
self.hidden_size = hidden_dim
self.initializer_range = 0.02
self.intermediate_size = intermediate_size
self.layer_norm_eps = 1e-8
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_bert_layers
self.output_attentions = output_attentions
self._attn_implementation = "eager" # Add this for transformers-4.41.0
[docs]
class Hidden_HFace_Transformer(torch.nn.Module):
    """
    Transformer NN based on HuggingFace's BertEncoder class
    """

    def __init__(
        self,
        hidden_dim,
        hidden_expand=4,
        nheads=8,
        nlayers=4,
        dropout=0.1,
        output_attentions=False,
    ):
        """
        Parameters
        ----------
        hidden_dim : int
            Input and output feature size.
        hidden_expand : int, optional
            Intermediate size is `hidden_dim * hidden_expand`. Defaults to 4.
        nheads : int, optional
            Number of attention heads. Defaults to 8.
        nlayers : int, optional
            Number of BERT layers. Defaults to 4.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        output_attentions : bool, optional
            Whether `forward` asks the encoder for attention values.
            Defaults to False.
        """
        super().__init__()
        # The config flag stays False; the per-call flag stored on `self`
        # is what `forward` hands to the encoder.
        config = _Pseudo_Bert_Config(
            hidden_dim=hidden_dim,
            intermediate_size=hidden_dim * hidden_expand,
            num_attention_heads=nheads,
            num_bert_layers=nlayers,
            dropout=dropout,
            output_attentions=False,
        )
        self.config = config
        self.output_attentions = output_attentions
        self.bert = BertEncoder(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> tuple:
        """
        Parameters
        ----------
        x : torch.Tensor
            shape = (batch, seq_len, dim)
        attention_mask : torch.Tensor
            shape = (batch, seq_len), [0,1] tensor , 1=enable

        Returns
        -------
        (Tensor, [Tensor])
            out[0] is the hidden layer output,
            and out[1] is the output attention
            if self.output_attentions==True
        """
        extended_mask = None
        if attention_mask is not None:
            # Convert the 0/1 mask into the additive form, in x's dtype.
            extended_mask = invert_attention_mask(attention_mask, dtype=x.dtype)
        outputs = self.bert(
            x,
            attention_mask=extended_mask,
            output_attentions=self.output_attentions,
            return_dict=False,
        )
        return outputs


# legacy
HiddenBert = Hidden_HFace_Transformer
[docs]
class HFace_Transformer_with_PositionalEncoder(torch.nn.Module):
    """
    HuggingFace transformer with a positional encoder in front.

    Parameters
    ----------
    hidden_dim : int
        Input and output feature dimension.
    hidden_expand : int, optional
        FFN hidden size = hidden*hidden_expand. Defaults to 4.
    nheads : int, optional
        Multi-head attention number. Defaults to 8.
    nlayers : int, optional
        Number of transformer layers. Defaults to 4.
    dropout : float, optional
        Dropout rate. Defaults to 0.1.
    output_attentions : bool, optional
        If output attention values. Defaults to False.
    max_len : int, optional
        Max input sequence length. Defaults to 200.
    """

    def __init__(
        self,
        hidden_dim: int,
        hidden_expand=4,
        nheads=8,
        nlayers=4,
        dropout=0.1,
        output_attentions=False,
        max_len=200,
    ):
        super().__init__()
        self.pos_encoder = PositionalEncoding(hidden_dim, max_len=max_len)
        transformer_kwargs = dict(
            hidden_dim=hidden_dim,
            hidden_expand=hidden_expand,
            nheads=nheads,
            nlayers=nlayers,
            dropout=dropout,
            output_attentions=output_attentions,
        )
        self.bert = Hidden_HFace_Transformer(**transformer_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> tuple:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        attention_mask : torch.Tensor, optional
            Mask handed unchanged to the wrapped transformer.

        Returns
        -------
        tuple
            Tensor: Output tensor.
            [Tensor]: Attention tensor, returned only if output_attentions is True.
        """
        encoded = self.pos_encoder(x)
        return self.bert(encoded, attention_mask)
[docs]
class SeqLSTM(torch.nn.Module):
    """
    LSTM over a batch-first sequence, started from stored initial states.
    """

    def __init__(self, in_features, out_features, rnn_layer=2, bidirectional=True):
        """
        Parameters
        ----------
        in_features : int
            Input feature size.
        out_features : int
            Output feature size; must be even when `bidirectional` is set.
        rnn_layer : int, optional
            Number of LSTM layers. Defaults to 2.
        bidirectional : bool, optional
            Use a bidirectional LSTM. Defaults to True.

        Raises
        ------
        ValueError
            If `bidirectional` is set and `out_features` is odd.
        """
        super().__init__()
        if not bidirectional:
            hidden = out_features
        elif out_features % 2 != 0:
            raise ValueError("'out_features' must be able to be divided by 2")
        else:
            # Forward and reverse outputs are concatenated, so each
            # direction carries half of the requested width.
            hidden = out_features // 2
        n_states = rnn_layer + rnn_layer * bidirectional
        self.rnn_h0 = init_state(n_states, 1, hidden)
        self.rnn_c0 = init_state(n_states, 1, hidden)
        self.rnn = torch.nn.LSTM(
            input_size=in_features,
            hidden_size=hidden,
            num_layers=rnn_layer,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x: torch.Tensor):
        """Run the LSTM on `x` (batch first) and return only the output sequence."""
        batch_size = x.size(0)
        # The stored states have batch size 1; tile them across the batch.
        initial = tuple(
            state.repeat(1, batch_size, 1) for state in (self.rnn_h0, self.rnn_c0)
        )
        output, _ = self.rnn(x, initial)
        return output
[docs]
class SeqGRU(torch.nn.Module):
    """
    GRU over a batch-first sequence, started from a stored initial state.
    """

    def __init__(self, in_features, out_features, rnn_layer=2, bidirectional=True):
        """
        Parameters
        ----------
        in_features : int
            Input feature size.
        out_features : int
            Output feature size; must be even when `bidirectional` is set.
        rnn_layer : int, optional
            Number of GRU layers. Defaults to 2.
        bidirectional : bool, optional
            Use a bidirectional GRU. Defaults to True.

        Raises
        ------
        ValueError
            If `bidirectional` is set and `out_features` is odd.
        """
        super().__init__()
        if not bidirectional:
            hidden = out_features
        elif out_features % 2 != 0:
            raise ValueError("'out_features' must be able to be divided by 2")
        else:
            # `bidirectional` concatenates the forward and reverse RNNs, so
            # halving keeps the output width at out_features.
            hidden = out_features // 2
        n_states = rnn_layer + rnn_layer * bidirectional
        self.rnn_h0 = init_state(n_states, 1, hidden)
        self.rnn = torch.nn.GRU(
            input_size=in_features,
            hidden_size=hidden,
            num_layers=rnn_layer,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x: torch.Tensor):
        """Run the GRU on `x` (batch first) and return only the output sequence."""
        # The stored state has batch size 1; tile it across the batch.
        initial = self.rnn_h0.repeat(1, x.size(0), 1)
        output, _ = self.rnn(x, initial)
        return output
[docs]
class SeqAttentionSum(torch.nn.Module):
    """
    Collapse a sequence into one vector: a bias-free linear layer scores
    every position, the scores are softmax-normalised along dim 1, and the
    input is summed along dim 1 with those weights.
    """

    def __init__(self, in_features):
        """
        Parameters
        ----------
        in_features : int
            Feature size of the input sequence.
        """
        super().__init__()
        self.attn = torch.nn.Sequential(
            torch.nn.Linear(in_features, 1, bias=False),
            torch.nn.Softmax(dim=1),
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Assumed shape (batch, seq_len, in_features).

        Returns
        -------
        torch.Tensor
            Attention-weighted sum over dim 1, shape (batch, in_features).
        """
        weights = self.attn(x)  # (batch, seq_len, 1), sums to 1 along dim 1
        return torch.sum(torch.mul(x, weights), dim=1)
[docs]
class PositionalEncoding(torch.nn.Module):
    """
    Add a fixed sinusoidal positional table to a sequence input.
    """

    def __init__(self, out_features=128, max_len=200):
        """
        Parameters
        ----------
        out_features : int, optional
            Feature size of the sequence input. Defaults to 128.
        max_len : int, optional
            Number of positions in the table. Defaults to 200.
        """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        # The frequency base is `max_len`, exactly as in the original code.
        div_term = torch.exp(
            torch.arange(0, out_features, 2) * (-np.log(max_len) / out_features)
        )
        angles = position * div_term
        pe = torch.zeros(1, max_len, out_features)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd `out_features` there is one cosine column fewer than
        # sine columns; the slice is a no-op when `out_features` is even.
        pe[0, :, 1::2] = torch.cos(angles)[:, : out_features // 2]
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Assumed shape (batch, seq_len, out_features) with
            seq_len <= max_len.

        Returns
        -------
        torch.Tensor
            `x` plus the first `seq_len` rows of the positional table.
        """
        return x + self.pe[:, : x.size(1), :]
[docs]
class PositionalEmbedding(torch.nn.Module):
    """
    Add a learned positional embedding to a sequence input.
    """

    def __init__(self, out_features=128, max_len=200):
        """
        Parameters
        ----------
        out_features : int, optional
            Size of each positional vector. Defaults to 128.
        max_len : int, optional
            Number of positions in the embedding table. Defaults to 200.
        """
        super().__init__()
        self.pos_emb = torch.nn.Embedding(max_len, out_features)

    def forward(self, x: torch.Tensor):
        """Return `x` plus the embeddings of positions 0..x.size(1)-1."""
        seq_len = x.size(1)
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        # The leading axis of size 1 broadcasts over the batch.
        offsets = self.pos_emb(positions.unsqueeze(0))
        return x + offsets
[docs]
class Meta_Embedding(torch.nn.Module):
    """Encodes Charge state, Normalized Collision Energy (NCE) and Instrument for a given spectrum
    into a 'meta' single layer network
    """

    def __init__(
        self,
        out_features,
    ):
        """
        Parameters
        ----------
        out_features : int
            Output size; the linear layer emits `out_features - 1` values
            and `forward` appends `charges` to them.
        """
        super().__init__()
        linear_in = max_instrument_num + 1
        linear_out = out_features - 1
        self.nn = torch.nn.Linear(linear_in, linear_out)

    def forward(
        self,
        charges,
        NCEs,
        instrument_indices,
    ):
        """
        One-hot encode the instrument, concatenate it with `NCEs` along
        dim 1, apply the linear layer and append `charges` along dim 1.

        `charges` and `NCEs` are presumably (batch, 1) tensors — TODO confirm
        against callers.
        """
        instrument_one_hot = torch.nn.functional.one_hot(
            instrument_indices, max_instrument_num
        )
        linear_input = torch.cat((instrument_one_hot, NCEs), 1)
        encoded = self.nn(linear_input)
        return torch.cat((encoded, charges), 1)


# legacy
InputMetaNet = Meta_Embedding
[docs]
class Mod_Embedding_FixFirstK(torch.nn.Module):
    """
    Encodes the modification vector with a single bias-free linear layer
    while passing the first k (=6) features through unchanged.
    """

    def __init__(
        self,
        out_features,
    ):
        """
        Parameters
        ----------
        out_features : int
            Output size, counting the k untouched features.
        """
        super().__init__()
        self.k = 6
        n_in = mod_feature_size - self.k
        n_out = out_features - self.k
        self.nn = torch.nn.Linear(n_in, n_out, bias=False)

    def forward(
        self,
        mod_x,
    ):
        """
        Split `mod_x` along dim 2 at index k, transform the tail with the
        linear layer and concatenate head and transformed tail along dim 2.
        """
        untouched = mod_x[:, :, : self.k]
        projected = self.nn(mod_x[:, :, self.k :])
        return torch.cat((untouched, projected), 2)


# legacy
InputModNetFixFirstK = Mod_Embedding_FixFirstK
[docs]
class AA_Mod_Embedding(torch.nn.Module):
    """
    Concatenates the AA (128 ASCII codes) embedding with the modification vector
    """

    def __init__(
        self,
        out_features,
        mod_feature_size=8,
    ):
        """
        Parameters
        ----------
        out_features : int
            Total output size (AA embedding plus modification features).
        mod_feature_size : int, optional
            Size of the encoded modification vector. Defaults to 8.
        """
        super().__init__()
        aa_features = out_features - mod_feature_size
        self.mod_embedding = Mod_Embedding_FixFirstK(mod_feature_size)
        self.aa_embedding = ascii_embedding(aa_features)

    def forward(self, aa_indices, mod_x):
        """Embed both inputs and concatenate them (AA first) along dim 2."""
        encoded_mods = self.mod_embedding(mod_x)
        embedded_aa = self.aa_embedding(aa_indices)
        return torch.cat((embedded_aa, encoded_mods), 2)


# legacy
InputAAEmbedding = AA_Mod_Embedding
[docs]
class Mod_Embedding(torch.nn.Module):
    """
    Encodes the modification vector in a single layer feed forward network
    """

    def __init__(
        self,
        out_features,
    ):
        """
        Parameters
        ----------
        out_features : int
            Output size of the bias-free linear layer.
        """
        super().__init__()
        self.nn = torch.nn.Linear(mod_feature_size, out_features, bias=False)

    def forward(
        self,
        mod_x,
    ):
        """
        Apply the linear layer to the modification features.

        Parameters
        ----------
        mod_x : torch.Tensor
            Tensor whose last dimension has `mod_feature_size` entries.

        Returns
        -------
        torch.Tensor
            Encoded tensor with `out_features` entries in the last dimension.
        """
        return self.nn(mod_x)


# legacy
InputModNet = Mod_Embedding
[docs]
class Input_26AA_Mod_PositionalEncoding(torch.nn.Module):
    """
    Encodes AA (26 AA letters) and modification vector
    """

    def __init__(self, out_features, max_len=200):
        """
        Parameters
        ----------
        out_features : int
            Total output size; 8 features come from the modification encoder.
        max_len : int, optional
            Number of positions in the positional encoder. Defaults to 200.
        """
        super().__init__()
        mod_hidden = 8
        aa_hidden = out_features - mod_hidden
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        self.aa_emb = aa_embedding(aa_hidden)
        self.pos_encoder = PositionalEncoding(out_features, max_len)

    def forward(self, aa_indices, mod_x):
        """Embed AAs and modifications, join them on dim 2, then add positions."""
        encoded_mods = self.mod_nn(mod_x)
        embedded_aa = self.aa_emb(aa_indices)
        combined = torch.cat((embedded_aa, encoded_mods), 2)
        return self.pos_encoder(combined)


# legacy
AATransformerEncoding = Input_26AA_Mod_PositionalEncoding
[docs]
class Input_AA_Mod_PositionalEncoding(torch.nn.Module):
    """
    Encodes AA (ASCII codes) and modification vector
    """

    def __init__(self, out_features, max_len=200):
        """
        Parameters
        ----------
        out_features : int
            Total output size; 8 features come from the modification encoder.
        max_len : int, optional
            Number of positions in the positional encoder. Defaults to 200.
        """
        super().__init__()
        mod_hidden = 8
        aa_hidden = out_features - mod_hidden
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        self.aa_emb = ascii_embedding(aa_hidden)
        self.pos_encoder = PositionalEncoding(out_features, max_len)

    def forward(self, aa_indices, mod_x):
        """Embed AAs and modifications, join them on dim 2, then add positions."""
        encoded_mods = self.mod_nn(mod_x)
        embedded_aa = self.aa_emb(aa_indices)
        combined = torch.cat((embedded_aa, encoded_mods), 2)
        return self.pos_encoder(combined)
[docs]
class Input_AA_Mod_Charge_PositionalEncoding(torch.nn.Module):
    """
    Embed AA (128 ASCII codes), modification, and charge state
    """

    def __init__(self, out_features, max_len=200):
        """
        Parameters
        ----------
        out_features : int
            Total output size: AA embedding + 8 modification features +
            `charge_dim` (=2) charge features.
        max_len : int, optional
            Number of positions in the positional encoder. Defaults to 200.
        """
        super().__init__()
        mod_hidden = 8
        self.charge_dim = 2
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        self.aa_emb = ascii_embedding(out_features - mod_hidden - self.charge_dim)
        self.pos_encoder = PositionalEncoding(out_features, max_len)

    def forward(self, aa_indices, mod_x, charges):
        """
        Parameters
        ----------
        aa_indices : torch.Tensor
            AA indices fed to the embedding.
        mod_x : torch.Tensor
            Modification features; dim 1 is taken as the sequence length.
        charges : torch.Tensor
            One charge value per sample, shaped (batch,) or (batch, 1).

        Returns
        -------
        torch.Tensor
            Positional-encoded concatenation of AA, modification and
            charge features along dim 2.
        """
        mod_x = self.mod_nn(mod_x)
        x = self.aa_emb(aa_indices)
        # Normalise charges to (batch, 1, 1): the other charge-aware blocks
        # of this module take (batch, 1), while the previous indexing
        # `charges[:, None, None]` only worked for (batch,).
        charge_x = charges.reshape(charges.size(0), 1, 1).repeat(
            1, mod_x.size(1), self.charge_dim
        )
        return self.pos_encoder(torch.cat((x, mod_x, charge_x), 2))
[docs]
class Input_26AA_Mod_LSTM(torch.nn.Module):
    """
    Applies an LSTM network to a AA (26 AA letters) sequence & modifications
    """

    def __init__(
        self,
        out_features,
        n_lstm_layers=1,
    ):
        """
        Parameters
        ----------
        out_features : int
            Output size of the bidirectional LSTM (must be even).
        n_lstm_layers : int, optional
            Number of LSTM layers. Defaults to 1.
        """
        super().__init__()
        mod_hidden = 8
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        # SeqLSTM names its layer-count parameter `rnn_layer`; passing
        # `n_lstm_layers=` raised a TypeError on construction.
        self.lstm = SeqLSTM(
            aa_embedding_size + mod_hidden,
            out_features,
            rnn_layer=n_lstm_layers,
            bidirectional=True,
        )

    def forward(self, aa_indices, mod_x):
        """One-hot encode AAs, append the encoded modifications, run the LSTM."""
        mod_x = self.mod_nn(mod_x)
        x = aa_one_hot(aa_indices, mod_x)
        return self.lstm(x)


# legacy
InputAALSTM = Input_26AA_Mod_LSTM
[docs]
class Input_26AA_Mod_Meta_LSTM(torch.nn.Module):
    """
    Applies a LSTM network to a AA (26 AA letters) sequence and modifications,
    and concatenates with 'meta' information (charge, nce, instrument_indices)
    """

    def __init__(
        self,
        out_features,
    ):
        """
        Parameters
        ----------
        out_features : int
            Total output size; 4 of these features are the meta embedding.
        """
        super().__init__()
        meta_dim = 4
        mod_hidden = 8
        lstm_in = aa_embedding_size + mod_hidden
        lstm_out = out_features - meta_dim
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        self.meta_nn = Meta_Embedding(meta_dim)
        self.nn = SeqLSTM(lstm_in, lstm_out, rnn_layer=1, bidirectional=True)

    def forward(self, aa_indices, mod_x, charges, NCEs, instrument_indices):
        """
        Run the LSTM on one-hot AAs plus encoded modifications, then append
        the meta embedding, repeated for every sequence position, on dim 2.
        """
        mod_x = self.mod_nn(mod_x)
        seq_features = self.nn(aa_one_hot(aa_indices, mod_x))
        meta = self.meta_nn(charges, NCEs, instrument_indices)
        # Broadcast the per-sample meta vector along the sequence axis.
        meta_per_position = meta.unsqueeze(1).repeat(1, mod_x.size(1), 1)
        return torch.cat((seq_features, meta_per_position), 2)


# legacy
InputAALSTM_cat_Meta = Input_26AA_Mod_Meta_LSTM
[docs]
class Input_26AA_Mod_Charge_LSTM(torch.nn.Module):
    """
    Applies a LSTM network to a AA (26 AA letters) sequence and modifications,
    and concatenates with charge state information
    """

    def __init__(
        self,
        out_features,
    ):
        """
        Parameters
        ----------
        out_features : int
            Total output size; `charge_dim` (=2) of these features carry
            the charge.
        """
        super().__init__()
        self.charge_dim = 2
        mod_hidden = 8
        lstm_in = aa_embedding_size + mod_hidden
        lstm_out = out_features - self.charge_dim
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        self.nn = SeqLSTM(lstm_in, lstm_out, rnn_layer=1, bidirectional=True)

    def forward(self, aa_indices, mod_x, charges):
        """
        Run the LSTM on one-hot AAs plus encoded modifications, then append
        the charge, repeated over positions and `charge_dim` columns.

        `charges` is presumably (batch, 1) — TODO confirm against callers.
        """
        mod_x = self.mod_nn(mod_x)
        seq_features = self.nn(aa_one_hot(aa_indices, mod_x))
        seq_len = mod_x.size(1)
        charge_per_position = charges.unsqueeze(1).repeat(1, seq_len, self.charge_dim)
        return torch.cat((seq_features, charge_per_position), 2)


# legacy
InputAALSTM_cat_Charge = Input_26AA_Mod_Charge_LSTM
[docs]
class Seq_Meta_LSTM(torch.nn.Module):
    """
    Appends the 'meta' embedding (NCE, Instrument, Charge) to every position
    of a hidden sequence tensor and processes the result with a
    unidirectional, single-layer LSTM.
    """

    def __init__(
        self,
        in_features,
        out_features,
    ):
        """
        Parameters
        ----------
        in_features : int
            Feature size of the incoming hidden tensor.
        out_features : int
            Output size of the LSTM.
        """
        super().__init__()
        meta_dim = 4
        self.meta_nn = Meta_Embedding(meta_dim)
        self.nn = SeqLSTM(
            in_features + meta_dim,
            out_features,
            rnn_layer=1,
            bidirectional=False,
        )

    def forward(self, x, charges, NCEs, instrument_indices):
        """Concatenate the repeated meta embedding to `x` on dim 2 and run the LSTM."""
        meta = self.meta_nn(charges, NCEs, instrument_indices)
        # Broadcast the per-sample meta vector along the sequence axis.
        meta_per_position = meta.unsqueeze(1).repeat(1, x.size(1), 1)
        lstm_input = torch.cat((x, meta_per_position), 2)
        return self.nn(lstm_input)


# legacy
OutputLSTM_cat_Meta = Seq_Meta_LSTM
[docs]
class Seq_Meta_Linear(torch.nn.Module):
    """
    Appends the 'meta' embedding (NCE, Instrument, Charge) to every position
    of a hidden sequence tensor and maps the result with a bias-free linear
    layer.
    """

    def __init__(
        self,
        in_features,
        out_features,
    ):
        """
        Parameters
        ----------
        in_features : int
            Feature size of the incoming hidden tensor.
        out_features : int
            Output size of the linear layer.
        """
        super().__init__()
        meta_dim = 4
        self.meta_nn = Meta_Embedding(meta_dim)
        self.nn = torch.nn.Linear(in_features + meta_dim, out_features, bias=False)

    def forward(self, x, charges, NCEs, instrument_indices):
        """Concatenate the repeated meta embedding to `x` on dim 2 and apply the linear layer."""
        meta = self.meta_nn(charges, NCEs, instrument_indices)
        # Broadcast the per-sample meta vector along the sequence axis.
        meta_per_position = meta.unsqueeze(1).repeat(1, x.size(1), 1)
        linear_input = torch.cat((x, meta_per_position), 2)
        return self.nn(linear_input)


# legacy
OutputLinear_cat_Meta = Seq_Meta_Linear
[docs]
class Encoder_26AA_Mod_LSTM(torch.nn.Module):
    """
    Two LSTM stages on AA (26 AA letters) and modifications: an input LSTM
    followed by a second `SeqLSTM`.
    """

    def __init__(self, out_features, n_lstm_layers=1):
        """
        Parameters
        ----------
        out_features : int
            Output size of both LSTM stages.
        n_lstm_layers : int, optional
            Number of layers of the second LSTM stage. Defaults to 1.
        """
        super().__init__()
        self.input_nn = Input_26AA_Mod_LSTM(out_features)
        self.nn = SeqLSTM(out_features, out_features, rnn_layer=n_lstm_layers)

    def forward(self, aa_indices, mod_x):
        """Apply the input LSTM and then the second LSTM stage."""
        return self.nn(self.input_nn(aa_indices, mod_x))


# legacy
Input_AA_LSTM_Encoder = Encoder_26AA_Mod_LSTM
[docs]
class Encoder_26AA_Mod_CNN_LSTM(torch.nn.Module):
    """
    Encode AAs (26 AA letters) and modifications by CNN and LSTM layers
    """

    def __init__(self, out_features, n_lstm_layers=1):
        """
        Parameters
        ----------
        out_features : int
            Output size of the LSTM.
        n_lstm_layers : int, optional
            Number of LSTM layers. Defaults to 1.
        """
        super().__init__()
        mod_hidden = 8
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        input_dim = aa_embedding_size + mod_hidden
        self.input_cnn = SeqCNN(input_dim)
        # SeqCNN outputs 4*input_dim
        cnn_out_dim = input_dim * 4
        self.hidden_nn = SeqLSTM(cnn_out_dim, out_features, rnn_layer=n_lstm_layers)

    def forward(self, aa_indices, mod_x):
        """One-hot AAs + encoded modifications -> CNN -> LSTM."""
        features = aa_one_hot(aa_indices, self.mod_nn(mod_x))
        return self.hidden_nn(self.input_cnn(features))


# legacy
Input_AA_CNN_Encoder = Encoder_26AA_Mod_CNN_LSTM
[docs]
class Encoder_26AA_Mod_CNN_LSTM_AttnSum(torch.nn.Module):
    """
    Encode AAs (26 AA letters) and modifications by CNN and LSTM layers,
    then by 'SeqAttentionSum'.
    """

    def __init__(self, out_features, n_lstm_layers=2):
        """
        Parameters
        ----------
        out_features : int
            Output size of the LSTM and of the attention sum.
        n_lstm_layers : int, optional
            Number of LSTM layers. Defaults to 2.
        """
        super().__init__()
        mod_hidden = 8
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        input_dim = aa_embedding_size + mod_hidden
        self.input_cnn = SeqCNN(input_dim)
        # SeqCNN outputs 4*input_dim
        cnn_out_dim = input_dim * 4
        self.hidden_nn = SeqLSTM(cnn_out_dim, out_features, rnn_layer=n_lstm_layers)
        self.attn_sum = SeqAttentionSum(out_features)

    def forward(self, aa_indices, mod_x):
        """One-hot AAs + encoded modifications -> CNN -> LSTM -> attention sum."""
        features = aa_one_hot(aa_indices, self.mod_nn(mod_x))
        hidden = self.hidden_nn(self.input_cnn(features))
        return self.attn_sum(hidden)


# legacy
Input_AA_CNN_LSTM_Encoder = Encoder_26AA_Mod_CNN_LSTM_AttnSum
[docs]
class Encoder_AA_Mod_CNN_LSTM_AttnSum(torch.nn.Module):
    """
    Encode AAs (128 ASCII codes) and modifications by CNN and LSTM layers,
    and then by 'SeqAttentionSum'.
    """

    def __init__(self, out_features, n_lstm_layers=2):
        """
        Parameters
        ----------
        out_features : int
            Output size of the LSTM and of the attention sum; the embedding
            width is `out_features // 4`.
        n_lstm_layers : int, optional
            Number of LSTM layers. Defaults to 2.
        """
        super().__init__()
        mod_hidden = 8
        input_dim = out_features // 4
        self.aa_mod_embedding = AA_Mod_Embedding(input_dim, mod_feature_size=mod_hidden)
        self.input_cnn = SeqCNN(input_dim)
        # SeqCNN outputs 4*input_dim
        cnn_out_dim = input_dim * 4
        self.hidden_nn = SeqLSTM(cnn_out_dim, out_features, rnn_layer=n_lstm_layers)
        self.attn_sum = SeqAttentionSum(out_features)

    def forward(self, aa_indices, mod_x):
        """AA/modification embedding -> CNN -> LSTM -> attention sum."""
        embedded = self.aa_mod_embedding(aa_indices, mod_x)
        hidden = self.hidden_nn(self.input_cnn(embedded))
        return self.attn_sum(hidden)
[docs]
class Encoder_AA_Mod_Transformer(torch.nn.Module):
    """
    AAs (128 ASCII codes) and modifications are embedded with positional
    encoding and encoded by a BERT-style transformer. The per-position
    hidden states are returned.
    """

    def __init__(self, out_features, dropout=0.1, nlayers=4, output_attentions=False):
        """
        Parameters
        ----------
        out_features : int
            Feature size of the embedding and of the transformer.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        nlayers : int, optional
            Number of transformer layers. Defaults to 4.
        output_attentions : bool, optional
            Keep the attention values of the last call in `self.attentions`.
            Defaults to False.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.input_nn = Input_AA_Mod_PositionalEncoding(out_features)
        self.output_attentions = output_attentions
        self.encoder = Hidden_HFace_Transformer(
            out_features,
            nlayers=nlayers,
            dropout=dropout,
            output_attentions=output_attentions,
        )

    def forward(self, aa_indices, mod_x, attention_mask=None):
        """
        Embed, apply dropout and encode; returns element 0 of the encoder
        output and stores element 1 in `self.attentions` when
        `output_attentions` is set (otherwise `None`).
        """
        embedded = self.dropout(self.input_nn(aa_indices, mod_x))
        outputs = self.encoder(embedded, attention_mask)
        self.attentions = outputs[1] if self.output_attentions else None
        return outputs[0]
[docs]
class Encoder_AA_Mod_Transformer_AttnSum(torch.nn.Module):
    """
    Encode AAs (128 ASCII codes) and modifications by transformers,
    and then by 'SeqAttentionSum' followed by dropout.
    """

    def __init__(self, out_features, dropout=0.1, nlayers=4, output_attentions=False):
        """
        Parameters
        ----------
        out_features : int
            Feature size of the transformer and of the attention sum.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        nlayers : int, optional
            Number of transformer layers. Defaults to 4.
        output_attentions : bool, optional
            Passed on to the wrapped transformer encoder. Defaults to False.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        encoder_kwargs = dict(
            dropout=dropout,
            nlayers=nlayers,
            output_attentions=output_attentions,
        )
        self.encoder_nn = Encoder_AA_Mod_Transformer(out_features, **encoder_kwargs)
        self.attn_sum = SeqAttentionSum(out_features)

    def forward(self, aa_indices, mod_x):
        """Transformer encoding -> attention sum -> dropout."""
        hidden = self.encoder_nn(aa_indices, mod_x)
        pooled = self.attn_sum(hidden)
        return self.dropout(pooled)
[docs]
class Encoder_AA_Mod_Charge_Transformer(torch.nn.Module):
    """
    Encode AAs (128 ASCII codes), modifications and charge by transformers.
    The per-position hidden states are returned.
    """

    def __init__(self, out_features, dropout=0.1, nlayers=4, output_attentions=False):
        """
        Parameters
        ----------
        out_features : int
            Feature size of the embedding and of the transformer.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        nlayers : int, optional
            Number of transformer layers. Defaults to 4.
        output_attentions : bool, optional
            Keep the attention values of the last call in `self.attentions`.
            Defaults to False.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.input_nn = Input_AA_Mod_Charge_PositionalEncoding(out_features)
        self.output_attentions = output_attentions
        self.encoder = Hidden_HFace_Transformer(
            out_features,
            nlayers=nlayers,
            dropout=dropout,
            output_attentions=output_attentions,
        )

    def forward(
        self,
        aa_indices,
        mod_x,
        charges,
        attention_mask=None,
    ):
        """
        Embed, apply dropout and encode; returns element 0 of the encoder
        output and stores element 1 in `self.attentions` when
        `output_attentions` is set (otherwise `None`).
        """
        embedded = self.dropout(self.input_nn(aa_indices, mod_x, charges))
        outputs = self.encoder(embedded, attention_mask)
        self.attentions = outputs[1] if self.output_attentions else None
        return outputs[0]
[docs]
class Encoder_AA_Mod_Charge_Transformer_AttnSum(torch.nn.Module):
    """
    Encode AAs (128 ASCII codes), modifications and charge by transformers,
    and then by 'SeqAttentionSum' followed by dropout.
    """

    def __init__(self, out_features, dropout=0.1, nlayers=4, output_attentions=False):
        """
        Parameters
        ----------
        out_features : int
            Feature size of the transformer and of the attention sum.
        dropout : float, optional
            Dropout rate. Defaults to 0.1.
        nlayers : int, optional
            Number of transformer layers. Defaults to 4.
        output_attentions : bool, optional
            Passed on to the wrapped transformer encoder. Defaults to False.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        encoder_kwargs = dict(
            dropout=dropout,
            nlayers=nlayers,
            output_attentions=output_attentions,
        )
        self.encoder_nn = Encoder_AA_Mod_Charge_Transformer(
            out_features, **encoder_kwargs
        )
        self.attn_sum = SeqAttentionSum(out_features)

    def forward(self, aa_indices, mod_x, charges):
        """Transformer encoding -> attention sum -> dropout."""
        hidden = self.encoder_nn(aa_indices, mod_x, charges)
        pooled = self.attn_sum(hidden)
        return self.dropout(pooled)
[docs]
class Encoder_26AA_Mod_Charge_CNN_LSTM_AttnSum(torch.nn.Module):
    """
    Encode AAs (26 AA letters), modifications and charge by CNN and LSTM
    layers, and then by 'SeqAttentionSum'
    """

    def __init__(self, out_features):
        """
        Parameters
        ----------
        out_features : int
            Output size of the LSTM and of the attention sum.
        """
        super().__init__()
        mod_hidden = 8
        self.mod_nn = Mod_Embedding_FixFirstK(mod_hidden)
        # One extra input feature carries the charge.
        input_dim = aa_embedding_size + mod_hidden + 1
        self.input_cnn = SeqCNN(input_dim)
        # SeqCNN outputs 4*input_dim
        cnn_out_dim = input_dim * 4
        self.hidden_nn = SeqLSTM(cnn_out_dim, out_features, rnn_layer=2)
        self.attn_sum = SeqAttentionSum(out_features)

    def forward(self, aa_indices, mod_x, charges):
        """
        One-hot AAs + encoded modifications + per-position charge
        -> CNN -> LSTM -> attention sum.

        `charges` is presumably (batch, 1) — TODO confirm against callers.
        """
        mod_x = self.mod_nn(mod_x)
        seq_len = mod_x.size(1)
        charge_per_position = charges.unsqueeze(1).repeat(1, seq_len, 1)
        features = aa_one_hot(aa_indices, mod_x, charge_per_position)
        hidden = self.hidden_nn(self.input_cnn(features))
        return self.attn_sum(hidden)


# legacy
Input_AA_CNN_LSTM_cat_Charge_Encoder = Encoder_26AA_Mod_Charge_CNN_LSTM_AttnSum
[docs]
class Decoder_LSTM(torch.nn.Module):
    """
    Decode a per-sample feature vector into a sequence with an LSTM
    followed by a bias-free linear layer.
    """

    def __init__(self, in_features, out_features):
        """
        Parameters
        ----------
        in_features : int
            Feature size of the vector to decode.
        out_features : int
            Feature size of every decoded position.
        """
        super().__init__()
        hidden = 128
        # The LSTM must emit `hidden` features because `output_nn` consumes
        # `hidden` features. Sizing it with `out_features` only worked when
        # out_features == 128 (where the layer shapes are unchanged).
        self.rnn = SeqLSTM(
            in_features,
            hidden,
            rnn_layer=1,
            bidirectional=False,
        )
        self.output_nn = torch.nn.Linear(hidden, out_features, bias=False)

    def forward(self, x: torch.Tensor, output_len):
        """
        Parameters
        ----------
        x : torch.Tensor
            Assumed shape (batch, in_features).
        output_len : int
            Number of sequence positions to decode.

        Returns
        -------
        torch.Tensor
            Decoded sequence, (batch, output_len, out_features).
        """
        # Feed the same vector at every time step.
        x = self.rnn(x.unsqueeze(1).repeat(1, output_len, 1))
        x = self.output_nn(x)
        return x


# legacy
SeqLSTMDecoder = Decoder_LSTM
[docs]
class Decoder_GRU(torch.nn.Module):
    """
    Decode a per-sample feature vector into a sequence with a GRU
    followed by a bias-free linear layer.
    """

    def __init__(self, in_features, out_features):
        """
        Parameters
        ----------
        in_features : int
            Feature size of the vector to decode.
        out_features : int
            Feature size of every decoded position.
        """
        super().__init__()
        hidden = 128
        # The GRU must emit `hidden` features because `output_nn` consumes
        # `hidden` features. Sizing it with `out_features` only worked when
        # out_features == 128 (where the layer shapes are unchanged).
        self.rnn = SeqGRU(
            in_features,
            hidden,
            rnn_layer=1,
            bidirectional=False,
        )
        self.output_nn = torch.nn.Linear(hidden, out_features, bias=False)

    def forward(self, x: torch.Tensor, output_len):
        """
        Parameters
        ----------
        x : torch.Tensor
            Assumed shape (batch, in_features).
        output_len : int
            Number of sequence positions to decode.

        Returns
        -------
        torch.Tensor
            Decoded sequence, (batch, output_len, out_features).
        """
        # Feed the same vector at every time step.
        x = self.rnn(x.unsqueeze(1).repeat(1, output_len, 1))
        x = self.output_nn(x)
        return x


# legacy
SeqGRUDecoder = Decoder_GRU
[docs]
class Decoder_Linear(torch.nn.Module):
    """
    Decode w linear NN: Linear(in_features, 64) -> PReLU -> Linear(64, out_features)
    """

    def __init__(self, in_features, out_features):
        """
        Parameters
        ----------
        in_features : int
            Input feature size.
        out_features : int
            Output feature size.
        """
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(in_features, 64),
            torch.nn.PReLU(),
            torch.nn.Linear(64, out_features),
        )

    def forward(self, x):
        """
        Apply the two-layer network to `x`.

        Parameters
        ----------
        x : torch.Tensor
            Tensor whose last dimension has `in_features` entries.

        Returns
        -------
        torch.Tensor
            Tensor with `out_features` entries in the last dimension.
        """
        return self.nn(x)


# legacy
LinearDecoder = Decoder_Linear