forked from mindspore-Ecosystem/mindspore
!43588 update backprop of conv2d op register
Merge pull request !43588 from xulei/backprop_conv
This commit is contained in:
commit
a44e39e1cd
|
@ -37,30 +37,20 @@ ATTR_MAP(Conv2D) = {
|
|||
OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2D, prim::kPrimConv2D->name(), ADPT_DESC(Conv2D))
|
||||
|
||||
// Conv2DBackpropInputD
|
||||
INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}};
|
||||
INPUT_ATTR_MAP(Conv2DBackpropInputD) = {
|
||||
{3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(Conv2DBackpropInputD) = {
|
||||
// Conv2DBackpropInput
|
||||
INPUT_MAP(Conv2DBackpropInput) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(input_size)}};
|
||||
ATTR_MAP(Conv2DBackpropInput) = {
|
||||
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>())},
|
||||
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
|
||||
};
|
||||
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD))
|
||||
OUTPUT_MAP(Conv2DBackpropInput) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInput))
|
||||
|
||||
// Conv2DBackpropInput for tf inference
|
||||
INPUT_MAP(Conv2DBackpropInput) = {{1, INPUT_DESC(input_size)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(out_backprop)}};
|
||||
ATTR_MAP(Conv2DBackpropInput) = {
|
||||
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>())},
|
||||
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
};
|
||||
OUTPUT_MAP(Conv2DBackpropInput) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropInput, kNameConv2DBackpropInputV2, ADPT_DESC(Conv2DBackpropInput))
|
||||
REG_ADPT_DESC(Conv2DBackpropInputV2, kNameConv2DBackpropInputV2, ADPT_DESC(Conv2DBackpropInput))
|
||||
|
||||
// Deconvolution for caffe inference
|
||||
INPUT_MAP(Deconvolution) = {
|
||||
|
@ -74,7 +64,7 @@ ATTR_MAP(Deconvolution) = {
|
|||
{"offset", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(Deconvolution) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Deconvolution, kNameDeconvolution, ADPT_DESC(Deconvolution))
|
||||
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD))
|
||||
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInput))
|
||||
|
||||
// Conv2DTransposeD for tf onnx inference
|
||||
INPUT_MAP(Conv2DTransposeD) = {
|
||||
|
@ -91,19 +81,17 @@ ATTR_MAP(Conv2DTransposeD) = {
|
|||
OUTPUT_MAP(Conv2DTransposeD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DTransposeD, kNameConv2DTransposeD, ADPT_DESC(Conv2DTransposeD))
|
||||
|
||||
// Conv2DBackpropFilterD
|
||||
INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}};
|
||||
INPUT_ATTR_MAP(Conv2DBackpropFilterD) = {
|
||||
{3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(Conv2DBackpropFilterD) = {
|
||||
// Conv2DBackpropFilter
|
||||
INPUT_MAP(Conv2DBackpropFilter) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(filter_size)}};
|
||||
ATTR_MAP(Conv2DBackpropFilter) = {
|
||||
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
|
||||
};
|
||||
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropFilterD, prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD))
|
||||
OUTPUT_MAP(Conv2DBackpropFilter) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilter))
|
||||
|
||||
// Conv3DTransposeD
|
||||
INPUT_MAP(Conv3DTransposeD) = {
|
||||
|
|
|
@ -29,18 +29,13 @@ DECLARE_OP_ADAPTER(Conv2D)
|
|||
DECLARE_OP_USE_ENUM(Conv2D)
|
||||
DECLARE_OP_USE_OUTPUT(Conv2D)
|
||||
|
||||
DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
|
||||
DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
|
||||
DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD)
|
||||
|
||||
DECLARE_OP_ADAPTER(Conv2DBackpropInput)
|
||||
DECLARE_OP_USE_ENUM(Conv2DBackpropInput)
|
||||
DECLARE_OP_USE_OUTPUT(Conv2DBackpropInput)
|
||||
|
||||
DECLARE_OP_ADAPTER(Conv2DBackpropFilterD)
|
||||
DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD)
|
||||
DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD)
|
||||
DECLARE_OP_ADAPTER(Conv2DBackpropFilter)
|
||||
DECLARE_OP_USE_ENUM(Conv2DBackpropFilter)
|
||||
DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilter)
|
||||
|
||||
DECLARE_OP_ADAPTER(Conv3DTransposeD)
|
||||
DECLARE_OP_USE_ENUM(Conv3DTransposeD)
|
||||
|
|
Loading…
Reference in New Issue