From 212d7ebd0bbc5ad52310f2e1e33d8d33a92ccdf6 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Wed, 27 Jan 2021 15:29:57 +0800 Subject: [PATCH] replace DepthWiseConv with nn.Conv2D --- RELEASE.md | 16 +++--- mindspore/ops/operations/nn_ops.py | 3 + model_zoo/official/cv/centerface/README.md | 12 ++-- .../cv/centerface/src/convert_weight.py | 12 ---- .../src/convert_weight_mobilenetv2.py | 6 -- .../official/cv/centerface/src/mobile_v2.py | 33 +---------- model_zoo/official/cv/centerface/src/utils.py | 4 +- model_zoo/official/cv/centerface/train.py | 4 +- .../cv/efficientnet/src/efficientnet.py | 44 +++----------- .../cv/ssd_ghostnet/src/ssd_ghostnet.py | 57 ++----------------- model_zoo/research/cv/tinynet/src/tinynet.py | 55 +++--------------- model_zoo/research/nlp/dscnn/src/ds_cnn.py | 33 +---------- 12 files changed, 50 insertions(+), 229 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 12dccf126ec..026216075a0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -20,22 +20,22 @@ Previously the kernel size and pad mode attrs of pooling ops are named "ksize" a ```python ->>> from mindspore.ops import operations as P +>>> import mindspore.ops as ops >>> ->>> avg_pool = P.AvgPool(ksize=2, padding='same') ->>> max_pool = P.MaxPool(ksize=2, padding='same') ->>> max_pool_with_argmax = P.MaxPoolWithArgmax(ksize=2, padding='same') +>>> avg_pool = ops.AvgPool(ksize=2, padding='same') +>>> max_pool = ops.MaxPool(ksize=2, padding='same') +>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(ksize=2, padding='same') ``` ```python ->>> from mindspore.ops import operations as P +>>> import mindspore.ops as ops >>> ->>> avg_pool = P.AvgPool(kernel_size=2, pad_mode='same') ->>> max_pool = P.MaxPool(kernel_size=2, pad_mode='same') ->>> max_pool_with_argmax = P.MaxPoolWithArgmax(kernel_size=2, pad_mode='same') +>>> avg_pool = ops.AvgPool(kernel_size=2, pad_mode='same') +>>> max_pool = ops.MaxPool(kernel_size=2, pad_mode='same') +>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(kernel_size=2, pad_mode='same') ``` diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 60f140b3288..6659e9dc54d 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -18,6 +18,7 @@ import math import operator from functools import reduce, partial +from mindspore import log as logger from mindspore._checkparam import _check_3d_int_or_tuple import numpy as np from ... import context @@ -1476,6 +1477,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): dilation=1, group=1): """Initialize DepthwiseConv2dNative""" + logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated." + " Please use nn.Conv2D.") self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) self.stride = _check_positive_int_or_tuple('stride', stride, self.name) diff --git a/model_zoo/official/cv/centerface/README.md b/model_zoo/official/cv/centerface/README.md index 7650f574097..119d37f04d9 100644 --- a/model_zoo/official/cv/centerface/README.md +++ b/model_zoo/official/cv/centerface/README.md @@ -102,7 +102,7 @@ step1: prepare pretrained model: train a mobilenet_v2 model by mindspore or use # The key/cell/module name must as follow, otherwise you need to modify "name_map" function: # --mindspore: as the same as mobilenet_v2_key.ckpt # --pytorch: same as official pytorch model(e.g., official mobilenet_v2-b0353104.pth) -python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt +python convert_weight_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt ``` step2: prepare user rank_table @@ -120,7 +120,7 @@ step3: train cd scripts; # prepare data_path, use symbolic link ln -sf [USE_DATA_DIR] dataset -# check you dir to make sure your datas are in the right path +# check you dir to make sure your data are in the right path ls ./dataset/centerface # data path ls ./dataset/centerface/annotations/train.json # annot_path ls ./dataset/centerface/images/train/images # img_dir @@ -147,7 +147,7 @@ python setup.py install; # used for eval cd -; #cd ../../scripts; mkdir ./output mkdir ./output/centerface -# check you dir to make sure your datas are in the right path +# check you dir to make sure your data are in the right path ls ./dataset/images/val/images/ # data path ls ./dataset/centerface/ground_truth/val.mat # annot_path ``` @@ -195,7 +195,7 @@ sh eval_all.sh │ ├──lr_scheduler.py // learning rate scheduler │ ├──mobile_v2.py // modified mobilenet_v2 backbone │ ├──utils.py // auxiliary functions for train, to log and preload - │ ├──var_init.py // weight initilization + │ ├──var_init.py // weight initialization │ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore │ ├──convert_weight.py // CenterFace model convert to mindspore └── dependency // third party codes: MIT License @@ -414,7 +414,7 @@ After testing, you can find many txt file save the box information and scores, open it you can see: ```python -646.3 189.1 42.1 51.8 0.747 # left top hight weight score +646.3 189.1 42.1 51.8 0.747 # left top height weight score 157.4 408.6 43.1 54.1 0.667 120.3 212.4 38.7 42.8 0.650 ... @@ -553,7 +553,7 @@ CenterFace on 3.2K images(The annotation and data format must be the same as wid # [Description of Random Situation](#contents) In dataset.py, we set the seed inside ```create_dataset``` function. -In var_init.py, we set seed for weight initilization +In var_init.py, we set seed for weight initialization # [ModelZoo Homepage](#contents) diff --git a/model_zoo/official/cv/centerface/src/convert_weight.py b/model_zoo/official/cv/centerface/src/convert_weight.py index c3dcc8c872e..66fe32c69c5 100644 --- a/model_zoo/official/cv/centerface/src/convert_weight.py +++ b/model_zoo/official/cv/centerface/src/convert_weight.py @@ -133,11 +133,6 @@ def pt_to_ckpt(pt, ckpt, out_path): parameter = state_dict_torch[key] parameter = parameter.numpy() - # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout - if state_dict_ms[name_relate[key]].data.shape != parameter.shape: - parameter = parameter.transpose(1, 0, 2, 3) - print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) - param_dict['name'] = name_relate[key] param_dict['data'] = Tensor(parameter) new_params_list.append(param_dict) @@ -158,13 +153,6 @@ def ckpt_to_pt(pt, ckpt, out_path): name = name_relate[key] parameter = state_dict_ms[name].data parameter = parameter.asnumpy() - if state_dict_ms[name_relate[key]].data.shape != state_dict_torch[key].numpy().shape: - print('before ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', - state_dict_torch[key].numpy().shape, 'name=', key) - parameter = parameter.transpose(1, 0, 2, 3) - print('after ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', - state_dict_torch[key].numpy().shape, 'name=', key) - state_dict[key] = torch.from_numpy(parameter) save_model(out_path, epoch=0, model=None, optimizer=None, state_dict=state_dict) diff --git a/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py index e4ef3bf4270..798da8b8f00 100644 --- a/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py +++ b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py @@ -120,12 +120,6 @@ def pt_to_ckpt(pt, ckpt, out_ckpt): parameter = state_dict_torch[key] parameter = parameter.numpy() - # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout - if state_dict_ms[name_relate[key]].data.shape != parameter.shape: - parameter = parameter.transpose(1, 0, 2, 3) - print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) - - param_dict['name'] = name_relate[key] param_dict['data'] = Tensor(parameter) new_params_list.append(param_dict) diff --git a/model_zoo/official/cv/centerface/src/mobile_v2.py b/model_zoo/official/cv/centerface/src/mobile_v2.py index 9345c554e5f..2b84a002a37 100644 --- a/model_zoo/official/cv/centerface/src/mobile_v2.py +++ b/model_zoo/official/cv/centerface/src/mobile_v2.py @@ -17,12 +17,10 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.ops.operations import TensorAdd -from mindspore import Parameter -from mindspore.common.initializer import initializer from src.var_init import KaimingNormal -__all__ = ['MobileNetV2', 'mobilenet_v2', 'DepthWiseConv'] +__all__ = ['MobileNetV2', 'mobilenet_v2'] def _make_divisible(v, divisor, min_value=None): """ @@ -43,32 +41,6 @@ def _make_divisible(v, divisor, min_value=None): new_v += divisor return new_v -class DepthWiseConv(nn.Cell): - """ - Depthwise convolution - """ - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthWiseConv, self).__init__() - self.has_bias = has_bias - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - - weight_shape = [channel_multiplier, in_planes, kernel_size, kernel_size] - self.weight = Parameter(initializer(KaimingNormal(mode='fan_out'), weight_shape)) - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape)) - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output - class ConvBNReLU(nn.Cell): """ @@ -81,7 +53,8 @@ class ConvBNReLU(nn.Cell): conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding, has_bias=False) else: - conv = DepthWiseConv(in_planes, kernel_size, stride, pad_mode="pad", pad=padding) + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding, + has_bias=False, group=groups, weight_init=KaimingNormal(mode='fan_out')) layers = [conv, nn.BatchNorm2d(out_planes).add_flags_recursive(fp32=True), nn.ReLU6()] #, momentum=0.9 self.features = nn.SequentialCell(layers) diff --git a/model_zoo/official/cv/centerface/src/utils.py b/model_zoo/official/cv/centerface/src/utils.py index bbf48bca2f5..5f37d7132dd 100644 --- a/model_zoo/official/cv/centerface/src/utils.py +++ b/model_zoo/official/cv/centerface/src/utils.py @@ -24,8 +24,6 @@ import numpy as np from mindspore.train.serialization import load_checkpoint import mindspore.nn as nn -from src.mobile_v2 import DepthWiseConv - def load_backbone(net, ckpt_path, args): """ Load backbone @@ -52,7 +50,7 @@ def load_backbone(net, ckpt_path, args): for name, cell in net.cells_and_names(): if name.startswith(centerface_backbone_prefix): name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix) - if isinstance(cell, (nn.Conv2d, nn.Dense, DepthWiseConv)): + if isinstance(cell, (nn.Conv2d, nn.Dense)): name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx) mobilev2_weight = '{}.weight'.format(name) mobilev2_bias = '{}.bias'.format(name) diff --git a/model_zoo/official/cv/centerface/train.py b/model_zoo/official/cv/centerface/train.py index d8dfde1b2cc..a25ca33f2a8 100644 --- a/model_zoo/official/cv/centerface/train.py +++ b/model_zoo/official/cv/centerface/train.py @@ -33,6 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.profiler.profiling import Profiler +from mindspore.common import set_seed from src.utils import get_logger from src.utils import AverageMeter @@ -47,6 +48,7 @@ from src.config import ConfigCenterface from src.centerface import CenterFaceWithLossCell, TrainingWrapper from src.dataset import GetDataLoader +set_seed(1) dev_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False, device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False) @@ -130,7 +132,7 @@ if __name__ == "__main__": args.rank = get_rank() args.group_size = get_group_size() - # select for master rank save ckpt or all rank save, compatiable for model parallel + # select for master rank save ckpt or all rank save, compatible for model parallel args.rank_save_ckpt_flag = 0 if args.is_save_on_master: if args.rank == 0: diff --git a/model_zoo/official/cv/efficientnet/src/efficientnet.py b/model_zoo/official/cv/efficientnet/src/efficientnet.py index 01bfbf9f4da..8b43ad25d9c 100644 --- a/model_zoo/official/cv/efficientnet/src/efficientnet.py +++ b/model_zoo/official/cv/efficientnet/src/efficientnet.py @@ -20,10 +20,8 @@ from copy import deepcopy import mindspore as ms import mindspore.nn as nn -from mindspore import context, ms_function -from mindspore.common.initializer import (Normal, One, Uniform, Zero, - initializer) -from mindspore.common.parameter import Parameter +from mindspore import ms_function +from mindspore.common.initializer import (Normal, One, Uniform, Zero) from mindspore.ops import operations as P from mindspore.ops.composite import clip_by_value @@ -224,13 +222,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): # activation fn key = op[0] v = op[1:] - if v == 're': - print('not support') - elif v == 'r6': - print('not support') - elif v == 'hs': - print('not support') - elif v == 'sw': + if v in ('re', 'r6', 'hs', 'sw'): print('not support') else: continue @@ -459,28 +451,6 @@ class BlockBuilder(nn.Cell): return self.layer(x) -class DepthWiseConv(nn.Cell): - def __init__(self, in_planes, kernel_size, stride): - super(DepthWiseConv, self).__init__() - platform = context.get_context("device_target") - weight_shape = [1, kernel_size, in_planes] - weight_init = _initialize_weight_goog(shape=weight_shape) - if platform == "GPU": - self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size, - stride=stride, pad_mode="same", group=in_planes) - self.weight = Parameter(initializer( - weight_init, [in_planes * 1, 1, kernel_size, kernel_size])) - else: - self.depthwise_conv = P.DepthwiseConv2dNative( - channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',) - self.weight = Parameter(initializer( - weight_init, [1, in_planes, kernel_size, kernel_size])) - - def construct(self, x): - x = self.depthwise_conv(x, self.weight) - return x - - class DropConnect(nn.Cell): def __init__(self, drop_connect_rate=0., seed0=0, seed1=0): super(DropConnect, self).__init__() @@ -540,7 +510,9 @@ class DepthwiseSeparableConv(nn.Cell): self.has_pw_act = pw_act self.act_fn = act_fn self.drop_connect_rate = drop_connect_rate - self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) + self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="same", + has_bias=False, group=in_chs, + weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs])) self.bn1 = _fused_bn(in_chs, **bn_args) # @@ -595,7 +567,9 @@ class InvertedResidual(nn.Cell): if self.shuffle_type is not None and isinstance(exp_kernel_size, list): self.shuffle = None - self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="same", + has_bias=False, group=mid_chs, + weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs])) self.bn2 = _fused_bn(mid_chs, **bn_args) if self.has_se: diff --git a/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py b/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py index f25deba99e7..ccfbece691e 100644 --- a/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py +++ b/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py @@ -20,13 +20,12 @@ import numpy as np import mindspore.common.dtype as mstype import mindspore as ms import mindspore.nn as nn -from mindspore import Parameter, context, Tensor +from mindspore import context, Tensor from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.communication.management import get_group_size from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C -from mindspore.common.initializer import initializer def _make_divisible(x, divisor=4): @@ -44,8 +43,8 @@ def _bn(channel): def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0): - depthwise_conv = DepthwiseConv( - in_channel, kernel_size, stride, pad_mode='same', pad=pad) + depthwise_conv = nn.Conv2d(in_channel, in_channel, kernel_size, stride, pad_mode='same', padding=pad, + has_bias=False, group=in_channel, weight_init='ones') conv = _conv2d(in_channel, out_channel, kernel_size=1) return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) @@ -75,8 +74,8 @@ class ConvBNReLU(nn.Cell): conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', padding=padding) else: - conv = DepthwiseConv(in_planes, kernel_size, - stride, pad_mode='same', pad=padding) + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', padding=padding, + has_bias=False, group=groups, weight_init='ones') layers = [conv, _bn(out_planes)] if use_act: layers.append(Activation(act_type)) @@ -87,52 +86,6 @@ class ConvBNReLU(nn.Cell): return output -class DepthwiseConv(nn.Cell): - """ - Depthwise Convolution warpper definition. - - Args: - in_planes (int): Input channel. - kernel_size (int): Input kernel size. - stride (int): Stride size. - pad_mode (str): pad mode in (pad, same, valid) - channel_multiplier (int): Output channel multiplier - has_bias (bool): has bias or not - - Returns: - Tensor, output tensor. - - Examples: - >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) - """ - - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthwiseConv, self).__init__() - self.has_bias = has_bias - self.in_channels = in_planes - self.channel_multiplier = channel_multiplier - self.out_channels = in_planes * channel_multiplier - self.kernel_size = (kernel_size, kernel_size) - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, - kernel_size=self.kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - weight_shape = [channel_multiplier, in_planes, *self.kernel_size] - self.weight = Parameter(initializer('ones', weight_shape), name="weight") - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape), name="bias") - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output - - class MyHSigmoid(nn.Cell): def __init__(self): super(MyHSigmoid, self).__init__() diff --git a/model_zoo/research/cv/tinynet/src/tinynet.py b/model_zoo/research/cv/tinynet/src/tinynet.py index 50b518b937d..ebb5663872b 100755 --- a/model_zoo/research/cv/tinynet/src/tinynet.py +++ b/model_zoo/research/cv/tinynet/src/tinynet.py @@ -20,9 +20,8 @@ from copy import deepcopy import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P -from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform -from mindspore import context, ms_function -from mindspore.common.parameter import Parameter +from mindspore.common.initializer import Normal, Zero, One, Uniform +from mindspore import ms_function from mindspore import Tensor # Imagenet constant values @@ -244,13 +243,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0): # activation fn key = op[0] v = op[1:] - if v == 're': - print('not support') - elif v == 'r6': - print('not support') - elif v == 'hs': - print('not support') - elif v == 'sw': + if v in ('re', 'r6', 'hs', 'sw'): print('not support') else: continue @@ -485,40 +478,6 @@ class BlockBuilder(nn.Cell): return self.layer(x) -class DepthWiseConv(nn.Cell): - """depth-wise convolution""" - - def __init__(self, in_planes, kernel_size, stride): - super(DepthWiseConv, self).__init__() - platform = context.get_context("device_target") - weight_shape = [1, kernel_size, in_planes] - weight_init = _initialize_weight_goog(shape=weight_shape) - - if platform == "GPU": - self.depthwise_conv = P.Conv2D(out_channel=in_planes*1, - kernel_size=kernel_size, - stride=stride, - pad=int(kernel_size/2), - pad_mode="pad", - group=in_planes) - - self.weight = Parameter(initializer(weight_init, - [in_planes*1, 1, kernel_size, kernel_size])) - - else: - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, - kernel_size=kernel_size, - stride=stride, pad_mode='pad', - pad=int(kernel_size/2)) - - self.weight = Parameter(initializer(weight_init, - [1, in_planes, kernel_size, kernel_size])) - - def construct(self, x): - x = self.depthwise_conv(x, self.weight) - return x - - class DropConnect(nn.Cell): """drop connect implementation""" @@ -584,7 +543,9 @@ class DepthwiseSeparableConv(nn.Cell): self.has_pw_act = pw_act self.act_fn = Swish() self.drop_connect_rate = drop_connect_rate - self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride) + self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad", + padding=int(dw_kernel_size/2), has_bias=False, group=in_chs, + weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs])) self.bn1 = _fused_bn(in_chs, **bn_args) if self.has_se: @@ -640,7 +601,9 @@ class InvertedResidual(nn.Cell): if self.shuffle_type is not None and isinstance(exp_kernel_size, list): self.shuffle = None - self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride) + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad", + padding=int(dw_kernel_size/2), has_bias=False, group=mid_chs, + weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs])) self.bn2 = _fused_bn(mid_chs, **bn_args) if self.has_se: diff --git a/model_zoo/research/nlp/dscnn/src/ds_cnn.py b/model_zoo/research/nlp/dscnn/src/ds_cnn.py index f7daaa5bfa7..0cad78fdde0 100644 --- a/model_zoo/research/nlp/dscnn/src/ds_cnn.py +++ b/model_zoo/research/nlp/dscnn/src/ds_cnn.py @@ -16,34 +16,6 @@ import math import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.initializer import initializer -from mindspore import Parameter - - -class DepthWiseConv(nn.Cell): - '''Build DepthWise conv.''' - def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): - super(DepthWiseConv, self).__init__() - self.has_bias = has_bias - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, - stride=stride, pad_mode=pad_mode, pad=pad) - self.bias_add = P.BiasAdd() - - weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]] - self.weight = Parameter(initializer('ones', weight_shape)) - - if has_bias: - bias_shape = [channel_multiplier * in_planes] - self.bias = Parameter(initializer('zeros', bias_shape)) - else: - self.bias = None - - def construct(self, x): - output = self.depthwise_conv(x, self.weight) - if self.has_bias: - output = self.bias_add(output, self.bias) - return output class DSCNN(nn.Cell): @@ -85,8 +57,9 @@ class DSCNN(nn.Cell): seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98)) in_channel = conv_feat[layer_no] else: - seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]), - stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0)) + seq_cell.append(nn.Conv2d(in_channel, in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]), + stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', + has_bias=False, group=in_channel, weight_init='ones')) seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98)) seq_cell.append(nn.ReLU()) seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1),