
Commit c712016

Further enhance Classification Reference (#4444)
* Adding ExponentialLR and LinearLR
* Fix arg type of --lr-warmup-decay
* Adding support of Zero gamma BN and SGD with nesterov.
* Fix --lr-warmup-decay for video_classification.
* Update bn_reinit
* Fix pre-existing bug on num_classes of model
* Remove zero gamma.
* Use fstrings.
1 parent 16405ac · commit c712016

File tree: 3 files changed (+23, -9 lines)


references/classification/train.py

Lines changed: 21 additions & 7 deletions
@@ -186,17 +186,19 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)

     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
+
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
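The change from an exact match to opt_name.startswith("sgd") lets the optimizer name double as a feature flag: --opt sgd selects plain momentum SGD, while --opt sgd_nesterov (or any sgd* name containing "nesterov") switches Nesterov momentum on. A minimal sketch of that dispatch, using a toy model and made-up hyper-parameter values:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in for the real classification model

    for opt_name in ("sgd", "sgd_nesterov"):
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4,
            # the substring check turns the optimizer name into a feature flag
            nesterov="nesterov" in opt_name)
        print(opt_name, "-> nesterov =", optimizer.param_groups[0]["nesterov"])

Note that torch.optim.SGD raises a ValueError when nesterov=True is combined with zero momentum, so --momentum must stay positive for the nesterov variant.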
@@ -214,15 +216,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))

     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                               "are supported.")
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
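With the warmup scheduler factored out into warmup_lr_scheduler, SequentialLR hands control from the chosen warmup variant to the main scheduler at the epoch listed in milestones. A minimal sketch of the combined schedule, assuming a base lr of 1.0, five linear warmup epochs with --lr-warmup-decay 0.01, and the newly added ExponentialLR as the main scheduler:

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

    # ramps from 0.01 * lr up to the full lr over the first five epochs
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=5)
    # then decays the lr by 10% per epoch
    main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[5])

    for epoch in range(10):
        print(f"epoch {epoch}: lr={optimizer.param_groups[0]['lr']:.4f}")
        optimizer.step()   # step the optimizer before the scheduler
        lr_scheduler.step()

The constant variant instead holds the lr at factor * lr for total_iters epochs and then jumps to the full value, which matches what the previously inlined ConstantLR did.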
@@ -307,7 +319,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
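The --lr-warmup-decay fix above is easy to miss: argparse applies type only to values supplied on the command line, so the float default of 0.01 appeared to work while any fractional override was rejected by int(). A throwaway parser demonstrates the failure mode:

    import argparse

    parser = argparse.ArgumentParser()
    # the old, buggy declaration
    parser.add_argument('--lr-warmup-decay', default=0.01, type=int)

    print(parser.parse_args([]).lr_warmup_decay)  # 0.01 -- defaults bypass type
    try:
        parser.parse_args(['--lr-warmup-decay', '0.5'])
    except SystemExit:
        # argparse reports "invalid int value: '0.5'" and exits
        print("override rejected")

The same one-character fix (type=int to type=float) is applied to the segmentation and video classification references below.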

references/segmentation/train.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def get_args_parser(add_help=True):
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')

references/video_classification/train.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def parse_args():
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
