fix_conv3d_bn3d

This commit is contained in:
jiangzhenguang 2021-01-27 15:38:22 +08:00
parent e21bc108cd
commit 8913505536
7 changed files with 98 additions and 58 deletions

View File

@ -93,18 +93,22 @@ rel_strs = {
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False,
ret_five=False, greater_zero=True): ret_five=False, greater_zero=True, third_one=False):
""" """
Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements. Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.
""" """
def _raise_message(): def _raise_message(third_one=False):
if third_one:
raise ValueError(f"For '{prim_name}' attr '{arg_name[-3]}' should be 1, but got {arg_value}")
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three " raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three "
f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}") f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")
def _get_return_value(): def _get_return_value():
if isinstance(arg_value, int): if isinstance(arg_value, int):
ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value) ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
if third_one:
ret = (1, 1, 1, arg_value, arg_value) if ret_five else (1, arg_value, arg_value)
elif len(arg_value) == 3: elif len(arg_value) == 3:
ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
elif len(arg_value) == 5: elif len(arg_value) == 5:
@ -123,7 +127,10 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False,
continue continue
if not greater_zero and item >= 0: if not greater_zero and item >= 0:
continue continue
_raise_message() if third_one:
if ret_value[-3] != 1:
_raise_message(third_one)
return tuple(ret_value) return tuple(ret_value)

View File

@ -404,6 +404,12 @@ class BatchNorm2d(_BatchNorm):
pass pass
@constexpr
def _check_3d_shape(input_shape):
if len(input_shape) != 5:
raise ValueError("For BatchNorm3d, input data must be 5-dimensional.")
class BatchNorm3d(Cell): class BatchNorm3d(Cell):
r""" r"""
Batch normalization layer over a 5D input. Batch normalization layer over a 5D input.
@ -429,17 +435,13 @@ class BatchNorm3d(Cell):
running_mean and running_var computation. Default: 0.9. running_mean and running_var computation. Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
'he_uniform', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
'he_uniform', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, the training process will use the mean use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use and variance of current batch data and track the running mean and variance, the evaluation process will use
@ -477,6 +479,7 @@ class BatchNorm3d(Cell):
data_format='NCDHW'): data_format='NCDHW'):
super(BatchNorm3d, self).__init__() super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()
self.bn2d = BatchNorm2d(num_features=num_features, self.bn2d = BatchNorm2d(num_features=num_features,
eps=eps, eps=eps,
momentum=momentum, momentum=momentum,
@ -487,11 +490,10 @@ class BatchNorm3d(Cell):
moving_var_init=moving_var_init, moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics, use_batch_statistics=use_batch_statistics,
data_format="NCHW") data_format="NCHW")
self.shape = P.Shape()
self.reshape = P.Reshape()
def construct(self, input_x): def construct(self, input_x):
x_shape = self.shape(input_x) x_shape = F.shape(input_x)
_check_3d_shape(x_shape)
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4])) input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(input_x) bn2d_out = self.bn2d(input_x)
bn3d_out = self.reshape(bn2d_out, x_shape) bn3d_out = self.reshape(bn2d_out, x_shape)

View File

@ -97,12 +97,11 @@ def get_bprop_conv3d_transpose(self):
out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad", out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
) )
input_size = self.input_size
def bprop(x, w, out, dout): def bprop(x, w, out, dout):
dx = input_grad(dout, w) dx = input_grad(dout, w)
dw = filter_grad(dout, x, F.shape(w)) dw = filter_grad(dout, x, F.shape(w))
return dx, dw, zeros_like(input_size) return dx, dw, zeros_like(out)
return bprop return bprop

View File

@ -7106,8 +7106,17 @@ class Conv3D(PrimitiveWithInfer):
3D convolution layer. 3D convolution layer.
Applies a 3D convolution over an input tensor which is typically of shape Applies a 3D convolution over an input tensor which is typically of shape
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number. For input shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` and output shape
For each batch of shape :math:`(C_{in}, D_{in}, H_{in}, W_{in})`. :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. where :math:`N` is batch size. :math:`C` is channel number.
the formula is defined as:
.. math::
\operatorname{out}\left(N_{i}, C_{\text {out}_j}\right)=\operatorname{bias}\left(C_{\text {out}_j}\right)+
\sum_{k=0}^{C_{in}-1} ccor(\text {weight}\left(C_{\text {out}_j}, k\right),
\operatorname{input}\left(N_{i}, k\right))
where :math:`ccor` is the cross-correlation operator.
If the 'pad_mode' is set to be "valid", the output height and width will be If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} - :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} -
@ -7123,7 +7132,7 @@ class Conv3D(PrimitiveWithInfer):
mode (int): Modes for different convolutions. Not currently used. mode (int): Modes for different convolutions. Not currently used.
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
pad[3], pad[4] and pad[5] correspondingly. pad[3], pad[4] and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
@ -7135,6 +7144,7 @@ class Conv3D(PrimitiveWithInfer):
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
- **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is - **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`. :math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
- **bias** (Tensor) - Tensor of shape :math:`C_{in}`. Currently, only support none or zero.
Outputs: Outputs:
Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
@ -7143,8 +7153,8 @@ class Conv3D(PrimitiveWithInfer):
``Ascend`` ``Ascend``
Examples: Examples:
>>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float32) >>> input = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float32) >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float16)
>>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3)) >>> conv3d = P.Conv3D(out_channel=32, kernel_size=(4, 3, 3))
>>> output = conv3d(input, weight) >>> output = conv3d(input, weight)
>>> print(output.shape) >>> print(output.shape)
@ -7167,7 +7177,8 @@ class Conv3D(PrimitiveWithInfer):
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('strides', self.stride) self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation) self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name) validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int): if isinstance(pad, int):
@ -7175,17 +7186,17 @@ class Conv3D(PrimitiveWithInfer):
validator.check_equal_int(len(pad), 6, 'pad size', self.name) validator.check_equal_int(len(pad), 6, 'pad size', self.name)
self.add_prim_attr("pad", pad) self.add_prim_attr("pad", pad)
self.padding = pad self.padding = pad
validator.check_int_range(self.padding[0], 0, kernel_size[0], Rel.INC_LEFT, validator.check_int_range(self.padding[0], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name) 'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.padding[1], 0, kernel_size[0], Rel.INC_LEFT, validator.check_int_range(self.padding[1], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name) 'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.padding[2], 0, kernel_size[1], Rel.INC_LEFT, validator.check_int_range(self.padding[2], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name) 'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.padding[3], 0, kernel_size[1], Rel.INC_LEFT, validator.check_int_range(self.padding[3], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name) 'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.padding[4], 0, kernel_size[2], Rel.INC_LEFT, validator.check_int_range(self.padding[4], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name) 'pad_w belonging [0, kernel_size_w)', self.name)
validator.check_int_range(self.padding[5], 0, kernel_size[2], Rel.INC_LEFT, validator.check_int_range(self.padding[5], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name) 'pad_w belonging [0, kernel_size_w)', self.name)
self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
self.add_prim_attr('pad_mode', self.pad_mode) self.add_prim_attr('pad_mode', self.pad_mode)
@ -7309,8 +7320,8 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
``Ascend`` ``Ascend``
Examples: Examples:
>>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float32) >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float32) >>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
>>> x = Tensor(np.ones([16, 32, 13, 37, 33])) >>> x = Tensor(np.ones([16, 32, 13, 37, 33]))
>>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2)) >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2))
>>> output = conv3d_backprop_input(dout, weight, F.shape(x)) >>> output = conv3d_backprop_input(dout, weight, F.shape(x))
@ -7361,12 +7372,15 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('io_format', self.format) self.add_prim_attr('io_format', self.format)
def __infer__(self, w, doutput, x_size): def __infer__(self, w, doutput, x_size):
validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
validator.check_equal_int(len(doutput['shape']), 5, 'The dimension of dout', self.name)
validator.check_equal_int(len(x_size['shape']), 5, 'The dimension of input_size', self.name)
x_size_v = x_size['value'] x_size_v = x_size['value']
validator.check_value_type('x_size', x_size_v, [tuple], self.name) validator.check_value_type('x_size', x_size_v, [tuple], self.name)
for i, dim_len in enumerate(x_size_v): for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']} args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_dtypes = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name) validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name)
@ -7411,15 +7425,30 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
class Conv3DTranspose(PrimitiveWithInfer): class Conv3DTranspose(PrimitiveWithInfer):
""" """
Computes the gradients of convolution 3D with respect to the input. Compute a 3D transposed convolution, which is also known as a deconvolution
(although it is not an actual deconvolution).
Input is typically of shape :math:`(N, C, D, H, W)`, where :math:`N` is batch size and :math:`C` is channel number.
If the 'pad_mode' is set to be "pad", the height and width of output are defined as:
.. math::
D_{out} = (D_{in} - 1) \times \text{stride_d} - 2 \times \text{padding_d} + \text{dilation_d} \times
(\text{kernel_size_d} - 1) + \text{output_padding_d} + 1
H_{out} = (H_{in} - 1) \times \text{stride_h} - 2 \times \text{padding_h} + \text{dilation_h} \times
(\text{kernel_size_h} - 1) + \text{output_padding_h} + 1
W_{out} = (W_{in} - 1) \times \text{stride_w} - 2 \times \text{padding_w} + \text{dilation_w} \times
(\text{kernel_size_w} - 1) + 1
Args: Args:
in_channel (int): The channel of the input x. in_channel (int): The channel of the input x.
out_channel (int): The channel of the weight x. out_channel (int): The channel of the weight x.
kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
mode (int): Modes for different convolutions. Not currently used. mode (int): Modes for different convolutions. Default is 1. Not currently used.
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers,
the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4] the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4]
and pad[5] correspondingly. and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
@ -7427,14 +7456,15 @@ class Conv3DTranspose(PrimitiveWithInfer):
group (int): Splits input into groups. Default: 1. group (int): Splits input into groups. Default: 1.
output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0. output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
data_format (str): The optional value for data format. Currently only support 'NCDHW'. data_format (str): The optional value for data format. Currently only support 'NCDHW'.
input_size (tuple[int]): A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Not currently used.
Inputs: Inputs:
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default - **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.
- **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is - **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`. :math:`(C_{in}//groups, C_{out}, D_{in}, K_h, K_w)`.
- **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format - **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only support none or zero.
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
Outputs: Outputs:
Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input. Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input.
@ -7443,8 +7473,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
``Ascend`` ``Ascend``
Examples: Examples:
>>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32) >>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float16)
>>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float32) >>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float16)
>>> conv3d_transpose = P.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2)) >>> conv3d_transpose = P.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))
>>> output = conv3d_transpose(input_x, weight) >>> output = conv3d_transpose(input_x, weight)
>>> print(output.shape) >>> print(output.shape)
@ -7472,7 +7502,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('strides', self.stride) self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name,
allow_five=True, ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation) self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name) validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int): if isinstance(pad, int):
@ -7481,17 +7512,17 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.pad_list = pad self.pad_list = pad
for item in self.pad_list: for item in self.pad_list:
validator.check_non_negative_int(item, 'pad item', self.name) validator.check_non_negative_int(item, 'pad item', self.name)
validator.check_int_range(self.pad_list[0], 0, kernel_size[0], Rel.INC_LEFT, validator.check_int_range(self.pad_list[0], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name) 'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.pad_list[1], 0, kernel_size[0], Rel.INC_LEFT, validator.check_int_range(self.pad_list[1], 0, self.kernel_size[0], Rel.INC_LEFT,
'pad_d belonging [0, kernel_size_d)', self.name) 'pad_d belonging [0, kernel_size_d)', self.name)
validator.check_int_range(self.pad_list[2], 0, kernel_size[1], Rel.INC_LEFT, validator.check_int_range(self.pad_list[2], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name) 'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.pad_list[3], 0, kernel_size[1], Rel.INC_LEFT, validator.check_int_range(self.pad_list[3], 0, self.kernel_size[1], Rel.INC_LEFT,
'pad_h belonging [0, kernel_size_h)', self.name) 'pad_h belonging [0, kernel_size_h)', self.name)
validator.check_int_range(self.pad_list[4], 0, kernel_size[2], Rel.INC_LEFT, validator.check_int_range(self.pad_list[4], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name) 'pad_w belonging [0, kernel_size_w)', self.name)
validator.check_int_range(self.pad_list[5], 0, kernel_size[2], Rel.INC_LEFT, validator.check_int_range(self.pad_list[5], 0, self.kernel_size[2], Rel.INC_LEFT,
'pad_w belonging [0, kernel_size_w)', self.name) 'pad_w belonging [0, kernel_size_w)', self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode) self.add_prim_attr('mode', self.mode)
@ -7517,7 +7548,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
raise ValueError("Bias currently only support None.") raise ValueError("Bias currently only support None.")
valid_dtypes = [mstype.float16, mstype.float32] valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
validator.check("filter's batch", w['shape'][0], "input x's channel", x['shape'][1], Rel.EQ, self.name) validator.check("filter's batch", w['shape'][0], "input x's channel",
x['shape'][1], Rel.EQ, self.name)
# infer shape # infer shape
x_shape = x['shape'] x_shape = x['shape']
w_shape = w['shape'] w_shape = w['shape']
@ -7529,7 +7561,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
(self.kernel_size[1] - 1) + self.output_padding[3] + 1 (self.kernel_size[1] - 1) + self.output_padding[3] + 1
w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \ w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
(self.kernel_size[2] - 1) + self.output_padding[4] + 1 (self.kernel_size[2] - 1) + self.output_padding[4] + 1
output_shape = (x_shape[0], w_shape[1], d_out, h_out, w_out) output_shape = (x_shape[0], w_shape[1]*self.group, d_out, h_out, w_out)
self.add_prim_attr('input_size', output_shape) self.add_prim_attr('input_size', output_shape)
out = { out = {
'value': None, 'value': None,

View File

@ -116,7 +116,7 @@ run_distribute_train_s16_r1.sh
run_distribute_train_s8_r1.sh run_distribute_train_s8_r1.sh
``` ```
3. Train s8 with voctrain dataset, finetuning from model in pervious step, training script is: 3. Train s8 with voctrain dataset, finetuning from model in previous step, training script is:
```shell ```shell
run_distribute_train_s8_r2.sh run_distribute_train_s8_r2.sh
@ -302,7 +302,7 @@ do
done done
``` ```
3. Train s8 with voctrain dataset, finetuning from model in pervious step, training script is as follows: 3. Train s8 with voctrain dataset, finetuning from model in previous step, training script is as follows:
```shell ```shell
# run_distribute_train_s8_r2.sh # run_distribute_train_s8_r2.sh

View File

@ -38,7 +38,7 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
datas = [] data = []
with open(args.data_lst) as f: with open(args.data_lst) as f:
lines = f.readlines() lines = f.readlines()
if args.shuffle: if args.shuffle:
@ -59,14 +59,14 @@ if __name__ == '__main__':
sample_['data'] = f.read() sample_['data'] = f.read()
with open(os.path.join(args.data_root, label_path), 'rb') as f: with open(os.path.join(args.data_root, label_path), 'rb') as f:
sample_['label'] = f.read() sample_['label'] = f.read()
datas.append(sample_) data.append(sample_)
cnt += 1 cnt += 1
if cnt % 1000 == 0: if cnt % 1000 == 0:
writer.write_raw_data(datas) writer.write_raw_data(data)
print('number of samples written:', cnt) print('number of samples written:', cnt)
datas = [] data = []
if datas: if data:
writer.write_raw_data(datas) writer.write_raw_data(data)
writer.commit() writer.commit()
print('number of samples written:', cnt) print('number of samples written:', cnt)

View File

@ -112,7 +112,7 @@ def create_voc_train_aug_lst_txt():
if id_ in voc_train_data_lst + voc_val_data_lst: if id_ in voc_train_data_lst + voc_val_data_lst:
continue continue
id_ = id_.strip() id_ = id_.strip()
img_ = os.path.join(SBD_ANNO_DIR, id_ + '.jpg') img_ = os.path.join(SBD_IMG_DIR, id_ + '.jpg')
anno_ = os.path.join(SBD_ANNO_GRAY_DIR, id_ + '.png') anno_ = os.path.join(SBD_ANNO_GRAY_DIR, id_ + '.png')
f.write(img_ + ' ' + anno_ + '\n') f.write(img_ + ' ' + anno_ + '\n')