forked from mindspore-Ecosystem/mindspore
replace DepthWiseConv with nn.Conv2D
This commit is contained in:
parent
a50a65adf9
commit
212d7ebd0b
16
RELEASE.md
16
RELEASE.md
|
@ -20,22 +20,22 @@ Previously the kernel size and pad mode attrs of pooling ops are named "ksize" a
|
|||
<td>
|
||||
|
||||
```python
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> import mindspore.ops as ops
|
||||
>>>
|
||||
>>> avg_pool = P.AvgPool(ksize=2, padding='same')
|
||||
>>> max_pool = P.MaxPool(ksize=2, padding='same')
|
||||
>>> max_pool_with_argmax = P.MaxPoolWithArgmax(ksize=2, padding='same')
|
||||
>>> avg_pool = ops.AvgPool(ksize=2, padding='same')
|
||||
>>> max_pool = ops.MaxPool(ksize=2, padding='same')
|
||||
>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(ksize=2, padding='same')
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> import mindspore.ops as ops
|
||||
>>>
|
||||
>>> avg_pool = P.AvgPool(kernel_size=2, pad_mode='same')
|
||||
>>> max_pool = P.MaxPool(kernel_size=2, pad_mode='same')
|
||||
>>> max_pool_with_argmax = P.MaxPoolWithArgmax(kernel_size=2, pad_mode='same')
|
||||
>>> avg_pool = ops.AvgPool(kernel_size=2, pad_mode='same')
|
||||
>>> max_pool = ops.MaxPool(kernel_size=2, pad_mode='same')
|
||||
>>> max_pool_with_argmax = ops.MaxPoolWithArgmax(kernel_size=2, pad_mode='same')
|
||||
```
|
||||
|
||||
</td>
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
import math
|
||||
import operator
|
||||
from functools import reduce, partial
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import _check_3d_int_or_tuple
|
||||
import numpy as np
|
||||
from ... import context
|
||||
|
@ -1476,6 +1477,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
|||
dilation=1,
|
||||
group=1):
|
||||
"""Initialize DepthwiseConv2dNative"""
|
||||
logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated."
|
||||
" Please use nn.Conv2D.")
|
||||
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
|
||||
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
||||
self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
|
||||
|
|
|
@ -102,7 +102,7 @@ step1: prepare pretrained model: train a mobilenet_v2 model by mindspore or use
|
|||
# The key/cell/module name must as follow, otherwise you need to modify "name_map" function:
|
||||
# --mindspore: as the same as mobilenet_v2_key.ckpt
|
||||
# --pytorch: same as official pytorch model(e.g., official mobilenet_v2-b0353104.pth)
|
||||
python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt
|
||||
python convert_weight_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt
|
||||
```
|
||||
|
||||
step2: prepare user rank_table
|
||||
|
@ -120,7 +120,7 @@ step3: train
|
|||
cd scripts;
|
||||
# prepare data_path, use symbolic link
|
||||
ln -sf [USE_DATA_DIR] dataset
|
||||
# check you dir to make sure your datas are in the right path
|
||||
# check you dir to make sure your data are in the right path
|
||||
ls ./dataset/centerface # data path
|
||||
ls ./dataset/centerface/annotations/train.json # annot_path
|
||||
ls ./dataset/centerface/images/train/images # img_dir
|
||||
|
@ -147,7 +147,7 @@ python setup.py install; # used for eval
|
|||
cd -; #cd ../../scripts;
|
||||
mkdir ./output
|
||||
mkdir ./output/centerface
|
||||
# check you dir to make sure your datas are in the right path
|
||||
# check you dir to make sure your data are in the right path
|
||||
ls ./dataset/images/val/images/ # data path
|
||||
ls ./dataset/centerface/ground_truth/val.mat # annot_path
|
||||
```
|
||||
|
@ -195,7 +195,7 @@ sh eval_all.sh
|
|||
│ ├──lr_scheduler.py // learning rate scheduler
|
||||
│ ├──mobile_v2.py // modified mobilenet_v2 backbone
|
||||
│ ├──utils.py // auxiliary functions for train, to log and preload
|
||||
│ ├──var_init.py // weight initilization
|
||||
│ ├──var_init.py // weight initialization
|
||||
│ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore
|
||||
│ ├──convert_weight.py // CenterFace model convert to mindspore
|
||||
└── dependency // third party codes: MIT License
|
||||
|
@ -414,7 +414,7 @@ After testing, you can find many txt file save the box information and scores,
|
|||
open it you can see:
|
||||
|
||||
```python
|
||||
646.3 189.1 42.1 51.8 0.747 # left top hight weight score
|
||||
646.3 189.1 42.1 51.8 0.747 # left top height weight score
|
||||
157.4 408.6 43.1 54.1 0.667
|
||||
120.3 212.4 38.7 42.8 0.650
|
||||
...
|
||||
|
@ -553,7 +553,7 @@ CenterFace on 3.2K images(The annotation and data format must be the same as wid
|
|||
# [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside ```create_dataset``` function.
|
||||
In var_init.py, we set seed for weight initilization
|
||||
In var_init.py, we set seed for weight initialization
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
|
|
|
@ -133,11 +133,6 @@ def pt_to_ckpt(pt, ckpt, out_path):
|
|||
parameter = state_dict_torch[key]
|
||||
parameter = parameter.numpy()
|
||||
|
||||
# depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout
|
||||
if state_dict_ms[name_relate[key]].data.shape != parameter.shape:
|
||||
parameter = parameter.transpose(1, 0, 2, 3)
|
||||
print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key)
|
||||
|
||||
param_dict['name'] = name_relate[key]
|
||||
param_dict['data'] = Tensor(parameter)
|
||||
new_params_list.append(param_dict)
|
||||
|
@ -158,13 +153,6 @@ def ckpt_to_pt(pt, ckpt, out_path):
|
|||
name = name_relate[key]
|
||||
parameter = state_dict_ms[name].data
|
||||
parameter = parameter.asnumpy()
|
||||
if state_dict_ms[name_relate[key]].data.shape != state_dict_torch[key].numpy().shape:
|
||||
print('before ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=',
|
||||
state_dict_torch[key].numpy().shape, 'name=', key)
|
||||
parameter = parameter.transpose(1, 0, 2, 3)
|
||||
print('after ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=',
|
||||
state_dict_torch[key].numpy().shape, 'name=', key)
|
||||
|
||||
state_dict[key] = torch.from_numpy(parameter)
|
||||
|
||||
save_model(out_path, epoch=0, model=None, optimizer=None, state_dict=state_dict)
|
||||
|
|
|
@ -120,12 +120,6 @@ def pt_to_ckpt(pt, ckpt, out_ckpt):
|
|||
parameter = state_dict_torch[key]
|
||||
parameter = parameter.numpy()
|
||||
|
||||
# depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout
|
||||
if state_dict_ms[name_relate[key]].data.shape != parameter.shape:
|
||||
parameter = parameter.transpose(1, 0, 2, 3)
|
||||
print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key)
|
||||
|
||||
|
||||
param_dict['name'] = name_relate[key]
|
||||
param_dict['data'] = Tensor(parameter)
|
||||
new_params_list.append(param_dict)
|
||||
|
|
|
@ -17,12 +17,10 @@
|
|||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
from src.var_init import KaimingNormal
|
||||
|
||||
__all__ = ['MobileNetV2', 'mobilenet_v2', 'DepthWiseConv']
|
||||
__all__ = ['MobileNetV2', 'mobilenet_v2']
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
|
@ -43,32 +41,6 @@ def _make_divisible(v, divisor, min_value=None):
|
|||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
"""
|
||||
Depthwise convolution
|
||||
"""
|
||||
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
self.has_bias = has_bias
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
|
||||
stride=stride, pad_mode=pad_mode, pad=pad)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
weight_shape = [channel_multiplier, in_planes, kernel_size, kernel_size]
|
||||
self.weight = Parameter(initializer(KaimingNormal(mode='fan_out'), weight_shape))
|
||||
|
||||
if has_bias:
|
||||
bias_shape = [channel_multiplier * in_planes]
|
||||
self.bias = Parameter(initializer('zeros', bias_shape))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def construct(self, x):
|
||||
output = self.depthwise_conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Cell):
|
||||
"""
|
||||
|
@ -81,7 +53,8 @@ class ConvBNReLU(nn.Cell):
|
|||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding,
|
||||
has_bias=False)
|
||||
else:
|
||||
conv = DepthWiseConv(in_planes, kernel_size, stride, pad_mode="pad", pad=padding)
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding,
|
||||
has_bias=False, group=groups, weight_init=KaimingNormal(mode='fan_out'))
|
||||
|
||||
layers = [conv, nn.BatchNorm2d(out_planes).add_flags_recursive(fp32=True), nn.ReLU6()] #, momentum=0.9
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
|
|
@ -24,8 +24,6 @@ import numpy as np
|
|||
from mindspore.train.serialization import load_checkpoint
|
||||
import mindspore.nn as nn
|
||||
|
||||
from src.mobile_v2 import DepthWiseConv
|
||||
|
||||
def load_backbone(net, ckpt_path, args):
|
||||
"""
|
||||
Load backbone
|
||||
|
@ -52,7 +50,7 @@ def load_backbone(net, ckpt_path, args):
|
|||
for name, cell in net.cells_and_names():
|
||||
if name.startswith(centerface_backbone_prefix):
|
||||
name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix)
|
||||
if isinstance(cell, (nn.Conv2d, nn.Dense, DepthWiseConv)):
|
||||
if isinstance(cell, (nn.Conv2d, nn.Dense)):
|
||||
name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
|
||||
mobilev2_weight = '{}.weight'.format(name)
|
||||
mobilev2_bias = '{}.bias'.format(name)
|
||||
|
|
|
@ -33,6 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, RunContext
|
|||
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.profiler.profiling import Profiler
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.utils import get_logger
|
||||
from src.utils import AverageMeter
|
||||
|
@ -47,6 +48,7 @@ from src.config import ConfigCenterface
|
|||
from src.centerface import CenterFaceWithLossCell, TrainingWrapper
|
||||
from src.dataset import GetDataLoader
|
||||
|
||||
set_seed(1)
|
||||
dev_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
|
||||
device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False)
|
||||
|
@ -130,7 +132,7 @@ if __name__ == "__main__":
|
|||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
||||
# select for master rank save ckpt or all rank save, compatiable for model parallel
|
||||
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if args.rank == 0:
|
||||
|
|
|
@ -20,10 +20,8 @@ from copy import deepcopy
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, ms_function
|
||||
from mindspore.common.initializer import (Normal, One, Uniform, Zero,
|
||||
initializer)
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore import ms_function
|
||||
from mindspore.common.initializer import (Normal, One, Uniform, Zero)
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.composite import clip_by_value
|
||||
|
||||
|
@ -224,13 +222,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
|||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
print('not support')
|
||||
elif v == 'r6':
|
||||
print('not support')
|
||||
elif v == 'hs':
|
||||
print('not support')
|
||||
elif v == 'sw':
|
||||
if v in ('re', 'r6', 'hs', 'sw'):
|
||||
print('not support')
|
||||
else:
|
||||
continue
|
||||
|
@ -459,28 +451,6 @@ class BlockBuilder(nn.Cell):
|
|||
return self.layer(x)
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
def __init__(self, in_planes, kernel_size, stride):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
platform = context.get_context("device_target")
|
||||
weight_shape = [1, kernel_size, in_planes]
|
||||
weight_init = _initialize_weight_goog(shape=weight_shape)
|
||||
if platform == "GPU":
|
||||
self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1, kernel_size=kernel_size,
|
||||
stride=stride, pad_mode="same", group=in_planes)
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [in_planes * 1, 1, kernel_size, kernel_size]))
|
||||
else:
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(
|
||||
channel_multiplier=1, kernel_size=kernel_size, stride=stride, pad_mode='same',)
|
||||
self.weight = Parameter(initializer(
|
||||
weight_init, [1, in_planes, kernel_size, kernel_size]))
|
||||
|
||||
def construct(self, x):
|
||||
x = self.depthwise_conv(x, self.weight)
|
||||
return x
|
||||
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
|
||||
super(DropConnect, self).__init__()
|
||||
|
@ -540,7 +510,9 @@ class DepthwiseSeparableConv(nn.Cell):
|
|||
self.has_pw_act = pw_act
|
||||
self.act_fn = act_fn
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
|
||||
self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="same",
|
||||
has_bias=False, group=in_chs,
|
||||
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
|
||||
self.bn1 = _fused_bn(in_chs, **bn_args)
|
||||
|
||||
#
|
||||
|
@ -595,7 +567,9 @@ class InvertedResidual(nn.Cell):
|
|||
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
|
||||
self.shuffle = None
|
||||
|
||||
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
|
||||
self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="same",
|
||||
has_bias=False, group=mid_chs,
|
||||
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
|
||||
self.bn2 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
|
|
|
@ -20,13 +20,12 @@ import numpy as np
|
|||
import mindspore.common.dtype as mstype
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter, context, Tensor
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
|
||||
def _make_divisible(x, divisor=4):
|
||||
|
@ -44,8 +43,8 @@ def _bn(channel):
|
|||
|
||||
|
||||
def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
|
||||
depthwise_conv = DepthwiseConv(
|
||||
in_channel, kernel_size, stride, pad_mode='same', pad=pad)
|
||||
depthwise_conv = nn.Conv2d(in_channel, in_channel, kernel_size, stride, pad_mode='same', padding=pad,
|
||||
has_bias=False, group=in_channel, weight_init='ones')
|
||||
conv = _conv2d(in_channel, out_channel, kernel_size=1)
|
||||
return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
|
||||
|
||||
|
@ -75,8 +74,8 @@ class ConvBNReLU(nn.Cell):
|
|||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
|
||||
padding=padding)
|
||||
else:
|
||||
conv = DepthwiseConv(in_planes, kernel_size,
|
||||
stride, pad_mode='same', pad=padding)
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same', padding=padding,
|
||||
has_bias=False, group=groups, weight_init='ones')
|
||||
layers = [conv, _bn(out_planes)]
|
||||
if use_act:
|
||||
layers.append(Activation(act_type))
|
||||
|
@ -87,52 +86,6 @@ class ConvBNReLU(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
class DepthwiseConv(nn.Cell):
|
||||
"""
|
||||
Depthwise Convolution warpper definition.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
kernel_size (int): Input kernel size.
|
||||
stride (int): Stride size.
|
||||
pad_mode (str): pad mode in (pad, same, valid)
|
||||
channel_multiplier (int): Output channel multiplier
|
||||
has_bias (bool): has bias or not
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
|
||||
"""
|
||||
|
||||
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
|
||||
super(DepthwiseConv, self).__init__()
|
||||
self.has_bias = has_bias
|
||||
self.in_channels = in_planes
|
||||
self.channel_multiplier = channel_multiplier
|
||||
self.out_channels = in_planes * channel_multiplier
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=stride, pad_mode=pad_mode, pad=pad)
|
||||
self.bias_add = P.BiasAdd()
|
||||
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
|
||||
self.weight = Parameter(initializer('ones', weight_shape), name="weight")
|
||||
|
||||
if has_bias:
|
||||
bias_shape = [channel_multiplier * in_planes]
|
||||
self.bias = Parameter(initializer('zeros', bias_shape), name="bias")
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def construct(self, x):
|
||||
output = self.depthwise_conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class MyHSigmoid(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MyHSigmoid, self).__init__()
|
||||
|
|
|
@ -20,9 +20,8 @@ from copy import deepcopy
|
|||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
|
||||
from mindspore import context, ms_function
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import Normal, Zero, One, Uniform
|
||||
from mindspore import ms_function
|
||||
from mindspore import Tensor
|
||||
|
||||
# Imagenet constant values
|
||||
|
@ -244,13 +243,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
|||
# activation fn
|
||||
key = op[0]
|
||||
v = op[1:]
|
||||
if v == 're':
|
||||
print('not support')
|
||||
elif v == 'r6':
|
||||
print('not support')
|
||||
elif v == 'hs':
|
||||
print('not support')
|
||||
elif v == 'sw':
|
||||
if v in ('re', 'r6', 'hs', 'sw'):
|
||||
print('not support')
|
||||
else:
|
||||
continue
|
||||
|
@ -485,40 +478,6 @@ class BlockBuilder(nn.Cell):
|
|||
return self.layer(x)
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
"""depth-wise convolution"""
|
||||
|
||||
def __init__(self, in_planes, kernel_size, stride):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
platform = context.get_context("device_target")
|
||||
weight_shape = [1, kernel_size, in_planes]
|
||||
weight_init = _initialize_weight_goog(shape=weight_shape)
|
||||
|
||||
if platform == "GPU":
|
||||
self.depthwise_conv = P.Conv2D(out_channel=in_planes*1,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad=int(kernel_size/2),
|
||||
pad_mode="pad",
|
||||
group=in_planes)
|
||||
|
||||
self.weight = Parameter(initializer(weight_init,
|
||||
[in_planes*1, 1, kernel_size, kernel_size]))
|
||||
|
||||
else:
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride, pad_mode='pad',
|
||||
pad=int(kernel_size/2))
|
||||
|
||||
self.weight = Parameter(initializer(weight_init,
|
||||
[1, in_planes, kernel_size, kernel_size]))
|
||||
|
||||
def construct(self, x):
|
||||
x = self.depthwise_conv(x, self.weight)
|
||||
return x
|
||||
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
"""drop connect implementation"""
|
||||
|
||||
|
@ -584,7 +543,9 @@ class DepthwiseSeparableConv(nn.Cell):
|
|||
self.has_pw_act = pw_act
|
||||
self.act_fn = Swish()
|
||||
self.drop_connect_rate = drop_connect_rate
|
||||
self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
|
||||
self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad",
|
||||
padding=int(dw_kernel_size/2), has_bias=False, group=in_chs,
|
||||
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
|
||||
self.bn1 = _fused_bn(in_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
|
@ -640,7 +601,9 @@ class InvertedResidual(nn.Cell):
|
|||
if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
|
||||
self.shuffle = None
|
||||
|
||||
self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
|
||||
self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad",
|
||||
padding=int(dw_kernel_size/2), has_bias=False, group=mid_chs,
|
||||
weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
|
||||
self.bn2 = _fused_bn(mid_chs, **bn_args)
|
||||
|
||||
if self.has_se:
|
||||
|
|
|
@ -16,34 +16,6 @@
|
|||
import math
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore import Parameter
|
||||
|
||||
|
||||
class DepthWiseConv(nn.Cell):
|
||||
'''Build DepthWise conv.'''
|
||||
def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
|
||||
super(DepthWiseConv, self).__init__()
|
||||
self.has_bias = has_bias
|
||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
|
||||
stride=stride, pad_mode=pad_mode, pad=pad)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
weight_shape = [channel_multiplier, in_planes, kernel_size[0], kernel_size[1]]
|
||||
self.weight = Parameter(initializer('ones', weight_shape))
|
||||
|
||||
if has_bias:
|
||||
bias_shape = [channel_multiplier * in_planes]
|
||||
self.bias = Parameter(initializer('zeros', bias_shape))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def construct(self, x):
|
||||
output = self.depthwise_conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class DSCNN(nn.Cell):
|
||||
|
@ -85,8 +57,9 @@ class DSCNN(nn.Cell):
|
|||
seq_cell.append(nn.BatchNorm2d(num_features=conv_feat[layer_no], momentum=0.98))
|
||||
in_channel = conv_feat[layer_no]
|
||||
else:
|
||||
seq_cell.append(DepthWiseConv(in_planes=in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
|
||||
stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same', pad=0))
|
||||
seq_cell.append(nn.Conv2d(in_channel, in_channel, kernel_size=(conv_kt[layer_no], conv_kf[layer_no]),
|
||||
stride=(conv_st[layer_no], conv_sf[layer_no]), pad_mode='same',
|
||||
has_bias=False, group=in_channel, weight_init='ones'))
|
||||
seq_cell.append(nn.BatchNorm2d(num_features=in_channel, momentum=0.98))
|
||||
seq_cell.append(nn.ReLU())
|
||||
seq_cell.append(nn.Conv2d(in_channels=in_channel, out_channels=conv_feat[layer_no], kernel_size=(1, 1),
|
||||
|
|
Loading…
Reference in New Issue