|
|
from __future__ import print_function |
|
|
import torch.backends.cudnn as cudnn |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import random |
|
|
import sys |
|
|
import pprint |
|
|
import datetime |
|
|
import dateutil |
|
|
import dateutil.tz |
|
|
|
|
|
|
|
|
# Make the directory that contains this script importable, so the local
# ``miscc`` package and ``trainer`` module imported below resolve even when
# the script is launched from another working directory.
# Note: the previous expression joined ``__file__`` with './.', which
# normalizes back to the script's *file* path, not its directory, so the
# entry appended to sys.path was ineffective.
dir_path = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
sys.path.append(dir_path)
|
|
|
|
|
from miscc.datasets import TextDataset |
|
|
from miscc.config import cfg, cfg_from_file |
|
|
from miscc.utils import mkdir_p |
|
|
from trainer import GANTrainer |
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for this entry point.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with the attributes ``cfg_file``, ``gpu_id``,
        ``data_dir`` and ``manualSeed``.
    """
    parser = argparse.ArgumentParser(description='Train a GAN network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default='birds_stage1.yml', type=str)
    # Kept as a string: the __main__ block splits it on ',' to count GPUs.
    parser.add_argument('--gpu', dest='gpu_id', type=str, default='0')
    # '' (the default) means "keep DATA_DIR from the config file".
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    # None (the default) means a random seed is drawn by the __main__ block.
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()

    # Load the config file first so the command-line flags below override it.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    # '--gpu -1' means "keep GPU_ID from the config file". --gpu is parsed
    # with type=str, so the sentinel must be compared as the string '-1';
    # the former comparison against the int -1 was always true.
    if args.gpu_id != '-1':
        cfg.GPU_ID = args.gpu_id
    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    # Seed Python's and torch's RNGs; draw a seed when none was given.
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    # Timestamped output path for this run, handed to GANTrainer below.
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    # GPU_ID is a comma-separated id list; the batch size scales with it.
    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        dataset = TextDataset(cfg.DATA_DIR, 'train',
                              rirsize=cfg.RIRSIZE)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not dataset:
            raise RuntimeError(
                'Failed to load a non-empty training dataset from %r'
                % (cfg.DATA_DIR,))

        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        # Sampling / evaluation mode reads its inputs from cfg.EVAL_DIR.
        file_path = cfg.EVAL_DIR
        algo = GANTrainer(output_dir)
        algo.sample(file_path, cfg.STAGE)
|
|
|