import os
import argparse


class BaseOptions(object):
    """Build, parse and print the command-line options for training.

    Subclasses must implement :meth:`modify_options`, which receives the
    parsed ``argparse.Namespace`` and returns the namespace that
    :meth:`parse` finally hands back (and stores in ``self.opt``).
    """

    def __init__(self):
        self.initialized = False  # True once self.parser has been built
        self.parser = None        # argparse.ArgumentParser, built lazily in parse()
        self.opt = None           # last namespace returned by parse()

        # Config for predefined methods. Each entry holds, in order:
        #   [modal, jigsaw, mem, aug, head, nce_t]
        # and is applied by override_options().
        self.override_dict = {
            'InsDis': ['RGB', False, 'bank', 'A', 'linear', 0.07],
            'CMC': ['CMC', False, 'bank', 'C', 'linear', 0.07],
            'MoCo': ['RGB', False, 'moco', 'A', 'linear', 0.07],
            'PIRL': ['RGB', True, 'bank', 'A', 'linear', 0.07],
            'MoCov2': ['RGB', False, 'moco', 'B', 'mlp', 0.2],
            'CMCv2': ['CMC', False, 'moco', 'E', 'mlp', 0.2],
            'InfoMin': ['RGB', True, 'moco', 'D', 'mlp', 0.15],
        }

    def initialize(self, parser):
        """Register all supported arguments on *parser* and return it."""
        # specify folder
        parser.add_argument('--csv_path', type=str, default='./contrast_train.csv',
                            help='path to csv file')
        parser.add_argument('--model_path', type=str, default='./test',
                            help='path to save model')
        parser.add_argument('--tb_path', type=str, default='./test',
                            help='path to tensorboard')

        # basics
        parser.add_argument('--print_freq', type=int, default=10,
                            help='print frequency')
        parser.add_argument('--save_freq', type=int, default=1,
                            help='save frequency')
        parser.add_argument('--batch_size', type=int, default=60,
                            help='batch_size')
        parser.add_argument('-j', '--num_workers', type=int, default=40,
                            help='num of workers to use')
        parser.add_argument('-n_aug', '--n_aug', type=int, default=7,
                            help='num of augmentations per image to use')
        parser.add_argument('-n_scale', '--n_scale', type=int, default=1,
                            help='num of scales per image to use. 1 only image, '
                                 '2 image and half resized image')
        parser.add_argument('-n_distortions', '--n_distortions', type=int, default=1,
                            help='num of distortions per image crop to use. '
                                 '1 for single distortion image, 2 for randomly '
                                 'selecting among 1/2 distortions')
        parser.add_argument('-patch_size', '--patch_size', type=int, default=224,
                            help='patch_size to crop for each image')
        # NOTE(review): the original help here was copy-pasted from
        # --patch_size; exact semantics live in the consumer of opt.swap_crops.
        parser.add_argument('-swap_crops', '--swap_crops', type=int, default=1,
                            help='whether to swap crops (integer flag)')

        # optimization
        parser.add_argument('--epochs', type=int, default=30,
                            help='number of training epochs')
        parser.add_argument('--learning_rate', type=float, default=0.05,
                            help='learning rate')
        parser.add_argument('--lr_decay_epochs', type=str, default='120,160',
                            help='where to decay lr, can be a list')
        parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                            help='decay rate for learning rate')
        parser.add_argument('--weight_decay', type=float, default=1e-4,
                            help='weight decay')
        parser.add_argument('--momentum', type=float, default=0.9,
                            help='momentum for SGD')
        parser.add_argument('--cosine', action='store_true',
                            help='using cosine annealing')
        parser.add_argument('--optimizer', type=str, default='SGD',
                            help='SGD/AdamW/LARS')

        # method selection
        parser.add_argument('--method', default='Customize', type=str,
                            choices=['InsDis', 'CMC', 'CMCv2', 'MoCo', 'MoCov2',
                                     'PIRL', 'InfoMin', 'Customize'],
                            help='Choose predefined method. Configs will be overridden '
                                 'for all methods except for `Customize`, which allows '
                                 'for user-defined combination of methods')

        # method configuration
        parser.add_argument('--modal', default='RGB', type=str, choices=['RGB', 'CMC'],
                            help='single RGB modal, or two modalities in CMC')
        parser.add_argument('--jigsaw', action='store_true',
                            help='adding PIRL branch')
        parser.add_argument('--mem', default='bank', type=str, choices=['bank', 'moco'],
                            help='memory mechanism: memory bank, or moco encoder cache')

        # model setup
        parser.add_argument('--arch', default='resnet50', type=str,
                            help='e.g., resnet50, resnext50, resnext101 '
                                 'and their wider variants, resnet50x4')
        parser.add_argument('-d', '--feat_dim', default=128, type=int,
                            help='feature dimension for contrastive loss')
        parser.add_argument('-k', '--nce_k', default=65536, type=int,
                            help='number of negatives')
        parser.add_argument('-m', '--nce_m', default=0.5, type=float,
                            help='momentum for memory update')
        parser.add_argument('-t', '--nce_t', default=0.07, type=float,
                            help='temperature')
        parser.add_argument('--alpha', default=0.999, type=float,
                            help='momentum coefficients for moco encoder update')
        parser.add_argument('--head', default='linear', type=str,
                            choices=['linear', 'mlp'], help='projection head')

        # resume
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')

        # Parallel setting
        parser.add_argument('--world-size', default=-1, type=int,
                            help='number of nodes for distributed training')
        parser.add_argument('--rank', default=-1, type=int,
                            help='node rank for distributed training')
        parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                            help='url used to set up distributed training')
        parser.add_argument('--dist-backend', default='nccl', type=str,
                            help='distributed backend')
        parser.add_argument('--seed', default=None, type=int,
                            help='seed for initializing training. ')
        parser.add_argument('--gpu', default=None, type=int,
                            help='GPU id to use.')
        parser.add_argument('--multiprocessing-distributed', action='store_true',
                            help='Use multi-processing distributed training to launch '
                                 'N processes per node, which has N GPUs. This is the '
                                 'fastest way to use PyTorch for either single node or '
                                 'multi node data parallel training')

        return parser

    def print_options(self, opt):
        """Print every option in *opt*, sorted by name.

        Values that differ from the parser default are annotated with
        ``[default: ...]``. If the parser has not been built yet (``parse``
        not called), the options are printed without default annotations
        instead of raising ``AttributeError``.
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            if self.parser is not None:
                default = self.parser.get_default(k)
                if v != default:
                    comment = '\t[default: %s]' % str(default)
            message += '{:>35}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

    def modify_options(self, opt):
        """Post-process the parsed namespace; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def override_options(self, opt):
        """Apply the preset for ``opt.method`` if one exists in override_dict.

        Sets ``modal``, ``jigsaw``, ``mem``, ``aug``, ``head`` and ``nce_t``
        on *opt*. Methods without a preset (e.g. ``Customize``) are left
        untouched. Returns *opt*.
        """
        preset = self.override_dict.get(opt.method)
        if preset is not None:
            (opt.modal, opt.jigsaw, opt.mem,
             opt.aug, opt.head, opt.nce_t) = preset
        return opt

    def parse(self, args=None):
        """Build the parser (once), parse arguments and return the options.

        Args:
            args: optional list of argument strings; ``None`` (the default)
                makes argparse read ``sys.argv[1:]`` as before.

        Returns:
            The namespace produced by :meth:`modify_options`, also stored
            in ``self.opt`` and printed via :meth:`print_options`.
        """
        if not self.initialized:
            parser = argparse.ArgumentParser('arguments options')
            parser = self.initialize(parser)
            self.parser = parser
            self.initialized = True
        else:
            parser = self.parser

        opt = parser.parse_args(args)
        opt = self.modify_options(opt)
        self.opt = opt
        self.print_options(opt)
        return opt