@@ -186,17 +186,19 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
 
     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
+
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
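For illustration, a minimal standalone sketch (not part of the diff) of what the relaxed optimizer check does: any --opt value that starts with "sgd" still selects SGD, and the substring "nesterov" switches Nesterov momentum on. The value "sgd_nesterov" below is only a hypothetical example, not something the script documents.

import torch

# Hedged sketch: mirrors the startswith("sgd") / "nesterov" in opt_name logic above.
# "sgd_nesterov" is an illustrative option value, not one defined by the reference script.
params = [torch.nn.Parameter(torch.zeros(1))]
for opt_name in ("sgd", "sgd_nesterov"):
    optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4,
                                nesterov="nesterov" in opt_name)
    print(opt_name, optimizer.defaults["nesterov"])  # prints False, then True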
@@ -214,15 +216,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))
 
     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                               "are supported.")
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
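For illustration, a minimal standalone sketch (not part of the diff) of how the chosen warmup scheduler is chained to the main scheduler with SequentialLR, assuming a PyTorch version that ships LinearLR, ConstantLR and SequentialLR (>= 1.10). The epoch counts and learning rate below are illustrative values, not the script's defaults.

import torch

# Hedged sketch: 5 linear warmup epochs starting at 1% of the base lr, then cosine
# decay over the remaining epochs, chained the same way the script does it above.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.5, momentum=0.9)

warmup_epochs, epochs, warmup_decay = 5, 30, 0.01
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=warmup_decay, total_iters=warmup_epochs)
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs - warmup_epochs)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_epochs])

for epoch in range(epochs):
    optimizer.step()     # stand-in for one epoch of training
    lr_scheduler.step()
    # lr ramps up over the first 5 epochs, then follows the cosine curve down
    print(epoch, optimizer.param_groups[0]["lr"])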
@@ -307,7 +319,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
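For illustration, a cut-down parser (not the real get_args_parser(), which has many more options) showing how the new --lr-warmup-method flag parses and why --lr-warmup-decay needs type=float rather than type=int.

import argparse

# Hedged sketch: only the warmup-related arguments are reproduced here.
parser = argparse.ArgumentParser()
parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
parser.add_argument('--lr-warmup-method', default="constant", type=str, help='the warmup method (default: constant)')
parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')

args = parser.parse_args(['--lr-scheduler', 'cosineannealinglr',
                          '--lr-warmup-epochs', '5',
                          '--lr-warmup-method', 'linear',
                          '--lr-warmup-decay', '0.1'])
print(args.lr_warmup_method, args.lr_warmup_decay)  # linear 0.1 (parsed as float; type=int would reject '0.1')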