forked from mindspore-Ecosystem/mindspore
conv2dbackpropinput & conv2dbackpropfilter c++ infer
This commit is contained in:
parent
2ec087e5bd
commit
0b0e9580d7
|
@ -83,6 +83,7 @@ constexpr auto kDropoutGenMask = "DropoutGenMask";
|
|||
constexpr auto kDropoutDoMask = "DropoutDoMask";
|
||||
constexpr auto kDropout = "Dropout";
|
||||
constexpr auto kDropoutGrad = "DropoutGrad";
|
||||
constexpr auto kConv2DTranspose = "Conv2DTranspose";
|
||||
|
||||
// Here list all primitives used in backend or some special primitives used by core.
|
||||
// Arithmetic
|
||||
|
@ -262,7 +263,7 @@ inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("Fus
|
|||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
inline const PrimitivePtr kPrimCTCLoss = std::make_shared<Primitive>(kCTCLoss);
|
||||
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
|
||||
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
|
||||
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);
|
||||
inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
|
||||
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
|
@ -583,7 +584,6 @@ inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("Space
|
|||
inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion");
|
||||
inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion");
|
||||
inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize");
|
||||
inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose");
|
||||
inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue");
|
||||
inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
|
||||
inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
|
||||
|
|
|
@ -21,31 +21,11 @@
|
|||
#include <set>
|
||||
|
||||
#include "ops/conv2d_transpose.h"
|
||||
#include "ops/grad/conv2d_backprop_input.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("doutput_dtye", input_args[0]->BuildType());
|
||||
types.emplace("w_dtype", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
void Conv2DTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
||||
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
|
||||
const Format &format, const std::vector<int64_t> &pad_list) {
|
||||
|
@ -62,16 +42,16 @@ void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::v
|
|||
set_pad_list(pad_list);
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_in_channel(int64_t in_channel) {
|
||||
void Conv2DTranspose::set_in_channel(int64_t in_channel) {
|
||||
AddAttr(kInChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_out_channel(int64_t out_channel) {
|
||||
void Conv2DTranspose::set_out_channel(int64_t out_channel) {
|
||||
AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
|
@ -79,7 +59,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
|||
AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
|
||||
for (int64_t item : stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
|
||||
|
@ -87,12 +67,12 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
|
|||
AddAttr(kStride, MakeValue(stride));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
|
||||
AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
|
||||
void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
|
||||
std::vector<int64_t> pad = get_pad();
|
||||
if (pad_mode == PAD) {
|
||||
for (auto item : pad) {
|
||||
|
@ -105,89 +85,83 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
|
|||
AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_mode(int64_t mode) {
|
||||
void Conv2DTranspose::set_mode(int64_t mode) {
|
||||
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_group(int64_t group) {
|
||||
void Conv2DTranspose::set_group(int64_t group) {
|
||||
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_format(const Format &format) {
|
||||
void Conv2DTranspose::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
|
||||
this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_in_channel() const {
|
||||
int64_t Conv2DTranspose::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_out_channel() const {
|
||||
int64_t Conv2DTranspose::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_kernel_size() const {
|
||||
std::vector<int64_t> Conv2DTranspose::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_stride() const {
|
||||
std::vector<int64_t> Conv2DTranspose::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_dilation() const {
|
||||
std::vector<int64_t> Conv2DTranspose::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
PadMode Conv2dTranspose::get_pad_mode() const {
|
||||
PadMode Conv2DTranspose::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_pad() const {
|
||||
std::vector<int64_t> Conv2DTranspose::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_mode() const {
|
||||
int64_t Conv2DTranspose::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_group() const {
|
||||
int64_t Conv2DTranspose::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
Format Conv2dTranspose::get_format() const {
|
||||
Format Conv2DTranspose::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_pad_list() const {
|
||||
std::vector<int64_t> Conv2DTranspose::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2dTransposeInferType(primitive, input_args),
|
||||
Conv2dTransposeInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2dTranspose, Conv2dTranspose);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DTranspose, prim::kPrimConv2DTranspose, Conv2DBackpropInputInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,17 +27,17 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConv2dTranspose = "Conv2dTranspose";
|
||||
class Conv2dTranspose : public PrimitiveC {
|
||||
constexpr auto kNameConv2DTranspose = "Conv2DTranspose";
|
||||
class Conv2DTranspose : public PrimitiveC {
|
||||
public:
|
||||
Conv2dTranspose() : PrimitiveC(kNameConv2dTranspose) {
|
||||
Conv2DTranspose() : PrimitiveC(kNameConv2DTranspose) {
|
||||
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
|
||||
}
|
||||
explicit Conv2dTranspose(const std::string k_name) : PrimitiveC(k_name) {
|
||||
explicit Conv2DTranspose(const std::string k_name) : PrimitiveC(k_name) {
|
||||
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
|
||||
}
|
||||
~Conv2dTranspose() = default;
|
||||
MS_DECLARE_PARENT(Conv2dTranspose, PrimitiveC);
|
||||
~Conv2DTranspose() = default;
|
||||
MS_DECLARE_PARENT(Conv2DTranspose, PrimitiveC);
|
||||
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
|
||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
|
||||
|
@ -66,9 +66,9 @@ class Conv2dTranspose : public PrimitiveC {
|
|||
Format get_format() const;
|
||||
std::vector<int64_t> get_pad_list() const;
|
||||
};
|
||||
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr Conv2DTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConv2dTransposePtr = std::shared_ptr<Conv2dTranspose>;
|
||||
using PrimConv2DTransposePtr = std::shared_ptr<Conv2DTranspose>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
|
||||
|
|
|
@ -25,13 +25,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConv2dTransposeFusion = "Conv2dTransposeFusion";
|
||||
class Conv2dTransposeFusion : public Conv2dTranspose {
|
||||
class Conv2dTransposeFusion : public Conv2DTranspose {
|
||||
public:
|
||||
Conv2dTransposeFusion() : Conv2dTranspose(kNameConv2dTransposeFusion) {
|
||||
Conv2dTransposeFusion() : Conv2DTranspose(kNameConv2dTransposeFusion) {
|
||||
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
|
||||
}
|
||||
~Conv2dTransposeFusion() = default;
|
||||
MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2dTranspose);
|
||||
MS_DECLARE_PARENT(Conv2dTransposeFusion, Conv2DTranspose);
|
||||
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
|
||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,20 +27,22 @@ namespace {
|
|||
abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto out_put = input_args[2]->BuildValue();
|
||||
auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
|
||||
return std::make_shared<abstract::Shape>(infer_shape);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto w_size_v = input_args[2]->BuildValue();
|
||||
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("drotput", input_args[0]->BuildType());
|
||||
types.emplace("input_x", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
types.emplace("doutput", input_args[0]->BuildType());
|
||||
types.emplace("x", input_args[1]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -142,9 +144,17 @@ Format Conv2DBackpropFilter::get_format() const {
|
|||
|
||||
AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilter, Conv2DBackpropFilter);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,23 +18,31 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "ops/grad/conv2d_backprop_input.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
|
||||
const std::vector<int64_t> &x_size_v) {
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
|
||||
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
|
||||
auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
|
||||
primitive->AddAttr(kPadList, MakeValue(pad_list));
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto kernel_size =
|
||||
CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
|
||||
auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
|
||||
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
|
||||
// default pad mode is valid
|
||||
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
|
||||
auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
|
||||
ShapeVector pad_list = {0, 0, 0, 0};
|
||||
if (!attr_pad_list_prt->isa<None>()) {
|
||||
pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
|
||||
} else if (pad_mode == SAME) {
|
||||
auto stride_h = stride[0];
|
||||
auto stride_w = stride[1];
|
||||
auto stride_h = stride[2];
|
||||
auto stride_w = stride[3];
|
||||
auto kernel_h = kernel_size[0];
|
||||
auto kernel_w = kernel_size[1];
|
||||
auto dilation_h = dilation[2];
|
||||
|
@ -43,7 +51,7 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
|
|||
pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
|
||||
auto pad_top = pad_needed_h / 2;
|
||||
auto pad_bottom = pad_needed_h - pad_top;
|
||||
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2];
|
||||
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
|
||||
pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
|
||||
auto pad_left = pad_needed_w / 2;
|
||||
auto pad_right = pad_needed_w - pad_left;
|
||||
|
@ -53,34 +61,44 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
|
|||
}
|
||||
primitive->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_size_v = input_args[2]->BuildValue();
|
||||
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
|
||||
ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
|
||||
auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
|
||||
SetPadList(primitive, dout_shape_norm, ret_shape);
|
||||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("doutput", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
|
||||
// check
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto doutput = input_args[0];
|
||||
auto x_size = input_args[2];
|
||||
auto x_size_value = x_size->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(x_size);
|
||||
auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
|
||||
// infer dtype
|
||||
auto dtype = doutput->BuildType();
|
||||
if (!dtype->isa<TensorType>()) {
|
||||
MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
|
||||
}
|
||||
auto input_tensor_type = dtype->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tensor_type);
|
||||
auto element = input_tensor_type->element();
|
||||
// infer shape
|
||||
auto dout_shape = doutput->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(doutput);
|
||||
auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
|
||||
auto dout_shape_norm = dout_shapeptr->shape();
|
||||
SetPadList(primitive, dout_shape_norm, x_size_v);
|
||||
return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
|
||||
Conv2DBackpropInputInferShape(primitive, input_args));
|
||||
return abs;
|
||||
}
|
||||
|
||||
void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
|
||||
|
@ -200,6 +218,7 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
|
|||
auto value_ptr = GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,6 +60,8 @@ class Conv2DBackpropInput : public PrimitiveC {
|
|||
Format get_format() const;
|
||||
std::vector<int64_t> get_pad_list() const;
|
||||
};
|
||||
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConv2DBackpropInputPtr = std::shared_ptr<Conv2DBackpropInput>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -109,6 +109,7 @@ static std::map<std::string, AttrConverterPair> ReductionMap = {
|
|||
|
||||
static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
|
||||
{"Conv2D", FormatAndPadAttrMap},
|
||||
{"Conv2DTranspose", FormatAndPadUpperAttrMap},
|
||||
{"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
|
||||
{"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
|
||||
{"Conv3D", FormatAndPadAttrMap},
|
||||
|
|
|
@ -83,7 +83,7 @@ using mindspore::ops::kNameBatchNorm;
|
|||
using mindspore::ops::kNameConv2D;
|
||||
using mindspore::ops::kNameConv2DBackpropFilter;
|
||||
using mindspore::ops::kNameConv2DBackpropInput;
|
||||
using mindspore::ops::kNameConv2dTranspose;
|
||||
using mindspore::ops::kNameConv2DTranspose;
|
||||
using mindspore::ops::kNameDiv;
|
||||
using mindspore::ops::kNameElu;
|
||||
using mindspore::ops::kNameExp;
|
||||
|
@ -573,7 +573,7 @@ REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
|
|||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DTranspose, MoveAttrMapCommon<ops::Conv2dTransposeFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon<ops::DivFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)
|
||||
|
|
|
@ -1065,8 +1065,9 @@ def get_bprop_roi_align(self):
|
|||
@bprop_getters.register(P.Conv2DBackpropInput)
|
||||
def get_bprop_conv2d_backprop_input(self):
|
||||
"""Grad definition for `Conv2DBackpropInput` operation."""
|
||||
pad_list = self.get_attr_dict()['pad_list']
|
||||
filter_grad = G.Conv2DBackpropFilter(
|
||||
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
|
||||
self.out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
|
||||
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
|
||||
)
|
||||
input_grad = P.Conv2D(
|
||||
|
|
|
@ -441,7 +441,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Conv2DBackpropFilter(PrimitiveWithInfer):
|
||||
class Conv2DBackpropFilter(Primitive):
|
||||
"""
|
||||
Computes the gradients of convolution with respect to the filter.
|
||||
|
||||
|
@ -500,21 +500,6 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
|
|||
raise ValueError("NHWC format only support in GPU target.")
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
def __infer__(self, doutput, x, w_size):
|
||||
w_size_v = w_size['value']
|
||||
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
|
||||
for i, dim_len in enumerate(w_size_v):
|
||||
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
|
||||
args = {"x": x['dtype'], "doutput": doutput['dtype']}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
|
||||
self.name)
|
||||
out = {
|
||||
'value': None,
|
||||
'shape': w_size_v,
|
||||
'dtype': doutput['dtype'],
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -1903,7 +1903,7 @@ class AvgPool(_Pool):
|
|||
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
|
||||
|
||||
class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||
class Conv2DBackpropInput(Primitive):
|
||||
"""
|
||||
Computes the gradients of convolution with respect to the input.
|
||||
|
||||
|
@ -2017,48 +2017,6 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|||
validator.check_non_negative_int(x, 'element of pad_list', self.name)
|
||||
self.pad_list = pad_list
|
||||
|
||||
def __infer__(self, doutput, w, x_size):
|
||||
x_size_v = x_size['value']
|
||||
validator.check_value_type('x_size', x_size_v, [tuple], self.name)
|
||||
for i, dim_len in enumerate(x_size_v):
|
||||
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
|
||||
args = {'doutput': doutput['dtype'], 'w': w['dtype']}
|
||||
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
|
||||
# infer shape
|
||||
dout_shape = doutput['shape']
|
||||
dout_shape_norm = dout_shape if self.format == "NCHW" else \
|
||||
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
|
||||
kernel_h = self.kernel_size[0]
|
||||
kernel_w = self.kernel_size[1]
|
||||
stride_h = self.stride[2]
|
||||
stride_w = self.stride[3]
|
||||
dilation_h = self.dilation[2]
|
||||
dilation_w = self.dilation[3]
|
||||
# default pad mode is valid
|
||||
pad_list = (0, 0, 0, 0)
|
||||
if self.pad_list:
|
||||
pad_list = tuple(self.pad_list)
|
||||
elif self.pad_mode == "SAME":
|
||||
pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
|
||||
pad_top = math.floor(pad_needed_h / 2)
|
||||
pad_bottom = pad_needed_h - pad_top
|
||||
|
||||
pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
|
||||
pad_left = math.floor(pad_needed_w / 2)
|
||||
pad_right = pad_needed_w - pad_left
|
||||
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
|
||||
elif self.pad_mode == 'PAD':
|
||||
pad_list = self.padding
|
||||
self.add_prim_attr('pad_list', pad_list)
|
||||
out = {
|
||||
'value': None,
|
||||
'shape': x_size_v,
|
||||
'dtype': doutput['dtype'],
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class Conv2DTranspose(Conv2DBackpropInput):
|
||||
"""
|
||||
|
|
|
@ -31,7 +31,7 @@ class SparseToDense(PrimitiveWithInfer):
|
|||
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
|
||||
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
|
||||
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
|
||||
The shape should be :math:`(n,).
|
||||
The shape should be :math:`(n,)`.
|
||||
- **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
|
||||
should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
|
||||
|
||||
|
@ -102,7 +102,7 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
|
|||
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
|
||||
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
|
||||
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
|
||||
Support float16, float32, float64, int32, int64. The shape should be :math:`(n,).
|
||||
Support float16, float32, float64, int32, int64. The shape should be :math:`(n,)`.
|
||||
- **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
|
||||
should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
|
||||
- **dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
|
||||
|
|
Loading…
Reference in New Issue