Revert "!16599 c++ infer for conv2dbackpropfilter and conv2dbackpropinput"

This reverts commit 3be79efd80, reversing
changes made to cf4479756a.
This commit is contained in:
changzherui 2021-05-26 17:50:21 +08:00
parent f22e0522fe
commit d9e2da299d
14 changed files with 343 additions and 87 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include "backend/optimizer/pass/add_training_attr.h"
#include "backend/optimizer/pass/optimize_updatestate.h"
#include "backend/optimizer/pass/conv_transpose_to_conv_bp.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"
@ -43,6 +44,7 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto common_pm = std::make_shared<PassManager>("common_pm");
common_pm->AddPass(std::make_shared<ConvTransposeToConvBackpropInputPass>());
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
common_pm->AddPass(std::make_shared<ConvertAttrToUnifyMindIR>());
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/pass/conv_transpose_to_conv_bp.h"
#include <memory>
#include <vector>
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kCNodePrimitiveIdx = 0;
} // namespace
const BaseRef ConvTransposeToConvBackpropInputPass::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto conv_transpose = std::make_shared<Primitive>(kConv2DTransposeOpName);
return VectorRef({conv_transpose, Xs});
}
const AnfNodePtr ConvTransposeToConvBackpropInputPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto conv_transpose = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv_transpose);
if (conv_transpose->size() <= kCNodePrimitiveIdx) {
MS_LOG(EXCEPTION) << "Invalid cnode " << node->DebugString() << " input size " << conv_transpose->size();
}
auto prim = GetValueNode<PrimitivePtr>(conv_transpose->input(kCNodePrimitiveIdx));
MS_EXCEPTION_IF_NULL(prim);
prim->Named::operator=(Named(kConv2DBackpropInputOpName));
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONV_TRANSPOSE_TO_CONV_BP_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONV_TRANSPOSE_TO_CONV_BP_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ConvTransposeToConvBackpropInputPass : public PatternProcessPass {
public:
explicit ConvTransposeToConvBackpropInputPass(bool multigraph = true)
: PatternProcessPass("conv_transpose_to_conv_backprop_input", multigraph) {}
~ConvTransposeToConvBackpropInputPass() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_

View File

@ -49,6 +49,7 @@ ATTR_MAP(Conv2DBackpropInputD) = {
};
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD))
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD))
// Conv2DBackpropFilterD
INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}};

View File

@ -175,6 +175,7 @@ constexpr auto kSpaceToBatchOpName = "SpaceToBatch";
constexpr auto kBatchToSpaceOpName = "BatchToSpace";
constexpr auto kSpaceToDepthOpName = "SpaceToDepth";
constexpr auto kPadOpName = "Pad";
constexpr auto kConv2DTransposeOpName = "Conv2DTranspose";
constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput";
constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter";
constexpr auto kDepthwiseConv2dNativeOpName = "DepthwiseConv2dNative";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,22 +27,20 @@ namespace {
abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto w_size_v = input_args[2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("w_size", w_size_v, prim_name);
return std::make_shared<abstract::Shape>(ret_shape);
auto out_put = input_args[2]->BuildValue();
auto infer_shape = GetValue<std::vector<int64_t>>(out_put);
return std::make_shared<abstract::Shape>(infer_shape);
}
TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("doutput", input_args[0]->BuildType());
types.emplace("x", input_args[1]->BuildType());
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
types.emplace("drotput", input_args[0]->BuildType());
types.emplace("input_x", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
@ -144,17 +142,9 @@ Format Conv2DBackpropFilter::get_format() const {
AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
Conv2DBackpropFilterInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
true);
REGISTER_PRIMITIVE_C(kNameConv2DBackpropFilter, Conv2DBackpropFilter);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,25 +23,18 @@
namespace mindspore {
namespace ops {
namespace {
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
const std::vector<int64_t> &x_size_v) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
auto kernel_size =
CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
// default pad mode is valid
auto attr_pad_list_prt = primitive->GetAttr(kPadList);
auto pad_mode = GetValue<int64_t>(primitive->GetAttr(kPadMode));
ShapeVector pad_list = {0, 0, 0, 0};
if (!attr_pad_list_prt->isa<None>()) {
pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPadList));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
if (std::all_of(pad_list.begin(), pad_list.end(), [](int64_t elem) -> bool { return elem != 0; })) {
primitive->AddAttr(kPadList, MakeValue(pad_list));
} else if (pad_mode == SAME) {
auto stride_h = stride[2];
auto stride_w = stride[3];
auto stride_h = stride[0];
auto stride_w = stride[1];
auto kernel_h = kernel_size[0];
auto kernel_w = kernel_size[1];
auto dilation_h = dilation[2];
@ -50,7 +43,7 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
auto pad_top = pad_needed_h / 2;
auto pad_bottom = pad_needed_h - pad_top;
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3];
auto pad_needed_w = (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[2];
pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
auto pad_left = pad_needed_w / 2;
auto pad_right = pad_needed_w - pad_left;
@ -60,44 +53,34 @@ void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_
}
primitive->AddAttr(kPadList, MakeValue(pad_list));
}
abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_size_v = input_args[2]->BuildValue();
auto ret_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("x_size", x_size_v, prim_name);
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
SetPadList(primitive, dout_shape_norm, ret_shape);
return std::make_shared<abstract::Shape>(ret_shape);
}
TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
std::map<std::string, TypePtr> types;
types.emplace("doutput", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
} // namespace
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kGreaterEqual, 3, prim_name);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
Conv2DBackpropInputInferShape(primitive, input_args));
return abs;
auto doutput = input_args[0];
auto x_size = input_args[2];
auto x_size_value = x_size->GetValueTrack();
MS_EXCEPTION_IF_NULL(x_size);
auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
// infer dtype
auto dtype = doutput->BuildType();
if (!dtype->isa<TensorType>()) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
}
auto input_tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_tensor_type);
auto element = input_tensor_type->element();
// infer shape
auto dout_shape = doutput->BuildShape();
MS_EXCEPTION_IF_NULL(doutput);
auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
auto dout_shape_norm = dout_shapeptr->shape();
SetPadList(primitive, dout_shape_norm, x_size_v);
return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
}
void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
@ -217,7 +200,6 @@ std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
true);
REGISTER_PRIMITIVE_C(kNameConv2DBackpropInput, Conv2DBackpropInput);
} // namespace ops
} // namespace mindspore

View File

@ -599,7 +599,7 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
MS_EXCEPTION(NotSupportError) << class_name << " operator does not support PyNative mode.";
MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode.";
}
}

View File

@ -965,15 +965,15 @@ class Conv2dTranspose(_Conv):
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=dilation,
group=group)
# cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
self.conv2d_transpose = P.Conv2DTranspose(out_channel=in_channels,
kernel_size=kernel_size,
mode=1,
pad_mode=pad_mode,
pad=padding,
stride=stride,
dilation=dilation,
group=group)
self.bias_add = P.BiasAdd()
if isinstance(self.padding, int):
self.padding_top, self.padding_bottom, self.padding_left, self.padding_right = (self.padding,) * 4

View File

@ -1061,6 +1061,7 @@ def get_bprop_roi_align(self):
return bprop
@bprop_getters.register(P.Conv2DTranspose)
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
"""Grad definition for `Conv2DBackpropInput` operation."""

View File

@ -65,7 +65,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
LogUniformCandidateSampler)
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum,
BatchNorm, BiasAdd, Conv2D, Conv3D, Conv3DTranspose,
BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
@ -143,6 +143,7 @@ __all__ = [
'Xlogy',
'Conv2D',
'Conv3D',
'Conv2DTranspose',
'Conv3DTranspose',
'Flatten',
'MaxPoolWithArgmax',

View File

@ -448,7 +448,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
return out
class Conv2DBackpropFilter(Primitive):
class Conv2DBackpropFilter(PrimitiveWithInfer):
"""
Computes the gradients of convolution with respect to the filter.
@ -507,6 +507,21 @@ class Conv2DBackpropFilter(Primitive):
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
for i, dim_len in enumerate(w_size_v):
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32],
self.name)
out = {
'value': None,
'shape': w_size_v,
'dtype': doutput['dtype'],
}
return out
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
"""

View File

@ -1888,7 +1888,7 @@ class AvgPool(_Pool):
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
class Conv2DBackpropInput(Primitive):
class Conv2DBackpropInput(PrimitiveWithInfer):
"""
Computes the gradients of convolution with respect to the input.
@ -1996,6 +1996,109 @@ class Conv2DBackpropInput(Primitive):
validator.check_non_negative_int(x, 'element of pad_list', self.name)
self.pad_list = pad_list
def __infer__(self, doutput, w, x_size):
x_size_v = x_size['value']
validator.check_value_type('x_size', x_size_v, [tuple], self.name)
for i, dim_len in enumerate(x_size_v):
validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
args = {'doutput': doutput['dtype'], 'w': w['dtype']}
valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
# infer shape
dout_shape = doutput['shape']
dout_shape_norm = dout_shape if self.format == "NCHW" else \
[dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]]
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
# default pad mode is valid
pad_list = (0, 0, 0, 0)
if self.pad_list:
pad_list = tuple(self.pad_list)
elif self.pad_mode == "SAME":
pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
elif self.pad_mode == 'PAD':
pad_list = self.padding
self.add_prim_attr('pad_list', pad_list)
out = {
'value': None,
'shape': x_size_v,
'dtype': doutput['dtype'],
}
return out
class Conv2DTranspose(Conv2DBackpropInput):
"""
Compute a 2D transposed convolution, which is also known as a deconvolution
(although it is not an actual deconvolution).
Args:
out_channel (int): The dimensionality of the output space.
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
pad (Union[int, tuple[int]]): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
mode (int): Modes for different convolutions. 0 Math convolutiuon, 1 cross-correlation convolution ,
2 deconvolution, 3 depthwise convolution. Default: 1.
stride (Union[int. tuple[int]]): The stride to be applied to the convolution filter. Default: 1.
dilation (Union[int. tuple[int]]): Specifies the dilation rate to be used for the dilated convolution.
Default: 1.
group (int): Splits input into groups. Default: 1.
data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW'\
default is 'NCHW'.
Inputs:
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
data_format :math:`(N, C_{out}, H_{out}, W_{out})`.
- **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2)`, then the shape is
:math:`(C_{out}, C_{in}, K_1, K_2)`.
- **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor, the gradients w.r.t the input of convolution. It has the same shape as the input.
Raises:
TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
TypeError: If `out_channel` or `group` is not an int.
ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
ValueError: If `padding` is a tuple whose length is not equal to 4.
ValueError: If `pad_mode` it not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0).
ValueError: If `data_format` is neither 'NCHW' not 'NHWC'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
>>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
>>> x = Tensor(np.ones([10, 32, 32, 32]))
>>> conv2d_transpose_input = ops.Conv2DTranspose(out_channel=32, kernel_size=3)
>>> output = conv2d_transpose_input(dout, weight, F.shape(x))
>>> print(output.shape)
(10, 32, 32, 32)
"""
@prim_attr_register
def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0,
pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"):
super(Conv2DTranspose, self).__init__(out_channel, kernel_size, pad_mode, pad,
pad_list, mode, stride, dilation, group, data_format)
class BiasAdd(PrimitiveWithCheck):
r"""
Returns sum of input and bias tensor.

View File

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
context.set_context(device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
out_channel = 4
kernel_size = 1
self.conv_input = P.Conv2DTranspose(out_channel,
kernel_size,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1)
self.w = Parameter(
initializer(Tensor(np.array([[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]]).astype(np.float32)), [1, 1, 3, 3]),
name='w')
self.x = Parameter(initializer(Tensor(np.array([[[
[3, 0, 1, 2, 7, 4],
[1, 5, 8, 9, 3, 1],
[2, 7, 2, 5, 1, 3],
[0, 1, 3, 1, 7, 8],
[4, 2, 1, 6, 2, 8],
[2, 4, 5, 2, 3, 9]]]]).astype(np.float32)), [1, 1, 6, 6]), name='x')
self.out = Parameter(initializer(Tensor(np.array([[[
[-5, -4, 0, 8],
[-10, -2, 2, 3],
[0, -2, -4, -7],
[-3, -2, -3, -16]]]]).astype(np.float32)), [1, 1, 4, 4]), name='y')
self.get_shape = P.Shape()
@ms_function
def construct(self):
return self.conv_input(self.out, self.w, self.get_shape(self.x))
def test_conv2d_backprop_input():
conv2d_input = Net()
output = conv2d_input()
expect = np.array([[[[-5, -4, 5, 12, 0, -8],
[-15, -6, 17, 17, -2, -11],
[-15, -8, 13, 12, 2, -4],
[-13, -6, 8, -14, 5, 20],
[-3, -4, -4, -19, 7, 23],
[-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
print(output)
assert (output.asnumpy() == expect).all()