!11705 fix_conv3d_bn3d

From: @jiangzg001
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-02-20 13:49:57 +08:00 committed by Gitee
commit d4abe53f34
7 changed files with 98 additions and 58 deletions

View File

@ -93,18 +93,22 @@ rel_strs = {
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False,
ret_five=False, greater_zero=True):
ret_five=False, greater_zero=True, third_one=False):
"""
Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.
"""
def _raise_message():
def _raise_message(third_one=False):
if third_one:
raise ValueError(f"For '{prim_name}' attr '{arg_name[-3]}' should be 1, but got {arg_value}")
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three "
f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")
def _get_return_value():
if isinstance(arg_value, int):
ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
if third_one:
ret = (1, 1, 1, arg_value, arg_value) if ret_five else (1, arg_value, arg_value)
elif len(arg_value) == 3:
ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
elif len(arg_value) == 5:
@ -123,7 +127,10 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False,
continue
if not greater_zero and item >= 0:
continue
_raise_message()
if third_one:
if ret_value[-3] != 1:
_raise_message(third_one)
return tuple(ret_value)

View File

@ -418,6 +418,12 @@ class BatchNorm2d(_BatchNorm):
pass
@constexpr
def _check_3d_shape(input_shape):
if len(input_shape) != 5:
raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
class BatchNorm3d(Cell):
r"""
Batch normalization layer over a 5D input.
@ -443,17 +449,13 @@ class BatchNorm3d(Cell):
running_mean and running_var computation. Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
@ -491,6 +493,7 @@ class BatchNorm3d(Cell):
data_format='NCDHW'):
super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()
self.bn2d = BatchNorm2d(num_features=num_features,
eps=eps,
momentum=momentum,
@ -501,11 +504,10 @@ class BatchNorm3d(Cell):
moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics,
data_format="NCHW")
self.shape = P.Shape()
self.reshape = P.Reshape()
def construct(self, input_x):
x_shape = self.shape(input_x)
x_shape = F.shape(input_x)
_check_3d_shape(x_shape)
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(input_x)
bn3d_out = self.reshape(bn2d_out, x_shape)

View File

@ -98,12 +98,11 @@ def get_bprop_conv3d_transpose(self):
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
input_size = self.input_size
def bprop(x, w, out, dout):
dx = input_grad(dout, w)
dw = filter_grad(dout, x, F.shape(w))
return dx, dw, zeros_like(input_size)
return dx, dw, zeros_like(out)
return bprop

View File

@ -7385,8 +7385,17 @@ class Conv3D(PrimitiveWithInfer):
3D convolution layer.
Applies a 3D convolution over an input tensor which is typically of shape
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
For each batch of shape :math:`(C_{in}, D_{in}, H_{in}, W_{in})`.
For input shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` and output shape
:math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. where :math:`N` is batch size. :math:`C` is channel number.
the formula is defined as:
.. math::
\operatorname{out}\left(N_{i}, C_{\text {out}_j}\right)=\operatorname{bias}\left(C_{\text {out}_j}\right)+
\sum_{k=0}^{C_{in}-1} ccor(\text {weight}\left(C_{\text {out}_j}, k\right),
\operatorname{input}\left(N_{i}, k\right))
where :math:`ccor` is the cross-correlation operator.
If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} -
@ -7402,7 +7411,7 @@ class Conv3D(PrimitiveWithInfer):
mode (int): Modes for different convolutions. Not currently used.
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
pad[3], pad[4] and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
@ -7414,6 +7423,7 @@ class Conv3D(PrimitiveWithInfer):
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
- **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
- **bias** (Tensor) - Tensor of shape :math:`C_{in}`. Currently, only support none or zero.
Outputs:
Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
@ -7422,8 +7432,8 @@ class Conv3D(PrimitiveWithInfer):
``Ascend``
Examples:
>>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float32)
>>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float16)
>>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3))
>>> output = conv3d(input, weight)
>>> print(output.shape)
@ -7446,7 +7456,8 @@ class Conv3D(PrimitiveWithInfer):
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
@ -7454,17 +7465,17 @@ class Conv3D(PrimitiveWithInfer):
validator.check_equal_int(len(pad), 6, 'pad size', self.name)
self.add_prim_attr("pad", pad)
self.padding = pad
validator.check_int_range(self.padding[0], 0, kernel_size[0], Rel.INC_LEFT,
validator.check_int_range(self.padding[0], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.padding[1], 0, kernel_size[0], Rel.INC_LEFT,
validator.check_int_range(self.padding[1], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.padding[2], 0, kernel_size[1], Rel.INC_LEFT,
validator.check_int_range(self.padding[2], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.padding[3], 0, kernel_size[1], Rel.INC_LEFT,
validator.check_int_range(self.padding[3], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.padding[4], 0, kernel_size[2], Rel.INC_LEFT,
validator.check_int_range(self.padding[4], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name)
validator.check_int_range(self.padding[5], 0, kernel_size[2], Rel.INC_LEFT,
validator.check_int_range(self.padding[5], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name)
self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.add_prim_attr('pad_mode', self.pad_mode)
@ -7588,8 +7599,8 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
``Ascend``
Examples:
>>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float32)
>>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
>>> x = Tensor(np.ones([16, 32, 13, 37, 33]))
>>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2))
>>> output = conv3d_backprop_input(dout, weight, F.shape(x))
@ -7640,12 +7651,15 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('io_format', self.format)
def __infer__(self, w, doutput, x_size):
validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
validator.check_equal_int(len(doutput['shape']), 5, 'The dimension of dout', self.name)
validator.check_equal_int(len(x_size['shape']), 5, 'The dimension of input_size', self.name)
x_size_v = x_size['value']
validator.check_value_type('x_size', x_size_v, [tuple], self.name)
for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_dtypes = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name)
@ -7690,15 +7704,30 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
class Conv3DTranspose(PrimitiveWithInfer):
"""
Computes the gradients of convolution 3D with respect to the input.
Compute a 3D transposed convolution, which is also known as a deconvolution
(although it is not an actual deconvolution).
Input is typically of shape :math:`(N, C, D, H, W)`, where :math:`N` is batch size and :math:`C` is channel number.
If the 'pad_mode' is set to be "pad", the height and width of output are defined as:
.. math::
D_{out} = (D_{in} - 1) \times \text{stride_d} - 2 \times \text{padding_d} + \text{dilation_d} \times
(\text{kernel_size_d} - 1) + \text{output_padding_d} + 1
H_{out} = (H_{in} - 1) \times \text{stride_h} - 2 \times \text{padding_h} + \text{dilation_h} \times
(\text{kernel_size_h} - 1) + \text{output_padding_h} + 1
W_{out} = (W_{in} - 1) \times \text{stride_w} - 2 \times \text{padding_w} + \text{dilation_w} \times
(\text{kernel_size_w} - 1) + 1
Args:
in_channel (int): The channel of the input x.
out_channel (int): The channel of the weight x.
kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
mode (int): Modes for different convolutions. Not currently used.
mode (int): Modes for different convolutions. Default is 1. Not currently used.
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers,
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers,
the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4]
and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
@ -7706,14 +7735,15 @@ class Conv3DTranspose(PrimitiveWithInfer):
group (int): Splits input into groups. Default: 1.
output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
data_format (str): The optional value for data format. Currently only support 'NCDHW'.
input_size (tuple[int]): A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Not currently used.
Inputs:
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.
- **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
- **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
:math:`(C_{in}//groups, C_{out}, D_{in}, K_h, K_w)`.
- **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only support none or zero.
Outputs:
Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input.
@ -7722,8 +7752,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
``Ascend``
Examples:
>>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float32)
>>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float16)
>>> conv3d_transpose = P.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))
>>> output = conv3d_transpose(input_x, weight)
>>> print(output.shape)
@ -7751,7 +7781,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name,
allow_five=True, ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
@ -7760,17 +7791,17 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.pad_list = pad
for item in self.pad_list:
validator.check_non_negative_int(item, 'pad item', self.name)
validator.check_int_range(self.pad_list[0], 0, kernel_size[0], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[0], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.pad_list[1], 0, kernel_size[0], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[1], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.pad_list[2], 0, kernel_size[1], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[2], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.pad_list[3], 0, kernel_size[1], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[3], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.pad_list[4], 0, kernel_size[2], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[4], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name)
validator.check_int_range(self.pad_list[5], 0, kernel_size[2], Rel.INC_LEFT,
validator.check_int_range(self.pad_list[5], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
@ -7796,7 +7827,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
raise ValueError("Bias currently only support None.")
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check("filter's batch", w['shape'][0], "input x's channel", x['shape'][1], Rel.EQ, self.name)
validator.check("filter's batch", w['shape'][0], "input x's channel",
x['shape'][1], Rel.EQ, self.name)
# infer shape
x_shape = x['shape']
w_shape = w['shape']
@ -7808,7 +7840,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
(self.kernel_size[1] - 1) + self.output_padding[3] + 1
w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
(self.kernel_size[2] - 1) + self.output_padding[4] + 1
output_shape = (x_shape[0], w_shape[1], d_out, h_out, w_out)
output_shape = (x_shape[0], w_shape[1]*self.group, d_out, h_out, w_out)
self.add_prim_attr('input_size', output_shape)
out = {
'value': None,

View File

@ -116,7 +116,7 @@ run_distribute_train_s16_r1.sh
run_distribute_train_s8_r1.sh
```
3. Train s8 with voctrain dataset, finetuning from model in pervious step, training script is:
3. Train s8 with voctrain dataset, finetuning from model in previous step, training script is:
```shell
run_distribute_train_s8_r2.sh
@ -302,7 +302,7 @@ do
done
```
3. Train s8 with voctrain dataset, finetuning from model in pervious step, training script is as follows:
3. Train s8 with voctrain dataset, finetuning from model in previous step, training script is as follows:
```shell
# run_distribute_train_s8_r2.sh

View File

@ -38,7 +38,7 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
datas = []
data = []
with open(args.data_lst) as f:
lines = f.readlines()
if args.shuffle:
@ -59,14 +59,14 @@ if __name__ == '__main__':
sample_['data'] = f.read()
with open(os.path.join(args.data_root, label_path), 'rb') as f:
sample_['label'] = f.read()
datas.append(sample_)
data.append(sample_)
cnt += 1
if cnt % 1000 == 0:
writer.write_raw_data(datas)
writer.write_raw_data(data)
print('number of samples written:', cnt)
datas = []
data = []
if datas:
writer.write_raw_data(datas)
if data:
writer.write_raw_data(data)
writer.commit()
print('number of samples written:', cnt)

View File

@ -112,7 +112,7 @@ def create_voc_train_aug_lst_txt():
if id_ in voc_train_data_lst + voc_val_data_lst:
continue
id_ = id_.strip()
img_ = os.path.join(SBD_ANNO_DIR, id_ + '.jpg')
img_ = os.path.join(SBD_IMG_DIR, id_ + '.jpg')
anno_ = os.path.join(SBD_ANNO_GRAY_DIR, id_ + '.png')
f.write(img_ + ' ' + anno_ + '\n')