import inspect
from pathlib import Path

import torch
from transformers import AutoTokenizer

from models.peptide_classifiers import *
from utils.parsing import parse_guidance_args

args = parse_guidance_args()
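
# Multi-objective guided peptide generation: build one score model per requested
# objective, run guided sampling against the target protein, score the generated
# sequences, and append the results to a CSV file.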

# Sampling configuration.
step_size = 1 / 100               # integration step size for the solver
n_samples = 1                     # sequences generated per batch
vocab_size = 24                   # 4 ESM-2 special tokens + 20 standard amino acids
source_distribution = "uniform"   # initial distribution: "uniform" or "mask"
device = 'cuda:0'

length = args.length
target = args.target_protein

# Optional binding motifs on the target; default to None when args.motifs is not set.
motifs = None
if args.motifs:
    motifs = parse_motifs(args.motifs).to(device)
    print(motifs)

# The ESM-2 tokenizer encodes the target protein and decodes generated peptides.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)

# Pretrained generative solver used for guided sampling.
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)
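
# Assemble one score model per requested objective. The same models serve as
# guidance terms inside the solver and are reused below to score the final samples.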
score_models = []
if 'Hemolysis' in args.objectives:
    hemolysis_model = HemolysisModel(device=device)
    score_models.append(hemolysis_model)
if 'Non-Fouling' in args.objectives:
    nonfouling_model = NonfoulingModel(device=device)
    score_models.append(nonfouling_model)
if 'Solubility' in args.objectives:
    solubility_model = SolubilityModelNew(device=device)
    score_models.append(solubility_model)
if 'Half-Life' in args.objectives:
    halflife_model = HalfLifeModel(device=device)
    score_models.append(halflife_model)
if 'Affinity' in args.objectives:
    affinity_predictor = load_affinity_predictor('./classifier_ckpt/binding_affinity_unpooled.pt', device)
    affinity_model = AffinityModel(affinity_predictor, target_sequence)
    score_models.append(affinity_model)
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
    bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
    motif_model = MotifModel(bindevaluator, target_sequence, motifs, penalty=args.motif_penalty)
    score_models.append(motif_model)

# Header line: comma-separated objective names, e.g. "Hemolysis,Affinity".
objective_line = str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'

# Start the output CSV with a header row of objective names; rewrite it if an
# existing file's header does not match the current objectives.
if Path(args.output_file).exists():
    with open(args.output_file, 'r') as f:
        lines = f.readlines()
    if not lines or lines[0] != objective_line:
        with open(args.output_file, 'w') as f:
            f.write(objective_line)
else:
    with open(args.output_file, 'w') as f:
        f.write(objective_line)
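
# Each batch draws an initial sequence, runs guided sampling, scores the result
# with every objective model, and appends one CSV row of the form
# <sequence>,<score_1>,...,<score_k>.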
for batch_idx in range(args.n_batches):
    # Draw the initial sequence x_0 from the chosen source distribution.
    if source_distribution == "uniform":
        # Uniform over the 20 standard amino acids (ids 4-23 in the ESM-2 vocabulary).
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts from the same fixed token id (3).
        x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
    else:
        raise NotImplementedError

    # Frame the sequence with the <cls> token (id 0) and the <eos> token (id 2).
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)
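
    # Guided sampling along the configured time grid (t = 0 up to 1 - 1e-3),
    # weighting the objectives according to args.weights.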
    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0 - 1e-3]),
                                       score_models=score_models,
                                       num_objectives=len(score_models) + int(args.motif_penalty),
                                       weights=args.weights)

    # Decode to amino-acid strings; the slice strips the leading '<cls>' and trailing '<eos>' tags.
    samples = x_1.tolist()
    samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
    print(samples)
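
    # Re-score the generated sequence with each objective model.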
    scores = []
    for obj_idx, score_model in enumerate(score_models):
        # Some score models take a time argument; evaluate them at t = 1 (clean data).
        sig = inspect.signature(score_model.forward) if hasattr(score_model, 'forward') else inspect.signature(score_model)
        if 't' in sig.parameters:
            candidate_scores = score_model(x_1, 1)
        else:
            candidate_scores = score_model(x_1)

        # Rescale selected objectives for reporting.
        if args.objectives[obj_idx] == 'Half-Life':
            candidate_scores = 10 ** (candidate_scores * 2)
        if args.objectives[obj_idx] == 'Affinity':
            candidate_scores = 10 * candidate_scores

        # A model may return several score components as a tuple.
        if isinstance(candidate_scores, tuple):
            for score in candidate_scores:
                scores.append(score.item())
        else:
            scores.append(candidate_scores.item())
    print(scores)

    # Append this sample's row to the output file.
    with open(args.output_file, 'a') as f:
        f.write(samples[0])
        for score in scores:
            f.write(f",{score}")
        f.write('\n')