diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 8e496928107..e1f3a92bb8f 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -83,6 +83,7 @@ constexpr auto kDropoutGenMask = "DropoutGenMask";
 constexpr auto kDropoutDoMask = "DropoutDoMask";
 constexpr auto kDropout = "Dropout";
 constexpr auto kDropoutGrad = "DropoutGrad";
+constexpr auto kConv2DTranspose = "Conv2DTranspose";
 
 // Here list all primitives used in backend or some special primitives used by core.
 // Arithmetic
@@ -262,7 +263,7 @@ inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("Fus
 inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
 inline const PrimitivePtr kPrimCTCLoss = std::make_shared<Primitive>(kCTCLoss);
 inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
-inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
+inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);
 inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
 inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
 inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
@@ -583,7 +584,6 @@ inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("Space
 inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion");
 inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion");
 inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize");
-inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose");
 inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue");
 inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
 inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc
index b7c0cc1632b..a9d1bd9d2cd 100644
--- a/mindspore/core/ops/conv2d_transpose.cc
+++ b/mindspore/core/ops/conv2d_transpose.cc
@@ -21,31 +21,11 @@
 #include
 #include "ops/conv2d_transpose.h"
+#include "ops/grad/conv2d_backprop_input.h"
 
 namespace mindspore {
 namespace ops {
-namespace {
-abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
-                                             const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
-  return std::make_shared<abstract::Shape>(input_shape);
-}
-
-TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name());
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
-  std::map<std::string, TypePtr> types;
-  types.emplace("doutput_dtye", input_args[0]->BuildType());
-  types.emplace("w_dtype", input_args[1]->BuildType());
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
-}
-}  // namespace
-
-void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
+void Conv2DTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
                            int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
                            const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
                            const Format &format, const std::vector<int64_t> &pad_list) {
@@ -62,16 +42,16 @@ void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::v
   set_pad_list(pad_list);
 }
 
-void Conv2dTranspose::set_in_channel(int64_t in_channel) {
+void Conv2DTranspose::set_in_channel(int64_t in_channel) {
   AddAttr(kInChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
 }
 
-void Conv2dTranspose::set_out_channel(int64_t out_channel) {
+void Conv2DTranspose::set_out_channel(int64_t out_channel) {
   AddAttr(kOutChannel,
           MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
 }
 
-void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
+void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
   CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
   for (int64_t item : kernel_size) {
     CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
@@ -79,7 +59,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
   AddAttr(kKernelSize, MakeValue(kernel_size));
 }
 
-void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
+void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
   CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
   for (int64_t item : stride) {
     CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
@@ -87,12 +67,12 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
   AddAttr(kStride, MakeValue(stride));
 }
 
-void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
+void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
   CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
   AddAttr(kDilation, MakeValue(dilation));
 }
 
-void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
+void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
   std::vector<int64_t> pad = get_pad();
   if (pad_mode == PAD) {
     for (auto item : pad) {
@@ -105,89 +85,83 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
   AddAttr(kPadMode, MakeValue(swi));
 }
 
-void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
+void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
   CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
   AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
 }
 
-void Conv2dTranspose::set_mode(int64_t mode) {
+void Conv2DTranspose::set_mode(int64_t mode) {
   AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
 }
 
-void Conv2dTranspose::set_group(int64_t group) {
+void Conv2DTranspose::set_group(int64_t group) {
   AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
 }
 
-void Conv2dTranspose::set_format(const Format &format) {
+void Conv2DTranspose::set_format(const Format &format) {
   int64_t f = format;
   AddAttr(kFormat, MakeValue(f));
 }
 
-void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
+void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
   CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
   this->AddAttr(kPadList, MakeValue(pad_list));
 }
 
-int64_t Conv2dTranspose::get_in_channel() const {
+int64_t Conv2DTranspose::get_in_channel() const {
   auto value_ptr = GetAttr(kInChannel);
   return GetValue<int64_t>(value_ptr);
 }
 
-int64_t Conv2dTranspose::get_out_channel() const {
+int64_t Conv2DTranspose::get_out_channel() const {
   auto value_ptr = GetAttr(kOutChannel);
   return GetValue<int64_t>(value_ptr);
 }
 
-std::vector<int64_t> Conv2dTranspose::get_kernel_size() const {
+std::vector<int64_t> Conv2DTranspose::get_kernel_size() const {
   auto value_ptr = GetAttr(kKernelSize);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-std::vector<int64_t> Conv2dTranspose::get_stride() const {
+std::vector<int64_t> Conv2DTranspose::get_stride() const {
   auto value_ptr = GetAttr(kStride);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-std::vector<int64_t> Conv2dTranspose::get_dilation() const {
+std::vector<int64_t> Conv2DTranspose::get_dilation() const {
   auto value_ptr = GetAttr(kDilation);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-PadMode Conv2dTranspose::get_pad_mode() const {
+PadMode Conv2DTranspose::get_pad_mode() const {
   auto value_ptr = GetAttr(kPadMode);
   return PadMode(GetValue<int64_t>(value_ptr));
 }
 
-std::vector<int64_t> Conv2dTranspose::get_pad() const {
+std::vector<int64_t> Conv2DTranspose::get_pad() const {
   auto value_ptr = GetAttr(kPad);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
 
-int64_t Conv2dTranspose::get_mode() const {
+int64_t Conv2DTranspose::get_mode() const {
   auto value_ptr = GetAttr(kMode);
   return GetValue<int64_t>(value_ptr);
 }
 
-int64_t Conv2dTranspose::get_group() const {
+int64_t Conv2DTranspose::get_group() const {
   auto value_ptr = GetAttr(kGroup);
   return GetValue<int64_t>(value_ptr);
 }
 
-Format Conv2dTranspose::get_format() const {
+Format Conv2DTranspose::get_format() const {
   auto value_ptr = GetAttr(kFormat);
   return Format(GetValue<int64_t>(value_ptr));
 }
 
-std::vector<int64_t> Conv2dTranspose::get_pad_list() const {
+std::vector<int64_t> Conv2DTranspose::get_pad_list() const {
   auto value_ptr = GetAttr(kPadList);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
-
-AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                     const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(Conv2dTransposeInferType(primitive, input_args),
-                                                    Conv2dTransposeInferShape(primitive, input_args)->shape());
-}
-REGISTER_PRIMITIVE_C(kNameConv2dTranspose, Conv2dTranspose);
+REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DTranspose, prim::kPrimConv2DTranspose, Conv2DBackpropInputInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/conv2d_transpose.h b/mindspore/core/ops/conv2d_transpose.h
index 08be3ef0156..a88e50a3a97 100644
--- a/mindspore/core/ops/conv2d_transpose.h
+++ b/mindspore/core/ops/conv2d_transpose.h
@@ -27,17 +27,17 @@
 #include "utils/check_convert_utils.h"
 namespace mindspore {
 namespace ops {
-constexpr auto kNameConv2dTranspose = "Conv2dTranspose";
-class Conv2dTranspose : public PrimitiveC {
+constexpr auto kNameConv2DTranspose = "Conv2DTranspose";
+class Conv2DTranspose : public PrimitiveC {
  public:
-  Conv2dTranspose() : PrimitiveC(kNameConv2dTranspose) {
+  Conv2DTranspose() : PrimitiveC(kNameConv2DTranspose) {
     InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
   }
-  explicit Conv2dTranspose(const std::string k_name) : PrimitiveC(k_name) {
+  explicit Conv2DTranspose(const std::string k_name) : PrimitiveC(k_name) {
     InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
   }
-  ~Conv2dTranspose() = default;
-  MS_DECLARE_PARENT(Conv2dTranspose, PrimitiveC);
+  ~Conv2DTranspose() = default;
+  MS_DECLARE_PARENT(Conv2DTranspose, PrimitiveC);
   void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
             const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
             const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
@@ -66,9 +66,9 @@ class Conv2dTranspose : public PrimitiveC {
   Format get_format() const;
   std::vector<int64_t> get_pad_list() const;
 };
-AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr Conv2DTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<AbstractBasePtr> &input_args);
-using PrimConv2dTransposePtr = std::shared_ptr<Conv2dTranspose>;
+using PrimConv2DTransposePtr = std::shared_ptr<Conv2DTranspose>;
 }  // namespace ops
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
diff --git a/mindspore/core/ops/fusion/conv2d_transpose_fusion.h b/mindspore/core/ops/fusion/conv2d_transpose_fusion.h
index 210cb7a7989..0f5f26a72c2 100644
--- a/mindspore/core/ops/fusion/conv2d_transpose_fusion.h
+++ b/mindspore/core/ops/fusion/conv2d_transpose_fusion.h
@@ -25,13 +25,13 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameConv2dTransposeFusion = "Conv2dTransposeFusion";
-class Conv2dTransposeFusion : public Conv2dTranspose {
+class Conv2dTransposeFusion : public Conv2DTranspose {
  public:
-  Conv2dTransposeFusion() : Conv2dTranspose(kNameConv2dTransposeFusion) {
+  Conv2dTransposeFusion() : Conv2DTranspose(kNameConv2dTransposeFusion) {
     InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
   }
   ~Conv2dTransposeFusion() = default;
-  MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2dTranspose);
+  MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2DTranspose);
   void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
             const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
             const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
diff --git a/mindspore/core/ops/grad/conv2d_backprop_filter.cc b/mindspore/core/ops/grad/conv2d_backprop_filter.cc
index c6815c5e0d2..1a859a11973 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_filter.cc
+++ b/mindspore/core/ops/grad/conv2d_backprop_filter.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,20 +27,22 @@
 namespace {
 abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
                                                   const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto out_put = input_args[2]->BuildValue();
-  auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
-  return std::make_shared<abstract::Shape>(infer_shape);
+  auto prim_name = primitive->name();
+  // check
+  auto w_size_v = input_args[2]->BuildValue();
+  auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
+  return std::make_shared<abstract::Shape>(ret_shape);
 }
 
 TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  // check
   std::map<std::string, TypePtr> types;
-  types.emplace("drotput", input_args[0]->BuildType());
-  types.emplace("input_x", input_args[1]->BuildType());
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
+  types.emplace("doutput", input_args[0]->BuildType());
+  types.emplace("x", input_args[1]->BuildType());
+  std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
 }
 }  // namespace
@@ -142,9 +144,17 @@ Format Conv2DBackpropFilter::get_format() const {
 
 AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  // check
+  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
   return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
                                                     Conv2DBackpropFilterInferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilter, Conv2DBackpropFilter);
+REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
+                             true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc
index 24f8f26715d..a36f6c93572 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_input.cc
+++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,23 +18,31 @@
 #include
 #include
 #include
+#include
 #include "ops/grad/conv2d_backprop_input.h"
 
 namespace mindspore {
 namespace ops {
+namespace {
 void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
                 const std::vector<int64_t> &x_size_v) {
-  auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
-  auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
-  auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
-  auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
-  auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
-  if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
-    primitive->AddAttr(kPadList, MakeValue(pad_list));
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  // check
+  auto kernel_size =
+    CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
+  auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
+  auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
+  // default pad mode is valid
+  auto attr_pad_list_prt = primitive->GetAttr(kPadList);
+  auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
+  ShapeVector pad_list = {0, 0, 0, 0};
+  if (!attr_pad_list_prt->isa<None>()) {
+    pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
   } else if (pad_mode == SAME) {
-    auto stride_h = stride[0];
-    auto stride_w = stride[1];
+    auto stride_h = stride[2];
+    auto stride_w = stride[3];
     auto kernel_h = kernel_size[0];
     auto kernel_w = kernel_size[1];
     auto dilation_h = dilation[2];
@@ -43,7 +51,7 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
     pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
     auto pad_top = pad_needed_h / 2;
     auto pad_bottom = pad_needed_h - pad_top;
-    auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2];
+    auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
     pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
     auto pad_left = pad_needed_w / 2;
     auto pad_right = pad_needed_w - pad_left;
@@ -53,34 +61,44 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
   }
   primitive->AddAttr(kPadList, MakeValue(pad_list));
 }
+
+abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
+                                                 const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  auto x_size_v = input_args[2]->BuildValue();
+  auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
+  auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
+  ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
+  auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
+  SetPadList(primitive, dout_shape_norm, ret_shape);
+  return std::make_shared<abstract::Shape>(ret_shape);
+}
+
+TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  // check
+  std::map<std::string, TypePtr> types;
+  types.emplace("doutput", input_args[0]->BuildType());
+  types.emplace("w", input_args[1]->BuildType());
+  std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
+  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
+}
+}  // namespace
 AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
+  // check
+  CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto doutput = input_args[0];
-  auto x_size = input_args[2];
-  auto x_size_value = x_size->GetValueTrack();
-  MS_EXCEPTION_IF_NULL(x_size);
-  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
-  // infer dtype
-  auto dtype = doutput->BuildType();
-  if (!dtype->isa<TensorType>()) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
-  }
-  auto input_tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(input_tensor_type);
-  auto element = input_tensor_type->element();
-  // infer shape
-  auto dout_shape = doutput->BuildShape();
-  MS_EXCEPTION_IF_NULL(doutput);
-  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
-  auto dout_shape_norm = dout_shapeptr->shape();
-  SetPadList(primitive, dout_shape_norm, x_size_v);
-  return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
+  auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
+                                                        Conv2DBackpropInputInferShape(primitive, input_args));
+  return abs;
 }
 
 void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
@@ -200,6 +218,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
   auto value_ptr = GetAttr(kPadList);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
-REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
+REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
+                             true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.h b/mindspore/core/ops/grad/conv2d_backprop_input.h
index aeb9e67a7ee..3ccb736a9ec 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_input.h
+++ b/mindspore/core/ops/grad/conv2d_backprop_input.h
@@ -60,6 +60,8 @@ class Conv2DBackpropInput : public PrimitiveC {
   Format get_format() const;
   std::vector<int64_t> get_pad_list() const;
 };
+AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args);
 using PrimConv2DBackpropInputPtr = std::shared_ptr<Conv2DBackpropInput>;
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc
index f6e4a15b4b3..746cb910b46 100644
--- a/mindspore/core/utils/check_convert_utils.cc
+++ b/mindspore/core/utils/check_convert_utils.cc
@@ -109,6 +109,7 @@ static std::map ReductionMap = {
 
 static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
   {"Conv2D", FormatAndPadAttrMap},
+  {"Conv2DTranspose", FormatAndPadUpperAttrMap},
{"Conv2DBackpropInput", FormatAndPadUpperAttrMap}, {"Conv2DBackpropFilter", FormatAndPadUpperAttrMap}, {"Conv3D", FormatAndPadAttrMap}, diff --git a/mindspore/lite/tools/converter/import/primitive_adjust.cc b/mindspore/lite/tools/converter/import/primitive_adjust.cc index 4f587871671..f88ec599efa 100644 --- a/mindspore/lite/tools/converter/import/primitive_adjust.cc +++ b/mindspore/lite/tools/converter/import/primitive_adjust.cc @@ -83,7 +83,7 @@ using mindspore::ops::kNameBatchNorm; using mindspore::ops::kNameConv2D; using mindspore::ops::kNameConv2DBackpropFilter; using mindspore::ops::kNameConv2DBackpropInput; -using mindspore::ops::kNameConv2dTranspose; +using mindspore::ops::kNameConv2DTranspose; using mindspore::ops::kNameDiv; using mindspore::ops::kNameElu; using mindspore::ops::kNameExp; @@ -573,7 +573,7 @@ REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D) -REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon) +REGIST_PRIMITIVE_ADJUST(kNameConv2DTranspose, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation) REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 91a8ede9242..d878c6a9238 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -1065,8 +1065,9 @@ def get_bprop_roi_align(self): @bprop_getters.register(P.Conv2DBackpropInput) def get_bprop_conv2d_backprop_input(self): """Grad definition for `Conv2DBackpropInput` operation.""" + pad_list = self.get_attr_dict()['pad_list'] filter_grad = G.Conv2DBackpropFilter( - self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode, + self.out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode, dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format ) input_grad = P.Conv2D( diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index b0818af3416..6d2950bc641 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -441,7 +441,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer): return out -class Conv2DBackpropFilter(PrimitiveWithInfer): +class Conv2DBackpropFilter(Primitive): """ Computes the gradients of convolution with respect to the filter. 
@@ -500,21 +500,6 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
-    def __infer__(self, doutput, x, w_size):
-        w_size_v = w_size['value']
-        validator.check_value_type('w_size', w_size_v, [tuple], self.name)
-        for i, dim_len in enumerate(w_size_v):
-            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
-        args = {"x": x['dtype'], "doutput": doutput['dtype']}
-        validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
-                                                      self.name)
-        out = {
-            'value': None,
-            'shape': w_size_v,
-            'dtype': doutput['dtype'],
-        }
-        return out
-
 
 class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
     """
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index a6fd439a6ee..e1b09e24b02 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1903,7 +1903,7 @@ class AvgPool(_Pool):
         super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
 
 
-class Conv2DBackpropInput(PrimitiveWithInfer):
+class Conv2DBackpropInput(Primitive):
     """
     Computes the gradients of convolution with respect to the input.
 
@@ -2017,48 +2017,6 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
             validator.check_non_negative_int(x, 'element of pad_list', self.name)
         self.pad_list = pad_list
 
-    def __infer__(self, doutput, w, x_size):
-        x_size_v = x_size['value']
-        validator.check_value_type('x_size', x_size_v, [tuple], self.name)
-        for i, dim_len in enumerate(x_size_v):
-            validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
-        args = {'doutput': doutput['dtype'], 'w': w['dtype']}
-        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-
-        # infer shape
-        dout_shape = doutput['shape']
-        dout_shape_norm = dout_shape if self.format == "NCHW" else \
-            [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
-        kernel_h = self.kernel_size[0]
-        kernel_w = self.kernel_size[1]
-        stride_h = self.stride[2]
-        stride_w = self.stride[3]
-        dilation_h = self.dilation[2]
-        dilation_w = self.dilation[3]
-        # default pad mode is valid
-        pad_list = (0, 0, 0, 0)
-        if self.pad_list:
-            pad_list = tuple(self.pad_list)
-        elif self.pad_mode == "SAME":
-            pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
-            pad_top = math.floor(pad_needed_h / 2)
-            pad_bottom = pad_needed_h - pad_top
-
-            pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
-            pad_left = math.floor(pad_needed_w / 2)
-            pad_right = pad_needed_w - pad_left
-            pad_list = (pad_top, pad_bottom, pad_left, pad_right)
-        elif self.pad_mode == 'PAD':
-            pad_list = self.padding
-        self.add_prim_attr('pad_list', pad_list)
-        out = {
-            'value': None,
-            'shape': x_size_v,
-            'dtype': doutput['dtype'],
-        }
-        return out
-
 
 class Conv2DTranspose(Conv2DBackpropInput):
     """
diff --git a/mindspore/ops/operations/sparse_ops.py b/mindspore/ops/operations/sparse_ops.py
index 982980e64e0..7d04ba2411b 100644
--- a/mindspore/ops/operations/sparse_ops.py
+++ b/mindspore/ops/operations/sparse_ops.py
@@ -31,7 +31,7 @@ class SparseToDense(PrimitiveWithInfer):
         - **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
           Support int32, int64, each element value should be a non-negative int number.
           The shape is :math:`(n, 2)`.
         - **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
-          The shape should be :math:`(n,).
+          The shape should be :math:`(n,)`.
         - **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor, should
           have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
@@ -102,7 +102,7 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
         - **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
           Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
         - **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
-          Support float16, float32, float64, int32, int64. The shape should be :math:`(n,).
+          Support float16, float32, float64, int32, int64. The shape should be :math:`(n,)`.
         - **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor, should
           have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
         - **dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.