-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmnist_mil_main.py
91 lines (72 loc) · 3.1 KB
/
mnist_mil_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import json
from mnist_mil_trainer import *
import os
import torch
import seaborn as sns
import matplotlib.pyplot as plt
def get_args(argv=None):
    """Parse command-line options for the MNIST MIL benchmark script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behavior of
            reading ``sys.argv[1:]``.

    Returns:
        dict mapping each option name (e.g. ``"lr"``, ``"bag_size"``) to its
        parsed value.

    Raises:
        SystemExit: raised by argparse when an option value is invalid
            (e.g. ``--wandb maybe`` or an unknown ``--mode``).
    """

    def _str2bool(value):
        # ``type=bool`` is a trap: bool("False") is True, so any non-empty
        # string would switch the flag on. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='MNIST MIL benchmarks:')
    parser.add_argument("--project_name", default="MNIST-MIL")
    parser.add_argument('--wandb', default=False, type=_str2bool)
    # Model params
    # "gsh" is included because the __main__ driver runs the Trainer with it.
    parser.add_argument('--mode', default="softmax", choices=["softmax", "entmax", "sparsemax", "gsh"])
    parser.add_argument('--d_model', default=512, type=int)
    parser.add_argument('--input_size', default=784, type=int)
    parser.add_argument('--model', default="pooling", type=str)
    parser.add_argument('--num_pattern', default=2, type=int)
    parser.add_argument('--n_heads', default=4, type=int)
    # type=float so a value passed on the command line is numeric like the default.
    parser.add_argument('--scale', default=0.1, type=float)
    parser.add_argument('--update_steps', default=1, type=int)
    parser.add_argument('--dropout', default=0.3, type=float)
    # Training params
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--epoch', default=100, type=int)
    parser.add_argument('--seed', default=1111, type=int)
    # Data params
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--train_size', default=2000, type=int)
    parser.add_argument('--test_size', default=500, type=int)
    parser.add_argument('--pos_per_bag', default=1, type=int)
    parser.add_argument('--bag_size', default=10, type=int)
    parser.add_argument('--tgt_num', default=9, type=int)
    args = parser.parse_args(argv)
    return vars(args)
if __name__ == "__main__":
    # Benchmark driver: train every attention mode for several trials, pool
    # the per-trial logs, and plot loss/accuracy curves per mode.
    torch.set_num_threads(3)
    config = get_args()
    num_trials = 5
    if config["bag_size"] == 100:
        config["num_pattern"] = 4
    bag_size = config["bag_size"]
    # Other bag sizes previously swept here (pass via --bag_size):
    # 5, 10, 20, 50, 100, 200, 300.

    # Create the output directory before training so a missing ./imgs does
    # not crash savefig after all runs have already finished.
    os.makedirs("./imgs", exist_ok=True)

    models = ["softmax", "sparsemax", "entmax", "gsh"]
    data_log = None
    for m in models:
        config["mode"] = m
        for trial in range(num_trials):
            # Deterministic per-trial seed: trials differ from each other, but
            # the whole run is reproducible from --seed. (The previous
            # torch.random.seed() reseeded from a non-deterministic source,
            # which made --seed a no-op.) The same trial index gets the same
            # seed for every mode.
            torch.manual_seed(config["seed"] + trial)
            trainer = Trainer(config, trial)
            trial_log = trainer.train()
            # NOTE(review): assumes train() returns a dict whose values support
            # "+" concatenation (e.g. lists) and include the "epoch", "model"
            # and metric columns plotted below — confirm in mnist_mil_trainer.
            if data_log is None:
                # Shallow copy so the trainer's returned dict is not mutated.
                data_log = dict(trial_log)
            else:
                for key in data_log:
                    data_log[key] = data_log[key] + trial_log[key]

    # One figure per metric, e.g. "train loss" -> ./imgs/train_loss_<bag>.pdf
    for metric in ("train loss", "test loss", "train acc", "test acc"):
        sns.lineplot(data=data_log, x="epoch", y=metric, hue="model", alpha=0.4, errorbar=None, linewidth=2)
        plt.tight_layout()
        plt.savefig(f'./imgs/{metric.replace(" ", "_")}_{bag_size}.pdf')
        plt.clf()