import os
import io
import sys
import pandas as pd
import urllib
import shutil
import ssl
import typing
from pickle import UnpicklingError
import torch.multiprocessing as mp
from peptdeep.utils.deprecations import ModuleWithDeprecations
if sys.platform.lower().startswith("linux"):
# to prevent `too many open files` bug on Linux
mp.set_sharing_strategy("file_system")
from typing import Dict
from zipfile import ZipFile
from typing import Union
from alphabase.peptide.fragment import (
create_fragment_mz_dataframe,
concat_precursor_fragment_dataframes,
)
from alphabase.peptide.precursor import refine_precursor_df, update_precursor_mz
from alphabase.peptide.mobility import mobility_to_ccs_for_df, ccs_to_mobility_for_df
from peptdeep.utils import logging, process_bar
from peptdeep.model.ms2 import (
pDeepModel,
normalize_fragment_intensities,
)
from peptdeep.model.rt import AlphaRTModel
from peptdeep.model.ccs import AlphaCCSModel
from peptdeep.model.charge import ChargeModelForAASeq, ChargeModelForModAASeq
from peptdeep.utils import uniform_sampling
from peptdeep.settings import global_settings, update_global_settings
[docs]
def get_pretrain_dir() -> str:
"""Get the pretrained models directory path dynamically from settings."""
return os.path.join(
os.path.expanduser(global_settings["PEPTDEEP_HOME"]), "pretrained_models"
)
[docs]
def get_local_model_zip_name() -> str:
"""Get the local model zip file name dynamically from settings."""
return global_settings["local_model_zip_name"]
[docs]
def get_model_url() -> str:
"""Get the model URL dynamically from settings."""
return global_settings["model_url"]
[docs]
def get_model_zip_file_path() -> str:
"""Get the full path to the model zip file dynamically from settings."""
return os.path.join(get_pretrain_dir(), get_local_model_zip_name())
sys.modules[__name__].__class__ = ModuleWithDeprecations
ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "get_pretrain_dir()")
ModuleWithDeprecations.deprecate(
__name__, "model_zip_name", "get_local_model_zip_name()"
)
ModuleWithDeprecations.deprecate(__name__, "model_url", "get_model_url()")
ModuleWithDeprecations.deprecate(__name__, "model_zip", "get_model_zip_file_path()")
[docs]
def get_model_download_instructions() -> str:
"""Get the model download instructions dynamically from settings."""
return (
"Please download the "
f'zip or tar file by yourself from "{get_model_url()}",'
" and use \n"
f'"peptdeep install-models --model-file /path/to/{get_local_model_zip_name()}.zip"\n'
" to install the models"
)
[docs]
def is_model_zip(downloaded_zip):
with ZipFile(downloaded_zip) as zip:
return any(x == "generic/ms2.pth" for x in zip.namelist())
[docs]
def download_models(url: str = None, target_path: str = None, overwrite: bool = True):
"""
Parameters
----------
url : str, optional
Remote or local path.
Defaults to None, which will take the default using get_model_url()
target_path : str, optional
Target file path after download.
Defaults to None, which will take the default using get_model_zip_file_path()
overwrite : bool, optional
overwrite old model files.
Defaults to True.
Raises
------
FileNotFoundError
If remote url is not accessible.
"""
if url is None:
url = get_model_url()
if target_path is None:
target_path = get_model_zip_file_path()
if not overwrite and os.path.exists(target_path):
raise FileExistsError(f"Model file already exists: {target_path}")
if url is None:
raise ValueError(
"Cannot download models: 'model_url' is not set in settings. "
"Please either set 'model_url' in your settings file, or ensure "
"the model file already exists at the expected location."
)
if not os.path.isfile(url):
logging.info(f"Downloading pretrained models from {url} to {target_path} ...")
try:
os.makedirs(os.path.dirname(target_path), exist_ok=True)
context = ssl._create_unverified_context()
requests = urllib.request.urlopen(url, context=context, timeout=10)
with open(target_path, "wb") as f:
f.write(requests.read())
except Exception as e:
raise FileNotFoundError(
f"Downloading model failed: {e}.\n" + get_model_download_instructions()
) from e
else:
logging.info(f"Copying pretrained models from {url} to {target_path} ...")
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.copy(url, target_path)
logging.info(f"Successfully downloaded pretrained models.")
def _download_models(model_zip_file_path: str = None) -> None:
"""Download models if not done yet."""
if model_zip_file_path is None:
model_zip_file_path = get_model_zip_file_path()
os.makedirs(get_pretrain_dir(), exist_ok=True)
if not os.path.exists(model_zip_file_path):
download_models()
if not is_model_zip(model_zip_file_path):
raise ValueError(
f"Local model file is not a valid zip: {model_zip_file_path}.\n"
f"Please delete this file and try again.\n"
f"Or: {get_model_download_instructions()}"
)
model_mgr_settings = global_settings["model_mgr"]
[docs]
def count_mods(psm_df) -> pd.DataFrame:
mods = psm_df[psm_df.mods.str.len() > 0].mods.apply(lambda x: x.split(";"))
mod_dict = {}
mod_dict["mutation"] = {}
mod_dict["mutation"]["spec_count"] = 0
for one_mods in mods.values:
for mod in set(one_mods):
items = mod.split("->")
if len(items) == 2 and len(items[0]) == 3 and len(items[1]) == 5:
mod_dict["mutation"]["spec_count"] += 1
elif mod not in mod_dict:
mod_dict[mod] = {}
mod_dict[mod]["spec_count"] = 1
else:
mod_dict[mod]["spec_count"] += 1
return (
pd.DataFrame()
.from_dict(mod_dict, orient="index")
.reset_index(drop=False)
.rename(columns={"index": "mod"})
.sort_values("spec_count", ascending=False)
.reset_index(drop=True)
)
[docs]
def psm_sampling_with_important_mods(
psm_df,
n_sample,
top_n_mods=10,
n_sample_each_mod=0,
uniform_sampling_column=None,
random_state=1337,
):
psm_df_list = []
if uniform_sampling_column is None:
def _sample(psm_df, n):
if n < len(psm_df):
return psm_df.sample(n, replace=False, random_state=random_state).copy()
else:
return psm_df.copy()
else:
def _sample(psm_df, n):
if len(psm_df) == 0:
return psm_df
return uniform_sampling(
psm_df,
target=uniform_sampling_column,
n_train=n,
random_state=random_state,
)
psm_df_list.append(_sample(psm_df, n_sample))
if n_sample_each_mod > 0:
mod_df = count_mods(psm_df)
mod_df = mod_df[mod_df["mod"] != "mutation"]
if len(mod_df) > top_n_mods:
mod_df = mod_df.iloc[:top_n_mods, :]
for mod in mod_df["mod"].values:
psm_df_list.append(
_sample(
psm_df[psm_df.mods.str.contains(mod, regex=False)],
n_sample_each_mod,
)
)
if len(psm_df_list) > 0:
return pd.concat(psm_df_list, ignore_index=True)
else:
return pd.DataFrame()
[docs]
def load_phos_models(mask_modloss=True):
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(model_zip_file_path, model_path_in_zip="phospho/ms2_phos.pth")
rt_model = AlphaRTModel()
rt_model.load(model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model
[docs]
def load_models(mask_modloss=True):
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(model_zip_file_path, model_path_in_zip="generic/ms2.pth")
rt_model = AlphaRTModel()
rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(model_zip_file_path, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model
[docs]
def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True):
model_zip_file_path = get_model_zip_file_path()
_download_models(model_zip_file_path)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(
model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ms2.pth"
)
rt_model = AlphaRTModel()
rt_model.load(model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(
model_zip_file_path, model_path_in_zip=f"{model_type_in_zip}/ccs.pth"
)
return ms2_model, rt_model, ccs_model
[docs]
def clear_error_modloss_intensities(fragment_mz_df, fragment_intensity_df):
# clear error modloss intensities
for col in fragment_mz_df.columns.values:
if "modloss" in col:
fragment_intensity_df.loc[fragment_mz_df[col] == 0, col] = 0
[docs]
class ModelManager(object):
"""
The manager class to access MS2/RT/CCS models.
Attributes
----------
ms2_model : peptdeep.model.ms2.pDeepModel
The MS2 prediction model.
rt_model : peptdeep.model.rt.AlphaRTModel
The RT prediction model.
ccs_model : peptdeep.model.ccs.AlphaCCSModel
The CCS prediciton model.
psm_num_to_train_ms2 : int
Number of PSMs to train the MS2 model.
Defaults to global_settings['model_mgr']['transfer']['psm_num_to_train_ms2'].
epoch_to_train_ms2 : int
Number of epoches to train the MS2 model.
Defaults to global_settings['model_mgr']['transfer']['epoch_ms2'].
psm_num_to_train_rt_ccs : int
Number of PSMs to train RT/CCS model.
Defaults to global_settings['model_mgr']['transfer']['psm_num_to_train_rt_ccs'].
epoch_to_train_rt_ccs : int
Number of epoches to train RT/CCS model.
Defaults to global_settings['model_mgr']['transfer']['epoch_rt_ccs'].
nce : float
Default NCE value for a precursor_df without the 'nce' column.
Defaults to global_settings['model_mgr']['default_nce'].
instrument : str
Default instrument type for a precursor_df without the 'instrument' column.
Defaults to global_settings['model_mgr']['default_instrument'].
use_grid_nce_search : bool
If self.ms2_model uses `peptdeep.model.ms2.pDeepModel.grid_nce_search()` to determine optimal
NCE and instrument type. This will change `self.nce` and `self.instrument` values.
Defaults to global_settings['model_mgr']['transfer']['grid_nce_search'].
"""
[docs]
def __init__(
self,
mask_modloss: bool = False,
device: str = "gpu",
):
"""
Parameters
----------
mask_modloss : bool, optional
If modloss ions are masked to zeros in the ms2 model. `modloss`
ions are mostly useful for phospho MS2 prediciton model.
Defaults to True.
device : str, optional
Device for DL models, could be 'gpu' ('cuda') or 'cpu'.
if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
Defaults to 'gpu'
"""
_download_models(get_model_zip_file_path())
self._device = device
self._train_psm_logging = True
self.ms2_model: pDeepModel = pDeepModel(
mask_modloss=mask_modloss, device=self._device
)
self.rt_model: AlphaRTModel = AlphaRTModel(device=self._device)
self.ccs_model: AlphaCCSModel = AlphaCCSModel(device=self._device)
self.charge_model: ChargeModelForModAASeq = ChargeModelForModAASeq(
device=self._device
)
self.load_installed_models()
self.reset_by_global_settings(reload_models=False)
[docs]
def reinitialize_ms2_model(self, charged_frag_types: typing.List[str], **kwargs):
"""
Reinitialize the MS2 model with new charged fragment types.
Parameters
----------
charged_frag_types : List[str]
Charged fragment types for the new MS2 model.
kwargs : dict
Other keyword arguments for `pDeepModel`.
"""
device = kwargs.pop("device", self._device)
if device != self._device:
logging.warning(
f"Overwriting MS2 model device from '{self._device}' to '{device}'"
)
self.ms2_model = pDeepModel(
charged_frag_types=charged_frag_types,
device=device,
**kwargs,
)
[docs]
def reset_by_global_settings(
self,
reload_models=True,
):
mgr_settings = global_settings["model_mgr"]
if reload_models:
self.load_installed_models(mgr_settings["model_type"])
self.load_external_models(
ms2_model_file=mgr_settings["external_ms2_model"],
rt_model_file=mgr_settings["external_rt_model"],
ccs_model_file=mgr_settings["external_ccs_model"],
charge_model_file=mgr_settings["external_charge_model"],
)
self.ms2_model.model._mask_modloss = global_settings["model_mgr"][
"mask_modloss"
]
device = global_settings["torch_device"]["device_type"]
self.ms2_model.set_device(device)
self.rt_model.set_device(device)
self.ccs_model.set_device(device)
self.charge_model.set_device(device)
self.use_grid_nce_search = mgr_settings["transfer"]["grid_nce_search"]
self.psm_num_to_train_ms2 = mgr_settings["transfer"]["psm_num_to_train_ms2"]
self.psm_num_to_test_ms2 = mgr_settings["transfer"]["psm_num_to_test_ms2"]
self.epoch_to_train_ms2 = mgr_settings["transfer"]["epoch_ms2"]
self.warmup_epoch_to_train_ms2 = mgr_settings["transfer"]["warmup_epoch_ms2"]
self.batch_size_to_train_ms2 = mgr_settings["transfer"]["batch_size_ms2"]
self.lr_to_train_ms2 = float(mgr_settings["transfer"]["lr_ms2"])
self.psm_num_to_train_rt_ccs = mgr_settings["transfer"][
"psm_num_to_train_rt_ccs"
]
self.psm_num_to_test_rt_ccs = mgr_settings["transfer"]["psm_num_to_test_rt_ccs"]
self.epoch_to_train_rt_ccs = mgr_settings["transfer"]["epoch_rt_ccs"]
self.warmup_epoch_to_train_rt_ccs = mgr_settings["transfer"][
"warmup_epoch_rt_ccs"
]
self.batch_size_to_train_rt_ccs = mgr_settings["transfer"]["batch_size_rt_ccs"]
self.lr_to_train_rt_ccs = float(mgr_settings["transfer"]["lr_rt_ccs"])
self.psm_num_per_mod_to_train_ms2 = mgr_settings["transfer"][
"psm_num_per_mod_to_train_ms2"
]
self.psm_num_per_mod_to_train_rt_ccs = mgr_settings["transfer"][
"psm_num_per_mod_to_train_rt_ccs"
]
# loading charge model parameters
self.charge_model.predict_batch_size = mgr_settings["predict"][
"batch_size_charge"
]
self.charge_prob_cutoff = mgr_settings["charge_prob_cutoff"]
self.use_predicted_charge_in_speclib = mgr_settings[
"use_predicted_charge_in_speclib"
]
self.psm_num_to_test_charge = mgr_settings["transfer"]["psm_num_to_test_charge"]
self.psm_num_to_train_charge = mgr_settings["transfer"][
"psm_num_to_train_charge"
]
self.psm_num_per_mod_to_train_charge = mgr_settings["transfer"][
"psm_num_per_mod_to_train_charge"
]
self.epoch_to_train_charge = mgr_settings["transfer"]["epoch_charge"]
self.batch_size_to_train_charge = mgr_settings["transfer"]["batch_size_charge"]
self.lr_to_train_charge = float(mgr_settings["transfer"]["lr_charge"])
self.warmup_epoch_to_train_charge = mgr_settings["transfer"][
"warmup_epoch_charge"
]
self.top_n_mods_to_train = mgr_settings["transfer"]["top_n_mods_to_train"]
self.nce = mgr_settings["default_nce"]
if self.nce == "from_ms_file":
self.use_grid_nce_search = False
self.instrument = mgr_settings["default_instrument"]
self.verbose = mgr_settings["predict"]["verbose"]
self.train_verbose = mgr_settings["transfer"]["verbose"]
@property
def instrument(self):
return self._instrument
@instrument.setter
def instrument(self, instrument_name: str):
instrument_name = instrument_name.upper()
if instrument_name in model_mgr_settings["instrument_group"]:
self._instrument = model_mgr_settings["instrument_group"][instrument_name]
else:
self._instrument = "Lumos"
[docs]
def set_default_nce_instrument(self, df):
"""
Append 'nce' and 'instrument' columns into df
with self.nce and self.instrument
"""
if "nce" not in df.columns and "instrument" not in df.columns:
df["nce"] = float(self.nce)
df["instrument"] = self.instrument
elif "nce" not in df.columns:
df["nce"] = float(self.nce)
elif "instrument" not in df.columns:
df["instrument"] = self.instrument
[docs]
def set_default_nce(self, df):
"""Alias for `set_default_nce_instrument`"""
self.set_default_nce_instrument(df)
[docs]
def save_models(self, folder: str):
"""Save MS2/RT/CCS models into a folder
Parameters
----------
folder : str
folder to save
"""
if os.path.isdir(folder):
self.ms2_model.save(os.path.join(folder, "ms2.pth"))
self.rt_model.save(os.path.join(folder, "rt.pth"))
self.ccs_model.save(os.path.join(folder, "ccs.pth"))
if self.charge_model is not None:
self.charge_model.save(os.path.join(folder, "charge.pth"))
elif not os.path.exists(folder):
os.makedirs(folder)
self.save_models(folder)
[docs]
def load_installed_models(self, model_type: str = "generic"):
"""Load built-in MS2/CCS/RT models.
Parameters
----------
model_type : str, optional
To load the installed MS2/RT/CCS models or phos MS2/RT/CCS models.
It could be 'digly', 'phospho', 'HLA', or 'generic'.
Defaults to 'generic'.
"""
model_zip_file_path = get_model_zip_file_path()
if model_type.lower() in ["phospho", "phos", "phosphorylation"]:
self.ms2_model.load(
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
model_zip_file_path, model_path_in_zip="phospho/rt_phos.pth"
)
self.ccs_model.load(
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in [
"digly",
"glygly",
"ubiquitylation",
"ubiquitination",
"ubiquitinylation",
]:
self.ms2_model.load(
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
model_zip_file_path, model_path_in_zip="digly/rt_digly.pth"
)
self.ccs_model.load(
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in ["regular", "common", "generic"]:
self.ms2_model.load(
model_zip_file_path, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(model_zip_file_path, model_path_in_zip="generic/rt.pth")
self.ccs_model.load(
model_zip_file_path, model_path_in_zip="generic/ccs.pth"
)
self.charge_model.load(
model_zip_file_path, model_path_in_zip="generic/charge.pth"
)
elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]:
self.load_installed_models(model_type="generic")
else:
logging.warning(
f"model_type='{model_type}' is not supported, use 'generic' instead."
)
self.load_installed_models(model_type="generic")
[docs]
def load_external_models(
self,
*,
ms2_model_file: Union[str, io.BytesIO] = "",
rt_model_file: Union[str, io.BytesIO] = "",
ccs_model_file: Union[str, io.BytesIO] = "",
charge_model_file: Union[str, io.BytesIO] = "",
):
"""Load external MS2/RT/CCS models.
Parameters
----------
ms2_model_file : Tuple[str, io.BytesIO], optional
MS2 model file or stream. Do nothing if the value is '' or None.
Defaults to ''.
rt_model_file : Tuple[str, io.BytesIO], optional
RT model file or stream. Do nothing if the value is '' or None.
Defaults to ''.
ccs_model_file : Tuple[str, io.BytesIO], optional
CCS model or stream. Do nothing if the value is '' or None.
Defaults to ''.
charge_model_file : Tuple[str, io.BytesIO], optional
Charge model or stream. Do nothing if the value is '' or None.
Defaults to ''.
"""
def _load_file(model, model_file):
if model_file is None:
return
try:
if isinstance(model_file, str):
if os.path.isfile(model_file):
model.load(model_file)
else:
return
else:
model.load(model_file)
except (UnpicklingError, TypeError, ValueError, KeyError):
logging.info(
f"Cannot load {model_file} as {model.__class__} model, peptdeep will use the pretrained model instead."
)
if isinstance(ms2_model_file, str) and ms2_model_file:
logging.info(f"Using external ms2 model: '{ms2_model_file}'")
if not os.path.isfile(ms2_model_file):
logging.info(" -- This model file does not exist")
_load_file(self.ms2_model, ms2_model_file)
if isinstance(rt_model_file, str) and rt_model_file:
logging.info(f"Using external rt model: '{rt_model_file}'")
if not os.path.isfile(rt_model_file):
logging.info(" -- This model file does not exist")
_load_file(self.rt_model, rt_model_file)
if isinstance(ccs_model_file, str) and ccs_model_file:
logging.info(f"Using external ccs model: '{ccs_model_file}'")
if not os.path.isfile(ccs_model_file):
logging.info(" -- This model file does not exist")
_load_file(self.ccs_model, ccs_model_file)
if isinstance(charge_model_file, str) and charge_model_file:
logging.info(f"Using external charge model: '{charge_model_file}'")
if not os.path.isfile(charge_model_file):
logging.info(" -- This model file does not exist")
_load_file(self.charge_model, charge_model_file)
[docs]
def train_rt_model(
self,
psm_df: pd.DataFrame,
):
"""
Train/fine-tune the RT model. The fine-tuning will be skipped
if `self.psm_num_to_train_rt_ccs` is zero.
Parameters
----------
psm_df : pd.DataFrame
Training psm_df which contains 'rt_norm' column.
"""
psm_df = (
psm_df.groupby(["sequence", "mods", "mod_sites"])[["rt_norm"]]
.median()
.reset_index(drop=False)
)
if self.psm_num_to_train_rt_ccs > 0:
if self.psm_num_to_train_rt_ccs < len(psm_df):
tr_df = psm_sampling_with_important_mods(
psm_df,
self.psm_num_to_train_rt_ccs,
self.top_n_mods_to_train,
self.psm_num_per_mod_to_train_rt_ccs,
).copy()
else:
tr_df = psm_df
if self._train_psm_logging:
logging.info(
f"{len(tr_df)} PSMs for RT model training/transfer learning"
)
else:
tr_df = []
if self.psm_num_to_test_rt_ccs > 0:
if len(tr_df) > 0:
test_psm_df = psm_df[~psm_df.sequence.isin(set(tr_df.sequence))].copy()
if len(test_psm_df) > self.psm_num_to_test_rt_ccs:
test_psm_df = test_psm_df.sample(
n=self.psm_num_to_test_rt_ccs
).copy()
elif len(test_psm_df) == 0:
logging.info(
"No enough PSMs for testing RT models, "
"please reduce the `psm_num_to_train_rt_ccs` "
"value according to overall peptide numbers. "
)
test_psm_df = []
else:
test_psm_df = psm_df
else:
test_psm_df = []
if len(test_psm_df) > 0:
logging.info(
"Testing pretrained RT model:\n" + str(self.rt_model.test(test_psm_df))
)
if len(tr_df) > 0:
self.rt_model.train(
tr_df,
batch_size=self.batch_size_to_train_rt_ccs,
epoch=self.epoch_to_train_rt_ccs,
warmup_epoch=self.warmup_epoch_to_train_rt_ccs,
lr=self.lr_to_train_rt_ccs,
verbose=self.train_verbose,
)
if len(test_psm_df) > 0:
logging.info(
"Testing refined RT model:\n" + str(self.rt_model.test(test_psm_df))
)
[docs]
def train_charge_model(self, psm_df: pd.DataFrame):
"""
Train/fine-tune the charge model.
Parameters
----------
psm_df : pd.DataFrame
Training psm_df which contains 'charge' column.
"""
psm_df = self.charge_model.create_charge_indicators(psm_df.copy())
if self.psm_num_to_train_charge > 0:
if self.psm_num_to_train_charge < len(psm_df):
tr_df = psm_sampling_with_important_mods(
psm_df,
self.psm_num_to_train_charge,
self.top_n_mods_to_train,
self.psm_num_per_mod_to_train_charge,
).copy()
else:
tr_df = psm_df
if self._train_psm_logging:
logging.info(
f"{len(tr_df)} PSMs for charge model training/transfer learning"
)
else:
tr_df = []
if self.psm_num_to_test_charge > 0:
if len(tr_df) > 0:
test_psm_df = psm_df[~psm_df.sequence.isin(set(tr_df.sequence))].copy()
if len(test_psm_df) > self.psm_num_to_test_charge:
test_psm_df = test_psm_df.sample(
n=self.psm_num_to_test_charge
).copy()
elif len(test_psm_df) == 0:
logging.info(
"No enough PSMs for testing CHarge models, "
"please reduce the `psm_num_to_train_charge` "
"value according to overall peptide numbers. "
)
test_psm_df = []
else:
test_psm_df = psm_df
else:
test_psm_df = []
if len(test_psm_df) > 0:
logging.info(
"Testing pretrained RT model:\n" + str(self.rt_model.test(test_psm_df))
)
# train_df = grouped_df.sample(frac=0.8)
# test_df = grouped_df.drop(train_df.index)
# if len(test_df) > 0:
# logging.info(
# "Testing pretrained charge model:\n"
# + str(self.charge_model.test(test_df))
# )
self.charge_model.train(
tr_df,
batch_size=self.batch_size_to_train_charge,
epoch=self.epoch_to_train_charge,
warmup_epoch=self.warmup_epoch_to_train_charge,
lr=self.lr_to_train_charge,
verbose=self.train_verbose,
)
if len(test_psm_df) > 0:
logging.info(
"Testing refined charge model:\n"
+ str(self.charge_model.test(test_psm_df))
)
[docs]
def train_ccs_model(
self,
psm_df: pd.DataFrame,
):
"""
Train/fine-tune the CCS model. The fine-tuning will be skipped
if `self.psm_num_to_train_rt_ccs` is zero.
Parameters
----------
psm_df : pd.DataFrame
Training psm_df which contains 'ccs' or 'mobility' column.
"""
if "mobility" not in psm_df.columns or "ccs" not in psm_df.columns:
return
elif "ccs" not in psm_df.columns:
psm_df["ccs"] = mobility_to_ccs_for_df(psm_df, "mobility")
elif "mobility" not in psm_df.columns:
psm_df["mobility"] = ccs_to_mobility_for_df(psm_df, "ccs")
psm_df = (
psm_df.groupby(["sequence", "mods", "mod_sites", "charge"])[
["mobility", "ccs"]
]
.median()
.reset_index(drop=False)
)
if self.psm_num_to_train_rt_ccs > 0:
if self.psm_num_to_train_rt_ccs < len(psm_df):
tr_df = psm_sampling_with_important_mods(
psm_df,
self.psm_num_to_train_rt_ccs,
self.top_n_mods_to_train,
self.psm_num_per_mod_to_train_rt_ccs,
).copy()
else:
tr_df = psm_df
if self._train_psm_logging:
logging.info(
f"{len(tr_df)} PSMs for CCS model training/transfer learning"
)
else:
tr_df = []
if self.psm_num_to_test_rt_ccs > 0:
if len(tr_df) > 0:
test_psm_df = psm_df[~psm_df.sequence.isin(set(tr_df.sequence))].copy()
if len(test_psm_df) > self.psm_num_to_test_rt_ccs:
test_psm_df = test_psm_df.sample(
n=self.psm_num_to_test_rt_ccs
).copy()
elif len(test_psm_df) == 0:
logging.info(
"No enough PSMs for testing CCS models, "
"please reduce the `psm_num_to_train_rt_ccs` "
"value according to overall precursor numbers. "
)
test_psm_df = []
else:
test_psm_df = psm_df
else:
test_psm_df = []
if len(test_psm_df) > 0:
logging.info(
"Testing pretrained CCS model:\n"
+ str(self.ccs_model.test(test_psm_df))
)
if len(tr_df) > 0:
self.ccs_model.train(
tr_df,
batch_size=self.batch_size_to_train_rt_ccs,
epoch=self.epoch_to_train_rt_ccs,
warmup_epoch=self.warmup_epoch_to_train_rt_ccs,
lr=self.lr_to_train_rt_ccs,
verbose=self.train_verbose,
)
if len(test_psm_df) > 0:
logging.info(
"Testing refined CCS model:\n" + str(self.ccs_model.test(test_psm_df))
)
[docs]
def train_ms2_model(
self,
psm_df: pd.DataFrame,
matched_intensity_df: pd.DataFrame,
):
"""
Using matched_intensity_df to train/fine-tune the ms2 model.
1. It will sample `n=self.psm_num_to_train_ms2` PSMs into training dataframe (`tr_df`) to for fine-tuning.
2. This method will also consider some important PTMs (`n=self.top_n_mods_to_train`) into `tr_df` for fine-tuning.
3. If `self.use_grid_nce_search==True`, this method will call `self.ms2_model.grid_nce_search` to find the best NCE and instrument.
Parameters
----------
psm_df : pd.DataFrame
PSM dataframe for fine-tuning
matched_intensity_df : pd.DataFrame
The matched fragment intensities for `psm_df`.
"""
if self.psm_num_to_train_ms2 > 0:
if self.psm_num_to_train_ms2 < len(psm_df):
tr_df = psm_sampling_with_important_mods(
psm_df,
self.psm_num_to_train_ms2,
self.top_n_mods_to_train,
self.psm_num_per_mod_to_train_ms2,
).copy()
else:
tr_df = psm_df
if len(tr_df) > 0:
tr_inten_df = pd.DataFrame()
for frag_type in self.ms2_model.charged_frag_types:
if frag_type in matched_intensity_df.columns:
tr_inten_df[frag_type] = matched_intensity_df[frag_type]
else:
tr_inten_df[frag_type] = 0.0
normalize_fragment_intensities(tr_df, tr_inten_df)
if self.use_grid_nce_search:
self.nce, self.instrument = self.ms2_model.grid_nce_search(
tr_df,
tr_inten_df,
nce_first=model_mgr_settings["transfer"]["grid_nce_first"],
nce_last=model_mgr_settings["transfer"]["grid_nce_last"],
nce_step=model_mgr_settings["transfer"]["grid_nce_step"],
search_instruments=model_mgr_settings["transfer"][
"grid_instrument"
],
)
tr_df["nce"] = self.nce
tr_df["instrument"] = self.instrument
else:
self.set_default_nce_instrument(tr_df)
else:
tr_df = []
if self.psm_num_to_test_ms2 > 0:
if len(tr_df) > 0:
test_psm_df = psm_df[~psm_df.sequence.isin(set(tr_df.sequence))].copy()
if len(test_psm_df) > self.psm_num_to_test_ms2:
test_psm_df = test_psm_df.sample(n=self.psm_num_to_test_ms2)
elif len(test_psm_df) == 0:
logging.info(
"No enough PSMs for testing MS2 models, "
"please reduce the `psm_num_to_train_ms2` "
"value according to overall PSM numbers. "
)
test_psm_df = []
else:
test_psm_df = psm_df.copy()
tr_inten_df = pd.DataFrame()
for frag_type in self.ms2_model.charged_frag_types:
if frag_type in matched_intensity_df.columns:
tr_inten_df[frag_type] = matched_intensity_df[frag_type]
else:
tr_inten_df[frag_type] = 0.0
self.set_default_nce_instrument(test_psm_df)
else:
test_psm_df = []
if len(test_psm_df) > 0:
logging.info(
"Testing pretrained MS2 model on testing df:\n"
+ str(self.ms2_model.test(test_psm_df, tr_inten_df))
)
if len(tr_df) > 0:
if self._train_psm_logging:
logging.info(
f"{len(tr_df)} PSMs for MS2 model training/transfer learning"
)
self.ms2_model.train(
tr_df,
fragment_intensity_df=tr_inten_df,
batch_size=self.batch_size_to_train_ms2,
epoch=self.epoch_to_train_ms2,
warmup_epoch=self.warmup_epoch_to_train_ms2,
lr=self.lr_to_train_ms2,
verbose=self.train_verbose,
)
logging.info(
"Testing refined MS2 model on training df:\n"
+ str(self.ms2_model.test(tr_df, tr_inten_df))
)
if len(test_psm_df) > 0:
logging.info(
"Testing refined MS2 model on testing df:\n"
+ str(self.ms2_model.test(test_psm_df, tr_inten_df))
)
[docs]
def predict_ms2(
self,
precursor_df: pd.DataFrame,
*,
batch_size: int = 512,
reference_frag_df: pd.DataFrame = None,
) -> pd.DataFrame:
"""Predict MS2 for the given precursor_df
Parameters
----------
precursor_df : pd.DataFrame
Precursor dataframe for MS2 prediction
batch_size : int, optional
Batch size for prediction.
Defaults to 512.
reference_frag_df : pd.DataFrame, optional
If precursor_df has 'frag_start_idx' pointing to reference_frag_df.
Defaults to None
Returns
-------
pd.DataFrame
Predicted fragment intensity dataframe.
If there are no such two columns in precursor_df,
it will insert 'frag_start_idx' and `frag_stop_idx` in
precursor_df pointing to this predicted fragment dataframe.
"""
self.set_default_nce_instrument(precursor_df)
if self.verbose:
logging.info("Predicting MS2 ...")
return self.ms2_model.predict(
precursor_df,
batch_size=batch_size,
reference_frag_df=reference_frag_df,
verbose=self.verbose,
)
[docs]
def predict_charge(
self,
psm_df: pd.DataFrame,
min_precursor_charge: int,
max_precursor_charge: int,
charge_prob_cutoff: float = None,
) -> pd.DataFrame:
"""
Predict charge states for a given PSM dataframe by predicting the probabilities of each charge state,
and including precursors with charge probabilities above the cutoff.
Parameters
----------
psm_df : pd.DataFrame
PSM dataframe to predict charge states.
min_precursor_charge : int
Minimum precursor charge.
max_precursor_charge : int
Maximum precursor charge.
charge_prob_cutoff : float
Charge probability cutoff for including precursors set to 0.0 to predict all charges in the given range, and set to None to use the default value from the default_settings yaml.
Returns
-------
pd.DataFrame
PSM dataframe with predicted charge states.
"""
charge_prob_cutoff = (
self.charge_prob_cutoff
if charge_prob_cutoff is None
else charge_prob_cutoff
)
return self.charge_model.predict_and_clip_charges(
psm_df,
min_precursor_charge=min_precursor_charge,
max_precursor_charge=max_precursor_charge,
charge_prob_cutoff=charge_prob_cutoff,
)
[docs]
def predict_rt(
self, precursor_df: pd.DataFrame, *, batch_size: int = 1024
) -> pd.DataFrame:
"""Predict RT ('rt_pred') inplace into `precursor_df`.
Parameters
----------
precursor_df : pd.DataFrame
precursor_df for RT prediction
batch_size : int, optional
Batch size for prediction.
Defaults to 1024.
Returns
-------
pd.DataFrame
df with 'rt_pred' and 'rt_norm_pred' columns.
"""
if self.verbose:
logging.info("Predicting RT ...")
df = self.rt_model.predict(
precursor_df, batch_size=batch_size, verbose=self.verbose
)
df["rt_norm_pred"] = df.rt_pred
return df
[docs]
def predict_mobility(
self, precursor_df: pd.DataFrame, *, batch_size: int = 1024
) -> pd.DataFrame:
"""Predict mobility (`ccs_pred` and `mobility_pred`) inplace into `precursor_df`.
Parameters
----------
precursor_df : pd.DataFrame
Precursor_df for CCS/mobility prediction
batch_size : int, optional
Batch size for prediction.
Defaults to 1024.
Returns
-------
pd.DataFrame
df with 'ccs_pred' and 'mobility_pred' columns.
"""
if self.verbose:
logging.info("Predicting mobility ...")
precursor_df = self.ccs_model.predict(
precursor_df, batch_size=batch_size, verbose=self.verbose
)
return self.ccs_model.ccs_to_mobility_pred(precursor_df)
def _predict_func_for_mp(self, arg_dict: dict):
"""Internal function, for multiprocessing"""
update_global_settings(arg_dict.pop("mp_global_settings"))
return self.predict_all(multiprocessing=False, **arg_dict)
[docs]
def predict_all_mp(
self,
precursor_df: pd.DataFrame,
*,
predict_items: list = ["rt", "mobility", "ms2"],
frag_types: list = None,
process_num: int = 8,
mp_batch_size: int = 100000,
):
self.ms2_model.model.share_memory()
self.rt_model.model.share_memory()
self.ccs_model.model.share_memory()
df_groupby = precursor_df.groupby("nAA")
mgr = mp.Manager()
mp_global_settings = mgr.dict()
mp_global_settings.update(global_settings)
def get_batch_num_mp(df_groupby):
batch_num = 0
for group_len in df_groupby.size().values:
for i in range(0, group_len, mp_batch_size):
batch_num += 1
return batch_num
def mp_param_generator(df_groupby):
for nAA, df in df_groupby:
for i in range(0, len(df), mp_batch_size):
yield {
"precursor_df": df.iloc[i : i + mp_batch_size, :],
"predict_items": predict_items,
"frag_types": frag_types,
"mp_global_settings": mp_global_settings,
}
precursor_df_list = []
if "ms2" in predict_items:
fragment_mz_df_list = []
fragment_intensity_df_list = []
else:
fragment_mz_df_list = None
if self.verbose:
logging.info(f"Predicting {','.join(predict_items)} ...")
verbose_bak = self.verbose
self.verbose = False
with mp.get_context("spawn").Pool(process_num) as p:
for ret_dict in process_bar(
p.imap_unordered(
self._predict_func_for_mp, mp_param_generator(df_groupby)
),
get_batch_num_mp(df_groupby),
):
precursor_df_list.append(ret_dict["precursor_df"])
if fragment_mz_df_list is not None:
fragment_mz_df_list.append(ret_dict["fragment_mz_df"])
fragment_intensity_df_list.append(ret_dict["fragment_intensity_df"])
self.verbose = verbose_bak
if fragment_mz_df_list is not None:
(precursor_df, fragment_mz_df, fragment_intensity_df) = (
concat_precursor_fragment_dataframes(
precursor_df_list,
fragment_mz_df_list,
fragment_intensity_df_list,
)
)
return {
"precursor_df": precursor_df,
"fragment_mz_df": fragment_mz_df,
"fragment_intensity_df": fragment_intensity_df,
}
else:
precursor_df = pd.concat(precursor_df_list)
precursor_df.reset_index(drop=True, inplace=True)
return {"precursor_df": precursor_df}
[docs]
def predict_all(
self,
precursor_df: pd.DataFrame,
*,
predict_items: list = ["rt", "mobility", "ms2"],
frag_types: list = None,
multiprocessing: bool = True,
min_required_precursor_num_for_mp: int = 3000,
process_num: int = 8,
mp_batch_size: int = 100000,
) -> Dict[str, pd.DataFrame]:
"""
Predict all items defined by `predict_items`,
which may include rt, mobility, fragment_mz
and fragment_intensity.
Parameters
----------
precursor_df : pd.DataFrame
Precursor dataframe contains `sequence`, `mods`, `mod_sites`, `charge` ... columns.
predict_items : list, optional
items ('rt', 'mobility', 'ms2') to predict.
Defaults to ['rt' ,'mobility' ,'ms2'].
frag_types : list, optional
Fragment types to predict.
If it is None, it then depends on `self.ms2_model.charged_frag_types` and
`self.ms2_model.model._mask_modloss`.
Defaults to None.
multiprocessing : bool, optional
If use multiprocessing is gpu is not available
Defaults to True.
process_num : int, optional
Defaults to 4
min_required_precursor_num_for_mp : int, optional
It will not use multiprocessing when the number of precursors in precursor_df
is lower than this value.
Defaults to 3000.
mp_batch_size : int, optional
Splitting data into batches for multiprocessing.
Defaults to 100000.
Returns
-------
Dict[str, pd.DataFrame]
`{'precursor_df': precursor_df}`
and if 'ms2' in predict_items, it also contains:
```
{
'fragment_mz_df': fragment_mz_df,
'fragment_intensity_df': fragment_intensity_df
}
```
"""
def refine_df(df):
if "ms2" in predict_items:
refine_precursor_df(df)
else:
refine_precursor_df(df, drop_frag_idx=False)
if frag_types is None:
if self.ms2_model.mask_modloss:
frag_types = [
frag
for frag in self.ms2_model.charged_frag_types
if "modloss" not in frag
]
else:
frag_types = self.ms2_model.charged_frag_types
if "precursor_mz" not in precursor_df.columns:
update_precursor_mz(precursor_df)
if (
self.ms2_model.device_type != "cpu"
or not multiprocessing
or process_num <= 1
or len(precursor_df) < min_required_precursor_num_for_mp
):
refine_df(precursor_df)
if "rt" in predict_items:
self.predict_rt(
precursor_df,
batch_size=model_mgr_settings["predict"]["batch_size_rt_ccs"],
)
if "mobility" in predict_items:
self.predict_mobility(
precursor_df,
batch_size=model_mgr_settings["predict"]["batch_size_rt_ccs"],
)
if "ms2" in predict_items:
if "frag_start_idx" in precursor_df.columns:
precursor_df.drop(
columns=["frag_start_idx", "frag_stop_idx"], inplace=True
)
fragment_mz_df = create_fragment_mz_dataframe(precursor_df, frag_types)
fragment_intensity_df = self.predict_ms2(
precursor_df,
batch_size=model_mgr_settings["predict"]["batch_size_ms2"],
)
fragment_intensity_df.drop(
columns=[
col
for col in fragment_intensity_df.columns
if col not in frag_types
],
inplace=True,
)
clear_error_modloss_intensities(fragment_mz_df, fragment_intensity_df)
return {
"precursor_df": precursor_df,
"fragment_mz_df": fragment_mz_df,
"fragment_intensity_df": fragment_intensity_df,
}
else:
return {"precursor_df": precursor_df}
else:
logging.info(f"Using multiprocessing with {process_num} processes ...")
return self.predict_all_mp(
precursor_df,
predict_items=predict_items,
process_num=process_num,
mp_batch_size=mp_batch_size,
)