Commit c7ec630

Merge branch 'dev-1.x' into 1.x
2 parents: e9f9bb2 + 0d8f918

File tree: 220 files changed (+9805 / -452 lines)


.circleci/test.yml

Lines changed: 3 additions & 3 deletions
@@ -207,9 +207,9 @@ workflows:
       - lint
       - build_cpu_with_3rdparty:
           name: maximum_version_cpu
-          torch: 1.12.1
-          torchvision: 0.13.1
-          python: 3.9.0
+          torch: 1.13.0
+          torchvision: 0.14.0
+          python: 3.10.0
           requires:
             - minimum_version_cpu
       - hold:
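
The maximum-version CPU job now builds against torch 1.13.0 and torchvision 0.14.0 on Python 3.10.0. Below is a minimal sketch (not part of the repository) for checking a local environment against those new CI maximums:

# Sketch: compare the local environment with the new maximum_version_cpu job.
# The expected values are taken from the .circleci/test.yml change above.
import platform

import torch
import torchvision

expected = {'torch': '1.13.0', 'torchvision': '0.14.0', 'python': '3.10'}
print(f"torch       {torch.__version__} (CI max: {expected['torch']})")
print(f"torchvision {torchvision.__version__} (CI max: {expected['torchvision']})")
print(f"python      {platform.python_version()} (CI max: {expected['python']}.x)")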

.dev_scripts/benchmark_regression/1-benchmark_valid.py

Lines changed: 28 additions & 20 deletions
@@ -12,11 +12,12 @@
 from mmengine import Config, DictAction, MMLogger
 from mmengine.dataset import Compose, default_collate
 from mmengine.fileio import FileClient
-from mmengine.runner import Runner
+from mmengine.runner import Runner, load_checkpoint
 from modelindex.load_model_index import load
 from rich.console import Console
 from rich.table import Table
 
+from mmcls.apis import init_model
 from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
 from mmcls.utils import register_all_modules
 from mmcls.visualization import ClsVisualizer
@@ -82,37 +83,44 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    # build the data pipeline
-    test_dataset = cfg.test_dataloader.dataset
-    if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
-        test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
-    if test_dataset.type in ['CIFAR10', 'CIFAR100']:
-        # The image shape of CIFAR is (32, 32, 3)
-        test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
-
-    data = Compose(test_dataset.pipeline)({'img_path': args.img})
-    data = default_collate([data] * args.batch_size)
-    resolution = tuple(data['inputs'].shape[-2:])
-
-    runner: Runner = Runner.from_cfg(cfg)
-    model = runner.model
+    if 'test_dataloader' in cfg:
+        # build the data pipeline
+        test_dataset = cfg.test_dataloader.dataset
+        if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
+            test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
+        if test_dataset.type in ['CIFAR10', 'CIFAR100']:
+            # The image shape of CIFAR is (32, 32, 3)
+            test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
+
+        data = Compose(test_dataset.pipeline)({'img_path': args.img})
+        data = default_collate([data] * args.batch_size)
+        resolution = tuple(data['inputs'].shape[-2:])
+        model = Runner.from_cfg(cfg).model
+        forward = model.val_step
+    else:
+        # For configs that only define the model.
+        model = init_model(cfg)
+        load_checkpoint(model, checkpoint, map_location='cpu')
+        data = torch.empty(1, 3, 224, 224).to(model.data_preprocessor.device)
+        resolution = (224, 224)
+        forward = model.extract_feat
 
     # forward the model
     result = {'resolution': resolution}
     with torch.no_grad():
         if args.inference_time:
             time_record = []
             for _ in range(10):
-                model.val_step(data)  # warmup before profiling
+                forward(data)  # warmup before profiling
                 torch.cuda.synchronize()
                 start = time()
-                model.val_step(data)
+                forward(data)
                 torch.cuda.synchronize()
                 time_record.append((time() - start) / args.batch_size * 1000)
             result['time_mean'] = np.mean(time_record[1:-1])
             result['time_std'] = np.std(time_record[1:-1])
         else:
-            model.val_step(data)
+            forward(data)
 
     result['model'] = config_file.stem
 
@@ -144,8 +152,8 @@ def show_summary(summary_data, args):
     if args.inference_time:
         table.add_column('Inference Time (std) (ms/im)')
     if args.flops:
-        table.add_column('Flops', justify='right', width=11)
-        table.add_column('Params', justify='right')
+        table.add_column('Flops', justify='right', width=13)
+        table.add_column('Params', justify='right', width=11)
 
     for model_name, summary in summary_data.items():
         row = [model_name]
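
For reference, the new fallback branch covers configs that only define a model and have no test_dataloader. Below is a minimal standalone sketch of that path; 'some_config.py' and 'some_checkpoint.pth' are placeholder names, not files from the repository:

# Sketch of the model-only config path added above; names are placeholders.
import torch
from mmengine.runner import load_checkpoint

from mmcls.apis import init_model
from mmcls.utils import register_all_modules

register_all_modules()

model = init_model('some_config.py', device='cpu')  # config without a test_dataloader
load_checkpoint(model, 'some_checkpoint.pth', map_location='cpu')

# With no test pipeline available, profile a dummy 224x224 batch instead.
data = torch.empty(1, 3, 224, 224).to(model.data_preprocessor.device)
with torch.no_grad():
    feats = model.extract_feat(data)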

.dev_scripts/ckpt_tree.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+import argparse
+import math
+from pathlib import Path
+
+import torch
+from rich.console import Console
+
+console = Console()
+
+prog_description = """\
+Draw the state dict tree.
+"""
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description=prog_description)
+    parser.add_argument(
+        'path',
+        type=Path,
+        help='The path of the checkpoint or model config to draw.')
+    parser.add_argument('--depth', type=int, help='The max depth to draw.')
+    parser.add_argument(
+        '--full-name',
+        action='store_true',
+        help='Whether to print the full name of the key.')
+    parser.add_argument(
+        '--shape',
+        action='store_true',
+        help='Whether to print the shape of the parameter.')
+    parser.add_argument(
+        '--state-key',
+        type=str,
+        help='The key of the state dict in the checkpoint.')
+    parser.add_argument(
+        '--number',
+        action='store_true',
+        help='Mark all parameters and their index number.')
+    parser.add_argument(
+        '--node',
+        type=str,
+        help='Show the sub-tree of a node, like "backbone.layers".')
+    args = parser.parse_args()
+    return args
+
+
+def ckpt_to_state_dict(checkpoint, key=None):
+    if key is not None:
+        state_dict = checkpoint[key]
+    elif 'state_dict' in checkpoint:
+        # try mmcls style
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    elif isinstance(next(iter(checkpoint.values())), torch.Tensor):
+        # try native style
+        state_dict = checkpoint
+    else:
+        raise KeyError('Please specify the key of state '
+                       f'dict from {list(checkpoint.keys())}.')
+    return state_dict
+
+
+class StateDictTree:
+
+    def __init__(self, key='', value=None):
+        self.children = {}
+        self.key: str = key
+        self.value = value
+
+    def add_parameter(self, key, value):
+        keys = key.split('.', 1)
+        if len(keys) == 1:
+            self.children[key] = StateDictTree(key, value)
+        elif keys[0] in self.children:
+            self.children[keys[0]].add_parameter(keys[1], value)
+        else:
+            node = StateDictTree(keys[0])
+            node.add_parameter(keys[1], value)
+            self.children[keys[0]] = node
+
+    def __getitem__(self, key: str):
+        return self.children[key]
+
+    def __repr__(self) -> str:
+        with console.capture() as capture:
+            for line in self.iter_tree():
+                console.print(line)
+        return capture.get()
+
+    def __len__(self):
+        return len(self.children)
+
+    def draw_tree(self,
+                  max_depth=None,
+                  full_name=False,
+                  with_shape=False,
+                  with_value=False):
+        for line in self.iter_tree(
+                max_depth=max_depth,
+                full_name=full_name,
+                with_shape=with_shape,
+                with_value=with_value,
+        ):
+            console.print(line, highlight=False)
+
+    def iter_tree(
+        self,
+        lead='',
+        prefix='',
+        max_depth=None,
+        full_name=False,
+        with_shape=False,
+        with_value=False,
+    ):
+        if self.value is None:
+            key_str = f'[blue]{self.key}[/]'
+        elif with_shape:
+            key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
+        elif with_value:
+            key_str = f'[green]{self.key}[/] {self.value}'
+        else:
+            key_str = f'[green]{self.key}[/]'
+
+        yield lead + prefix + key_str
+
+        lead = lead.replace('├─', '│ ')
+        lead = lead.replace('└─', '  ')
+        if self.key and full_name:
+            prefix = f'{prefix}{self.key}.'
+
+        if max_depth == 0:
+            return
+        elif max_depth is not None:
+            max_depth -= 1
+
+        for i, child in enumerate(self.children.values()):
+            level_lead = '├─' if i < len(self.children) - 1 else '└─'
+            yield from child.iter_tree(
+                lead=f'{lead}{level_lead} ',
+                prefix=prefix,
+                max_depth=max_depth,
+                full_name=full_name,
+                with_shape=with_shape,
+                with_value=with_value)
+
+
+def main():
+    args = parse_args()
+    if args.path.suffix in ['.json', '.py', '.yml']:
+        from mmengine.runner import get_state_dict
+
+        from mmcls.apis import init_model
+        model = init_model(args.path, device='cpu')
+        state_dict = get_state_dict(model)
+    else:
+        ckpt = torch.load(args.path, map_location='cpu')
+        state_dict = ckpt_to_state_dict(ckpt, args.state_key)
+
+    root = StateDictTree()
+    for k, v in state_dict.items():
+        root.add_parameter(k, v)
+
+    para_index = 0
+    mark_width = math.floor(math.log(len(state_dict), 10) + 1)
+    if args.node is not None:
+        for key in args.node.split('.'):
+            root = root[key]
+
+    for line in root.iter_tree(
+            max_depth=args.depth,
+            full_name=args.full_name,
+            with_shape=args.shape,
+    ):
+        if not args.number:
+            mark = ''
+        # A hack to determine whether a line is a parameter.
+        elif '[green]' in line:
+            mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
+            para_index += 1
+        else:
+            mark = ' ' * (mark_width + 2)
+        console.print(mark + line, highlight=False)
+
+
+if __name__ == '__main__':
+    main()
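
A short usage sketch for the new script follows. The checkpoint name is a placeholder, and the import assumes the snippet is run from .dev_scripts/ so that ckpt_tree is importable; the equivalent CLI call would be roughly `python .dev_scripts/ckpt_tree.py resnet50.pth --depth 2 --shape`.

# Usage sketch: build and draw a StateDictTree from a checkpoint file.
# 'resnet50.pth' is a placeholder path; run from .dev_scripts/ so ckpt_tree imports.
import torch

from ckpt_tree import StateDictTree, ckpt_to_state_dict

ckpt = torch.load('resnet50.pth', map_location='cpu')
state_dict = ckpt_to_state_dict(ckpt)  # picks the 'state_dict'/'model' key automatically

root = StateDictTree()
for name, param in state_dict.items():
    root.add_parameter(name, param)

# Print only the top two levels, showing parameter shapes.
root.draw_tree(max_depth=2, with_shape=True)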

0 commit comments
