!6688 vgg16 hub support

Merge pull request !6688 from caojian05/ms_master_vgg16_hub
This commit is contained in:
mindspore-ci-bot 2020-09-22 14:31:50 +08:00 committed by Gitee
commit f7150a6dd8
2 changed files with 33 additions and 3 deletions

View File

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vgg import vgg16 as VGG16
def vgg16(*args, **kwargs):
return VGG16(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == "vgg16":
return vgg16(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -60,6 +60,7 @@ class Vgg(nn.Cell):
num_classes (int): Class numbers. Default: 1000.
batch_norm (bool): Whether to do the batchnorm. Default: False.
batch_size (int): Batch size. Default: 1.
include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True.
Returns:
Tensor, infer output tensor.
@ -69,10 +70,12 @@ class Vgg(nn.Cell):
>>> num_classes=1000, batch_norm=False, batch_size=1)
"""
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train",
include_top=True):
super(Vgg, self).__init__()
_ = batch_size
self.layers = _make_layer(base, args, batch_norm=batch_norm)
self.include_top = include_top
self.flatten = nn.Flatten()
dropout_ratio = 0.5
if not args.has_dropout or phase == "test":
@ -91,8 +94,9 @@ class Vgg(nn.Cell):
def construct(self, x):
x = self.layers(x)
x = self.flatten(x)
x = self.classifier(x)
if self.include_top:
x = self.flatten(x)
x = self.classifier(x)
return x
def custom_init_weight(self):