Skip to content

Commit 56ae5fc

Browse files
authored
Updating hyrax_cnn so that the first linear layer is dynamic based on input image size. Removing superfluous testing notebook. (#386)
1 parent da513c0 commit 56ae5fc

File tree

3 files changed

+42
-117
lines changed

3 files changed

+42
-117
lines changed

src/hyrax/data_sets/random_dataset_testing.ipynb

Lines changed: 0 additions & 111 deletions
This file was deleted.

src/hyrax/hyrax_default_config.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ latent_dim = 64
9595
final_layer = "tanh"
9696

9797

98+
[model.hyrax_cnn]
99+
# The number of classes to predict as the output of the model, e.g. 2 would be a
100+
# binary classifier, 10 would predict the 10 classes in the CIFAR dataset.
101+
output_classes = 10
102+
103+
98104
[criterion]
99105
# The name of the built-in criterion to use or the import path to an external criterion
100106
name = "torch.nn.CrossEntropyLoss"

src/hyrax/models/hyrax_cnn.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,46 @@ class HyraxCNN(nn.Module):
1919
This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.
2020
"""
2121

22-
def __init__(self, config, shape):
22+
def __init__(self, config, shape=(3, 32, 32)):
2323
super().__init__()
24-
self.conv1 = nn.Conv2d(3, 6, 5)
24+
self.config = config
25+
26+
self.num_input_channels, self.image_width, self.image_height = shape
27+
hidden_channels_1 = 6
28+
hidden_channels_2 = 16
29+
30+
# Calculate how much our convolutional layers and pooling will affect
31+
# the size of the final convolution output.
32+
#
33+
# If the number of layers is changed, this will need to be rewritten.
34+
conv1_end_w = self.conv2d_output_size(self.image_width, kernel_size=5)
35+
conv1_end_h = self.conv2d_output_size(self.image_height, kernel_size=5)
36+
37+
pool1_end_w = self.pool2d_output_size(conv1_end_w, kernel_size=2, stride=2)
38+
pool1_end_h = self.pool2d_output_size(conv1_end_h, kernel_size=2, stride=2)
39+
40+
conv2_end_w = self.conv2d_output_size(pool1_end_w, kernel_size=5)
41+
conv2_end_h = self.conv2d_output_size(pool1_end_h, kernel_size=5)
42+
43+
pool2_end_w = self.pool2d_output_size(conv2_end_w, kernel_size=2, stride=2)
44+
pool2_end_h = self.pool2d_output_size(conv2_end_h, kernel_size=2, stride=2)
45+
46+
self.conv1 = nn.Conv2d(self.num_input_channels, hidden_channels_1, 5)
2547
self.pool = nn.MaxPool2d(2, 2)
26-
self.conv2 = nn.Conv2d(6, 16, 5)
27-
self.fc1 = nn.Linear(16 * 5 * 5, 120)
48+
self.conv2 = nn.Conv2d(hidden_channels_1, hidden_channels_2, 5)
49+
self.fc1 = nn.Linear(hidden_channels_2 * pool2_end_h * pool2_end_w, 120)
2850
self.fc2 = nn.Linear(120, 84)
29-
self.fc3 = nn.Linear(84, 10)
51+
self.fc3 = nn.Linear(84, self.config["model"]["hyrax_cnn"]["output_classes"])
3052

31-
self.config = config
53+
def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
54+
# From https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
55+
numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
56+
return int((numerator / stride) + 1)
57+
58+
def pool2d_output_size(self, input_size, kernel_size, stride, padding=0, dilation=1) -> int:
59+
# From https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
60+
numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
61+
return int((numerator / stride) + 1)
3262

3363
def forward(self, x):
3464
# This check is inefficient - we assume that the example CNN will be primarily

0 commit comments

Comments
 (0)