From 0497092cf827b3957f9d589e5641d24870b95b8c Mon Sep 17 00:00:00 2001 From: zhouyaqiang Date: Mon, 21 Sep 2020 10:29:48 +0800 Subject: [PATCH] add hub for densenet121 and inceptionv3 --- .../cv/densenet121/mindspore_hub_conf.py | 21 +++++++++++++++++++ .../cv/densenet121/src/network/densenet.py | 8 +++++-- .../cv/inceptionv3/mindspore_hub_conf.py | 21 +++++++++++++++++++ .../cv/inceptionv3/src/inception_v3.py | 8 +++++-- 4 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 model_zoo/official/cv/densenet121/mindspore_hub_conf.py create mode 100644 model_zoo/official/cv/inceptionv3/mindspore_hub_conf.py diff --git a/model_zoo/official/cv/densenet121/mindspore_hub_conf.py b/model_zoo/official/cv/densenet121/mindspore_hub_conf.py new file mode 100644 index 00000000000..5e1ed149c05 --- /dev/null +++ b/model_zoo/official/cv/densenet121/mindspore_hub_conf.py @@ -0,0 +1,21 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.network import DenseNet121 + +def create_network(name, *args, **kwargs): + if name == 'densenet121': + return DenseNet121(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/densenet121/src/network/densenet.py b/model_zoo/official/cv/densenet121/src/network/densenet.py index 69f42228428..42a0665a8e8 100644 --- a/model_zoo/official/cv/densenet121/src/network/densenet.py +++ b/model_zoo/official/cv/densenet121/src/network/densenet.py @@ -205,11 +205,13 @@ class DenseNet121(nn.Cell): """ the densenet121 architectur """ - def __init__(self, num_classes): + def __init__(self, num_classes, include_top=True): super(DenseNet121, self).__init__() self.backbone = _densenet121() out_channels = self.backbone.get_out_channels() - self.head = CommonHead(num_classes, out_channels) + self.include_top = include_top + if self.include_top: + self.head = CommonHead(num_classes, out_channels) default_recurisive_init(self) for _, cell in self.cells_and_names(): @@ -226,5 +228,7 @@ class DenseNet121(nn.Cell): def construct(self, x): x = self.backbone(x) + if not self.include_top: + return x x = self.head(x) return x diff --git a/model_zoo/official/cv/inceptionv3/mindspore_hub_conf.py b/model_zoo/official/cv/inceptionv3/mindspore_hub_conf.py new file mode 100644 index 00000000000..972ed0302b7 --- /dev/null +++ b/model_zoo/official/cv/inceptionv3/mindspore_hub_conf.py @@ -0,0 +1,21 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.inception_v3 import InceptionV3 + +def create_network(name, *args, **kwargs): + if name == 'inceptionv3': + return InceptionV3(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/inceptionv3/src/inception_v3.py b/model_zoo/official/cv/inceptionv3/src/inception_v3.py index 8d2faf7a914..1facbcf3f42 100644 --- a/model_zoo/official/cv/inceptionv3/src/inception_v3.py +++ b/model_zoo/official/cv/inceptionv3/src/inception_v3.py @@ -203,7 +203,7 @@ class AuxLogits(nn.Cell): class InceptionV3(nn.Cell): - def __init__(self, num_classes=10, is_training=True, has_bias=False, dropout_keep_prob=0.8): + def __init__(self, num_classes=10, is_training=True, has_bias=False, dropout_keep_prob=0.8, include_top=True): super(InceptionV3, self).__init__() self.is_training = is_training self.Conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2, pad_mode='valid', has_bias=has_bias) @@ -226,7 +226,9 @@ class InceptionV3(nn.Cell): self.Mixed_7c = Inception_E(2048, has_bias=has_bias) if is_training: self.aux_logits = AuxLogits(768, num_classes) - self.logits = Logits(num_classes, dropout_keep_prob) + self.include_top = include_top + if self.include_top: + self.logits = Logits(num_classes, dropout_keep_prob) def construct(self, x): x = self.Conv2d_1a(x) @@ -251,6 +253,8 @@ class InceptionV3(nn.Cell): x = self.Mixed_7a(x) x = self.Mixed_7b(x) x = self.Mixed_7c(x) + if not self.include_top: + return x logits = self.logits(x) if self.is_training: return logits, aux_logits