!18225 fix conv2d backprop pad mode wrong type

Merge pull request !18225 from Simson/fix-bug
This commit is contained in:
i-robot 2021-06-18 01:27:03 +00:00 committed by Gitee
commit 6c173716f8
2 changed files with 8 additions and 5 deletions

View File

@ -197,10 +197,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("w", input_args[1]->BuildType());
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
auto out_type = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
if (out_type->type_id() == TypeId::kNumberTypeInt8) {
out_type = kInt32;
}
@ -317,6 +314,11 @@ Format Conv2D::get_format() const {
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name());
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("w", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);

View File

@ -36,7 +36,8 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
// default pad mode is valid
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
ShapeVector pad_list = {0, 0, 0, 0};
if (!attr_pad_list_prt->isa<None>()) {
pad_list = GetValue<ShapeVector>(attr_pad_list_prt);