This repository was archived by the owner on Sep 13, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmain.py
More file actions
141 lines (118 loc) · 4.83 KB
/
main.py
File metadata and controls
141 lines (118 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from torch.utils.data import DataLoader
import argparse
from datautils import ZSLDataset
from trainer import Trainer
# Command-line configuration for the ZSL/GZSL training pipeline.
parser = argparse.ArgumentParser()

# (flag, keyword-arguments) pairs keep the option table compact;
# registration order matches the original one-call-per-option form.
_options = [
    ('--dataset',      dict(type=str, default='awa2')),
    ('--gzsl',         dict(action='store_true', default=False)),
    ('--latent_dim',   dict(type=int, default=128)),
    ('--n_critic',     dict(type=int, default=5)),
    ('--lmbda',        dict(type=float, default=10.0)),
    ('--beta',         dict(type=float, default=0.01)),
    ('--batch_size',   dict(type=int, default=128)),
    ('--n_epochs',     dict(type=int, default=10)),
    ('--use_cls_loss', dict(action='store_true', default=False)),
    ('--visualize',    dict(action='store_true', default=False)),
]
for _flag, _kwargs in _options:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
# Per-dataset dimensions, replacing the repetitive if/elif chain:
# (visual feature dim, attribute dim, #seen/train classes, #unseen/test classes).
# All supported datasets use 2048-d visual features.
DATASET_CONFIG = {
    'awa1': (2048, 85, 40, 10),
    'awa2': (2048, 85, 40, 10),
    'cub':  (2048, 312, 150, 50),
    'sun':  (2048, 102, 645, 72),
}

try:
    x_dim, attr_dim, n_train, n_test = DATASET_CONFIG[args.dataset]
except KeyError:
    # Same exception type as before, but now it names the bad value
    # and the supported choices instead of raising bare.
    raise NotImplementedError(
        "Unknown dataset %r; choose one of %s"
        % (args.dataset, sorted(DATASET_CONFIG))
    )
n_epochs = args.n_epochs

# Prefer GPU when one is visible; the Trainer handles placement.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trainer object for mini-batch training: bundles the classifier and
# GAN components with their hyper-parameters.
train_agent = Trainer(
    device, x_dim, args.latent_dim, attr_dim,
    n_train=n_train, n_test=n_test, gzsl=args.gzsl,
    n_critic=args.n_critic, lmbda=args.lmbda, beta=args.beta,
    batch_size=args.batch_size
)

# Shared DataLoader settings; drop_last keeps every batch full-sized.
params = dict(
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

train_dataset = ZSLDataset(args.dataset, n_train, n_test, args.gzsl)
train_generator = DataLoader(train_dataset, **params)
# =============================================================
# PRETRAIN THE SOFTMAX CLASSIFIER
# =============================================================
# The discriminative classifier is trained once and cached; later
# runs reuse the saved weights instead of retraining.
model_name = "%s_disc_classifier" % args.dataset
success = train_agent.load_model(model=model_name)
if success:
    print("Discriminative classifier parameters loaded...")
else:
    print("Training the discriminative classifier...")
    for ep in range(1, n_epochs + 1):
        loss = 0
        # The enumerate index was unused; iterate the loader directly.
        for img_features, label_attr, label_idx in train_generator:
            loss += train_agent.fit_classifier(img_features, label_attr, label_idx)
        print("Loss for epoch: %3d - %.4f" %(ep, loss))
    train_agent.save_model(model=model_name)
# =============================================================
# TRAIN THE GANs
# =============================================================
# Adversarial training of the feature generator; also cached to disk
# so completed runs are not repeated.
model_name = "%s_generator" % args.dataset
success = train_agent.load_model(model=model_name)
if success:
    print("\nGAN parameters loaded....")
else:
    print("\nTraining the GANS...")
    for ep in range(1, n_epochs + 1):
        loss_dis = 0
        loss_gan = 0
        # The enumerate index was unused; iterate the loader directly.
        for img_features, label_attr, label_idx in train_generator:
            l_d, l_g = train_agent.fit_GAN(img_features, label_attr, label_idx, args.use_cls_loss)
            loss_dis += l_d
            loss_gan += l_g
        print("Loss for epoch: %3d - D: %.4f | G: %.4f"\
                %(ep, loss_dis, loss_gan))
    train_agent.save_model(model=model_name)
# =============================================================
# TRAIN FINAL CLASSIFIER ON SYNTHETIC DATASET
# =============================================================
# Use the trained generator to synthesise features for the unseen
# classes; under GZSL the real seen-class data is included as well.
seen_dataset = train_dataset.gzsl_dataset if args.gzsl else None

syn_dataset = train_agent.create_syn_dataset(
        train_dataset.test_classmap, train_dataset.attributes, seen_dataset)

final_dataset = ZSLDataset(args.dataset, n_train, n_test,
        gzsl=args.gzsl, train=True, synthetic=True, syn_dataset=syn_dataset)
final_train_generator = DataLoader(final_dataset, **params)
# Train (or reload) the final classifier on the synthetic features.
model_name = "%s_final_classifier" % args.dataset
success = train_agent.load_model(model=model_name)
if success:
    print("\nFinal classifier parameters loaded....")
else:
    print("\nTraining the final classifier on the synthetic dataset...")
    for ep in range(1, n_epochs + 1):
        syn_loss = 0
        # The enumerate index was unused; iterate the loader directly.
        for img, label_attr, label_idx in final_train_generator:
            syn_loss += train_agent.fit_final_classifier(img, label_attr, label_idx)
        # print losses on real and synthetic datasets
        print("Loss for epoch: %3d - %.4f" %(ep, syn_loss))
    train_agent.save_model(model=model_name)
# =============================================================
# TESTING PHASE
# =============================================================
# Evaluate the final classifier on the held-out test split.
test_dataset = ZSLDataset(args.dataset, n_train, n_test, gzsl=args.gzsl, train=False)
test_generator = DataLoader(test_dataset, **params)

accuracy = train_agent.test(test_generator)
print("\nFinal Accuracy on ZSL Task: %.3f" % accuracy)