diff --git a/references/classification/train.py b/references/classification/train.py
index 90abdb0b47e..48ab75bc2c1 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -186,17 +186,19 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
 
     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
+
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
@@ -214,15 +216,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))
 
     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                               "are supported.")
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
@@ -307,7 +319,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 476058ce0c0..83277de9c2c 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -220,7 +220,7 @@ def get_args_parser(add_help=True):
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 353e0d6d1f7..0eefbc0b282 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -296,7 +296,7 @@ def parse_args():
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
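
Usage sketch (not part of the patch): a minimal, self-contained illustration of how the new --lr-warmup-method / --lr-warmup-decay flags compose a warmup scheduler with the main scheduler via SequentialLR, assuming PyTorch >= 1.10 (where LinearLR, ConstantLR and SequentialLR exist); the toy model and the hyper-parameter values below are placeholders, not values from the patch.

# Minimal sketch of the scheduler composition introduced by this diff.
import torch

model = torch.nn.Linear(10, 10)  # stand-in model, hypothetical
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9, nesterov=True)

lr_warmup_epochs, lr_warmup_decay, epochs = 5, 0.01, 30  # placeholder values

# --lr-warmup-method linear: LR ramps from lr * lr_warmup_decay up to lr over the warmup epochs.
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=lr_warmup_decay, total_iters=lr_warmup_epochs)

# --lr-scheduler cosineannealinglr: main schedule covers the remaining epochs.
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs - lr_warmup_epochs)

# SequentialLR hands over from warmup to the main scheduler at the milestone epoch.
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler],
    milestones=[lr_warmup_epochs])

for epoch in range(epochs):
    # train_one_epoch(...) would run here
    lr_scheduler.step()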