New file (ViT-huge DeepSpeed config)
@@ -0,0 +1,49 @@
_base_ = [
    '../../_base_/datasets/imagenet_bs64_swin_224.py',
    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../../_base_/default_runtime.py', 'vit-huge-p14_8xb128-coslr-50e_in1k.py'
]

# optimizer wrapper
# learning rate and layer decay rate are set to 0.004 and 0.75, respectively
optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
    constructor='LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(
        layer_decay_rate=0.75,
        custom_keys={
            '.ln': dict(decay_mult=0.0),
            '.bias': dict(decay_mult=0.0),
            '.cls_token': dict(decay_mult=0.0),
            '.pos_embed': dict(decay_mult=0.0)
        }))

# training strategy
# DeepSpeed with ZeRO-3 + fp16
strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(
        enabled=True,
        fp16_master_weights_and_grads=False,
        loss_scale=0,
        loss_scale_window=500,
        hysteresis=2,
        min_loss_scale=1,
        initial_scale_power=15,
    ),
    inputs_to_half=['inputs'],
    zero_optimization=dict(
        stage=3,
        allgather_partitions=True,
        reduce_scatter=True,
        allgather_bucket_size=50000000,
        reduce_bucket_size=50000000,
        overlap_comm=True,
        contiguous_gradients=True,
        cpu_offload=False,
    ))

# runner which supports strategies
runner_type = 'FlexibleRunner'
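For context, 'LearningRateDecayOptimWrapperConstructor' applies layer-wise learning-rate decay: parameters near the output keep close to the base LR, while earlier layers are scaled down geometrically. Below is a minimal sketch of the usual decay rule, not the constructor's actual code; the layer-id convention and num_layers value are assumptions here.

# Hedged sketch of layer-wise LR decay with the values configured above
# (base lr 4e-3, layer_decay_rate 0.75). Assumes layer_id runs from 0
# (patch/position embeddings) up to num_layers (the classification head).
def scaled_lr(base_lr: float, decay_rate: float, layer_id: int,
              num_layers: int) -> float:
    # Each step away from the head multiplies the LR by decay_rate.
    return base_lr * decay_rate ** (num_layers - layer_id)

# With num_layers = 33 (32 ViT-huge blocks + head), the head trains at the
# full 4e-3 while layer_id = 0 gets about 4e-3 * 0.75**33 ≈ 3e-7.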
New file (ViT-large DeepSpeed config)
@@ -0,0 +1,49 @@
_base_ = [
    '../../_base_/datasets/imagenet_bs64_swin_224.py',
    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../../_base_/default_runtime.py', 'vit-large-p16_8xb128-coslr-50e_in1k.py'
]

# optimizer wrapper
# learning rate and layer decay rate are set to 0.004 and 0.75, respectively
optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
    constructor='LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(
        layer_decay_rate=0.75,
        custom_keys={
            '.ln': dict(decay_mult=0.0),
            '.bias': dict(decay_mult=0.0),
            '.cls_token': dict(decay_mult=0.0),
            '.pos_embed': dict(decay_mult=0.0)
        }))

# training strategy
# DeepSpeed with ZeRO-3 + fp16
strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(
        enabled=True,
        fp16_master_weights_and_grads=False,
        loss_scale=0,
        loss_scale_window=500,
        hysteresis=2,
        min_loss_scale=1,
        initial_scale_power=15,
    ),
    inputs_to_half=['inputs'],
    zero_optimization=dict(
        stage=3,
        allgather_partitions=True,
        reduce_scatter=True,
        allgather_bucket_size=50000000,
        reduce_bucket_size=50000000,
        overlap_comm=True,
        contiguous_gradients=True,
        cpu_offload=False,
    ))

# runner which supports strategies
runner_type = 'FlexibleRunner'
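A note on the fp16 block above: loss_scale=0 turns on DeepSpeed's dynamic loss scaling, which starts at 2**initial_scale_power (2**15 = 32768), shrinks when gradients overflow, and doubles again after loss_scale_window overflow-free steps, bounded below by min_loss_scale. The toy scaler below is a rough approximation for intuition only, not DeepSpeed's implementation; in particular, the exact hysteresis bookkeeping is an assumption.

# Toy dynamic loss scaler mirroring the fp16 settings above.
class ToyLossScaler:
    def __init__(self, initial_scale_power=15, loss_scale_window=500,
                 hysteresis=2, min_loss_scale=1):
        self.scale = 2.0 ** initial_scale_power
        self.window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_scale = min_loss_scale
        self._overflows = 0
        self._good_steps = 0

    def step(self, overflow: bool):
        if overflow:
            self._good_steps = 0
            self._overflows += 1
            if self._overflows >= self.hysteresis:
                # Shrink the scale once `hysteresis` overflows accumulate.
                self.scale = max(self.scale / 2, self.min_scale)
                self._overflows = 0
        else:
            self._good_steps += 1
            if self._good_steps >= self.window:
                # Double the scale after a clean window of steps.
                self.scale *= 2
                self._good_steps = 0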
tools/test.py (8 additions, 1 deletion)
@@ -7,6 +7,7 @@
 import mmengine
 from mmengine.config import Config, ConfigDict, DictAction
 from mmengine.evaluator import DumpResults
+from mmengine.registry import RUNNERS
 from mmengine.runner import Runner


@@ -169,7 +170,13 @@ def main():
     cfg = merge_args(cfg, args)

     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)

     if args.out and args.out_item in ['pred', None]:
         runner.test_evaluator.metrics.append(
tools/train.py (8 additions, 1 deletion)
@@ -5,6 +5,7 @@
 from copy import deepcopy

 from mmengine.config import Config, ConfigDict, DictAction
+from mmengine.registry import RUNNERS
 from mmengine.runner import Runner
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -149,7 +150,13 @@ def main():
     cfg = merge_args(cfg, args)

     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)

     # start training
     runner.train()
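Since both configs set runner_type = 'FlexibleRunner', tools/test.py and tools/train.py now route construction through the RUNNERS registry rather than calling Runner.from_cfg directly, which is what lets the DeepSpeedStrategy take effect. A minimal end-to-end sketch follows; the config path and work_dir are hypothetical.

from mmengine.config import Config
from mmengine.registry import RUNNERS

# Load one of the DeepSpeed configs added in this PR (path is hypothetical).
cfg = Config.fromfile('configs/path/to/vit-large_deepspeed_in1k.py')
cfg.work_dir = './work_dirs/vit-large-deepspeed'

# RUNNERS resolves 'runner_type' to FlexibleRunner, which in turn
# activates the DeepSpeedStrategy defined in the config.
runner = RUNNERS.build(cfg)
runner.train()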