import os
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb

from utils.log import get_logger

logger = get_logger(__name__)


class GeneralModule(pl.LightningModule):
    """Base LightningModule that accumulates scalar metrics in per-stage
    buckets (``iter_``, ``train_``, ``val_``), gathers them across ranks,
    and periodically flushes their means to stdout, Lightning, and wandb.
    Subclasses are expected to set ``self.stage`` ('train' or 'val')
    before calling ``lg``."""

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.iter_step = -1
        self._log = defaultdict(list)
        self.generator = np.random.default_rng()
        self.last_log_time = time.time()

    def try_print_log(self):
        # In validation-only runs global_step does not advance, so fall back
        # to the manually tracked iter_step.
        step = self.iter_step if self.args.validate else self.trainer.global_step
        if (step + 1) % self.args.print_freq == 0:
            print(os.environ["MODEL_DIR"])
            log = self._log
            log = {key: log[key] for key in log if "iter_" in key}
            log = self.gather_log(log, self.trainer.world_size)
            mean_log = self.get_log_mean(log)
            mean_log.update({
                'epoch': float(self.trainer.current_epoch),
                'step': float(self.trainer.global_step),
                'iter_step': float(self.iter_step),
            })
            if self.trainer.is_global_zero:
                print(str(mean_log))
                self.log_dict(mean_log, batch_size=1)
                if self.args.wandb:
                    wandb.log(mean_log)
            # Clear the per-iteration buckets now that they have been flushed.
            for key in list(log.keys()):
                if "iter_" in key:
                    del self._log[key]

    def lg(self, key, data):
        # Record a scalar; tensors are detached and moved to the CPU first.
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().item()
        log = self._log
        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def on_train_epoch_end(self):
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })
        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)
        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]

    def on_validation_epoch_end(self):
        # Reset the generator so every validation pass draws identical samples.
        self.generator = np.random.default_rng()
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })
        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)
            path = os.path.join(
                os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv"
            )
            pd.DataFrame(log).to_csv(path)
        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]

    def gather_log(self, log, world_size):
        # Concatenate each rank's metric lists into a single list per key.
        if world_size == 1:
            return log
        log_list = [None] * world_size
        torch.distributed.all_gather_object(log_list, log)
        log = {key: sum([l[key] for l in log_list], []) for key in log}
        return log

    def get_log_mean(self, log):
        out = {}
        for key in log:
            try:
                out[key] = np.nanmean(log[key])
            except (TypeError, ValueError):
                # Skip entries that cannot be averaged (e.g. strings).
                pass
        return out
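

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It shows the intended logging flow: a subclass sets `self.stage`, records
# scalars with `lg`, and calls `try_print_log` so the iter_ buckets are
# flushed every `args.print_freq` steps. The network, the loss, and `args.lr`
# below are assumptions made for this example.
class ExampleModule(GeneralModule):
    def __init__(self, args):
        super().__init__(args)
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        self.stage = 'train'
        self.iter_step += 1
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.lg('loss', loss)  # accumulates under both iter_loss and train_loss
        self.try_print_log()   # no-op except every args.print_freq steps
        return loss

    def validation_step(self, batch, batch_idx):
        self.stage = 'val'
        x, y = batch
        self.lg('loss', torch.nn.functional.mse_loss(self.net(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.args.lr)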