forked from mindspore-Ecosystem/mindspore
!6335 add include_top parameter in googlenet for hub
Merge pull request !6335 from hzf/add_inctop
This commit is contained in:
commit
defd74e261
|
@ -81,7 +81,7 @@ class GoogleNet(nn.Cell):
|
|||
Googlenet architecture
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes):
|
||||
def __init__(self, num_classes, include_top=True):
|
||||
super(GoogleNet, self).__init__()
|
||||
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
@ -104,11 +104,13 @@ class GoogleNet(nn.Cell):
|
|||
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
|
||||
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.dropout = nn.Dropout(keep_prob=0.8)
|
||||
self.flatten = nn.Flatten()
|
||||
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
|
||||
bias_init=weight_variable())
|
||||
self.include_top = include_top
|
||||
if self.include_top:
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
|
||||
bias_init=weight_variable())
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -133,6 +135,8 @@ class GoogleNet(nn.Cell):
|
|||
|
||||
x = self.block5a(x)
|
||||
x = self.block5b(x)
|
||||
if not self.include_top:
|
||||
return x
|
||||
|
||||
x = self.mean(x, (2, 3))
|
||||
x = self.flatten(x)
|
||||
|
|
Loading…
Reference in New Issue