From 9870fc083e3bad82ecc9dd7f6de6e2775535bd99 Mon Sep 17 00:00:00 2001 From: simson Date: Fri, 11 Jun 2021 15:17:49 +0800 Subject: [PATCH] fix conv2d backprop pad mode wrong type --- mindspore/core/ops/conv2d.cc | 10 ++++++---- mindspore/core/ops/grad/conv2d_backprop_input.cc | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc index 93f32942ce3..f00d111cbc8 100644 --- a/mindspore/core/ops/conv2d.cc +++ b/mindspore/core/ops/conv2d.cc @@ -193,10 +193,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector &input_args) { const std::set valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32}; - std::map types; - (void)types.emplace("x", input_args[0]->BuildType()); - (void)types.emplace("w", input_args[1]->BuildType()); - auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + auto out_type = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name()); if (out_type->type_id() == TypeId::kNumberTypeInt8) { out_type = kInt32; } @@ -313,6 +310,11 @@ Format Conv2D::get_format() const { AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name()); + const std::set valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32}; + std::map types; + (void)types.emplace("x", input_args[0]->BuildType()); + (void)types.emplace("w", input_args[1]->BuildType()); + CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name()); return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true); diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc index a36f6c93572..e6934a53495 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_input.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc @@ -36,7 +36,8 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector &dout_ auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name); // default pad mode is valid auto attr_pad_list_prt = primitive->GetAttr(kPadList); - auto pad_mode = GetValue(primitive->GetAttr(kPadMode)); + int64_t pad_mode; + CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true); ShapeVector pad_list = {0, 0, 0, 0}; if (!attr_pad_list_prt->isa()) { pad_list = GetValue(attr_pad_list_prt);