-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathpretrain_detector.py
171 lines (130 loc) · 6.35 KB
/
pretrain_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Optional step of pretraining the detector.
Based on http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html.
"""
import os
import sys
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from detector.engine import train_one_epoch
import detector.utils as utils
import detector.transforms as T
from dataloaders.visual_genome import VG
from PIL import Image
from lib.pytorch_misc import save_checkpoint, get_smallest_lr
from config import BOX_SCALE
# Command-line arguments: <split> <data_dir> <save_dir>.
if len(sys.argv) < 4:
    sys.exit('usage: python %s <split> <data_dir> <save_dir>' % sys.argv[0])
# VG.split is a class attribute; VGLoader and main() read it to choose
# split-specific behaviour (box rescaling, class count, batch size).
VG.split = sys.argv[1]
data_dir = sys.argv[2]
save_dir = sys.argv[3]
# Checkpoint file name; main() writes it inside save_dir.
checkpoint_name = '%s_maskrcnn_res50fpn.pth' % VG.split
if not save_dir:
    raise ValueError("save_dir must be a valid path")
if not os.path.exists(save_dir):
    # makedirs, unlike mkdir, also creates any missing parent directories.
    os.makedirs(save_dir)
class VGLoader(VG):
    """Visual Genome dataset wrapper yielding ``(img, target)`` pairs for
    torchvision-style detection training (boxes and labels only, no masks).
    """

    def __init__(self, mode, data_dir, transforms):
        """Load the ``mode`` split from ``data_dir`` through ``VG``.

        Args:
            mode: split name forwarded to ``VG`` (``main`` passes ``'train'``).
            data_dir: dataset root directory forwarded to ``VG``.
            transforms: callable ``(img, target) -> (img, target)`` applied to
                every item, or ``None`` to skip transformation.
        """
        # NOTE(review): -1 for the graph-size limits presumably disables those
        # filters in VG -- confirm against dataloaders.visual_genome.
        super(VGLoader, self).__init__(mode, data_dir, num_val_im=5000, filter_duplicate_rels=True,
                                       min_graph_size=-1,
                                       max_graph_size=-1,
                                       filter_non_overlap=False)
        self.transforms = transforms

    @staticmethod
    def _fix_degenerate_boxes(boxes):
        """Give zero-width / zero-height boxes a size of one unit, in place.

        ``boxes`` is an (N, 4) array of ``x1, y1, x2, y2``. For each axis, a box
        whose low and high coordinates are equal is widened by moving the low
        coordinate down by 1 when it is > 0, and otherwise by moving the high
        coordinate up by 1 (instead of pushing the low one below zero).
        Returns ``boxes``.
        """
        for lo, hi in ((0, 2), (1, 3)):
            flat = boxes[:, hi] == boxes[:, lo]
            # Parenthesised explicitly: `a == b & c` would parse as `a == (b & c)`.
            shift_lo = flat & (boxes[:, lo] > 0)
            boxes[shift_lo, lo] -= 1
            boxes[flat & ~shift_lo, hi] += 1
        return boxes

    def __getitem__(self, idx):
        """Return ``(img, target)`` for image ``idx``.

        ``target`` holds ``boxes`` (float32 ``x1, y1, x2, y2``), ``labels``,
        ``image_id``, ``area`` and ``iscrowd`` (all zeros). ``self.transforms``
        is applied to the pair when it is not ``None``.
        """
        path = os.path.join(self.images_dir, self.filenames[idx])
        # convert() returns an in-memory copy, so the file can be closed here.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        w, h = img.size

        gt_boxes = self.gt_boxes[idx].copy()
        if VG.split == 'stanford':
            # Rescale boxes so they match the image size (they appear to be
            # stored relative to a longest side of BOX_SCALE).
            gt_boxes = gt_boxes / (BOX_SCALE / max(w, h))
        if self.is_train:
            # Clip boxes that extend past the image (upper bound only).
            gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]].clip(None, h)
            gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]].clip(None, w)
            if VG.split in ['vte', 'gqa']:
                # Width/height can become zero after clipping on these splits
                # (root cause still to be checked); widen such boxes.
                self._fix_degenerate_boxes(gt_boxes)

        gt_boxes = torch.as_tensor(gt_boxes, dtype=torch.float32)
        target = {
            "boxes": gt_boxes,
            "labels": torch.from_numpy(self.gt_classes[idx]).long(),
            # No mask annotations are available, so no "masks" entry.
            "image_id": torch.tensor([idx]),
            "area": (gt_boxes[:, 3] - gt_boxes[:, 1]) * (gt_boxes[:, 2] - gt_boxes[:, 0]),
            # No crowd annotations: treat every instance as non-crowd.
            "iscrowd": torch.zeros((len(self.gt_classes[idx]),), dtype=torch.int64),
        }
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
def get_model_optimizer(num_classes, device=None):
    """Build a COCO-pretrained Mask R-CNN with a new box head, plus its optimizer.

    If a checkpoint is found (``save_dir/checkpoint_name`` first, then a bare
    ``checkpoint_name`` in the working directory for backward compatibility),
    the model and optimizer states are restored from it.

    Args:
        num_classes: number of output classes for the new box predictor
            (torchvision's predictor counts background as a class).
        device: device the model is moved to *before* the optimizer is built
            and restored, so restored optimizer state (e.g. momentum buffers)
            sits on the same device as the parameters. Defaults to CUDA when
            available, else CPU -- the same choice ``main`` makes.

    Returns:
        ``(model, optimizer, start_epoch)``; ``start_epoch`` is the ``'epoch'``
        value stored in the checkpoint, or -1 when no checkpoint exists.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Mask R-CNN (ResNet-50 FPN) with weights pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    # Replace the pre-trained box classifier/regressor with one for num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # No mask annotations in these datasets: drop the mask predictor. Only the
    # predictor is removed (mask_roi_pool / mask_head remain), which keeps the
    # state_dict layout identical to checkpoints saved by earlier runs.
    model.roi_heads.mask_predictor = None
    # Place the model before creating/restoring the optimizer (see `device`).
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    # Hyperparameters were not tuned, but looked good.
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    start_epoch = -1
    # main() saves to save_dir/checkpoint_name, so look there first.
    for candidate in (os.path.join(save_dir, checkpoint_name), checkpoint_name):
        if os.path.exists(candidate):
            print('loading the model and optimizer state from %s' % candidate)
            state_dict = torch.load(candidate, map_location=device)
            model.load_state_dict(state_dict['state_dict'])
            optimizer.load_state_dict(state_dict['optimizer'])
            start_epoch = state_dict['epoch']
            break
    return model, optimizer, start_epoch
def get_transform(train):
    """Return the paired image/target transform pipeline.

    Images are always converted to tensors; when ``train`` is true a random
    horizontal flip (probability 0.5) is appended after the conversion.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline += [T.RandomHorizontalFlip(0.5)]
    return T.Compose(pipeline)
def main():
    """Pretrain the detector on the split selected by ``VG.split``.

    Trains up to epoch ``num_epochs - 1``, resuming after the epoch stored in
    an existing checkpoint, and saves a checkpoint to ``save_dir`` after every
    epoch.
    """
    # Train on the GPU when available, otherwise on the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Number of detector classes: 151 for the 'stanford' split, 1704 otherwise.
    num_classes = 151 if VG.split == 'stanford' else 1704

    dataset = VGLoader('train', data_dir, get_transform(train=True))
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=3 if VG.split == 'stanford' else 2, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)

    model, optimizer, start_epoch = get_model_optimizer(num_classes)
    print('start_epoch', start_epoch)

    model.to(device)
    # Optimizer.load_state_dict casts restored state (e.g. SGD momentum
    # buffers) to the parameters' device at load time. Move any such state to
    # the model's current device so optimizer.step() never mixes devices
    # after resuming.
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if torch.is_tensor(value):
                param_state[key] = value.to(device)

    # Decay the learning rate by 10x every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    num_epochs = 10
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    for epoch in range(start_epoch + 1, num_epochs):
        print('\nepoch %d, smallest lr %f\n' % (epoch, get_smallest_lr(optimizer)))
        # Train for one epoch, printing every 10 iterations.
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        try:
            print("\nCheckpointing to %s" % checkpoint_path)
            save_checkpoint(model, optimizer, checkpoint_path, {'epoch': epoch})
            print('done!\n')
        except Exception as e:
            # Best effort: a failed save should not abort training.
            print('error saving checkpoint', e)
        # Passing `epoch` explicitly (deprecated in recent PyTorch releases)
        # positions the scheduler by epoch index, so the schedule also lines
        # up after resuming from a checkpoint.
        lr_scheduler.step(epoch)
    # NOTE: no evaluation pass is run; the original author noted issues with
    # the coco_eval-based evaluation code.
    print("That's it!")


if __name__ == "__main__":
    main()