I tried using wide_resnet_imagenet64_1000, but found that it gives zero accuracy, so I wanted to double-check that my usage is correct. Here I am evaluating its natural accuracy on Imagenet64_val_npz, which I downloaded from the ImageNet website (I also tried certified accuracy, which is also zero).
Step 1: Load the model
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)

def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

# Pre-activation wide basic block; dropout is disabled in this variant.
class wide_basic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        # self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        # out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)
        return out

class Wide_ResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes,
                 in_planes=16, in_dim=56):
        super(Wide_ResNet, self).__init__()
        self.in_planes = in_planes
        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) / 6
        k = widen_factor
        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [in_planes, in_planes*k, in_planes*2*k, in_planes*4*k]
        self.conv1 = conv3x3(3, nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.1)
        self.linear = nn.Linear(nStages[3] * (in_dim//4//7)**2, num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        strides = [stride] + [1]*(int(num_blocks)-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.avg_pool2d(out, 7)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

def wide_resnet_imagenet64(in_ch=3, in_dim=56, in_planes=16, widen_factor=10):
    return Wide_ResNet(10, widen_factor, 0.3, 200, in_dim=in_dim, in_planes=in_planes)

def wide_resnet_imagenet64_1000class(in_ch=3, in_dim=56, in_planes=16, widen_factor=10):
    return Wide_ResNet(10, widen_factor, 0.3, 1000, in_dim=in_dim, in_planes=in_planes)

# Build the 1000-class ImageNet64 model and load the pretrained checkpoint.
model = wide_resnet_imagenet64_1000class()
chkpt = torch.load("wide_resnet_imagenet64_1000")
model.load_state_dict(chkpt['state_dict'])
model.eval()
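Since load_state_dict ran with the default strict=True and did not raise, the checkpoint keys match the architecture. As an extra sanity check (a sketch using the objects defined above; I don't know what else this checkpoint stores), one can confirm that real weights, not the random initialization, ended up in the model. BatchNorm running statistics are all zeros until a trained state is loaded into them:

print(chkpt.keys())                # see what else the checkpoint stores
print(model.bn1.running_mean[:5])  # non-zero values indicate loaded weights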
Step 2: Load the dataset
import numpy as np

file_path = 'val_data.npz'
npz_data = np.load(file_path)
all_data, all_labels = npz_data['data'], npz_data['labels']

# Per-channel normalization constants, broadcastable over NCHW batches.
mean = torch.Tensor([0.4815, 0.4578, 0.4082]).reshape([1, 3, 1, 1])
std = torch.Tensor([0.2153, 0.2111, 0.2121]).reshape([1, 3, 1, 1])
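Before evaluating, it is worth checking the raw arrays (a sketch; the names follow the loading code above). In particular, the downsampled-ImageNet npz releases are, to my knowledge, labeled 1..1000 rather than 0..999, which a min/max check reveals immediately:

print(all_data.shape, all_data.dtype)      # expect (N, 12288) uint8 for 64x64x3 images
print(all_labels.min(), all_labels.max())  # 1..1000 here would mean 1-indexed labels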
Step 3: Evaluate
b_size = 32
correct = 0
for i in range(len(all_data)//b_size):
    data, labels = all_data[i*b_size:(i+1)*b_size], all_labels[i*b_size:(i+1)*b_size]
    data = data.reshape([len(data), 3, 64, 64])/255      # uint8 -> floats in [0, 1]
    data = torch.Tensor(data[:, :, 4:60, 4:60]).float()  # central 56x56 crop
    data = (data - mean)/std
    labels = torch.Tensor(labels).int()
    pred = model(data).max(1)[1]
    correct += float((pred == labels).float().sum())
    print("\n\n", correct, "\n", pred, "\n", labels)
    # if i > 10: break
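As a side note (a sketch, and not itself the cause of the zero accuracy): inference is normally wrapped in torch.no_grad() so no autograd graph is built, and an accuracy ratio is easier to read than a running count. evaluate is a hypothetical helper, not part of the repository:

def evaluate(model, all_data, all_labels, b_size=32):
    correct = total = 0
    with torch.no_grad():  # no gradients needed at evaluation time
        for i in range(len(all_data)//b_size):
            data = all_data[i*b_size:(i+1)*b_size].reshape([-1, 3, 64, 64])/255
            data = torch.Tensor(data[:, :, 4:60, 4:60]).float()
            data = (data - mean)/std
            labels = torch.tensor(all_labels[i*b_size:(i+1)*b_size]).long()
            pred = model(data).argmax(1)
            correct += (pred == labels).sum().item()
            total += len(labels)
    return correct/total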
Here is the output of the loop above for the first batches (no correct predictions in any of them). Question: what is the right way to evaluate this model?
0.0
tensor([765, 228, 773, 561, 428, 190, 41, 988, 196, 859, 424, 934, 286, 26,
16, 146, 442, 719, 675, 951, 342, 864, 608, 778, 479, 153, 324, 736,
210, 351, 977, 327])
tensor([372, 814, 227, 604, 226, 51, 663, 308, 276, 749, 384, 95, 357, 772,
665, 896, 189, 856, 664, 771, 407, 451, 647, 896, 14, 230, 603, 181,
860, 387, 351, 134], dtype=torch.int32)
0.0
tensor([294, 605, 831, 368, 293, 881, 608, 553, 610, 286, 129, 771, 417, 382,
732, 820, 945, 444, 267, 684, 607, 427, 474, 709, 128, 296, 197, 680,
232, 765, 758, 98])
tensor([363, 901, 856, 44, 245, 746, 940, 349, 475, 132, 167, 935, 1, 443,
970, 606, 138, 500, 929, 598, 484, 639, 958, 738, 205, 733, 513, 703,
310, 556, 977, 1], dtype=torch.int32)
0.0
tensor([898, 101, 994, 417, 791, 519, 120, 549, 287, 974, 257, 638, 909, 818,
38, 261, 427, 697, 909, 680, 354, 989, 60, 90, 252, 105, 279, 287,
209, 561, 433, 796])
tensor([754, 180, 803, 78, 969, 291, 684, 706, 330, 977, 34, 674, 179, 257,
731, 731, 639, 898, 149, 930, 418, 721, 175, 19, 326, 13, 836, 358,
288, 606, 618, 489], dtype=torch.int32)
0.0
tensor([986, 518, 9, 928, 29, 304, 864, 723, 137, 453, 582, 512, 928, 23,
366, 356, 432, 438, 765, 904, 371, 398, 77, 609, 536, 435, 465, 902,
617, 287, 315, 691])
tensor([351, 566, 504, 220, 49, 560, 956, 348, 222, 606, 745, 57, 790, 466,
158, 123, 464, 642, 978, 913, 433, 458, 417, 648, 504, 762, 297, 983,
650, 332, 429, 912], dtype=torch.int32)
0.0
tensor([668, 293, 417, 282, 407, 217, 205, 107, 182, 904, 363, 419, 934, 26,
115, 631, 21, 675, 600, 992, 782, 998, 670, 398, 71, 966, 419, 823,
956, 941, 882, 382])
tensor([179, 81, 669, 348, 634, 540, 378, 588, 21, 424, 714, 478, 126, 345,
204, 303, 136, 704, 640, 768, 769, 670, 832, 653, 60, 341, 356, 367,
66, 947, 967, 444], dtype=torch.int32)
0.0
tensor([719, 743, 770, 414, 312, 227, 347, 912, 765, 709, 131, 941, 884, 251,
942, 292, 407, 929, 345, 144, 648, 743, 674, 555, 951, 70, 26, 135,
237, 590, 547, 286])
tensor([ 77, 670, 41, 996, 90, 203, 411, 93, 556, 816, 495, 611, 866, 737,
913, 59, 446, 936, 456, 229, 683, 620, 709, 164, 406, 114, 111, 884,
297, 43, 592, 711], dtype=torch.int32)
0.0
tensor([207, 289, 127, 287, 556, 182, 354, 411, 478, 370, 89, 74, 980, 579,
100, 6, 745, 756, 180, 226, 590, 864, 0, 11, 841, 284, 357, 657,
288, 798, 74, 989])
tensor([286, 360, 395, 739, 113, 926, 695, 6, 373, 675, 672, 166, 280, 174,
124, 46, 108, 18, 261, 406, 727, 216, 196, 109, 580, 350, 99, 699,
816, 818, 624, 197], dtype=torch.int32)
0.0
tensor([559, 690, 902, 35, 604, 545, 155, 371, 145, 463, 51, 941, 814, 162,
377, 31, 697, 50, 707, 942, 841, 832, 118, 100, 881, 622, 205, 573,
291, 770, 32, 989])
tensor([633, 676, 940, 212, 624, 158, 239, 433, 503, 978, 4, 947, 831, 245,
142, 552, 250, 427, 455, 698, 270, 774, 174, 158, 795, 512, 257, 562,
295, 966, 60, 100], dtype=torch.int32)
0.0
tensor([356, 529, 943, 631, 171, 555, 253, 127, 486, 545, 122, 903, 439, 71,
74, 441, 766, 442, 679, 783, 344, 966, 244, 32, 647, 179, 621, 174,
605, 716, 879, 425])
tensor([508, 982, 865, 317, 572, 733, 653, 199, 906, 127, 837, 722, 147, 471,
38, 494, 696, 909, 72, 564, 998, 634, 319, 151, 527, 71, 659, 256,
44, 501, 929, 487], dtype=torch.int32)
0.0
tensor([394, 873, 908, 29, 330, 388, 109, 457, 632, 436, 407, 464, 178, 836,
992, 127, 709, 373, 73, 933, 436, 741, 988, 115, 568, 371, 388, 5,
212, 544, 711, 711])
tensor([541, 635, 292, 168, 397, 449, 198, 394, 935, 479, 23, 559, 260, 814,
304, 440, 734, 826, 115, 916, 498, 917, 474, 187, 1, 431, 449, 170,
30, 545, 322, 740], dtype=torch.int32)
0.0
tensor([942, 152, 601, 890, 128, 207, 716, 189, 333, 618, 192, 900, 530, 794,
71, 6, 638, 77, 934, 633, 600, 755, 864, 992, 247, 606, 914, 811,
561, 631, 516, 785])
tensor([948, 881, 165, 91, 890, 889, 744, 880, 400, 657, 96, 774, 1, 546,
919, 588, 597, 169, 59, 375, 632, 508, 893, 872, 328, 738, 732, 730,
606, 666, 902, 69], dtype=torch.int32)
0.0
tensor([ 93, 225, 408, 227, 209, 555, 35, 241, 399, 380, 155, 178, 622, 875,
872, 987, 494, 973, 962, 585, 205, 613, 564, 439, 261, 335, 48, 208,
71, 337, 862, 611])
tensor([ 211, 301, 506, 710, 68, 374, 375, 301, 794, 751, 112, 274,
660, 622, 152, 716, 858, 113, 811, 134, 549, 651, 609, 741,
903, 1000, 735, 280, 922, 410, 292, 788], dtype=torch.int32)
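One detail in the printout itself may be the culprit: the last batch contains the label 1000, but a 1000-class model only ever predicts indices 0..999. If the npz labels are indeed 1-indexed (my assumption from the output above), the label line in the loop needs a shift before comparison:

# assumption from the printed labels (they reach 1000): shift 1-indexed
# npz labels to the 0-indexed range the model predicts in
labels = torch.tensor(labels).long() - 1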