diff --git a/RELEASE.md b/RELEASE.md
index 12dccf126ec..026216075a0 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -20,22 +20,22 @@ Previously the kernel size and pad mode attrs of pooling ops are named "ksize" a
```python
->>> from mindspore.ops import operations as P
+>>> import mindspore.ops as ops
>>>
->>> avg_pool = P.AvgPool(ksize=2, padding='same')
->>> max_pool = P.MaxPool(ksize=2, padding='same')
->>> max_pool_with_argmax = P.MaxPoolWithArgmax(ksize=2, padding='same')
+>>> avg_pool = ops.AvgPool(ksize=2, padding='same')
+>>> max_pool = ops.MaxPool(ksize=2, padding='same')
+>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(ksize=2, padding='same')
```
|
```python
->>> from mindspore.ops import operations as P
+>>> import mindspore.ops as ops
>>>
->>> avg_pool = P.AvgPool(kernel_size=2, pad_mode='same')
->>> max_pool = P.MaxPool(kernel_size=2, pad_mode='same')
->>> max_pool_with_argmax = P.MaxPoolWithArgmax(kernel_size=2, pad_mode='same')
+>>> avg_pool = ops.AvgPool(kernel_size=2, pad_mode='same')
+>>> max_pool = ops.MaxPool(kernel_size=2, pad_mode='same')
+>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(kernel_size=2, pad_mode='same')
```
|
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 60f140b3288..6659e9dc54d 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -18,6 +18,7 @@
import math
import operator
from functools import reduce, partial
+from mindspore import log as logger
from mindspore._checkparam import _check_3d_int_or_tuple
import numpy as np
from ... import context
@@ -1476,6 +1477,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
dilation=1,
group=1):
"""Initialize DepthwiseConv2dNative"""
+ logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated."
+ " Please use nn.Conv2D.")
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
diff --git a/model_zoo/official/cv/centerface/README.md b/model_zoo/official/cv/centerface/README.md
index 7650f574097..119d37f04d9 100644
--- a/model_zoo/official/cv/centerface/README.md
+++ b/model_zoo/official/cv/centerface/README.md
@@ -102,7 +102,7 @@ step1: prepare pretrained model: train a mobilenet_v2 model by mindspore or use
# The key/cell/module name must as follow, otherwise you need to modify "name_map" function:
# --mindspore: as the same as mobilenet_v2_key.ckpt
# --pytorch: same as official pytorch model(e.g., official mobilenet_v2-b0353104.pth)
-python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt
+python convert_weight_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt
```
step2: prepare user rank_table
@@ -120,7 +120,7 @@ step3: train
cd scripts;
# prepare data_path, use symbolic link
ln -sf [USE_DATA_DIR] dataset
-# check you dir to make sure your datas are in the right path
+# check you dir to make sure your data are in the right path
ls ./dataset/centerface # data path
ls ./dataset/centerface/annotations/train.json # annot_path
ls ./dataset/centerface/images/train/images # img_dir
@@ -147,7 +147,7 @@ python setup.py install; # used for eval
cd -; #cd ../../scripts;
mkdir ./output
mkdir ./output/centerface
-# check you dir to make sure your datas are in the right path
+# check you dir to make sure your data are in the right path
ls ./dataset/images/val/images/ # data path
ls ./dataset/centerface/ground_truth/val.mat # annot_path
```
@@ -195,7 +195,7 @@ sh eval_all.sh
│ ├──lr_scheduler.py // learning rate scheduler
│ ├──mobile_v2.py // modified mobilenet_v2 backbone
│ ├──utils.py // auxiliary functions for train, to log and preload
- │ ├──var_init.py // weight initilization
+ │ ├──var_init.py // weight initialization
│ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore
│ ├──convert_weight.py // CenterFace model convert to mindspore
└── dependency // third party codes: MIT License
@@ -414,7 +414,7 @@ After testing, you can find many txt file save the box information and scores,
open it you can see:
```python
-646.3 189.1 42.1 51.8 0.747 # left top hight weight score
+646.3 189.1 42.1 51.8 0.747 # left top height weight score
157.4 408.6 43.1 54.1 0.667
120.3 212.4 38.7 42.8 0.650
...
@@ -553,7 +553,7 @@ CenterFace on 3.2K images(The annotation and data format must be the same as wid
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside ```create_dataset``` function.
-In var_init.py, we set seed for weight initilization
+In var_init.py, we set seed for weight initialization
# [ModelZoo Homepage](#contents)
diff --git a/model_zoo/official/cv/centerface/src/convert_weight.py b/model_zoo/official/cv/centerface/src/convert_weight.py
index c3dcc8c872e..66fe32c69c5 100644
--- a/model_zoo/official/cv/centerface/src/convert_weight.py
+++ b/model_zoo/official/cv/centerface/src/convert_weight.py
@@ -133,11 +133,6 @@ def pt_to_ckpt(pt, ckpt, out_path):
parameter = state_dict_torch[key]
parameter = parameter.numpy()
- # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout
- if state_dict_ms[name_relate[key]].data.shape != parameter.shape:
- parameter = parameter.transpose(1, 0, 2, 3)
- print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key)
-
param_dict['name'] = name_relate[key]
param_dict['data'] = Tensor(parameter)
new_params_list.append(param_dict)
@@ -158,13 +153,6 @@ def ckpt_to_pt(pt, ckpt, out_path):
name = name_relate[key]
parameter = state_dict_ms[name].data
parameter = parameter.asnumpy()
- if state_dict_ms[name_relate[key]].data.shape != state_dict_torch[key].numpy().shape:
- print('before ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=',
- state_dict_torch[key].numpy().shape, 'name=', key)
- parameter = parameter.transpose(1, 0, 2, 3)
- print('after ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=',
- state_dict_torch[key].numpy().shape, 'name=', key)
-
state_dict[key] = torch.from_numpy(parameter)
save_model(out_path, epoch=0, model=None, optimizer=None, state_dict=state_dict)
diff --git a/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py
index e4ef3bf4270..798da8b8f00 100644
--- a/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py
+++ b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py
@@ -120,12 +120,6 @@ def pt_to_ckpt(pt, ckpt, out_ckpt):
parameter = state_dict_torch[key]
parameter = parameter.numpy()
- # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout
- if state_dict_ms[name_relate[key]].data.shape != parameter.shape:
- parameter = parameter.transpose(1, 0, 2, 3)
- print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key)
-
-
param_dict['name'] = name_relate[key]
param_dict['data'] = Tensor(parameter)
new_params_list.append(param_dict)
diff --git a/model_zoo/official/cv/centerface/src/mobile_v2.py b/model_zoo/official/cv/centerface/src/mobile_v2.py
index 9345c554e5f..2b84a002a37 100644
--- a/model_zoo/official/cv/centerface/src/mobile_v2.py
+++ b/model_zoo/official/cv/centerface/src/mobile_v2.py
@@ -17,12 +17,10 @@
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter
-from mindspore.common.initializer import initializer
from src.var_init import KaimingNormal
-__all__ = ['MobileNetV2', 'mobilenet_v2', 'DepthWiseConv']
+__all__ = ['MobileNetV2', 'mobilenet_v2']
def _make_divisible(v, divisor, min_value=None):
"""
@@ -43,32 +41,6 @@ def _make_divisible(v, divisor, min_value=None):
new_v += divisor
return new_v
-class DepthWiseConv(nn.Cell):
- """
- Depthwise convolution
- """
- def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
- super(DepthWiseConv, self).__init__()
- self.has_bias = has_bias
- self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
- stride=stride, pad_mode=pad_mode, pad=pad)
- self.bias_add = P.BiasAdd()
-
- weight_shape = [channel_multiplier, in_planes, kernel_size, kernel_size]
- self.weight = Parameter(initializer(KaimingNormal(mode='fan_out'), weight_shape))
-
- if has_bias:
- bias_shape = [channel_multiplier * in_planes]
- self.bias = Parameter(initializer('zeros', bias_shape))
- else:
- self.bias = None
-
- def construct(self, x):
- output = self.depthwise_conv(x, self.weight)
- if self.has_bias:
- output = self.bias_add(output, self.bias)
- return output
-
class ConvBNReLU(nn.Cell):
"""
@@ -81,7 +53,8 @@ class ConvBNReLU(nn.Cell):
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding,
has_bias=False)
else:
- conv = DepthWiseConv(in_planes, kernel_size, stride, pad_mode="pad", pad=padding)
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding,
+ has_bias=False, group=groups, weight_init=KaimingNormal(mode='fan_out'))
layers = [conv, nn.BatchNorm2d(out_planes).add_flags_recursive(fp32=True), nn.ReLU6()] #, momentum=0.9
self.features = nn.SequentialCell(layers)
diff --git a/model_zoo/official/cv/centerface/src/utils.py b/model_zoo/official/cv/centerface/src/utils.py
index bbf48bca2f5..5f37d7132dd 100644
--- a/model_zoo/official/cv/centerface/src/utils.py
+++ b/model_zoo/official/cv/centerface/src/utils.py
@@ -24,8 +24,6 @@ import numpy as np
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
-from src.mobile_v2 import DepthWiseConv
-
def load_backbone(net, ckpt_path, args):
"""
Load backbone
@@ -52,7 +50,7 @@ def load_backbone(net, ckpt_path, args):
for name, cell in net.cells_and_names():
if name.startswith(centerface_backbone_prefix):
name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix)
- if isinstance(cell, (nn.Conv2d, nn.Dense, DepthWiseConv)):
+ if isinstance(cell, (nn.Conv2d, nn.Dense)):
name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
mobilev2_weight = '{}.weight'.format(name)
mobilev2_bias = '{}.bias'.format(name)
diff --git a/model_zoo/official/cv/centerface/train.py b/model_zoo/official/cv/centerface/train.py
index d8dfde1b2cc..a25ca33f2a8 100644
--- a/model_zoo/official/cv/centerface/train.py
+++ b/model_zoo/official/cv/centerface/train.py
@@ -33,6 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.profiler.profiling import Profiler
+from mindspore.common import set_seed
from src.utils import get_logger
from src.utils import AverageMeter
@@ -47,6 +48,7 @@ from src.config import ConfigCenterface
from src.centerface import CenterFaceWithLossCell, TrainingWrapper
from src.dataset import GetDataLoader
+set_seed(1)
dev_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False)
@@ -130,7 +132,7 @@ if __name__ == "__main__":
args.rank = get_rank()
args.group_size = get_group_size()
- # select for master rank save ckpt or all rank save, compatiable for model parallel
+ # select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
diff --git a/model_zoo/official/cv/efficientnet/src/efficientnet.py b/model_zoo/official/cv/efficientnet/src/efficientnet.py
index 01bfbf9f4da..8b43ad25d9c 100644
--- a/model_zoo/official/cv/efficientnet/src/efficientnet.py
+++ b/model_zoo/official/cv/efficientnet/src/efficientnet.py
@@ -20,10 +20,8 @@ from copy import deepcopy
import mindspore as ms
import mindspore.nn as nn
-from mindspore import context, ms_function
-from mindspore.common.initializer import (Normal, One, Uniform, Zero,
- initializer)
-from mindspore.common.parameter import Parameter
+from mindspore import ms_function
+from mindspore.common.initializer import (Normal, One, Uniform, Zero)
from mindspore.ops import operations as P
from mindspore.ops.composite import clip_by_value
@@ -224,13 +222,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
# activation fn
key = op[0]
v = op[1:]
- if v == 're':
- print('not support')
- elif v == 'r6':
- print('not support')
- elif v == 'hs':
- print('not support')
- elif v == 'sw':
+ if v in ('re', 'r6', 'hs', 'sw'):
print('not support')
else:
continue
@@ -459,28 +451,6 @@ class BlockBuilder(nn.Cell):
return self.layer(x)
-class DepthWiseConv(nn.Cell):
- def __init__(self, in_planes, kernel_size, stride):
- super(DepthWiseConv, self).__init__()
- platform = context.get_context("device_target")
- weight_shape = [1, kernel_size, in_planes]
- weight_init = _initialize_weight_goog(shape=weight_shape)
- if platform == "GPU":
- self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size,
- stride=stride, pad_mode="same", group=in_planes)
- self.weight = Parameter(initializer(
- weight_init, [in_planes * 1, 1, kernel_size, kernel_size]))
- else:
- self.depthwise_conv = P.DepthwiseConv2dNative(
- channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',)
- self.weight = Parameter(initializer(
- weight_init, [1, in_planes, kernel_size, kernel_size]))
-
- def construct(self, x):
- x = self.depthwise_conv(x, self.weight)
- return x
-
-
class DropConnect(nn.Cell):
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
super(DropConnect, self).__init__()
@@ -540,7 +510,9 @@ class DepthwiseSeparableConv(nn.Cell):
self.has_pw_act = pw_act
self.act_fn = act_fn
self.drop_connect_rate = drop_connect_rate
- self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
+ self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="same",
+ has_bias=False, group=in_chs,
+ weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
self.bn1 = _fused_bn(in_chs, **bn_args)
#
@@ -595,7 +567,9 @@ class InvertedResidual(nn.Cell):
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = None
- self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
+ self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="same",
+ has_bias=False, group=mid_chs,
+ weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
self.bn2 = _fused_bn(mid_chs, **bn_args)
if self.has_se:
diff --git a/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py b/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py
index f25deba99e7..ccfbece691e 100644
--- a/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py
+++ b/model_zoo/research/cv/ssd_ghostnet/src/ssd_ghostnet.py
@@ -20,13 +20,12 @@ import numpy as np
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
-from mindspore import Parameter, context, Tensor
+from mindspore import context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
-from mindspore.common.initializer import initializer
def _make_divisible(x, divisor=4):
@@ -44,8 +43,8 @@ def _bn(channel):
def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
- depthwise_conv = DepthwiseConv(
- in_channel, kernel_size, stride, pad_mode='same', pad=pad)
+ depthwise_conv = nn.Conv2d(in_channel, in_channel, kernel_size, stride, pad_mode='same', padding=pad,
+ has_bias=False, group=in_channel, weight_init='ones')
conv = _conv2d(in_channel, out_channel, kernel_size=1)
return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
@@ -75,8 +74,8 @@ class ConvBNReLU(nn.Cell):
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
padding=padding)
else:
- conv = DepthwiseConv(in_planes, kernel_size,
- stride, pad_mode='same', pad=padding)
+ conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', padding=padding,
+ has_bias=False, group=groups, weight_init='ones')
layers = [conv, _bn(out_planes)]
if use_act:
layers.append(Activation(act_type))
@@ -87,52 +86,6 @@ class ConvBNReLU(nn.Cell):
return output
-class DepthwiseConv(nn.Cell):
- """
- Depthwise Convolution warpper definition.
-
- Args:
- in_planes (int): Input channel.
- kernel_size (int): Input kernel size.
- stride (int): Stride size.
- pad_mode (str): pad mode in (pad, same, valid)
- channel_multiplier (int): Output channel multiplier
- has_bias (bool): has bias or not
-
- Returns:
- Tensor, output tensor.
-
- Examples:
- >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
- """
-
- def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
- super(DepthwiseConv, self).__init__()
- self.has_bias = has_bias
- self.in_channels = in_planes
- self.channel_multiplier = channel_multiplier
- self.out_channels = in_planes * channel_multiplier
- self.kernel_size = (kernel_size, kernel_size)
- self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
- kernel_size=self.kernel_size,
- stride=stride, pad_mode=pad_mode, pad=pad)
- self.bias_add = P.BiasAdd()
- weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
- self.weight = Parameter(initializer('ones', weight_shape), name="weight")
-
- if has_bias:
- bias_shape = [channel_multiplier * in_planes]
- self.bias = Parameter(initializer('zeros', bias_shape), name="bias")
- else:
- self.bias = None
-
- def construct(self, x):
- output = self.depthwise_conv(x, self.weight)
- if self.has_bias:
- output = self.bias_add(output, self.bias)
- return output
-
-
class MyHSigmoid(nn.Cell):
def __init__(self):
super(MyHSigmoid, self).__init__()
diff --git a/model_zoo/research/cv/tinynet/src/tinynet.py b/model_zoo/research/cv/tinynet/src/tinynet.py
index 50b518b937d..ebb5663872b 100755
--- a/model_zoo/research/cv/tinynet/src/tinynet.py
+++ b/model_zoo/research/cv/tinynet/src/tinynet.py
@@ -20,9 +20,8 @@ from copy import deepcopy
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
-from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
-from mindspore import context, ms_function
-from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import Normal, Zero, One, Uniform
+from mindspore import ms_function
from mindspore import Tensor
# Imagenet constant values
@@ -244,13 +243,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
# activation fn
key = op[0]
v = op[1:]
- if v == 're':
- print('not support')
- elif v == 'r6':
- print('not support')
- elif v == 'hs':
- print('not support')
- elif v == 'sw':
+ if v in ('re', 'r6', 'hs', 'sw'):
print('not support')
else:
continue
@@ -485,40 +478,6 @@ class BlockBuilder(nn.Cell):
return self.layer(x)
-class DepthWiseConv(nn.Cell):
- """depth-wise convolution"""
-
- def __init__(self, in_planes, kernel_size, stride):
- super(DepthWiseConv, self).__init__()
- platform = context.get_context("device_target")
- weight_shape = [1, kernel_size, in_planes]
- weight_init = _initialize_weight_goog(shape=weight_shape)
-
- if platform == "GPU":
- self.depthwise_conv = P.Conv2D(out_channel=in_planes*1,
- kernel_size=kernel_size,
- stride=stride,
- pad=int(kernel_size/2),
- pad_mode="pad",
- group=in_planes)
-
- self.weight = Parameter(initializer(weight_init,
- [in_planes*1, 1, kernel_size, kernel_size]))
-
- else:
- self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
- kernel_size=kernel_size,
- stride=stride, pad_mode='pad',
- pad=int(kernel_size/2))
-
- self.weight = Parameter(initializer(weight_init,
- [1, in_planes, kernel_size, kernel_size]))
-
- def construct(self, x):
- x = self.depthwise_conv(x, self.weight)
- return x
-
-
class DropConnect(nn.Cell):
"""drop connect implementation"""
@@ -584,7 +543,9 @@ class DepthwiseSeparableConv(nn.Cell):
self.has_pw_act = pw_act
self.act_fn = Swish()
self.drop_connect_rate = drop_connect_rate
- self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
+ self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad",
+ padding=int(dw_kernel_size/2), has_bias=False, group=in_chs,
+ weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
self.bn1 = _fused_bn(in_chs, **bn_args)
if self.has_se:
@@ -640,7 +601,9 @@ class InvertedResidual(nn.Cell):
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
self.shuffle = None
- self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
+ self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad",
+ padding=int(dw_kernel_size/2), has_bias=False, group=mid_chs,
+ weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
self.bn2 = _fused_bn(mid_chs, **bn_args)
if self.has_se:
diff --git a/model_zoo/research/nlp/dscnn/src/ds_cnn.py b/model_zoo/research/nlp/dscnn/src/ds_cnn.py
index f7daaa5bfa7..0cad78fdde0 100644
--- a/model_zoo/research/nlp/dscnn/src/ds_cnn.py
+++ b/model_zoo/research/nlp/dscnn/src/ds_cnn.py
@@ -16,34 +16,6 @@
import math
import mindspore.nn as nn
-from mindspore.ops import operations as P
-from mindspore.common.initializer import initializer
-from mindspore import Parameter
-
-
-class DepthWiseConv(nn.Cell):
- '''Build DepthWise conv.'''
- def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
- super(DepthWiseConv, self).__init__()
- self.has_bias = has_bias
- self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
- stride=stride, pad_mode=pad_mode, pad=pad)
- self.bias_add = P.BiasAdd()
-
- weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]]
- self.weight = Parameter(initializer('ones', weight_shape))
-
- if has_bias:
- bias_shape = [channel_multiplier * in_planes]
- self.bias = Parameter(initializer('zeros', bias_shape))
- else:
- self.bias = None
-
- def construct(self, x):
- output = self.depthwise_conv(x, self.weight)
- if self.has_bias:
- output = self.bias_add(output, self.bias)
- return output
class DSCNN(nn.Cell):
@@ -85,8 +57,9 @@ class DSCNN(nn.Cell):
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
in_channel = conv_feat[layer_no]
else:
- seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
- stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0))
+ seq_cell.append(nn.Conv2d(in_channel, in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
+ stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same',
+ has_bias=False, group=in_channel, weight_init='ones'))
seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98))
seq_cell.append(nn.ReLU())
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1),