Conv2DBackpropInput & Conv2DBackpropFilter C++ infer

simson 2021-06-04 11:26:23 +08:00
parent 2ec087e5bd
commit 0b0e9580d7
13 changed files with 123 additions and 173 deletions

View File

@@ -83,6 +83,7 @@ constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kDropoutDoMask = "DropoutDoMask";
constexpr auto kDropout = "Dropout";
constexpr auto kDropoutGrad = "DropoutGrad";
constexpr auto kConv2DTranspose = "Conv2DTranspose";
// Here list all primitives used in backend or some special primitives used by core.
// Arithmetic
@@ -262,7 +263,7 @@ inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("Fus
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimCTCLoss = std::make_shared<Primitive>(kCTCLoss);
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);
inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
@@ -583,7 +584,6 @@ inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("Space
inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion");
inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion");
inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize");
inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose");
inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue");
inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
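
With the name defined once as kConv2DTranspose, the primitive object above and every backend lookup now resolve through the same constant, and the duplicate lowercase kPrimConv2dTranspose further down is deleted. A minimal standalone sketch of the single-definition pattern, with Primitive as a stand-in for MindSpore's class:

#include <memory>
#include <string>

struct Primitive {  // stand-in for mindspore::Primitive
  explicit Primitive(const std::string &name) : name(name) {}
  std::string name;
};

constexpr auto kConv2DTranspose = "Conv2DTranspose";  // defined exactly once

// Every reference reuses the constant, so a "Conv2dTranspose" vs
// "Conv2DTranspose" spelling mismatch can no longer create two identities.
inline const auto kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);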

View File

@@ -21,31 +21,11 @@
#include <set>
#include "ops/conv2d_transpose.h"
#include "ops/grad/conv2d_backprop_input.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape);
}
TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("doutput_dtye", input_args[0]->BuildType());
types.emplace("w_dtype", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
void Conv2DTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
const Format &format, const std::vector<int64_t> &pad_list) {
@@ -62,16 +42,16 @@ void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::v
set_pad_list(pad_list);
}
void Conv2dTranspose::set_in_channel(int64_t in_channel) {
void Conv2DTranspose::set_in_channel(int64_t in_channel) {
AddAttr(kInChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_out_channel(int64_t out_channel) {
void Conv2DTranspose::set_out_channel(int64_t out_channel) {
AddAttr(kOutChannel,
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
for (int64_t item : kernel_size) {
CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
@@ -79,7 +59,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
AddAttr(kKernelSize, MakeValue(kernel_size));
}
void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
for (int64_t item : stride) {
CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
@@ -87,12 +67,12 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(stride));
}
void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
AddAttr(kDilation, MakeValue(dilation));
}
void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
std::vector<int64_t> pad = get_pad();
if (pad_mode == PAD) {
for (auto item : pad) {
@@ -105,89 +85,83 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
AddAttr(kPadMode, MakeValue(swi));
}
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}
void Conv2dTranspose::set_mode(int64_t mode) {
void Conv2DTranspose::set_mode(int64_t mode) {
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
}
void Conv2dTranspose::set_group(int64_t group) {
void Conv2DTranspose::set_group(int64_t group) {
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_format(const Format &format) {
void Conv2DTranspose::set_format(const Format &format) {
int64_t f = format;
AddAttr(kFormat, MakeValue(f));
}
void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
this->AddAttr(kPadList, MakeValue(pad_list));
}
int64_t Conv2dTranspose::get_in_channel() const {
int64_t Conv2DTranspose::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2dTranspose::get_out_channel() const {
int64_t Conv2DTranspose::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_kernel_size() const {
std::vector<int64_t> Conv2DTranspose::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_stride() const {
std::vector<int64_t> Conv2DTranspose::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_dilation() const {
std::vector<int64_t> Conv2DTranspose::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
PadMode Conv2dTranspose::get_pad_mode() const {
PadMode Conv2DTranspose::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2dTranspose::get_pad() const {
std::vector<int64_t> Conv2DTranspose::get_pad() const {
auto value_ptr = GetAttr(kPad);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Conv2dTranspose::get_mode() const {
int64_t Conv2DTranspose::get_mode() const {
auto value_ptr = GetAttr(kMode);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2dTranspose::get_group() const {
int64_t Conv2DTranspose::get_group() const {
auto value_ptr = GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
Format Conv2dTranspose::get_format() const {
Format Conv2DTranspose::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2dTranspose::get_pad_list() const {
std::vector<int64_t> Conv2DTranspose::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(Conv2dTransposeInferType(primitive, input_args),
Conv2dTransposeInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameConv2dTranspose, Conv2dTranspose);
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DTranspose, prim::kPrimConv2DTranspose, Conv2DBackpropInputInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
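
The hand-written Conv2dTransposeInferShape and InferType above are deleted because a transposed convolution's forward pass is the input gradient of an ordinary convolution, so the renamed class can register Conv2DBackpropInputInfer directly; in both ops the output shape is just the input_sizes operand. A standalone illustration of that shape relationship (all values illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // doutput is 1x8x16x16 (NCHW); input_sizes asks for a 1x3x32x32 result.
  std::vector<int64_t> input_sizes = {1, 3, 32, 32};
  // Shape inference forwards input_sizes unchanged as the output shape.
  std::vector<int64_t> out_shape = input_sizes;
  for (int64_t d : out_shape) std::cout << d << ' ';  // prints: 1 3 32 32
  std::cout << '\n';
  return 0;
}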

View File

@@ -27,17 +27,17 @@
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConv2dTranspose = "Conv2dTranspose";
class Conv2dTranspose : public PrimitiveC {
constexpr auto kNameConv2DTranspose = "Conv2DTranspose";
class Conv2DTranspose : public PrimitiveC {
public:
Conv2dTranspose() : PrimitiveC(kNameConv2dTranspose) {
Conv2DTranspose() : PrimitiveC(kNameConv2DTranspose) {
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
}
explicit Conv2dTranspose(const std::string k_name) : PrimitiveC(k_name) {
explicit Conv2DTranspose(const std::string k_name) : PrimitiveC(k_name) {
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
}
~Conv2dTranspose() = default;
MS_DECLARE_PARENT(Conv2dTranspose, PrimitiveC);
~Conv2DTranspose() = default;
MS_DECLARE_PARENT(Conv2DTranspose, PrimitiveC);
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
@@ -66,9 +66,9 @@ class Conv2dTranspose : public PrimitiveC {
Format get_format() const;
std::vector<int64_t> get_pad_list() const;
};
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr Conv2DTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dTransposePtr = std::shared_ptr<Conv2dTranspose>;
using PrimConv2DTransposePtr = std::shared_ptr<Conv2DTranspose>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
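
For reference, a hedged usage sketch of the renamed class, based only on the Init declaration above; the argument values are illustrative:

#include "ops/conv2d_transpose.h"

void BuildConv2DTransposeExample() {
  mindspore::ops::Conv2DTranspose op;
  // in_channel=8, out_channel=3, kernel_size={3, 3}; the remaining parameters
  // keep their declared defaults (mode=1, pad_mode=VALID, stride={1, 1}, ...).
  op.Init(8, 3, {3, 3});
}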

View File

@@ -25,13 +25,13 @@
namespace mindspore {
namespace ops {
constexpr auto kNameConv2dTransposeFusion = "Conv2dTransposeFusion";
class Conv2dTransposeFusion : public Conv2dTranspose {
class Conv2dTransposeFusion : public Conv2DTranspose {
public:
Conv2dTransposeFusion() : Conv2dTranspose(kNameConv2dTransposeFusion) {
Conv2dTransposeFusion() : Conv2DTranspose(kNameConv2dTransposeFusion) {
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
}
~Conv2dTransposeFusion() = default;
MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2dTranspose);
MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2DTranspose);
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,20 +27,22 @@ namespace {
abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto out_put = input_args[2]->BuildValue();
auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
return std::make_shared<abstract::Shape>(infer_shape);
auto prim_name = primitive->name();
// check
auto w_size_v = input_args[2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
return std::make_shared<abstract::Shape>(ret_shape);
}
TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
std::map<std::string, TypePtr> types;
types.emplace("drotput", input_args[0]->BuildType());
types.emplace("input_x", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
types.emplace("doutput", input_args[0]->BuildType());
types.emplace("x", input_args[1]->BuildType());
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
} // namespace
@@ -142,9 +144,17 @@ Format Conv2DBackpropFilter::get_format() const {
AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
Conv2DBackpropFilterInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilter, Conv2DBackpropFilter);
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore
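
The rewritten infer functions route the constant w_size input through CheckAndConvertUtils::CheckAttrIntOrTupleInt instead of a bare GetValue. A hypothetical simplification of what such a helper does (the real utility in check_convert_utils is more involved, and the positivity check here is an assumption):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

using IntOrTuple = std::variant<int64_t, std::vector<int64_t>>;

// Normalize "an int or a tuple of ints" into a shape vector, rejecting
// non-positive entries. Sketch only; not MindSpore's implementation.
std::vector<int64_t> CheckAttrIntOrTupleIntSketch(const std::string &attr_name,
                                                  const IntOrTuple &value) {
  if (std::holds_alternative<int64_t>(value)) {
    return {std::get<int64_t>(value)};  // single int becomes a 1-element shape
  }
  const auto &vec = std::get<std::vector<int64_t>>(value);
  for (int64_t v : vec) {
    if (v <= 0) throw std::invalid_argument(attr_name + " must be positive");
  }
  return vec;
}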

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,23 +18,31 @@
#include <vector>
#include <string>
#include <memory>
#include <set>
#include "ops/grad/conv2d_backprop_input.h"
namespace mindspore {
namespace ops {
namespace {
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
const std::vector<int64_t> &x_size_v) {
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
primitive->AddAttr(kPadList, MakeValue(pad_list));
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto kernel_size =
CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
// default pad mode is valid
auto attr_pad_list_ptr = primitive->GetAttr(kPadList);
auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
ShapeVector pad_list = {0, 0, 0, 0};
if (!attr_pad_list_ptr->isa<None>()) {
pad_list = GetValue<ShapeVector>(attr_pad_list_ptr);
} else if (pad_mode == SAME) {
auto stride_h = stride[0];
auto stride_w = stride[1];
auto stride_h = stride[2];
auto stride_w = stride[3];
auto kernel_h = kernel_size[0];
auto kernel_w = kernel_size[1];
auto dilation_h = dilation[2];
@@ -43,7 +51,7 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
auto pad_top = pad_needed_h / 2;
auto pad_bottom = pad_needed_h - pad_top;
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2];
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
auto pad_left = pad_needed_w / 2;
auto pad_right = pad_needed_w - pad_left;
@@ -53,34 +61,44 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
}
primitive->AddAttr(kPadList, MakeValue(pad_list));
}
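
The SAME branch above solves for the total padding that maps the gradient back onto an input of exactly x_size along each spatial axis. The same arithmetic with concrete numbers (illustrative):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // dout 14 high, stride 2, kernel 3, dilation 1, requested input height 28.
  int64_t dout_h = 14, stride_h = 2, kernel_h = 3, dilation_h = 1, x_h = 28;
  int64_t pad_needed_h =
      (dout_h - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h;  // = 1
  pad_needed_h = std::max<int64_t>(0, pad_needed_h);  // clamp, as above
  int64_t pad_top = pad_needed_h / 2;                 // 0
  int64_t pad_bottom = pad_needed_h - pad_top;        // 1
  std::cout << pad_top << ' ' << pad_bottom << '\n';  // prints: 0 1
  return 0;
}
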
abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_size_v = input_args[2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
SetPadList(primitive, dout_shape_norm, ret_shape);
return std::make_shared<abstract::Shape>(ret_shape);
}
TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
std::map<std::string, TypePtr> types;
types.emplace("doutput", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
} // namespace
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
// check
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto doutput = input_args[0];
auto x_size = input_args[2];
auto x_size_value = x_size->GetValueTrack();
MS_EXCEPTION_IF_NULL(x_size);
auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
// infer dtype
auto dtype = doutput->BuildType();
if (!dtype->isa<TensorType>()) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
}
auto input_tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_tensor_type);
auto element = input_tensor_type->element();
// infer shape
auto dout_shape = doutput->BuildShape();
MS_EXCEPTION_IF_NULL(doutput);
auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
auto dout_shape_norm = dout_shapeptr->shape();
SetPadList(primitive, dout_shape_norm, x_size_v);
return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
Conv2DBackpropInputInferShape(primitive, input_args));
return abs;
}
void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
@@ -200,6 +218,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore
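
The file now pairs REGISTER_PRIMITIVE_C with REGISTER_PRIMITIVE_EVAL_IMPL so the C++ infer is discoverable at evaluation time. A generic sketch of the registry pattern such a macro typically expands to; the names and signatures below are illustrative, not MindSpore's:

#include <functional>
#include <map>
#include <string>
#include <utility>

using InferFunc = std::function<int(int)>;  // placeholder signature

std::map<std::string, InferFunc> &InferRegistry() {
  static std::map<std::string, InferFunc> registry;  // one map per process
  return registry;
}

struct InferRegistrar {
  InferRegistrar(const std::string &name, InferFunc fn) {
    InferRegistry().emplace(name, std::move(fn));
  }
};

// Roughly what one registration expands to: a global object whose constructor
// runs before main() and records the callback under the primitive's name.
static InferRegistrar g_conv2d_backprop_input_reg("Conv2DBackpropInput",
                                                  [](int x) { return x; });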

View File

@@ -60,6 +60,8 @@ class Conv2DBackpropInput : public PrimitiveC {
Format get_format() const;
std::vector<int64_t> get_pad_list() const;
};
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2DBackpropInputPtr = std::shared_ptr<Conv2DBackpropInput>;
} // namespace ops
} // namespace mindspore

View File

@@ -109,6 +109,7 @@ static std::map<std::string, AttrConverterPair> ReductionMap = {
static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
{"Conv2D", FormatAndPadAttrMap},
{"Conv2DTranspose", FormatAndPadUpperAttrMap},
{"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
{"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
{"Conv3D", FormatAndPadAttrMap},

View File

@@ -83,7 +83,7 @@ using mindspore::ops::kNameBatchNorm;
using mindspore::ops::kNameConv2D;
using mindspore::ops::kNameConv2DBackpropFilter;
using mindspore::ops::kNameConv2DBackpropInput;
using mindspore::ops::kNameConv2dTranspose;
using mindspore::ops::kNameConv2DTranspose;
using mindspore::ops::kNameDiv;
using mindspore::ops::kNameElu;
using mindspore::ops::kNameExp;
@@ -573,7 +573,7 @@ REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
REGIST_PRIMITIVE_ADJUST(kNameConv2DTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>)
REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)

View File

@@ -1065,8 +1065,9 @@ def get_bprop_roi_align(self):
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
"""Grad definition for `Conv2DBackpropInput` operation."""
pad_list = self.get_attr_dict()['pad_list']
filter_grad = G.Conv2DBackpropFilter(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
self.out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
input_grad = P.Conv2D(

View File

@@ -441,7 +441,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
return out
class Conv2DBackpropFilter(PrimitiveWithInfer):
class Conv2DBackpropFilter(Primitive):
"""
Computes the gradients of convolution with respect to the filter.
@@ -500,21 +500,6 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
for i, dim_len in enumerate(w_size_v):
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
self.name)
out = {
'value': None,
'shape': w_size_v,
'dtype': doutput['dtype'],
}
return out
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
"""

View File

@@ -1903,7 +1903,7 @@ class AvgPool(_Pool):
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
class Conv2DBackpropInput(PrimitiveWithInfer):
class Conv2DBackpropInput(Primitive):
"""
Computes the gradients of convolution with respect to the input.
@@ -2017,48 +2017,6 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
validator.check_non_negative_int(x, 'element of pad_list', self.name)
self.pad_list = pad_list
def __infer__(self, doutput, w, x_size):
x_size_v = x_size['value']
validator.check_value_type('x_size', x_size_v, [tuple], self.name)
for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
# infer shape
dout_shape = doutput['shape']
dout_shape_norm = dout_shape if self.format == "NCHW" else \
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
# default pad mode is valid
pad_list = (0, 0, 0, 0)
if self.pad_list:
pad_list = tuple(self.pad_list)
elif self.pad_mode == "SAME":
pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
elif self.pad_mode == 'PAD':
pad_list = self.padding
self.add_prim_attr('pad_list', pad_list)
out = {
'value': None,
'shape': x_size_v,
'dtype': doutput['dtype'],
}
return out
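
The deleted __infer__ above and the new C++ SetPadList encode the same SAME-mode relation; written out, with o the normalized dout extent, s the stride, d the dilation, k the kernel extent, and i the requested input extent:

\[
p_{\mathrm{needed}} = \max\bigl(0,\ (o - 1)\,s + d\,(k - 1) + 1 - i\bigr), \qquad
p_{\mathrm{top}} = \lfloor p_{\mathrm{needed}} / 2 \rfloor, \qquad
p_{\mathrm{bottom}} = p_{\mathrm{needed}} - p_{\mathrm{top}},
\]

with the left/right pads computed the same way along the width axis.
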
class Conv2DTranspose(Conv2DBackpropInput):
"""

View File

@@ -31,7 +31,7 @@ class SparseToDense(PrimitiveWithInfer):
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
The shape should be :math:`(n,).
The shape should be :math:`(n,)`.
- **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
@@ -102,7 +102,7 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
Support float16, float32, float64, int32, int64. The shape should be :math:`(n,).
Support float16, float32, float64, int32, int64. The shape should be :math:`(n,)`.
- **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
- **dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.