!6677 add mobilenetv3 and resnext50 hub

Merge pull request !6677 from zhaoting/hub
This commit is contained in:
mindspore-ci-bot 2020-09-22 21:20:57 +08:00 committed by Gitee
commit ff5828b66a
12 changed files with 103 additions and 21 deletions

View File

@ -77,6 +77,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
│ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn
├── train.py # training script
├── eval.py # evaluation script
├── mindspore_hub_conf.py # mindspore hub interface
```
## [Training process](#contents)

View File

@ -119,7 +119,7 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
for param in network.get_parameters():
param.requires_grad = False
def define_net(config, is_training):
def define_net(config, is_training=True):
backbone_net = MobileNetV2Backbone()
activation = config.activation if not is_training else "None"
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,

View File

@ -69,6 +69,7 @@ Dataset used: [imagenet](http://www.image-net.org/)
│ ├──mobilenetV3.py # MobileNetV3 architecture
├── train.py # training script
├── eval.py # evaluation script
├── mindspore_hub_conf.py # mindspore hub interface
```
## [Training process](#contents)

View File

@ -42,7 +42,7 @@ if __name__ == '__main__':
raise ValueError("Unsupported device_target.")
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net = mobilenet_v3_large(num_classes=config.num_classes)
net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax")
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,

View File

@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.mobilenetV3 import mobilenet_v3_large, mobilenet_v3_small
def create_network(name, *args, **kwargs):
    """Build the MobileNetV3 variant registered under ``name``.

    Args:
        name (str): Either "mobilenetv3_large" or "mobilenetv3_small".
        *args: Positional arguments passed through to the selected builder.
        **kwargs: Keyword arguments passed through to the selected builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        NotImplementedError: If ``name`` is neither of the two names above.
    """
    # Guard clauses: each supported name returns straight from its builder,
    # so anything that falls through is by definition unsupported.
    if name == "mobilenetv3_large":
        return mobilenet_v3_large(*args, **kwargs)
    if name == "mobilenetv3_small":
        return mobilenet_v3_small(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -246,7 +246,8 @@ class MobileNetV3(nn.Cell):
>>> MobileNetV3(num_classes=1000)
"""
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0.,
round_nearest=8, include_top=True, activation="None"):
super(MobileNetV3, self).__init__()
self.cfgs = model_cfgs['cfg']
self.inplanes = 16
@ -285,19 +286,34 @@ class MobileNetV3(nn.Cell):
# make it nn.CellList
self.features = nn.SequentialCell(self.features)
self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'],
out_channels=num_classes,
kernel_size=1, has_bias=True, pad_mode='pad')
self.squeeze = P.Squeeze(axis=(2, 3))
self.include_top = include_top
self.need_activation = False
if self.include_top:
self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'],
out_channels=num_classes,
kernel_size=1, has_bias=True, pad_mode='pad')
self.squeeze = P.Squeeze(axis=(2, 3))
if activation != "None":
self.need_activation = True
if activation == "Sigmoid":
self.activation = P.Sigmoid()
elif activation == "Softmax":
self.activation = P.Softmax()
else:
raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].")
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.output(x)
x = self.squeeze(x)
if self.include_top:
x = self.output(x)
x = self.squeeze(x)
if self.need_activation:
x = self.activation(x)
return x
def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1):
mid_planes = exp_ch
out_planes = out_channel

View File

@ -96,7 +96,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
├─warmup_cosine_annealing.py # learning rate each step
├─warmup_step_lr.py # warmup step learning rate
├─eval.py # eval net
└─train.py # train net
├─train.py # train net
└─mindspore_hub_conf.py # mindspore hub interface
```

View File

@ -201,7 +201,7 @@ def test(cloud_args=None):
max_epoch=1, rank=args.rank, group_size=args.group_size,
mode='eval')
eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
network = get_network(args.backbone, args.num_classes, platform=args.platform)
network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform)
if network is None:
raise NotImplementedError('not implement {}'.format(args.backbone))

View File

@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.image_classification import get_network
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name`` for the hub interface.

    Args:
        name (str): Network name. "resnext50" is supported; the historical
            misspelling "renext50" is accepted as an alias for it.
        *args: Extra positional arguments forwarded to ``get_network``.
        **kwargs: Keyword arguments forwarded to ``get_network``.

    Returns:
        The object returned by ``get_network`` for the "resnext50" backbone.

    Raises:
        NotImplementedError: If ``name`` is not a supported network name.
    """
    if name in ("resnext50", "renext50"):
        # Always pass the correctly spelled key: ``get_network`` only
        # recognises "resnext50". The result must be returned directly;
        # the previous code discarded it and returned an unbound ``net``.
        return get_network("resnext50", *args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -31,31 +31,46 @@ class ImageClassificationNetwork(nn.Cell):
Returns:
Tensor, output tensor.
"""
def __init__(self, backbone, head):
def __init__(self, backbone, head, include_top=True, activation="None"):
super(ImageClassificationNetwork, self).__init__()
self.backbone = backbone
self.head = head
self.include_top = include_top
self.need_activation = False
if self.include_top:
self.head = head
if activation != "None":
self.need_activation = True
if activation == "Sigmoid":
self.activation = P.Sigmoid()
elif activation == "Softmax":
self.activation = P.Softmax()
else:
raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].")
def construct(self, x):
x = self.backbone(x)
x = self.head(x)
if self.include_top:
x = self.head(x)
if self.need_activation:
x = self.activation(x)
return x
class Resnet(ImageClassificationNetwork):
"""
Resnet architecture.
Args:
backbone_name (string): backbone.
num_classes (int): number of classes.
num_classes (int): number of classes, Default is 1000.
Returns:
Resnet.
"""
def __init__(self, backbone_name, num_classes, platform="Ascend"):
def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"):
self.backbone_name = backbone_name
backbone = backbones.__dict__[self.backbone_name](platform=platform)
out_channels = backbone.get_out_channels()
head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
super(Resnet, self).__init__(backbone, head)
super(Resnet, self).__init__(backbone, head, include_top, activation)
default_recurisive_init(self)
@ -79,7 +94,7 @@ class Resnet(ImageClassificationNetwork):
def get_network(backbone_name, num_classes, platform="Ascend"):
def get_network(backbone_name, **kwargs):
if backbone_name in ['resnext50']:
return Resnet(backbone_name, num_classes, platform)
return Resnet(backbone_name, **kwargs)
return None

View File

@ -213,7 +213,7 @@ def train(cloud_args=None):
# network
args.logger.important_info('start create network')
# get network and init
network = get_network(args.backbone, args.num_classes, platform=args.platform)
network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform)
if network is None:
raise NotImplementedError('not implement {}'.format(args.backbone))

View File

@ -114,7 +114,8 @@ sh run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
├─ lr_schedule.py ## learning ratio generator
└─ ssd.py ## ssd architecture
├─ eval.py ## eval scripts
└─ train.py ## train scripts
├─ train.py ## train scripts
└─ mindspore_hub_conf.py ## mindspore hub interface
```
## [Script Parameters](#contents)