forked from yan-hao-tian/VW
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_flops.py
executable file
·120 lines (102 loc) · 4.22 KB
/
get_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
from mmcv import Config
from mmcv.cnn import get_model_complexity_info
from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
from mmseg.models import build_segmentor
import torch
def parse_args():
    """Parse command-line options for the FLOPs-counting script.

    Returns:
        argparse.Namespace: with ``config`` (path to the model config)
        and ``shape`` (input size as one or two ints, default 2048x1024).
    """
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[2048, 1024],
        help='input image size')
    return parser.parse_args()
def sra_flops(h, w, r, dim, num_heads):
    """FLOPs of the two attention matmuls in one spatial-reduction
    attention (SRA) layer on an ``h`` x ``w`` feature map.

    The keys/values are spatially reduced by factor ``r`` per side, so
    attention is between ``h*w`` queries and ``(h/r)*(w/r)`` key/value
    tokens, each of per-head width ``dim / num_heads``.

    Returns:
        float: FLOPs of Q@K^T plus attn@V (they cost the same).
    """
    head_dim = dim / num_heads
    n_query = h * w
    n_kv = h / r * w / r
    # Q @ K^T across all heads.
    qk = n_query * head_dim * n_kv * num_heads
    # attn @ V across all heads.
    av = n_query * n_kv * head_dim * num_heads
    return qk + av
def get_tr_flops(net, input_shape):
    """Count FLOPs and params for a transformer-based segmentor.

    mmcv's hook-based counter misses the attention matrix multiplies
    (Q@K^T and attn@V), so those are computed analytically per backbone
    stage with :func:`sra_flops` and added on top of the counted total.

    Args:
        net: built segmentor; ``net.backbone`` must expose stage lists
            ``block1`` .. ``block4`` whose blocks carry an ``attn`` module.
        input_shape (tuple): ``(C, H, W)`` of the dummy input.

    Returns:
        tuple[str, str]: human-readable FLOPs and params strings.
    """
    flops, params = get_model_complexity_info(net, input_shape, as_strings=False)
    _, H, W = input_shape
    backbone = net.backbone
    # Stage i operates at 1/strides[i] of the input resolution.
    strides = (4, 8, 16, 32)
    attn_flops = 0
    try:
        # PVT/SegFormer-style attention: sr_ratio and dim live on the module.
        for i, stride in enumerate(strides, start=1):
            stage = getattr(backbone, 'block{}'.format(i))
            attn = stage[0].attn
            attn_flops += sra_flops(H // stride, W // stride,
                                    attn.sr_ratio, attn.dim,
                                    attn.num_heads) * len(stage)
    # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit;
    # only a missing attribute should trigger the fallback.
    except AttributeError:
        # VW-style attention exposes squeeze_ratio instead and does not
        # record dim; fall back to the standard MiT embedding widths.
        attn_flops = 0
        for i, (stride, dim) in enumerate(zip(strides, (64, 128, 320, 512)),
                                          start=1):
            stage = getattr(backbone, 'block{}'.format(i))
            attn = stage[0].attn
            attn_flops += sra_flops(H // stride, W // stride,
                                    attn.squeeze_ratio, dim,
                                    attn.num_heads) * len(stage)
    print(attn_flops)  # raw attention FLOPs, printed for parity with upstream
    flops += attn_flops
    return flops_to_string(flops), params_to_string(params)
def main():
    """Build the segmentor from the given config and report FLOPs/params.

    Transformer backbones (detected via a ``block1`` attribute) go through
    :func:`get_tr_flops` so the attention matmuls are included; everything
    else uses mmcv's counter directly. Requires a CUDA device.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    cfg.model.pretrained = None  # no need to download weights just to count FLOPs
    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg')).cuda()
    model.eval()

    # The counter needs a plain single-tensor forward; segmentors provide
    # forward_dummy for exactly this purpose.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        # Fixed duplicated word in the original message
        # ("currently not currently supported").
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    if hasattr(model.backbone, 'block1'):
        print('#### get transformer flops ####')
        with torch.no_grad():
            flops, params = get_tr_flops(model, input_shape)
    else:
        print('#### get CNN flops ####')
        flops, params = get_model_complexity_info(model, input_shape)

    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()