forked from mindspore-Ecosystem/mindspore

fix Conv2DBackpropInputUnifyMindIR when it is a forward op in pynative

commit 90a83db5f1
parent 56650c7b2c
@@ -37,6 +37,8 @@ constexpr auto kAttrPadList = "pad_list";
 constexpr auto kAttrMode = "mode";
 constexpr auto kAttrChannelMultiplier = "channel_multiplier";
 constexpr auto kAttrPerm = "perm";
+constexpr auto kAttrInputSizes = "input_sizes";
+constexpr auto kAttrInputSize = "input_size";
 
 bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
   MS_EXCEPTION_IF_NULL(conv2d);
@@ -144,14 +146,22 @@ CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNo
                                             const CNodePtr &transpose) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(conv2d_backin);
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+
+  CNodePtr depth_conv_backin = nullptr;
+  if (conv2d_backin->inputs().size() == kConv2DBackpropInputNum) {
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+        NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
+        transpose, conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+  } else {
+    // In nn.Conv2DTranspose, Conv2DBackpropInput is a forward op and the input_sizes input will be converted to attr
+    // in pynative mode.
+    std::vector<AnfNodePtr> depth_conv_backin_inputs = {
+        NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), transpose,
+        conv2d_backin->input(1)};
+    depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
+    AnfAlgo::CopyNodeAttr(kAttrInputSizes, kAttrInputSize, conv2d_backin, depth_conv_backin);
   }
-  std::vector<AnfNodePtr> depth_conv_backin_inputs = {
-      NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
-      transpose, conv2d_backin->input(1)};
-  auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
   MS_EXCEPTION_IF_NULL(depth_conv_backin);
   depth_conv_backin->set_abstract(conv2d_backin->abstract());
   depth_conv_backin->set_scope(conv2d_backin->scope());
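The new else branch covers exactly the case described in its comment: when nn.Conv2DTranspose runs in PyNative mode, Conv2DBackpropInput executes as a forward op and its input_sizes input has already been folded into an attribute, so the node arrives here with one real input fewer. Below is a minimal sketch of that triggering scenario. It assumes the public PyNative and transposed-convolution APIs of the MindSpore version this commit targets (the Python class is spelled nn.Conv2dTranspose there; names and defaults may differ across versions) and a depthwise configuration (group == in_channels == out_channels) so this pass applies.

    # Illustrative sketch only: in PyNative mode, nn.Conv2dTranspose executes
    # Conv2DBackpropInput as a forward op, and its input_sizes input is converted
    # to the input_size attr, so the CNode reaches this pass with fewer inputs.
    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor, context

    context.set_context(mode=context.PYNATIVE_MODE)

    # group == in_channels == out_channels makes the convolution depthwise, which is
    # the form this unify-mindir pass rewrites to DepthwiseConv2dNativeBackpropInput.
    net = nn.Conv2dTranspose(in_channels=4, out_channels=4, kernel_size=3, group=4)
    x = Tensor(np.ones((1, 4, 16, 16)), ms.float32)
    out = net(x)  # runs Conv2DBackpropInput as a forward op in PyNative mode
    print(out.shape)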
@@ -265,10 +275,8 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
 }
 
 const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
-  VarPtr dout = std::make_shared<Var>();
-  VarPtr weight = std::make_shared<Var>();
-  VarPtr input_size = std::make_shared<Var>();
-  VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size});
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimConv2DBackpropInput, Xs});
   return pattern;
 }
 
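Replacing the three fixed Vars with a single SeqVar lets the pattern match a Conv2DBackpropInput node regardless of how many real inputs it carries, which is what admits the two-input PyNative form alongside the usual three-input one. The snippet below is a plain-Python schematic of that arity difference only, not the MindSpore pattern engine.

    # Schematic only: a fixed-arity pattern misses the two-input PyNative form,
    # while a sequence variable absorbs however many inputs remain.
    def matches_fixed(inputs):
        # old pattern: (Conv2DBackpropInput, dout, weight, input_size)
        return len(inputs) == 3

    def matches_seqvar(inputs):
        # new pattern: (Conv2DBackpropInput, Xs) where Xs is a SeqVar
        return len(inputs) >= 0

    print(matches_fixed(["dout", "weight", "input_sizes"]))  # True
    print(matches_fixed(["dout", "weight"]))                 # False: old pattern missed it
    print(matches_seqvar(["dout", "weight"]))                # True: new pattern matches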
@@ -285,9 +293,11 @@ const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &gra
     return nullptr;
   }
 
-  if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
-                      << conv2d_backin->inputs().size() - 1;
+  auto input_size = conv2d_backin->inputs().size();
+  // In pynative mode, the input_sizes input will be converted to attr if Conv2DBackpropInput is a forward op.
+  if (input_size != kConv2DBackpropInputNum && input_size != kConv2DBackpropInputNum - 1) {
+    MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << " or "
+                      << kConv2DBackpropInputNum - 2 << ", but got " << input_size - 1;
   }
   auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true);
   auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
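A CNode's inputs() includes the primitive itself, which is why the message subtracts one. Assuming kConv2DBackpropInputNum is 4 (primitive + dout + weight + input_sizes), the relaxed check now accepts three real inputs (graph mode, input_sizes still a tensor input) or two real inputs (PyNative forward op, input_sizes already moved into the input_size attribute). A stand-alone restatement of that arithmetic, with the constant's value as an assumption:

    # Stand-alone restatement of the relaxed check; the constant value is an assumption.
    K_CONV2D_BACKPROP_INPUT_NUM = 4  # assumed: primitive + dout + weight + input_sizes

    def check_input_count(num_inputs_including_primitive):
        valid = (K_CONV2D_BACKPROP_INPUT_NUM, K_CONV2D_BACKPROP_INPUT_NUM - 1)
        if num_inputs_including_primitive not in valid:
            raise ValueError(
                "Conv2DBackpropInput's input number should be "
                f"{K_CONV2D_BACKPROP_INPUT_NUM - 1} or {K_CONV2D_BACKPROP_INPUT_NUM - 2}, "
                f"but got {num_inputs_including_primitive - 1}")

    check_input_count(4)  # graph mode / backward op: dout, weight, input_sizes
    check_input_count(3)  # PyNative forward op: dout, weight (input_sizes is now an attr)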
@@ -803,8 +803,8 @@ class FusedBatchNorm(Primitive):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
 
     Inputs:
         - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
@@ -893,8 +893,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
-            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
-            Momentum value must be [0, 1]. Default: 0.9.
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
             Default: "NCHW".
 
@@ -1262,6 +1262,9 @@ class BatchNorm(PrimitiveWithInfer):
         is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
             If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
+        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
+            Momentum value must be [0, 1]. Default: 0.1.
         data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
             Default: "NCHW".
 
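The three docstring hunks above apply the same correction: the running statistic keeps (1 - momentum) of its previous value and takes momentum of the current batch statistic, and the documented default momentum becomes 0.1. A tiny numeric check of the corrected formula:

    # Numeric illustration of the documented update rule (values are arbitrary).
    momentum = 0.1            # new documented default
    running_mean = 0.0
    current_mean = 1.0
    new_running_mean = (1 - momentum) * running_mean + momentum * current_mean
    print(new_running_mean)   # 0.1 -> running mean drifts slowly toward the batch mean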