From 60759e85229802af5b9f005e5c49a814b485f391 Mon Sep 17 00:00:00 2001 From: huzhifeng Date: Wed, 16 Sep 2020 15:17:47 +0800 Subject: [PATCH] add googlenet include top for hub --- model_zoo/official/cv/googlenet/src/googlenet.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/model_zoo/official/cv/googlenet/src/googlenet.py b/model_zoo/official/cv/googlenet/src/googlenet.py index 78695f2d6ce..2ccf3954871 100644 --- a/model_zoo/official/cv/googlenet/src/googlenet.py +++ b/model_zoo/official/cv/googlenet/src/googlenet.py @@ -81,7 +81,7 @@ class GoogleNet(nn.Cell): Googlenet architecture """ - def __init__(self, num_classes): + def __init__(self, num_classes, include_top=True): super(GoogleNet, self).__init__() self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") @@ -104,11 +104,13 @@ class GoogleNet(nn.Cell): self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) - self.mean = P.ReduceMean(keep_dims=True) self.dropout = nn.Dropout(keep_prob=0.8) - self.flatten = nn.Flatten() - self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), - bias_init=weight_variable()) + self.include_top = include_top + if self.include_top: + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), + bias_init=weight_variable()) def construct(self, x): @@ -133,6 +135,8 @@ class GoogleNet(nn.Cell): x = self.block5a(x) x = self.block5b(x) + if not self.include_top: + return x x = self.mean(x, (2, 3)) x = self.flatten(x)