forked from mindspore-Ecosystem/mindspore
!28060 fix shared valuenode bug in Conv2DUnifyMindIR pass
Merge pull request !28060 from yuchaojie/ir_fusion
This commit is contained in:
commit
398cec40b3
|
@ -34,6 +34,7 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr size_t kConv2DBackpropInputNum = 3;
|
||||
constexpr size_t kConv2DAxisNum = 4;
|
||||
constexpr size_t kConv2DFilterSize = 4;
|
||||
constexpr auto kAttrOffsetA = "offset_a";
|
||||
constexpr auto kAttrPadList = "pad_list";
|
||||
constexpr auto kAttrMode = "mode";
|
||||
|
@ -291,15 +292,16 @@ CNodePtr Conv2DBackpropFilterUnifyMindIR::CreateDepthwiseConv2DBackpropFilter(co
|
|||
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(filter_size_vnode);
|
||||
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
|
||||
// swap axis 0 and 1 of filter shape, but don't swap twice since some node share same filter_size valuenode
|
||||
// when the filter_size value is same.
|
||||
if (filter_size[0] != 1) {
|
||||
std::swap(filter_size[0], filter_size[1]);
|
||||
conv2d_backfil->input(kIndex3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
|
||||
if (filter_size.size() < kConv2DFilterSize) {
|
||||
MS_LOG(EXCEPTION) << "Filter size input of node[" << conv2d_backfil->fullname_with_scope()
|
||||
<< "] should be 4-D, but got " << filter_size;
|
||||
}
|
||||
std::swap(filter_size[0], filter_size[1]);
|
||||
auto new_filter_size_vnode = CreateShapeValueNode(graph, filter_size);
|
||||
|
||||
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)),
|
||||
conv2d_backfil->input(kIndex2), conv2d_backfil->input(kIndex3), conv2d_backfil->input(kIndex1)};
|
||||
conv2d_backfil->input(kIndex2), new_filter_size_vnode, conv2d_backfil->input(kIndex1)};
|
||||
auto depth_conv_backfil = NewCNode(depth_conv_backfil_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
|
||||
depth_conv_backfil->set_scope(conv2d_backfil->scope());
|
||||
|
|
Loading…
Reference in New Issue