|
|
import numpy as np |
|
|
import torch |
|
|
import xgboost as xgb |
|
|
from transformers import EsmModel, EsmTokenizer |
|
|
import torch.nn as nn |
|
|
import pdb |
|
|
|
|
|
class PeptideCNN(nn.Module):
    """Frozen ESM-2 encoder followed by a small 1-D CNN regression head.

    The ESM backbone runs under ``torch.no_grad()``, so it is never fine-tuned.
    Its per-token embeddings pass through two ReLU+dropout convolutions, are
    mean-pooled over the sequence axis, and are projected to a feature vector
    (``fc``) and then to one scalar per sequence (``predictor``).

    Args:
        input_dim: channel count fed to ``conv1``; must equal the backbone's
            hidden size.
        hidden_dims: output channels of ``conv1`` and ``conv2`` (only the first
            two entries are used).
        output_dim: width of the feature vector produced by ``fc``.
        dropout_rate: dropout probability applied after each convolution.
        device: device to place the ESM backbone on. ``None`` (default) picks
            CUDA when available, otherwise CPU.
        esm_model_name: Hugging Face identifier of the ESM backbone to load.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate,
                 device=None, esm_model_name="facebook/esm2_t33_650M_UR50D"):
        super().__init__()

        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        # NOTE: kernel_size=5 with padding=1 shortens the sequence by 2 positions.
        # Kept as-is because trained checkpoints were produced with this layout.
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)

        # Resolve the device locally instead of reading a module-level global,
        # so the class can be constructed regardless of definition order.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.esm_model = EsmModel.from_pretrained(esm_model_name).to(device)
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        """Encode token ids with ESM and regress a scalar per sequence.

        Args:
            input_ids: token id tensor accepted by the ESM backbone
                (assumed shape ``(batch, seq_len)`` -- TODO confirm for callers).
            attention_mask: optional mask forwarded to the backbone only; it is
                NOT used for pooling.
            return_features: if True, return the ``fc`` output instead of the
                final prediction.

        Returns:
            ``(batch, output_dim)`` features when ``return_features`` is True,
            otherwise the ``(batch, 1)`` prediction.
        """
        # Backbone is frozen: no autograd graph is built through ESM.
        with torch.no_grad():
            hidden = self.esm_model(input_ids, attention_mask=attention_mask).last_hidden_state

        # Conv1d expects channels on dim 1: (batch, seq, hidden) -> (batch, hidden, seq).
        x = hidden.permute(0, 2, 1)
        x = self.dropout(nn.functional.relu(self.conv1(x)))
        x = self.dropout(nn.functional.relu(self.conv2(x)))
        x = x.permute(0, 2, 1)

        # Unmasked mean over all positions (special/padding tokens included);
        # kept because trained checkpoints depend on this pooling.
        x = x.mean(dim=1)

        features = self.fc(x)
        if return_features:
            return features
        return self.predictor(features)
|
|
|
|
|
# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Layer sizes must match the checkpoint's parameter shapes
# (load_state_dict is strict by default).
input_dim = 1280  # channel width of the ESM embeddings fed to the CNN head
hidden_dims = [input_dim // 2, input_dim // 4]
output_dim = input_dim // 8
dropout_rate = 0.3

CHECKPOINT_PATH = '/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth'

nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
# map_location remaps tensors saved on CUDA onto the selected device, so the
# checkpoint also loads on CPU-only hosts instead of raising at deserialization.
nn_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
nn_model.eval()
|
|
|
|
|
def predict(inputs):
    """Score one tokenized peptide with the module-level ``nn_model``.

    Args:
        inputs: tokenizer output unpacked as keyword arguments into the model's
            ``forward`` (e.g. ``input_ids`` / ``attention_mask``), expected to
            already live on the model's device.

    Returns:
        The scalar model output as a Python float. ``.item()`` raises if the
        output holds more than one element, e.g. a batch of several sequences.
        The ``__main__`` demo interprets the value as log10 of the half-life
        in hours.
    """
    with torch.no_grad():
        output = nn_model(**inputs, return_features=False)
    return output.item()
|
|
|
|
|
if __name__ == '__main__':
    # Demo: score a single example peptide and report its predicted half-life.
    sequence = 'RGLSDGFLKLKMGISGSLGC'

    esm_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    encoded = esm_tokenizer(
        sequence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    encoded = encoded.to(device)

    prediction = predict(encoded)
    print(prediction)
    # The model output is exponentiated (10**x) and reported in hours.
    print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h")
|
|
|