Convert the input "xdiff_shape" of ROIAlignGrad to attr

This commit is contained in:
suxin 2023-02-17 20:41:35 +08:00
parent 12771a8f68
commit f145cb916f
2 changed files with 7 additions and 6 deletions

View File

@ -78,12 +78,12 @@ REG_ADPT_DESC(ROIAlign, kNameROIAlign, ADPT_DESC(ROIAlign))
// ROIAlignGrad
INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}};
OUTPUT_MAP(ROIAlignGrad) = {{0, OUTPUT_DESC(xdiff)}};
ATTR_MAP(ROIAlignGrad) = {
{"xdiff_shape", ATTR_DESC(xdiff_shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int64_t>())},
{"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int64_t>())},
{"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())},
{"sample_num", ATTR_DESC(sample_num, AnyTraits<int64_t>())}};
INPUT_ATTR_MAP(ROIAlignGrad) = {
{3, ATTR_DESC(xdiff_shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(ROIAlignGrad) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int64_t>())},
{"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int64_t>())},
{"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())},
{"sample_num", ATTR_DESC(sample_num, AnyTraits<int64_t>())}};
REG_ADPT_DESC(ROIAlignGrad, kNameROIAlignGrad, ADPT_DESC(ROIAlignGrad))
// PSROIPooling

View File

@ -41,6 +41,7 @@ DECLARE_OP_ADAPTER(ROIAlign)
DECLARE_OP_USE_OUTPUT(ROIAlign)
DECLARE_OP_ADAPTER(ROIAlignGrad)
DECLARE_OP_USE_INPUT_ATTR(ROIAlignGrad)
DECLARE_OP_USE_OUTPUT(ROIAlignGrad)
DECLARE_OP_ADAPTER(PSROIPooling)