# moPPIt-v3 / moppit.py
# Web-page header captured together with the file (not program code):
#   "AlienChen's picture" / "Update moppit.py" / "c20ca8e verified"
import yaml
from tqdm import tqdm
import torch
from torch import nn
from transformers import AutoTokenizer
from models.peptide_classifiers import *
from utils.parsing import parse_guidance_args
# Parse the run options once, at import time. The rest of this script reads
# args.length, args.target_protein, args.motifs, args.n_batches, args.weights
# and args.output_file from this object.
args = parse_guidance_args()
import pdb
import random
import inspect
# --- Multi-objective sampling: fixed settings --------------------------------
device = 'cuda:0'
source_distribution = "uniform"  # "uniform" or "mask"; selects how start tokens are drawn
vocab_size = 24                  # exclusive upper bound for uniformly drawn token ids
n_samples = 1                    # number of sequences per batch
step_size = 0.01                 # handed to the solver as its step size

# --- Settings taken from the parsed arguments --------------------------------
length = args.length
target = args.target_protein

# Tokenise the target protein and keep only its token ids, moved to `device`.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_encoding = tokenizer(target, return_tensors='pt')
target_sequence = target_encoding['input_ids'].to(device)

# Parse the motif specification from the arguments and move the result to
# `device` (presumably a tensor -- confirm against parse_motifs). It is
# printed so the parsed value is visible in the run output.
motifs = parse_motifs(args.motifs).to(device)
print(motifs)

# --- Load Models -------------------------------------------------------------
# Sampler that performs the guided generation.
solver = load_solver(
    './ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt',
    vocab_size,
    device,
)

# Motif objective: wraps the BindEvaluator checkpoint together with the target
# token ids and the parsed motifs.
bindevaluator = load_bindevaluator(
    './classifier_ckpt/finetuned_BindEvaluator.ckpt',
    device,
)
motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True)

# Affinity objective: wraps the binding-affinity checkpoint together with the
# target token ids.
affinity_predictor = load_affinity_predictor(
    './classifier_ckpt/binding_affinity_unpooled.pt',
    device,
)
affinity_model = AffinityModel(affinity_predictor, target_sequence)

# Objectives passed to the solver and re-evaluated on every final sample.
score_models = [motif_model, affinity_model]
def _accepts_time(score_model):
    """Report whether ``score_model`` declares a parameter named ``t``.

    The signature of ``score_model.forward`` is inspected when that attribute
    exists; otherwise the signature of ``score_model`` itself is used.
    """
    fn = score_model.forward if hasattr(score_model, 'forward') else score_model
    return 't' in inspect.signature(fn).parameters


# A model's signature does not change between batches, so inspect each one
# once here instead of on every iteration.
time_aware = [_accepts_time(model) for model in score_models]

# Main loop: for each batch, draw start tokens, run guided sampling, score the
# result with every objective and append one CSV row to args.output_file.
for _batch in range(args.n_batches):
    # Draw the start token ids for the `length` generated positions.
    if source_distribution == "uniform":
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts as token id 3 (int64, matching the old
        # zeros + 3 -> .long() construction).
        x_init = torch.full((n_samples, length), 3, dtype=torch.long, device=device)
    else:
        raise NotImplementedError(
            f"unsupported source_distribution: {source_distribution!r}"
        )

    # Frame each sequence with token id 0 in front and token id 2 at the end
    # (presumably the tokenizer's <cls>/<eos> ids -- confirm).
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)

    # The time grid stops at 1 - 1e-3 rather than exactly 1.
    x_1 = solver.multi_guidance_sample(
        args=args,
        x_init=x_init,
        step_size=step_size,
        verbose=True,
        time_grid=torch.tensor([0.0, 1.0-1e-3]),
        score_models=score_models,
        num_objectives=3,
        weights=args.weights,
    )

    # Decode the token ids to text, remove the spaces the tokenizer inserts,
    # and cut five characters from each end (presumably the "<cls>" and
    # "<eos>" markers -- confirm).
    samples = [
        tokenizer.decode(seq).replace(' ', '')[5:-5]
        for seq in x_1.tolist()
    ]
    print(samples)

    # Re-score the final sample with every objective. Models that declare a
    # `t` parameter get 1 as their second positional argument.
    scores = []
    for model, with_time in zip(score_models, time_aware):
        candidate_scores = model(x_1, 1) if with_time else model(x_1)
        # A model may return a single score or a tuple of scores; both are
        # flattened into the same row. `.item()` requires one-element tensors,
        # which holds for the n_samples = 1 configuration used here.
        if not isinstance(candidate_scores, tuple):
            candidate_scores = (candidate_scores,)
        scores.extend(score.item() for score in candidate_scores)
    print(scores)

    # Append "<sequence>,<score>,<score>,..." as one line. The file is
    # reopened per batch, so every finished batch is written out before the
    # next one starts.
    row = ",".join([samples[0], *(f"{score}" for score in scores)])
    with open(args.output_file, 'a') as f:
        f.write(row + '\n')
# samples = x_1.tolist()
# sample = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples][0]
# with open(f"/vast/home/c/chentong/MOG-DFM/samples/{name}.csv", "a") as f:
# f.write(sample + ',' + str(score_list_0[-1]) + ',' + str(score_list_1[-1]) + '\n')