import inspect
import math
import random

import torch

from flow_matching.utils import categorical

def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    """Generate evenly spaced weight vectors on the (num_obj - 1)-simplex via a
    simplex-lattice design with num_div divisions per objective."""
    def rec(n, H):
        # Enumerate all non-negative integer n-tuples summing to H.
        if n == 1:
            return [[H]]
        points = []
        for i in range(H + 1):
            for tail in rec(n - 1, H - i):
                points.append([i] + tail)
        return points

    points = rec(num_obj, num_div)
    weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
    return weight_vectors

def select_random_weight_vector(num_obj: int, num_div: int):
    """Uniformly sample one weight vector from the simplex lattice; also return the full lattice."""
    weight_vectors = generate_simplex_lattice_points(num_obj, num_div)
    idx = torch.randint(0, weight_vectors.size(0), (1,)).item()
    random_weight_vector = weight_vectors[idx]
    return random_weight_vector, weight_vectors
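
# Illustrative example (not part of the original module): for num_obj = 3 and
# num_div = 4, the lattice contains C(4 + 3 - 1, 3 - 1) = 15 weight vectors whose
# entries lie in {0, 0.25, 0.5, 0.75, 1.0} and sum to 1; the first lattice point
# returned by generate_simplex_lattice_points(3, 4) is tensor([0., 0., 1.]).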

def z_score_norm(tensor, eps=1e-8):
    """Z-score normalize along the last dimension, with eps guarding against zero std."""
    mean = tensor.mean(dim=-1, keepdim=True)
    std = tensor.std(dim=-1, unbiased=False, keepdim=True).clamp(min=eps)
    return (tensor - mean) / std
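
# Worked example (illustrative): z_score_norm(torch.tensor([1., 2., 3.])) gives
# approximately tensor([-1.2247, 0.0000, 1.2247]), since the mean is 2 and the
# population std is sqrt(2/3) ≈ 0.8165.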

def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    """Rescale the off-diagonal velocities of u_t at one randomly chosen position,
    using importance-weighted improvements predicted by the score models s_models,
    and reset the self-transition so that the guided row still sums to zero."""
    B, L, vocab_size = u_t.shape
    device = x_t.device
    guided_u_t = u_t.clone()
    # 1. Randomly select one editable position, shared across the batch, excluding
    #    position 6. Note: this yields a single index of shape (1,), which matches
    #    the per-sequence shape (B,) expected downstream only when B == 1.
    # pos_indices = torch.randint(low=1, high=L-2, size=(B,), device=device)  # shape: (B,)  # CHANGE!
    pos_indices = torch.tensor([random.choice([i for i in range(1, L - 2) if i != 6])]).to(device)
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]  # shape: (B,)
    # 2. Build candidate tokens for each sequence, removing the self-transition and the
    #    reserved token id 23 (assumed never to be the current token at an editable position).
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)  # (B, vocab_size)
    mask = (full_cand_tokens != current_tokens.unsqueeze(1)) & (full_cand_tokens != 23)  # (B, vocab_size)
    # cand_tokens now contains only candidate tokens that differ from the current token.
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, vocab_size - 2)  # (B, vocab_size-2)
    # 3. Create candidate sequences by replacing the token at the selected position.
    new_x = x_t.unsqueeze(1).expand(B, vocab_size, L).clone()
    new_x = new_x[mask].view(B, vocab_size - 2, L)  # (B, vocab_size-2, L)
    new_x[batch_idx, :, pos_indices] = cand_tokens
    new_x_flat = new_x.view(B * (vocab_size - 2), L)
    # 4. Score every candidate with each objective model and collect the
    #    importance-weighted improvements over the current sequence.
    improvements_list = []
    with torch.no_grad():
        count = 0
        for i, s in enumerate(s_models):
            # Pass t only to models whose signature accepts it.
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(new_x_flat, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(new_x_flat)
                base_score = s(x_t)
            if isinstance(candidate_scores, tuple):
                # A model may return several objectives at once.
                for k, score in enumerate(candidate_scores):
                    improvement = score.view(B, vocab_size - 2) - base_score[k].unsqueeze(1)
                    improvement = improvement.float()
                    improvement *= importance[count]
                    improvements_list.append(improvement.unsqueeze(2))
                    count += 1
            else:
                improvement = candidate_scores.view(B, vocab_size - 2) - base_score.unsqueeze(1)
                improvement = improvement.float()
                improvement *= importance[count]
                improvements_list.append(improvement.unsqueeze(2))  # (B, vocab_size-2, 1)
                count += 1
    improvement_values = torch.cat(improvements_list, dim=2)  # (B, vocab_size-2, N)
    if args.is_peptide:
        improvement_values[:, :4, :] = -10  # Penalize candidates that correspond to non-residue (special) tokens.
    # 5. Compute ranking scores I_n: per-objective ranks of the improvements, normalized to (0, 1].
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1  # (B, vocab_size-2, N)
    I_n = ranks / float(vocab_size - 2)
    avg_I = I_n.mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)  # (B, vocab_size-2)
    # 6. Compute the directional score D: improvements projected onto the weight vector w.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)  # (B, vocab_size-2)
    # 7. Combine the scores.
    delta_S = norm_avg_I + args.lambda_ * norm_D  # (B, vocab_size-2)
    # 8. Update the guided velocities at the selected positions. exp(.) is always positive,
    #    so only the upper clamp is active.
    factor = torch.exp(args.beta * delta_S)  # (B, vocab_size-2)
    factor = torch.clamp(factor, max=100)
    guided_u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] = \
        u_t[batch_idx.unsqueeze(1), pos_indices.unsqueeze(1), cand_tokens] * factor
    # 9. For the self-transition (current token) at the selected position, set its guided
    #    velocity to the negative sum of the updated off-diagonals so the row sums to zero.
    updated_vals = guided_u_t[batch_idx, pos_indices, :]  # (B, vocab_size)
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag
    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
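
# Illustrative sanity check (an assumption, not part of the original code): at the
# edited position the guided velocities form a valid rate row, i.e. the off-diagonal
# entries are the original ones scaled by exp(beta * delta_S) and the self-transition
# is minus their sum, so each edited row sums to (numerically) zero.
def _check_guided_row_sums(guided_u_t, pos_indices, atol=1e-4):
    batch_idx = torch.arange(guided_u_t.size(0), device=guided_u_t.device)
    rows = guided_u_t[batch_idx, pos_indices]  # (B, vocab_size)
    return bool((rows.sum(dim=-1).abs() < atol).all())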

def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    """Keep candidates whose improvement vector lies inside a hypercone of half-angle Phi
    around the weight vector w, pick the best accepted candidate per sequence, and adapt
    Phi from an exponential moving average of the rejection rate."""
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8
    # Compute norms and the angle between each improvement vector and w.
    imp_norm = torch.norm(improvement_values.float(), dim=2)  # (B, num_candidates)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = dot_product / (imp_norm * w_norm + eps)
    cos_angle = cos_angle.clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)  # (B, num_candidates)
    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)  # (B, num_candidates)
    # Determine the best candidate for each sequence.
    # We loop over batch items (the batch size is typically moderate).
    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    for i in range(B):
        # For sequence i, consider only valid candidates.
        if valid_mask[i].any():
            # There is at least one candidate with α^i < π/2.
            if accepted_mask[i].any():
                # At least one candidate passes the hypercone: choose the one with max delta_S among accepted.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~accepted_mask[i], float('-inf')))
            else:
                # No candidate was accepted, but some are valid: select the best candidate among the valid ones.
                candidate_idx = torch.argmax(delta_S[i].masked_fill(~valid_mask[i], float('-inf')))
            best_candidate[i] = cand_tokens[i, candidate_idx]
        else:
            # No candidate is valid (all α^i >= π/2) -> self-transition.
            best_candidate[i] = -1
    # Compute the rejection rate over valid candidates only.
    rejection_rates = []
    for i in range(B):
        valid_candidates = valid_mask[i]
        total_valid = valid_candidates.sum().item()
        if total_valid > 0:
            # Among valid candidates, count how many are rejected by the hypercone.
            num_rejected = (valid_candidates.sum() - accepted_mask[i].sum()).item()
            rejection_rates.append(num_rejected / total_valid)
    if len(rejection_rates) > 0:
        r_t = sum(rejection_rates) / len(rejection_rates)
    else:
        # If no sequence has any valid candidate, set r_t to 0.
        r_t = 0.0
    if ema_r_t is None:
        ema_r_t = args.tau
    # Update the hypercone angle and the EMA rejection rate only if the batch contains
    # at least one valid candidate.
    if valid_mask.any():
        new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
        new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
        new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()
    else:
        new_ema_r_t = ema_r_t
        new_Phi = Phi  # No update if no valid candidate exists.
    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
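
# Worked example of the adaptive update rule (values are illustrative assumptions):
# with eta = 0.1 and target rejection rate tau = 0.5, an EMA rejection rate of 0.8
# widens the cone, Phi <- Phi * exp(0.1 * (0.8 - 0.5)) ≈ 1.03 * Phi, whereas an EMA
# of 0.2 shrinks it, Phi <- Phi * exp(0.1 * (0.2 - 0.5)) ≈ 0.97 * Phi, before
# clamping to [Phi_min, Phi_max]; too many rejections therefore open the cone and
# too few tighten it.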

def get_best_candidate(improvement_values, cand_tokens, delta_S):
    """Pick, for each sequence, the candidate token with the highest combined score
    delta_S (no hypercone filtering)."""
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    best_candidate = torch.empty(B, dtype=torch.long, device=device)
    for i in range(B):
        candidate_idx = torch.argmax(delta_S[i])
        best_candidate[i] = cand_tokens[i, candidate_idx]
    return best_candidate

def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    """One Euler jump step: with probability driven by the guided velocity, replace the
    token at the selected position with the best candidate; sequences with
    best_candidate == -1 keep their current token (self-transition)."""
    B, L, V = guided_u_t.shape
    device = x_t.device
    u = torch.zeros_like(guided_u_t)
    valid_mask = best_candidate != -1
    if valid_mask.any():
        valid_idx = torch.nonzero(valid_mask).squeeze(-1)
        # For these sequences, keep only the velocity at the selected position and candidate token.
        u[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]] = \
            guided_u_t[valid_idx, pos_indices[valid_idx], best_candidate[valid_idx]]
    # Compute the jump intensity at the selected positions.
    # For sequences with no valid candidate (i.e. self-transition), the intensity remains zero.
    intensity = torch.zeros(B, device=device)
    if valid_mask.any():
        intensity[valid_idx] = u[valid_idx, pos_indices[valid_idx]].sum(dim=-1)
    # According to the Euler sampling formula, `p_jump` should be `1 - torch.exp(-h * intensity)`.
    # However, since `h = 1 / T` is small, p_jump becomes tiny and slows down sampling.
    # To compensate, we scale `intensity` by T, which is equivalent to setting `args.beta`
    # to `T * args.beta`. So for faster sampling we simply use `1 - torch.exp(-intensity)`.
    p_jump = 1 - torch.exp(-1 * intensity)
    rand_val = torch.rand(B, device=device)
    jump_decision = (rand_val < p_jump) & valid_mask
    # For sequences where a jump is decided, replace the token at pos_indices with best_candidate.
    x_t[jump_decision, pos_indices[jump_decision]] = best_candidate[jump_decision]
    return x_t
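
# Illustrative end-to-end sketch (an assumption, not part of the original pipeline):
# this wires the helpers together for a single toy sequence with random tensors and
# two dummy scalar score models, purely to show the expected shapes and call order.
# All numbers, the dummy models, and the SimpleNamespace `args` fields are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    torch.manual_seed(0)
    B, L, V = 1, 12, 24  # single sequence, toy vocabulary
    args = SimpleNamespace(is_peptide=False, lambda_=1.0, beta=1.0, tau=0.5,
                           alpha_r=0.9, eta=0.1, Phi_min=0.1, Phi_max=math.pi / 2)

    # Avoid the reserved token id 23, which guided_transition_scoring always excludes.
    x_t = torch.randint(0, V - 1, (B, L))
    u_t = torch.rand(B, L, V)  # toy velocity field
    w, _ = select_random_weight_vector(num_obj=2, num_div=10)
    importance = torch.ones(2)

    # Dummy score models: one scalar objective per sequence (higher is better).
    s_models = [lambda seqs: seqs.float().mean(dim=1),
                lambda seqs: -seqs.float().std(dim=1)]

    guided_u_t, pos, cand, impr, delta_S = guided_transition_scoring(
        x_t, u_t, w, s_models, t=torch.tensor(0.5), importance=importance, args=args)
    best, accepted, valid, new_Phi, ema_r = adaptive_hypercone_filtering(
        impr, cand, delta_S, w, Phi=math.pi / 4, args=args)
    x_next = euler_sample(x_t.clone(), pos, best, guided_u_t, h=1e-2)
    print(x_next.shape, new_Phi, ema_r)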