import numpy as np
import torch
import math
from einops import rearrange
from torch.nn import functional as F


def add_gumbel_noise(t, temperature, device):
    # Add temperature-scaled Gumbel(0, 1) noise to logits; taking the argmax of
    # the result draws from softmax(t / temperature) (see the note below).
    return t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device)
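
# Note on the Gumbel-max trick used above (explanatory sketch, not from the
# original file): argmax(t + temperature * G) with G ~ Gumbel(0, 1) equals
# argmax(t / temperature + G), i.e. an exact sample from softmax(t / temperature);
# as temperature -> 0 this reduces to greedy argmax. Illustrative usage:
#
#   logits = torch.randn(4, 10)                                   # (batch, vocab)
#   draws = add_gumbel_noise(logits, 1.0, 'cpu').argmax(dim=-1)   # (batch,)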


class MUSE(object):
    def __init__(self, codebook_size, device, ignore_ind=-1, smoothing=0., gen_temp=4.5):
        self.mask_ind = codebook_size  # for input masking
        self.ignore_ind = ignore_ind  # for ce loss, excluding visible
        self.device = device
        self.smoothing = smoothing
        self.gen_temp = gen_temp

    def cosine_schedule(self, t):
        return torch.cos(t * math.pi * 0.5)
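
    # Endpoint check for the schedule above: cosine_schedule(0) = cos(0) = 1
    # (everything masked) and cosine_schedule(1) = cos(pi/2) = 0 (nothing masked).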

    def sample(self, x0):
        # Training-time corruption: mask a cosine-scheduled fraction of each
        # sequence; labels are kept only at masked positions (ignore_ind elsewhere).
        N, L, device = *x0.shape, self.device
        timesteps = torch.zeros((N,), device=device).float().uniform_(0, 1)
        rand_mask_probs = self.cosine_schedule(timesteps)  # cosine schedule
        num_token_masked = (L * rand_mask_probs).round().clamp(min=1)
        batch_randperm = torch.rand(N, L, device=device).argsort(dim=-1)
        mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_ids = torch.where(mask, self.mask_ind, x0)
        labels = torch.where(mask, x0, self.ignore_ind)
        return labels, masked_ids

    def loss(self, pred, label):
        # pred is (N, L, vocab); cross_entropy wants the class dim second, hence the transpose
        return F.cross_entropy(pred.transpose(1, 2), label.long(),
                               ignore_index=self.ignore_ind, label_smoothing=self.smoothing)
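
    # A hypothetical training step tying sample() and loss() together
    # (illustrative only; `nnet` and `x0` are assumed, not defined in this file):
    #
    #   muse = MUSE(codebook_size=1024, device='cuda')
    #   labels, masked_ids = muse.sample(x0)   # x0: (N, L) integer token maps
    #   pred = nnet(masked_ids)                # (N, L, vocab) logits
    #   step_loss = muse.loss(pred, labels)    # CE over masked positions only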

    def generate(self, config, _n_samples, nnet, decode_fn, is_eval=False, **kwargs):
        fmap_size, _sample_steps, device = config.z_shape[-1], config.sample.sample_steps, self.device
        seq_len = fmap_size ** 2

        # first pass: iterative parallel decoding without the adapter
        ids = torch.full((_n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device)
        cfg_scale = 0.
        for step in range(_sample_steps):
            ratio = 1. * (step + 1) / _sample_steps
            annealed_temp = self.gen_temp * (1 - ratio)  # anneal sampling temperature toward greedy
            is_mask = (ids == self.mask_ind)
            logits = nnet(ids, **kwargs, scale=cfg_scale)
            # sampling & scoring
            sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            sampled_ids = torch.where(is_mask, sampled_ids, ids)  # keep tokens decoded earlier
            sampled_logits = torch.where(is_mask, sampled_logits, np.inf).float()  # never re-mask visible tokens
            # masking: re-mask the least confident predictions per the cosine schedule
            mask_ratio = np.cos(ratio * math.pi * 0.5)
            mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(device)
            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                     torch.minimum(torch.sum(is_mask, dim=-1, keepdim=True) - 1,
                                                   mask_len))[0].squeeze()
            confidence = add_gumbel_noise(sampled_logits, annealed_temp, device)
            sorted_confidence, _ = torch.sort(confidence, dim=-1)
            cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
            masking = (confidence <= cut_off)
            ids = torch.where(masking, self.mask_ind, sampled_ids)
            cfg_scale = ratio * config.sample.scale  # ramp up classifier-free guidance
        _z1 = rearrange(sampled_ids, 'b (i j) -> b i j', i=fmap_size, j=fmap_size)
        # second pass: same iterative decoding, now with adapter strengths lambdaA / lambdaB
        ids = torch.full((_n_samples, seq_len), self.mask_ind, dtype=torch.long, device=device)
        cfg_scale = 0.
        lambdaA = 0.
        lambdaB = 0.
        for step in range(_sample_steps):
            ratio = 1. * (step + 1) / _sample_steps
            annealed_temp = self.gen_temp * (1 - ratio)
            is_mask = (ids == self.mask_ind)
            # try using *ratio
            logits = nnet(ids, **kwargs, scale=cfg_scale, lambdaA=lambdaA, lambdaB=lambdaB)
| # sampling & scoring | |
| sampled_ids = add_gumbel_noise(logits, annealed_temp, device).argmax(dim=-1) | |
| sampled_logits = torch.squeeze( | |
| torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1) | |
| sampled_ids = torch.where(is_mask, sampled_ids, ids) | |
| sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float() | |
| # masking | |
| mask_ratio = np.cos(ratio * math.pi * 0.5) | |
| mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(device) | |
| mask_len = torch.maximum(torch.Tensor([1]).to(device), | |
| torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1, | |
| mask_len))[0].squeeze() | |
| confidence = add_gumbel_noise(sampled_logits, annealed_temp, device) | |
| sorted_confidence, _ = torch.sort(confidence, axis=-1) | |
| cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()] | |
| masking = (confidence <= cut_off) | |
| ids = torch.where(masking, self.mask_ind, sampled_ids) | |
| cfg_scale = ratio * config.sample.scale | |
| lambdaA = config.sample.lambdaA | |
| lambdaB = config.sample.lambdaB | |
| _z2 = rearrange(sampled_ids, 'b (i j) -> b i j', i=fmap_size, j=fmap_size) | |
| _z = _z2 if is_eval else torch.cat([_z1,_z2],dim=0) | |
| out = decode_fn(_z) | |
| return out | |
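
# A hypothetical generation call (the names for config, nnet, and decode_fn are
# assumed from the signature above, not taken from the original repo):
#
#   muse = MUSE(codebook_size=1024, device='cuda')
#   images = muse.generate(config, 16, nnet, decode_fn, is_eval=True)
#   # is_eval=True decodes only the adapter pass; otherwise both passes are
#   # decoded and concatenated along the batch dim for side-by-side inspection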