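"""Nonfouling propensity scorer: embeds protein sequences with ESM-2 and
scores them with a pretrained XGBoost classifier."""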
import warnings

import numpy as np
import torch
import xgboost as xgb
from rdkit import rdBase
from transformers import AutoTokenizer, EsmModel

# Suppress RDKit error logging and noisy library warnings.
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Nonfouling:
    """Scores protein sequences for nonfouling propensity using mean-pooled
    ESM-2 embeddings fed to a pretrained XGBoost classifier."""

    def __init__(self):
        # Pretrained XGBoost classifier checkpoint.
        self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_nonfouling.json')

        # ESM-2 (650M) protein language model used as a frozen feature extractor.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

        # Move the model to the GPU once, up front, rather than inside the batch loop.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Return mean-pooled ESM-2 embeddings, one row per input sequence."""
        embeddings = []

        # Embed in small batches to bound peak memory.
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]

            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            last_hidden_states = outputs.last_hidden_state

            # Mean-pool over real tokens only: zero out padding positions,
            # then divide each per-sequence sum by its unpadded length.
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            embeddings.append(batch_embeddings.cpu().numpy())

        if embeddings:
            return np.vstack(embeddings)
        return np.array([])

    def get_scores(self, input_seqs: list):
        """Score sequences with the XGBoost classifier; returns zeros if no
        embeddings could be generated."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize features: replace NaNs with zero and clamp extreme values
        # to the float32 range before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)


def unittest():
    nonfouling = Nonfouling()
    sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
    ]

    scores = nonfouling(input_seqs=sequences)
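    # What predict() returns depends on how the checkpoint was trained; for a
    # binary:logistic objective (an assumption, not visible from this file)
    # these are class-1 probabilities in [0, 1], one per sequence.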
    print(scores)


if __name__ == '__main__':
    unittest()