forked from mindspore-Ecosystem/mindspore
!18225 fix conv2d backprop pad mode wrong type
Merge pull request !18225 from Simson/fix-bug
This commit is contained in:
commit
6c173716f8
|
@ -197,10 +197,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
|
||||
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
auto out_type = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
if (out_type->type_id() == TypeId::kNumberTypeInt8) {
|
||||
out_type = kInt32;
|
||||
}
|
||||
|
@ -317,6 +314,11 @@ Format Conv2D::get_format() const {
|
|||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name());
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);
|
||||
|
|
|
@ -36,7 +36,8 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
|
|||
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
|
||||
// default pad mode is valid
|
||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||
auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
|
||||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
|
||||
ShapeVector pad_list = {0, 0, 0, 0};
|
||||
if (!attr_pad_list_prt->isa<None>()) {
|
||||
pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
|
||||
|
|
Loading…
Reference in New Issue