forked from mindspore-Ecosystem/mindspore
!6335 add include_top parameter in googlenet for hub
Merge pull request !6335 from hzf/add_inctop
This commit is contained in:
commit
defd74e261
|
@ -81,7 +81,7 @@ class GoogleNet(nn.Cell):
|
||||||
Googlenet architecture
|
Googlenet architecture
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes):
|
def __init__(self, num_classes, include_top=True):
|
||||||
super(GoogleNet, self).__init__()
|
super(GoogleNet, self).__init__()
|
||||||
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
|
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
|
||||||
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||||
|
@ -104,11 +104,13 @@ class GoogleNet(nn.Cell):
|
||||||
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
|
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
|
||||||
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
|
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
|
||||||
|
|
||||||
self.mean = P.ReduceMean(keep_dims=True)
|
|
||||||
self.dropout = nn.Dropout(keep_prob=0.8)
|
self.dropout = nn.Dropout(keep_prob=0.8)
|
||||||
self.flatten = nn.Flatten()
|
self.include_top = include_top
|
||||||
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
|
if self.include_top:
|
||||||
bias_init=weight_variable())
|
self.mean = P.ReduceMean(keep_dims=True)
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
|
||||||
|
bias_init=weight_variable())
|
||||||
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
|
@ -133,6 +135,8 @@ class GoogleNet(nn.Cell):
|
||||||
|
|
||||||
x = self.block5a(x)
|
x = self.block5a(x)
|
||||||
x = self.block5b(x)
|
x = self.block5b(x)
|
||||||
|
if not self.include_top:
|
||||||
|
return x
|
||||||
|
|
||||||
x = self.mean(x, (2, 3))
|
x = self.mean(x, (2, 3))
|
||||||
x = self.flatten(x)
|
x = self.flatten(x)
|
||||||
|
|
Loading…
Reference in New Issue