@@ -19,16 +19,46 @@ class HyraxCNN(nn.Module):
19
19
This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.
20
20
"""
21
21
22
- def __init__ (self , config , shape ):
22
+ def __init__ (self , config , shape = ( 3 , 32 , 32 ) ):
23
23
super ().__init__ ()
24
- self .conv1 = nn .Conv2d (3 , 6 , 5 )
24
+ self .config = config
25
+
26
+ self .num_input_channels , self .image_width , self .image_height = shape
27
+ hidden_channels_1 = 6
28
+ hidden_channels_2 = 16
29
+
30
+ # Calculate how much our convolutional layers and pooling will affect
31
+ # the size of final convolution.
32
+ #
33
+ # If the number of layers are changed this will need to be rewritten.
34
+ conv1_end_w = self .conv2d_output_size (self .image_width , kernel_size = 5 )
35
+ conv1_end_h = self .conv2d_output_size (self .image_height , kernel_size = 5 )
36
+
37
+ pool1_end_w = self .pool2d_output_size (conv1_end_w , kernel_size = 2 , stride = 2 )
38
+ pool1_end_h = self .pool2d_output_size (conv1_end_h , kernel_size = 2 , stride = 2 )
39
+
40
+ conv2_end_w = self .conv2d_output_size (pool1_end_w , kernel_size = 5 )
41
+ conv2_end_h = self .conv2d_output_size (pool1_end_h , kernel_size = 5 )
42
+
43
+ pool2_end_w = self .pool2d_output_size (conv2_end_w , kernel_size = 2 , stride = 2 )
44
+ pool2_end_h = self .pool2d_output_size (conv2_end_h , kernel_size = 2 , stride = 2 )
45
+
46
+ self .conv1 = nn .Conv2d (self .num_input_channels , hidden_channels_1 , 5 )
25
47
self .pool = nn .MaxPool2d (2 , 2 )
26
- self .conv2 = nn .Conv2d (6 , 16 , 5 )
27
- self .fc1 = nn .Linear (16 * 5 * 5 , 120 )
48
+ self .conv2 = nn .Conv2d (hidden_channels_1 , hidden_channels_2 , 5 )
49
+ self .fc1 = nn .Linear (hidden_channels_2 * pool2_end_h * pool2_end_w , 120 )
28
50
self .fc2 = nn .Linear (120 , 84 )
29
- self .fc3 = nn .Linear (84 , 10 )
51
+ self .fc3 = nn .Linear (84 , self . config [ "model" ][ "hyrax_cnn" ][ "output_classes" ] )
30
52
31
- self .config = config
53
+ def conv2d_output_size (self , input_size , kernel_size , padding = 0 , stride = 1 , dilation = 1 ) -> int :
54
+ # From https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
55
+ numerator = input_size + 2 * padding - dilation * (kernel_size - 1 ) - 1
56
+ return int ((numerator / stride ) + 1 )
57
+
58
+ def pool2d_output_size (self , input_size , kernel_size , stride , padding = 0 , dilation = 1 ) -> int :
59
+ # From https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
60
+ numerator = input_size + 2 * padding - dilation * (kernel_size - 1 ) - 1
61
+ return int ((numerator / stride ) + 1 )
32
62
33
63
def forward (self , x ):
34
64
# This check is inefficient - we assume that the example CNN will be primarily
0 commit comments