import pickle
import os
import sys

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torchvision.utils import save_image
from torch import distributed as dist
from loguru import logger

logging = logger
def set_logger(log_level='info', fname=None):
    # This module routes logging through loguru (``logging = logger`` above), so
    # configure it with loguru's add/remove API; the absl-style calls used in the
    # original U-ViT utilities do not exist on the loguru logger.
    fmt = '{time} - {file} - {message}'
    logger.remove()
    logger.add(sys.stderr, level=log_level.upper(), format=fmt)
    if fname is not None:
        logger.add(fname, level=log_level.upper(), format=fmt)
def dct2str(dct):
    return str({k: f'{v:.6g}' for k, v in dct.items()})
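# Example: dct2str({'loss': 0.123456789, 'lr': 2e-4}) -> "{'loss': '0.123457', 'lr': '0.0002'}"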
def get_nnet(name, **kwargs):
    if name == 'uvit_t2i_vq':
        from libs.uvit_t2i_vq import UViT
        return UViT(**kwargs)
    elif name == 'uvit_vq':
        from libs.uvit_vq import UViT
        return UViT(**kwargs)
    else:
        raise NotImplementedError(name)
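# Typical call (the kwargs below are hypothetical; the real ones come from config.nnet):
#   nnet = get_nnet(name='uvit_t2i_vq', img_size=16, codebook_size=1024, embed_dim=512)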
def set_seed(seed: int):
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
def get_optimizer(params, name, **kwargs):
    if name == 'adam':
        from torch.optim import Adam
        return Adam(params, **kwargs)
    elif name == 'adamw':
        from torch.optim import AdamW
        return AdamW(params, **kwargs)
    else:
        raise NotImplementedError(name)
def customized_lr_scheduler(optimizer, warmup_steps=-1):
    from torch.optim.lr_scheduler import LambdaLR

    def fn(step):
        if warmup_steps > 0:
            return min(step / warmup_steps, 1)
        else:
            return 1
    return LambdaLR(optimizer, fn)


def get_lr_scheduler(optimizer, name, **kwargs):
    if name == 'customized':
        return customized_lr_scheduler(optimizer, **kwargs)
    else:
        raise NotImplementedError(name)
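# Example wiring (hyperparameter values here are illustrative; they normally come from the config):
#   optimizer = get_optimizer(params, name='adamw', lr=2e-4, weight_decay=0.03)
#   lr_scheduler = get_lr_scheduler(optimizer, name='customized', warmup_steps=5000)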
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        p_src = param_dict_src[p_name]
        assert p_src is not p_dest
        if 'adapter' not in p_name:
            # standard EMA update: dest = rate * dest + (1 - rate) * src
            p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
        else:
            # adapter weights (the only trained parameters) are copied directly rather than averaged
            p_dest.data = p_src.detach().clone()
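# Example: with rate=0.9999 a backbone weight w_ema becomes
#   w_ema = 0.9999 * w_ema + 0.0001 * w_src
# while every tensor whose name contains 'adapter' is cloned from model_src as-is.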
class TrainState(object):
    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.step = step
        self.nnet = nnet
        self.nnet_ema = nnet_ema

    def ema_update(self, rate=0.9999):
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)
    def save(self, path, adapter_only=False, name=''):
        os.makedirs(path, exist_ok=True)
        torch.save(self.step, os.path.join(path, 'step.pth'))
        if adapter_only:
            # only persist the adapter weights of the trainable network
            torch.save(self.nnet.adapter.state_dict(), os.path.join(path, name + 'adapter.pth'))
        else:
            for key, val in self.__dict__.items():
                if key != 'step' and val is not None:
                    torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))

    def make_dict(self, model, state_dict):
        # build a full state dict for `model`, taking tensors from `state_dict` when
        # available and falling back to the model's current values for missing keys
        state = {}
        model_state = model.state_dict()
        for k in model_state.keys():
            if k in state_dict:
                state[k] = state_dict[k].clone()
            else:
                state[k] = model_state[k].clone()
        return state
    def load(self, path):
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'), map_location='cpu')
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None and key != 'optimizer' and key != 'lr_scheduler':
                if key == 'nnet' or key == 'nnet_ema':
                    # make_dict tolerates checkpoints that miss some keys (e.g. adapter weights)
                    val.load_state_dict(self.make_dict(val, torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')))
                else:
                    val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))

    def load_adapter(self, path):
        logging.info(f'load adapter from {path}')
        adapter = torch.load(path, map_location='cpu')
        for key in ['nnet', 'nnet_ema']:
            if key in self.__dict__:
                self.__dict__[key].adapter.load_state_dict(adapter)
            else:
                logging.info('adapter not in state_dict')
    def resume(self, ckpt_root, adapter_path=None, step=None):
        if not os.path.exists(ckpt_root):
            return
        if ckpt_root.endswith('.ckpt'):
            ckpt_path = ckpt_root
        else:
            if step is None:
                ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
                if not ckpts:
                    return
                steps = map(lambda x: int(x.split('.')[0]), ckpts)
                step = max(steps)
            ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)
        if adapter_path is not None:
            self.load_adapter(adapter_path)
    def to(self, device):
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)

    def freeze(self):
        # freeze the whole backbone, then re-enable gradients for the adapter parameters only
        self.nnet.requires_grad_(False)
        for name, p in self.nnet.named_parameters():
            if 'adapter' in name:
                p.requires_grad_(True)
def cnt_params(model):
    return sum(param.numel() for param in model.parameters())
def initialize_train_state(config, device):
    # only the adapter parameters are handed to the optimizer; the backbone stays fixed
    params = []
    nnet = get_nnet(**config.nnet)
    params += nnet.adapter.parameters()
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()
    logging.info(f'nnet has {cnt_params(nnet)} parameters')
    optimizer = get_optimizer(params, **config.optimizer)
    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
    train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
                             nnet=nnet, nnet_ema=nnet_ema)
    train_state.ema_update(0)  # rate=0 copies nnet into nnet_ema
    train_state.to(device)
    return train_state
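# Typical setup/resume flow (sketch; the config fields and paths are illustrative, not fixed by this file):
#   train_state = initialize_train_state(config, device)
#   train_state.freeze()
#   train_state.resume(config.ckpt_root, adapter_path=getattr(config, 'adapter_path', None))
#   ... one training step ...
#   train_state.ema_update()
#   train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))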
def amortize(n_samples, batch_size):
    k = n_samples // batch_size
    r = n_samples % batch_size
    return k * [batch_size] if r == 0 else k * [batch_size] + [r]
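# Example: amortize(10, 4) == [4, 4, 2]; the final chunk holds the remainder.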
def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, dist=True):
    if path:
        os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes if dist else mini_batch_size
    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir'):
        samples = sample_fn(mini_batch_size)
        if unpreprocess_fn is not None:  # the post-processing hook is optional
            samples = unpreprocess_fn(samples)
        if dist:
            samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f'{idx}.png'))
                idx += 1
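# Sketch of a caller (names such as `nnet_sample` and `decode` are placeholders):
#   sample2dir(accelerator, 'samples/', n_samples=50000, mini_batch_size=50,
#              sample_fn=nnet_sample, unpreprocess_fn=decode)
# Each process draws `mini_batch_size` samples per step; with dist=True they are gathered
# and truncated so that exactly `n_samples` images land in the output directory.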
def grad_norm(model):
    total_norm = 0.
    for p in model.parameters():
        if p.grad is None:  # frozen parameters carry no gradient
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
from collections import defaultdict, deque


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        # the accessors above are properties, so they can be formatted directly
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
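# Example:
#   meter = SmoothedValue(window_size=20)
#   for loss in (0.9, 0.8, 0.7):
#       meter.update(loss)
#   print(meter)  # -> "0.8000 (0.8000)" (windowed median and global average)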
class MetricLogger(object):
    def __init__(self, delimiter=" "):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def add_meter(self, name, meter):
        self.meters[name] = meter
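# Example:
#   metric_logger = MetricLogger(delimiter="  ")
#   metric_logger.update(loss=0.5, grad_norm=1.2)
#   print(metric_logger)       # -> "loss: 0.5000 (0.5000)  grad_norm: 1.2000 (1.2000)"
#   print(metric_logger.loss)  # __getattr__ exposes each meter by name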
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # torch._six was removed in recent PyTorch releases, so compare against float('inf') directly
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == float('inf'):
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm
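# Example (illustrative): after loss.backward(), log the adapter gradient norm with
#   gnorm = get_grad_norm_([p for p in nnet.adapter.parameters()])
#   metric_logger.update(grad_norm=gnorm)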