import os
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb

from utils.log import get_logger

logger = get_logger(__name__)


class GeneralModule(pl.LightningModule):
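    """Shared base LightningModule: buffers per-step metrics in ``self._log``,
    gathers them across ranks, and reports their means to stdout, the
    Lightning logger, and (optionally) Weights & Biases.
    """
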
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()
        self.args = args

        self.iter_step = -1
        self._log = defaultdict(list)
        self.generator = np.random.default_rng()
        self.last_log_time = time.time()
        # `stage` is expected to be set by subclasses (e.g. to 'train' or
        # 'val' in their step methods) before `lg` is called.
        self.stage = None

    def try_print_log(self):
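        """Every ``args.print_freq`` steps, gather the buffered "iter_" metrics
        from all ranks, log their means, and clear them from the buffer."""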
        step = self.iter_step if self.args.validate else self.trainer.global_step
        if (step + 1) % self.args.print_freq == 0:
            print(os.environ["MODEL_DIR"])
            log = self._log
            log = {key: log[key] for key in log if "iter_" in key}

            log = self.gather_log(log, self.trainer.world_size)
            mean_log = self.get_log_mean(log)
            mean_log.update({
                'epoch': float(self.trainer.current_epoch),
                'step': float(self.trainer.global_step),
                'iter_step': float(self.iter_step),
            })
            if self.trainer.is_global_zero:
                print(str(mean_log))
                self.log_dict(mean_log, batch_size=1)
                if self.args.wandb:
                    wandb.log(mean_log)
            for key in list(log.keys()):
                if "iter_" in key:
                    del self._log[key]

    def lg(self, key, data):
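        """Buffer a scalar metric under the current stage and, during training
        or standalone validation, under the running "iter_" bucket as well."""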
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().item()
        log = self._log

        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def on_train_epoch_end(self):
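        """Gather this epoch's "train_" metrics across ranks, log their means,
        and clear them from the buffer."""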
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]

    def on_validation_epoch_end(self):
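        """Gather the "val_" metrics across ranks, log their means, write the
        raw values to a CSV in MODEL_DIR, and clear them from the buffer."""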
        # Swap in a fresh numpy generator for the next epoch's sampling.
        self.generator = np.random.default_rng()
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

            # Rank zero also dumps the raw (un-averaged) validation metrics.
            path = os.path.join(
                os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv"
            )
            pd.DataFrame(log).to_csv(path)

        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]

    def gather_log(self, log, world_size):
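        """Concatenate the per-rank metric lists so that every rank ends up
        with the complete log."""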
        if world_size == 1:
            return log
        log_list = [None] * world_size
        torch.distributed.all_gather_object(log_list, log)
        log = {key: sum([rank_log[key] for rank_log in log_list], []) for key in log}
        return log

    def get_log_mean(self, log):
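        """Take the NaN-ignoring mean of each metric list, skipping entries
        that cannot be averaged numerically."""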
        out = {}
        for key in log:
            try:
                out[key] = np.nanmean(log[key])
            except (TypeError, ValueError):
                # Skip values that cannot be averaged (e.g. strings).
                pass
        return out
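
# Illustrative usage sketch (not part of this module): a subclass sets
# `self.stage` and reports scalars through `lg`. `MyModule` and `compute_loss`
# are placeholder names.
#
#     class MyModule(GeneralModule):
#         def training_step(self, batch, batch_idx):
#             self.stage = 'train'
#             loss = compute_loss(batch)   # placeholder
#             self.lg('loss', loss)        # buffered as 'iter_loss' and 'train_loss'
#             self.try_print_log()
#             return loss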