fix Conv2DBackpropInputUnifyMindIR when it is a forward op in pynative

yuchaojie 2021-02-04 09:46:59 +08:00
parent 56650c7b2c
commit 90a83db5f1
2 changed files with 31 additions and 18 deletions


@@ -37,6 +37,8 @@ constexpr auto kAttrPadList = "pad_list";
 constexpr auto kAttrMode = "mode";
 constexpr auto kAttrChannelMultiplier = "channel_multiplier";
 constexpr auto kAttrPerm = "perm";
+constexpr auto kAttrInputSizes = "input_sizes";
+constexpr auto kAttrInputSize = "input_size";
 
 bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
   MS_EXCEPTION_IF_NULL(conv2d);
@@ -144,14 +146,22 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
                                             const CNodePtr &transpose) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(conv2d_backin);
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+  CNodePtr depth_conv_backin = nullptr;
+  if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
+      transpose, conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+  } else {
+    // In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be converted to attr
+    // in pynative mode.
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
+      conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+    AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
   }
-  std::vector<AnfNodePtr> depth_conv_backin_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
-    transpose, conv2d_backin->input(1)};
-  auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
   MS_EXCEPTION_IF_NULL(depth_conv_backin);
   depth_conv_backin->set_abstract(conv2d_backin->abstract());
   depth_conv_backin->set_scope(conv2d_backin->scope());
@@ -265,10 +275,8 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
 }
 
 const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
-  VarPtr dout = std::make_shared<Var>();
-  VarPtr weight = std::make_shared<Var>();
-  VarPtr input_size = std::make_shared<Var>();
-  VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size});
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
   return pattern;
 }
@@ -285,9 +293,11 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
     return nullptr;
   }
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+  auto input_size = conv2d_backin->inputs().size();
+  // In pynative mode, the input_sizes input will be converted to attr if Conv2DBackpropInput is a forward op.
+  if (input_size != kConv2DBackpropInputNum && input_size != kConv2DBackpropInputNum - 1) {
+    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << " or "
+                      << kConv2DBackpropInputNum - 2 << ", but got " << input_size - 1;
   }
   auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true);
   auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);

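Note: the situation the new else-branch handles can be reproduced from Python. The sketch below is illustrative only and is not part of this change; it assumes that group == in_channels == out_channels is the depthwise case this pass rewrites to DepthwiseConv2dNativeBackpropInput, and that an Ascend target is available.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# In PyNative mode, the Conv2DBackpropInput inside Conv2dTranspose runs as a forward op,
# and its input_sizes input is converted to an attr before the unify-mindir pass sees the CNode.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

# Assumption: group == in_channels == out_channels is what selects the depthwise path.
net = nn.Conv2dTranspose(in_channels=3, out_channels=3, kernel_size=3, group=3)
x = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))
out = net(x)
print(out.shape)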

@@ -803,8 +803,8 @@ class FusedBatchNorm(Primitive):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
 
     Inputs:
         - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
@@ -893,8 +893,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
             Default: "NCHW".
@@ -1262,6 +1262,9 @@ class BatchNorm(PrimitiveWithInfer):
         is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
             If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
+        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
             Default: "NCHW".
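The corrected update rule in the docstrings above is easy to sanity-check with plain arithmetic; a minimal standalone sketch using the new default momentum of 0.1:

import numpy as np

momentum = 0.1                        # new documented default
running_mean = np.array([0.0, 1.0])   # moving statistic so far
current_mean = np.array([2.0, 3.0])   # statistic of the current batch

# new_running_mean = (1 - momentum) * running_mean + momentum * current_mean
new_running_mean = (1 - momentum) * running_mean + momentum * current_mean
print(new_running_mean)  # [0.2 1.2]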