import torch
import torch.nn as nn
import xgboost as xgb  # unused in this script; presumably for a downstream model on the extracted features
from transformers import EsmModel, EsmTokenizer
class PeptideCNN(nn.Module):
    """1D CNN head on top of frozen ESM-2 embeddings for peptide property prediction."""

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)  # regression/classification head
        # Frozen ESM-2 backbone; relies on the module-level `device` defined below.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        # Embed the sequence with the frozen ESM-2 encoder.
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state
        # x shape: (B, L, input_dim)
        x = x.permute(0, 2, 1)  # reshape to (B, input_dim, L) for Conv1d
        x = nn.functional.relu(self.conv1(x))
        x = self.dropout(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)  # reshape back to (B, L', hidden_dims[1])
        # Global average pooling over the sequence dimension
        x = x.mean(dim=1)  # (B, hidden_dims[1])
        features = self.fc(x)  # (B, output_dim)
        if return_features:
            return features
        return self.predictor(features)  # (B, 1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ESM-2 t33 650M has a hidden size of 1280.
input_dim = 1280
hidden_dims = [input_dim // 2, input_dim // 4]
output_dim = input_dim // 8
dropout_rate = 0.3

nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
nn_model.load_state_dict(
    torch.load(
        '/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth',
        map_location=device,
    )
)
nn_model.eval()
def predict(inputs):
    """Run the model on tokenized inputs and return the scalar prediction."""
    with torch.no_grad():
        prediction = nn_model(**inputs, return_features=False)
    return prediction.item()
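
# Hedged sketch (not part of the original script): the `return_features=True` path and the
# xgboost import suggest the CNN features may feed a separate gradient-boosted model.
# `train_sequences` / `train_targets` in the commented example are hypothetical placeholders.
def extract_features(sequences, batch_size=8):
    """Return an (N, output_dim) array of CNN features for a list of peptide sequences."""
    tok = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    feats = []
    with torch.no_grad():
        for i in range(0, len(sequences), batch_size):
            batch = tok(sequences[i:i + batch_size], return_tensors="pt",
                        padding=True, truncation=True, max_length=128).to(device)
            feats.append(nn_model(**batch, return_features=True).cpu())
    return torch.cat(feats).numpy()

# Example usage (assumes labelled training data is available):
# X_train = extract_features(train_sequences)
# booster = xgb.XGBRegressor(n_estimators=200, max_depth=4)
# booster.fit(X_train, train_targets)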
if __name__ == '__main__':
    sequence = 'RGLSDGFLKLKMGISGSLGC'
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
    prediction = predict(inputs)
    print(prediction)
    # The model appears to predict log10(half-life) in hours, hence the 10**prediction conversion.
    print(f"Predicted half life of {sequence} is {10 ** prediction:.4f} h")