Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import argparse | |
class BaseOptions(object):
    """Command-line option container built on top of ``argparse``.

    The class registers every option in :meth:`initialize`, parses them in
    :meth:`parse`, and hands the parsed namespace to :meth:`modify_options`,
    which subclasses must implement (the base version raises
    ``NotImplementedError``).
    """

    def __init__(self):
        """Create an un-initialized option holder; no parsing happens here."""
        self.initialized = False
        self.parser = None
        self.opt = None
        # Config for predefined methods. Each value is positional and is
        # consumed by `override_options` in this order:
        #   [modal, jigsaw, mem, aug, head, nce_t]
        self.override_dict = {
            'InsDis': ['RGB', False, 'bank', 'A', 'linear', 0.07],
            'CMC': ['CMC', False, 'bank', 'C', 'linear', 0.07],
            'MoCo': ['RGB', False, 'moco', 'A', 'linear', 0.07],
            'PIRL': ['RGB', True, 'bank', 'A', 'linear', 0.07],
            'MoCov2': ['RGB', False, 'moco', 'B', 'mlp', 0.2],
            'CMCv2': ['CMC', False, 'moco', 'E', 'mlp', 0.2],
            'InfoMin': ['RGB', True, 'moco', 'D', 'mlp', 0.15],
        }

    def initialize(self, parser):
        """Register all options on ``parser`` and return the same parser.

        Args:
            parser: an ``argparse.ArgumentParser`` to add the arguments to.

        Returns:
            The parser that was passed in, with the arguments added.
        """
        # specify folder
        # parser.add_argument('--data_folder', type=str, default='./data',
        #                     help='path to data')
        parser.add_argument('--csv_path', type=str,
                            default='./contrast_train.csv',
                            help='path to csv file')
        parser.add_argument('--model_path', type=str, default='./test',
                            help='path to save model')
        parser.add_argument('--tb_path', type=str, default='./test',
                            help='path to tensorboard')

        # basics
        parser.add_argument('--print_freq', type=int, default=10,
                            help='print frequency')
        parser.add_argument('--save_freq', type=int, default=1,
                            help='save frequency')
        parser.add_argument('--batch_size', type=int, default=60,
                            help='batch_size')
        parser.add_argument('-j', '--num_workers', type=int, default=40,
                            help='num of workers to use')
        parser.add_argument('-n_aug', '--n_aug', type=int, default=7,
                            help='num of augmentations per image to use')
        parser.add_argument('-n_scale', '--n_scale', type=int, default=1,
                            help='num of scales per image to use. '
                                 '1 only image, '
                                 '2 image and half resized image')
        parser.add_argument('-n_distortions', '--n_distortions', type=int,
                            default=1,
                            help='num of distortions per image crop to use. '
                                 '1 for single distortion image, '
                                 '2 for randomly selecting amoing 1/2 '
                                 'distortions')
        parser.add_argument('-patch_size', '--patch_size', type=int,
                            default=224,
                            help='patch_size to crop for each image')
        # NOTE(review): this help text is identical to --patch_size and looks
        # copy-pasted; the intended meaning of swap_crops is not visible in
        # this file, so the string is left unchanged -- please confirm/fix.
        parser.add_argument('-swap_crops', '--swap_crops', type=int,
                            default=1,
                            help='patch_size to crop for each image')

        # optimization
        parser.add_argument('--epochs', type=int, default=30,
                            help='number of training epochs')
        parser.add_argument('--learning_rate', type=float, default=0.05,
                            help='learning rate')
        parser.add_argument('--lr_decay_epochs', type=str, default='120,160',
                            help='where to decay lr, can be a list')
        parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                            help='decay rate for learning rate')
        parser.add_argument('--weight_decay', type=float, default=1e-4,
                            help='weight decay')
        parser.add_argument('--momentum', type=float, default=0.9,
                            help='momentum for SGD')
        parser.add_argument('--cosine', action='store_true',
                            help='using cosine annealing')
        parser.add_argument('--optimizer', type=str, default='SGD',
                            help='SGD/AdamW/LARS')

        # method selection
        parser.add_argument('--method', default='Customize', type=str,
                            choices=['InsDis', 'CMC', 'CMCv2', 'MoCo',
                                     'MoCov2', 'PIRL', 'InfoMin',
                                     'Customize'],
                            help='Choose predefined method. Configs will be '
                                 'override '
                                 'for all methods except for `Customize`, '
                                 'which allows '
                                 'for user-defined combination of methods')

        # method configuration
        parser.add_argument('--modal', default='RGB', type=str,
                            choices=['RGB', 'CMC'],
                            help='single RGB modal, or two modalities in CMC')
        parser.add_argument('--jigsaw', action='store_true',
                            help='adding PIRL branch')
        parser.add_argument('--mem', default='bank', type=str,
                            choices=['bank', 'moco'],
                            help='memory mechanism: memory bank, '
                                 'or moco encoder cache')

        # model setup
        # The two help fragments are concatenated by the parser; the trailing
        # space on the first one keeps "resnext101" and "and" apart.
        parser.add_argument('--arch', default='resnet50', type=str,
                            help='e.g., resnet50, resnext50, resnext101 '
                                 'and their wider variants, resnet50x4')
        parser.add_argument('-d', '--feat_dim', default=128, type=int,
                            help='feature dimension for contrastive loss')
        parser.add_argument('-k', '--nce_k', default=65536, type=int,
                            help='number of negatives')
        parser.add_argument('-m', '--nce_m', default=0.5, type=float,
                            help='momentum for memory update')
        parser.add_argument('-t', '--nce_t', default=0.07, type=float,
                            help='temperature')
        parser.add_argument('--alpha', default=0.999, type=float,
                            help='momentum coefficients for moco encoder '
                                 'update')
        parser.add_argument('--head', default='linear', type=str,
                            choices=['linear', 'mlp'],
                            help='projection head')

        # resume
        parser.add_argument('--resume', default='', type=str, metavar='PATH',
                            help='path to latest checkpoint (default: none)')

        # Parallel setting (argparse exposes dashed names with underscores,
        # e.g. --world-size -> opt.world_size)
        parser.add_argument('--world-size', default=-1, type=int,
                            help='number of nodes for distributed training')
        parser.add_argument('--rank', default=-1, type=int,
                            help='node rank for distributed training')
        parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456',
                            type=str,
                            help='url used to set up distributed training')
        parser.add_argument('--dist-backend', default='nccl', type=str,
                            help='distributed backend')
        parser.add_argument('--seed', default=None, type=int,
                            help='seed for initializing training. ')
        parser.add_argument('--gpu', default=None, type=int,
                            help='GPU id to use.')
        parser.add_argument('--multiprocessing-distributed',
                            action='store_true',
                            help='Use multi-processing distributed training '
                                 'to launch '
                                 'N processes per node, which has N GPUs. '
                                 'This is the '
                                 'fastest way to use PyTorch for either '
                                 'single node or '
                                 'multi node data parallel training')

        return parser

    def print_options(self, opt):
        """Print every option in ``opt``, sorted by name.

        Options whose value differs from the parser default are annotated
        with that default. Requires ``self.parser`` to be set (done by
        :meth:`parse`).
        """
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>35}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        print('\n'.join(lines))

    def modify_options(self, opt):
        """Post-process parsed options; must be provided by a subclass.

        Called by :meth:`parse` with the parsed namespace; its return value
        becomes the final options object.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def override_options(self, opt):
        """Overwrite method-specific fields of ``opt`` for predefined methods.

        When ``opt.method`` is a key of ``self.override_dict`` the fields
        ``modal``, ``jigsaw``, ``mem``, ``aug``, ``head`` and ``nce_t`` are
        replaced by the preset values; otherwise ``opt`` is left untouched.

        NOTE(review): no ``--aug`` argument is registered in ``initialize``,
        so ``opt.aug`` only exists after a preset has been applied.

        Returns:
            The same ``opt`` object.
        """
        preset = self.override_dict.get(opt.method)
        if preset is not None:
            (opt.modal, opt.jigsaw, opt.mem,
             opt.aug, opt.head, opt.nce_t) = preset
        return opt

    def parse(self, args=None):
        """Parse options, post-process them, print them and return them.

        Args:
            args: optional list of argument strings. ``None`` (the default)
                makes ``argparse`` read ``sys.argv[1:]``, which is the
                original behaviour.

        Returns:
            The namespace returned by :meth:`modify_options`; it is also
            stored on ``self.opt``.
        """
        if not self.initialized:
            # Build the parser only once; later calls reuse it.
            parser = argparse.ArgumentParser('arguments options')
            self.parser = self.initialize(parser)
            self.initialized = True

        opt = self.parser.parse_args(args)
        opt = self.modify_options(opt)
        self.opt = opt
        self.print_options(opt)
        return opt