Skip to content

Commit

Permalink
Merge pull request #308 from not-lain/fix-shape-mismatch
Browse files Browse the repository at this point in the history
fix shape mismatch in VGG architecture
  • Loading branch information
ATaylorAerospace authored Jul 18, 2024
2 parents 75af513 + b9392c4 commit d6c99cb
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions chapters/en/unit2/cnns/vgg.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,20 @@ class VGG19(nn.Module):
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)

# Pooling Layer
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))

# Fully connected layers for classification
self.classifier = nn.Sequential(
nn.Linear(
Expand All @@ -87,6 +99,7 @@ class VGG19(nn.Module):

def forward(self, x):
x = self.feature_extractor(x) # Pass input through the feature extractor layers
x = self.avgpool(x) # Pass Data through a pooling layer
x = x.view(x.size(0), -1) # Flatten the output for the fully connected layers
x = self.classifier(x) # Pass flattened output through the classifier layers
return x
Expand Down

0 comments on commit d6c99cb

Please sign in to comment.