-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
49 lines (39 loc) · 1.58 KB
/
classifier.py
File metadata and controls
49 lines (39 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torch import nn
import copy
class MaskedClassifier(nn.Module):
    """Linear classifier whose logits for masked-out classes are forced to -1e9.

    Args:
        x_dim: Size of the input feature dimension (last dim of ``X``).
        mask: 1-D tensor with one entry per class. Truthy entries keep the
            linear layer's logit; falsy entries are replaced by -1e9 so those
            classes effectively receive zero softmax probability.
        device: Device the linear layer and the mask are created on.

    Raises:
        ValueError: If ``mask`` is not 1-D.
    """

    def __init__(self, x_dim, mask, device):
        super().__init__()
        if mask.dim() != 1:
            raise ValueError(
                f"mask must be 1-D (one entry per class), got shape {tuple(mask.shape)}"
            )
        self.x_dim = x_dim
        # Registered as a buffer so module.to()/cuda()/cpu() move it together
        # with the weights. persistent=False keeps it out of state_dict, so
        # previously saved checkpoints still load with strict=True.
        # Coerced to bool because torch.where requires a boolean condition.
        self.register_buffer(
            "mask", mask.to(device=device, dtype=torch.bool), persistent=False
        )
        self.device = device
        self.all_classes_dim, = mask.shape
        self.net = nn.Sequential(nn.Linear(self.x_dim, self.all_classes_dim)).to(device)

    def forward(self, X):
        """Return logits of shape (..., all_classes_dim) with masked classes set to -1e9."""
        pred = self.net(X)
        # Create the fill value on pred's device rather than the device recorded
        # at construction, so the module keeps working after it has been moved.
        fill = torch.tensor([-1e9], dtype=torch.float32, device=pred.device)
        return torch.where(self.mask, pred, fill)
def train_cls(cls_mask, train_X, train_y, val_X, val_y, device, epochs=10):
    """Train a MaskedClassifier with full-batch Adam and keep the best-on-validation copy.

    Args:
        cls_mask: Class mask forwarded to ``MaskedClassifier``; its length is
            the number of output classes.
        train_X: Training features; unpacked as ``(n_samples, fea_dim)``.
        train_y: Training class indices, as accepted by ``nn.CrossEntropyLoss``.
        val_X: Validation features, same feature size as ``train_X``.
        val_y: Validation class indices.
        device: Device the model and all data tensors are placed on.
        epochs: Number of full-batch optimisation steps (default 10).

    Returns:
        A deep copy of the model from the epoch with the highest validation
        accuracy (the earliest such epoch wins on ties).

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Move labels as well as features, and do it once up front, so the loss and
    # accuracy never compare tensors living on different devices.
    train_X, train_y = train_X.to(device), train_y.to(device)
    val_X, val_y = val_X.to(device), val_y.to(device)

    _, fea_dim = train_X.shape
    linear_cls = MaskedClassifier(fea_dim, cls_mask.to(device), device)
    optimizer = torch.optim.Adam(linear_cls.parameters())
    criterion = nn.CrossEntropyLoss()

    def accuracy(pred, y):
        """Fraction of rows whose arg-max prediction equals the label (0.0 if empty)."""
        if y.shape[0] == 0:
            return 0.0
        return (torch.argmax(pred, dim=-1) == y).sum().item() / y.shape[0]

    # Start below any reachable accuracy so the first epoch is always recorded
    # and a model is returned even if validation accuracy stays at 0.
    best_val_acc = float("-inf")
    best_linear_cls = None
    for _ in range(epochs):
        optimizer.zero_grad()
        train_pred = linear_cls(train_X)
        loss = criterion(train_pred, train_y)
        loss.backward()
        optimizer.step()

        # Validation needs no autograd graph. Note that train_acc is measured on
        # the predictions made *before* this epoch's optimiser step.
        with torch.no_grad():
            val_seen_pred = linear_cls(val_X)
        train_acc = accuracy(train_pred.detach(), train_y)
        val_acc = accuracy(val_seen_pred, val_y)
        print(f'Loss: {loss.item()}, train acc: {train_acc}, val acc: {val_acc}')
        if val_acc > best_val_acc:
            print('Best!')
            best_val_acc = val_acc
            # deepcopy of a module already on `device` stays on `device`.
            best_linear_cls = copy.deepcopy(linear_cls)
    return best_linear_cls