Modified pad validator of conv3d.
parent 0a870440e0
commit 57a75bb455

@@ -100,7 +100,7 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
 
     def _raise_message(third_one_flag=False, three_input_flag=False):
         if third_one_flag:
-            raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {arg_value[-3]}")
+            raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {ret_value[-3]}")
         if three_input_flag:
             raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of "
                              f"three positive int numbers, but got {arg_value}")

@@ -110,8 +110,6 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
     def _get_return_value():
         if isinstance(arg_value, int):
             ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
-            if third_one:
-                ret = (1, 1, 1, arg_value, arg_value) if ret_five else (1, arg_value, arg_value)
         elif len(arg_value) == 3:
             ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
         elif len(arg_value) == 5:
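
The two hunks above change `_check_3d_int_or_tuple` so that the `third_one` depth error quotes the normalized tuple (`ret_value[-3]`) instead of the raw argument, and so that an int input no longer has its depth rewritten silently. A minimal standalone sketch of the normalization contract (hypothetical names, not the actual MindSpore helper):

def check_3d_int_or_tuple(value, ret_five=False, third_one=False):
    # Normalize an int or a 3-tuple into a 3- or 5-tuple; the two leading
    # 1s stand in for the N and C positions when ret_five is set.
    if isinstance(value, int):
        ret = (1, 1, value, value, value) if ret_five else (value, value, value)
    elif isinstance(value, tuple) and len(value) == 3:
        ret = (1, 1, *value) if ret_five else value
    else:
        raise ValueError(f"should be an int or a tuple of three ints, but got {value}")
    if third_one and ret[-3] != 1:
        # Post-fix behavior: quote the normalized tuple (ret[-3]) so int
        # and tuple inputs produce the same diagnostic.
        raise ValueError(f"the depth should be 1, but got {ret[-3]}")
    return ret

assert check_3d_int_or_tuple(3) == (3, 3, 3)
assert check_3d_int_or_tuple(3, ret_five=True) == (1, 1, 3, 3, 3)
assert check_3d_int_or_tuple((1, 2, 2), ret_five=True, third_one=True) == (1, 1, 1, 2, 2)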

@@ -24,7 +24,8 @@
 #include "base/core_ops.h"
 #include "backend/optimizer/common/helper.h"
 
-namespace mindspore::opt {
+namespace mindspore {
+namespace opt {
 namespace {
 using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
 const std::set<std::string> InvalidOps = {kSplitOpName, kSplitVOpName, kConcatOpName};

@@ -211,4 +212,5 @@ const AnfNodePtr SplitOpOptimizer::Process(const FuncGraphPtr &func_graph, const
   }
   return nullptr;
 }
-}  // namespace mindspore::opt
+}  // namespace opt
+}  // namespace mindspore

@@ -23,7 +23,6 @@
 
 namespace mindspore {
 namespace opt {
-
 namespace {
 constexpr size_t kDynamicGRUV2GradInputNum = 12;
 constexpr size_t kDynamicGRUV2GradOutputNum = 6;

@@ -21,7 +21,8 @@
 #include "frontend/optimizer/opt.h"
 #include "backend/optimizer/common/helper.h"
 
-namespace mindspore::opt {
+namespace mindspore {
+namespace opt {
 constexpr size_t kInputNum = 3;
 constexpr size_t kFloat16Len = 2;  // size of float16;
 namespace {

@@ -117,4 +118,5 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
   }
   return new_cnode;
 }
-}  // namespace mindspore::opt
+}  // namespace opt
+}  // namespace mindspore

@@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-from mindspore._checkparam import Validator, Rel, twice, triple
+from mindspore._checkparam import Validator, Rel, twice, _check_3d_int_or_tuple
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
 

@@ -594,9 +594,9 @@ class Conv3d(_Conv):
                  weight_init='normal',
                  bias_init='zeros',
                  data_format='NCDHW'):
-        kernel_size = triple(kernel_size)
-        stride = triple(stride)
-        dilation = triple(dilation)
+        kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
+        stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
+        dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)
         Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
         if isinstance(padding, tuple):
             Validator.check_equal_int(len(padding), 6, 'padding size', self.cls_name)

@@ -765,13 +765,13 @@ class Conv3dTranspose(_Conv):
                  weight_init='normal',
                  bias_init='zeros',
                  data_format='NCDHW'):
-        kernel_size = triple(kernel_size)
-        stride = triple(stride)
-        dilation = triple(dilation)
+        kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.cls_name)
+        stride = _check_3d_int_or_tuple("stride", stride, self.cls_name)
+        dilation = _check_3d_int_or_tuple("dilation", dilation, self.cls_name)
         Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
         if isinstance(padding, tuple):
             Validator.check_equal_int(len(padding), 6, 'padding size', self.cls_name)
-        output_padding = triple(output_padding)
+        output_padding = _check_3d_int_or_tuple("output_padding", output_padding, self.cls_name, greater_zero=False)
         super(Conv3dTranspose, self).__init__(
             in_channels,
             out_channels,
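
With both constructors above routed through _check_3d_int_or_tuple instead of triple(), an int and an equivalent 3-tuple are canonicalized to the same form before the primitive is built, and malformed values fail fast in the Cell constructor. A hedged usage sketch (layer parameters are illustrative only, not from the commit):

import mindspore.nn as nn

# These two layers should now normalize kernel_size to the same canonical
# tuple; a malformed value such as kernel_size=(2, 2) is rejected up front
# rather than surfacing later inside the Conv3D primitive.
conv_a = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=2)
conv_b = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=(2, 2, 2))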

@@ -7759,8 +7759,8 @@ class Conv3D(PrimitiveWithInfer):
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
           Currently input data type only support float16 and float32.
-        - **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
-          :math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`. Currently weight data type only support float16 and float32.
+        - **weight** (Tensor) - Set size of kernel is :math:`(k_d, K_h, K_w)`, then the shape is
+          :math:`(C_{out}, C_{in}//groups, k_d, K_h, K_w)`. Currently weight data type only support float16 and float32.
         - **bias** (Tensor) - Tensor of shape :math:`C_{in}`. Currently, only support none.
 
     Outputs:

@@ -7815,18 +7815,7 @@ class Conv3D(PrimitiveWithInfer):
                              f"six positive int numbers, but got `{len(pad)}`.")
         self.add_prim_attr("pad", pad)
+        self.padding = pad
-        validator.check_int_range(self.padding[0], 0, self.kernel_size[0], Rel.INC_LEFT,
-                                  'pad_d belonging [0, kernel_size_d)', self.name)
-        validator.check_int_range(self.padding[1], 0, self.kernel_size[0], Rel.INC_LEFT,
-                                  'pad_d belonging [0, kernel_size_d)', self.name)
-        validator.check_int_range(self.padding[2], 0, self.kernel_size[1], Rel.INC_LEFT,
-                                  'pad_h belonging [0, kernel_size_h)', self.name)
-        validator.check_int_range(self.padding[3], 0, self.kernel_size[1], Rel.INC_LEFT,
-                                  'pad_h belonging [0, kernel_size_h)', self.name)
-        validator.check_int_range(self.padding[4], 0, self.kernel_size[2], Rel.INC_LEFT,
-                                  'pad_w belonging [0, kernel_size_w)', self.name)
-        validator.check_int_range(self.padding[5], 0, self.kernel_size[2], Rel.INC_LEFT,
-                                  'pad_w belonging [0, kernel_size_w)', self.name)
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode)
 

@@ -7902,6 +7891,21 @@ class Conv3D(PrimitiveWithInfer):
         w_out = math.floor(w_out)
 
         self.pad_list = [pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right]
+        filter_d = (self.kernel_size[0] - 1) * dilation_d + 1
+        filter_h = (self.kernel_size[1] - 1) * dilation_h + 1
+        filter_w = (self.kernel_size[2] - 1) * dilation_w + 1
+        validator.check_int_range(self.pad_list[0], 0, filter_d, Rel.INC_LEFT,
+                                  'pad_d belonging [0, filter_d)', self.name)
+        validator.check_int_range(self.pad_list[1], 0, filter_d, Rel.INC_LEFT,
+                                  'pad_d belonging [0, filter_d)', self.name)
+        validator.check_int_range(self.pad_list[2], 0, filter_h, Rel.INC_LEFT,
+                                  'pad_h belonging [0, filter_h)', self.name)
+        validator.check_int_range(self.pad_list[3], 0, filter_h, Rel.INC_LEFT,
+                                  'pad_h belonging [0, filter_h)', self.name)
+        validator.check_int_range(self.pad_list[4], 0, filter_w, Rel.INC_LEFT,
+                                  'pad_w belonging [0, filter_w)', self.name)
+        validator.check_int_range(self.pad_list[5], 0, filter_w, Rel.INC_LEFT,
+                                  'pad_w belonging [0, filter_w)', self.name)
         self.add_prim_attr('pad_list', (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right))
         out_channel = self.out_channel
         out_shape = [x_shape[0], out_channel, d_out, h_out, w_out]
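
The relocated checks in the hunk above bound each pad by the dilated filter extent, (k - 1) * dilation + 1, instead of the raw kernel size: with dilation > 1, a pad equal to or larger than kernel_size can still be covered by the receptive field, so the old __init__-time bound was too strict. A standalone sketch of the rule (hypothetical helper mirroring the added lines):

def check_pad_list(pad_list, kernel_size, dilation):
    # pad_list holds (head, tail, top, bottom, left, right); each pad must
    # lie in [0, filter_x) for its axis, matching Rel.INC_LEFT above.
    for axis in range(3):
        filter_x = (kernel_size[axis] - 1) * dilation[axis] + 1
        for pad in pad_list[2 * axis: 2 * axis + 2]:
            if not 0 <= pad < filter_x:
                raise ValueError(f"pad {pad} not in [0, {filter_x}) on axis {axis}")

# kernel 3 with dilation 2 gives an effective extent of 5, so a pad of 4 is
# legal here, while the old kernel-size bound (pad < 3) would have rejected it.
check_pad_list([4, 4, 0, 0, 0, 0], kernel_size=(3, 3, 3), dilation=(2, 2, 2))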

@@ -8124,8 +8128,8 @@ class Conv3DTranspose(PrimitiveWithInfer):
         - **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
           data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`. Currently dout data type only support float16
           and float32.
-        - **weight** (Tensor) - Set size of kernel is :math:`(D_in, K_h, K_w)`, then the shape is
-          :math:`(C_{in}//groups, C_{out}, D_{in}, K_h, K_w)`. Currently weight data type only support float16
+        - **weight** (Tensor) - Set size of kernel is :math:`(k_d, K_h, K_w)`, then the shape is
+          :math:`(C_{in}, C_{out}//groups, k_d, K_h, K_w)`. Currently weight data type only support float16
           and float32.
         - **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only support none.
 

@@ -8189,6 +8193,7 @@ class Conv3DTranspose(PrimitiveWithInfer):
             raise ValueError(f"For `conv3d` attr 'pad' should be an positive int number or a tuple of "
                              f"six positive int numbers, but got `{len(pad)}`.")
+        self.pad_list = pad
         validator.check_value_type('pad_mode', pad_mode, [str], self.name)
         self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode)
 

@@ -8198,21 +8203,9 @@ class Conv3DTranspose(PrimitiveWithInfer):
         if self.pad_mode == 'pad':
             for item in self.pad_list:
                 validator.check_non_negative_int(item, 'pad item', self.name)
-            validator.check_int_range(self.pad_list[0], 0, self.kernel_size[0], Rel.INC_LEFT,
-                                      'pad_d belonging [0, kernel_size_d)', self.name)
-            validator.check_int_range(self.pad_list[1], 0, self.kernel_size[0], Rel.INC_LEFT,
-                                      'pad_d belonging [0, kernel_size_d)', self.name)
-            validator.check_int_range(self.pad_list[2], 0, self.kernel_size[1], Rel.INC_LEFT,
-                                      'pad_h belonging [0, kernel_size_h)', self.name)
-            validator.check_int_range(self.pad_list[3], 0, self.kernel_size[1], Rel.INC_LEFT,
-                                      'pad_h belonging [0, kernel_size_h)', self.name)
-            validator.check_int_range(self.pad_list[4], 0, self.kernel_size[2], Rel.INC_LEFT,
-                                      'pad_w belonging [0, kernel_size_w)', self.name)
-            validator.check_int_range(self.pad_list[5], 0, self.kernel_size[2], Rel.INC_LEFT,
-                                      'pad_w belonging [0, kernel_size_w)', self.name)
         self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
         self.add_prim_attr('mode', self.mode)
-        self.mode = validator.check_equal_int(group, 1, 'group', self.name)
+        self.group = validator.check_equal_int(group, 1, 'group', self.name)
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
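
The final hunk makes two distinct fixes: the [0, kernel_size) upper bounds on Conv3DTranspose pads are dropped (each pad item now only has to be a non-negative int), and a copy-paste bug is corrected in which the result of the group check was assigned to self.mode, clobbering the mode attribute set two lines earlier. A reduced illustration of that bug (hypothetical classes, not the primitive itself):

class Before:
    def __init__(self, mode=1, group=1):
        self.mode = mode
        self.mode = group   # bug: 'group' silently overwrites 'mode'

class After:
    def __init__(self, mode=1, group=1):
        self.mode = mode
        self.group = group  # fix: each attribute keeps its own value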