!43588 update backprop of conv2d op register

Merge pull request !43588 from xulei/backprop_conv
This commit is contained in:
i-robot 2022-10-12 06:43:22 +00:00 committed by Gitee
commit a44e39e1cd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 16 additions and 33 deletions

View File

@ -37,30 +37,20 @@ ATTR_MAP(Conv2D) = {
OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2D, prim::kPrimConv2D->name(), ADPT_DESC(Conv2D))
// Conv2DBackpropInputD
INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}};
INPUT_ATTR_MAP(Conv2DBackpropInputD) = {
{3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropInputD) = {
// Conv2DBackpropInput
INPUT_MAP(Conv2DBackpropInput) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(input_size)}};
ATTR_MAP(Conv2DBackpropInput) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
};
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD))
OUTPUT_MAP(Conv2DBackpropInput) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInput))
// Conv2DBackpropInput for tf inference
INPUT_MAP(Conv2DBackpropInput) = {{1, INPUT_DESC(input_size)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(out_backprop)}};
ATTR_MAP(Conv2DBackpropInput) = {
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
};
OUTPUT_MAP(Conv2DBackpropInput) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropInput, kNameConv2DBackpropInputV2, ADPT_DESC(Conv2DBackpropInput))
REG_ADPT_DESC(Conv2DBackpropInputV2, kNameConv2DBackpropInputV2, ADPT_DESC(Conv2DBackpropInput))
// Deconvolution for caffe inference
INPUT_MAP(Deconvolution) = {
@ -74,7 +64,7 @@ ATTR_MAP(Deconvolution) = {
{"offset", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
OUTPUT_MAP(Deconvolution) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Deconvolution, kNameDeconvolution, ADPT_DESC(Deconvolution))
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD))
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInput))
// Conv2DTransposeD for tf onnx inference
INPUT_MAP(Conv2DTransposeD) = {
@ -91,19 +81,17 @@ ATTR_MAP(Conv2DTransposeD) = {
OUTPUT_MAP(Conv2DTransposeD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DTransposeD, kNameConv2DTransposeD, ADPT_DESC(Conv2DTransposeD))
// Conv2DBackpropFilterD
INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}};
INPUT_ATTR_MAP(Conv2DBackpropFilterD) = {
{3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(Conv2DBackpropFilterD) = {
// Conv2DBackpropFilter
INPUT_MAP(Conv2DBackpropFilter) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(filter_size)}};
ATTR_MAP(Conv2DBackpropFilter) = {
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
};
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropFilterD, prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD))
OUTPUT_MAP(Conv2DBackpropFilter) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilter))
// Conv3DTransposeD
INPUT_MAP(Conv3DTransposeD) = {

View File

@ -29,18 +29,13 @@ DECLARE_OP_ADAPTER(Conv2D)
DECLARE_OP_USE_ENUM(Conv2D)
DECLARE_OP_USE_OUTPUT(Conv2D)
DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD)
DECLARE_OP_ADAPTER(Conv2DBackpropInput)
DECLARE_OP_USE_ENUM(Conv2DBackpropInput)
DECLARE_OP_USE_OUTPUT(Conv2DBackpropInput)
DECLARE_OP_ADAPTER(Conv2DBackpropFilterD)
DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD)
DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD)
DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD)
DECLARE_OP_ADAPTER(Conv2DBackpropFilter)
DECLARE_OP_USE_ENUM(Conv2DBackpropFilter)
DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilter)
DECLARE_OP_ADAPTER(Conv3DTransposeD)
DECLARE_OP_USE_ENUM(Conv3DTransposeD)