add googlenet include top for hub

This commit is contained in:
huzhifeng 2020-09-16 15:17:47 +08:00
parent 8670306870
commit 60759e8522
1 changed files with 9 additions and 5 deletions

View File

@ -81,7 +81,7 @@ class GoogleNet(nn.Cell):
Googlenet architecture
"""
def __init__(self, num_classes):
def __init__(self, num_classes, include_top=True):
super(GoogleNet, self).__init__()
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
@ -104,11 +104,13 @@ class GoogleNet(nn.Cell):
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
self.mean = P.ReduceMean(keep_dims=True)
self.dropout = nn.Dropout(keep_prob=0.8)
self.flatten = nn.Flatten()
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
bias_init=weight_variable())
self.include_top = include_top
if self.include_top:
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
bias_init=weight_variable())
def construct(self, x):
@ -133,6 +135,8 @@ class GoogleNet(nn.Cell):
x = self.block5a(x)
x = self.block5b(x)
if not self.include_top:
return x
x = self.mean(x, (2, 3))
x = self.flatten(x)