"""Sample peptide sequences with multi-objective guidance and log their scores.

For each of ``args.n_batches`` batches the script draws an initial token
sequence, runs ``solver.multi_guidance_sample`` under the motif and affinity
score models, prints the decoded sample together with its scores, and appends
one ``sequence,score,...`` line to ``args.output_file``.
"""
import yaml
from tqdm import tqdm
import torch
from torch import nn
from transformers import AutoTokenizer
from models.peptide_classifiers import *
from utils.parsing import parse_guidance_args
args = parse_guidance_args()
import pdb
import random
import inspect

# MOO hyper-parameters
step_size = 1 / 100
n_samples = 1  # the scoring/writing code below relies on a single sample per batch
length = args.length
target = args.target_protein
motifs = args.motifs
vocab_size = 24
source_distribution = "uniform"
device = 'cuda:0'

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
motifs = parse_motifs(motifs).to(device)
print(motifs)

# Load Models
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=True)

affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
affinity_model = AffinityModel(affinity_predictor, target_sequence)

score_models = [motif_model, affinity_model]


def _accepts_time_argument(model):
    """Return True when the model's call signature has a parameter named ``t``."""
    target_callable = model.forward if hasattr(model, 'forward') else model
    return 't' in inspect.signature(target_callable).parameters


# The signatures do not change between batches, so inspect them only once.
_score_model_takes_t = [_accepts_time_argument(model) for model in score_models]

for _ in range(args.n_batches):
    # Draw the starting sequence from the configured source distribution.
    if source_distribution == "uniform":
        # Ids below 4 are excluded — presumably the tokenizer's special tokens; confirm.
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts as token id 3 (presumably the mask/unknown id; confirm).
        x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
    else:
        raise NotImplementedError

    # Frame each sequence with token ids 0 and 2 (presumably <cls> and <eos>,
    # matching the 5-character trim applied after decoding below).
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)

    # NOTE(review): there are two score models but num_objectives=3 — presumably
    # one model returns two scores (see the tuple branch below); confirm.
    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0-1e-3]),
                                       score_models=score_models,
                                       num_objectives=3,
                                       weights=args.weights)

    # Decode to strings, drop the spaces the tokenizer inserts, and trim five
    # characters from each end (the decoded boundary tokens).
    samples = x_1.tolist()
    samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
    print(samples)

    # Score the final sample with every model; a model may return one score
    # or a tuple of scores. ``.item()`` requires single-element tensors.
    scores = []
    for s, takes_t in zip(score_models, _score_model_takes_t):
        candidate_scores = s(x_1, 1) if takes_t else s(x_1)
        if isinstance(candidate_scores, tuple):
            scores.extend(score.item() for score in candidate_scores)
        else:
            scores.append(candidate_scores.item())
    print(scores)

    # Append one "sequence,score,score,..." record per batch.
    with open(args.output_file, 'a') as f:
        f.write(','.join([samples[0], *map(str, scores)]) + '\n')