!16537 Bugfix in arithmetic_simplify of graphkernel

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-18 20:24:23 +08:00 committed by Gitee
commit d6be1efedb
3 changed files with 28 additions and 18 deletions

View File

@@ -656,27 +656,36 @@ AnfNodePtr SimplifyReduce2(const AnfNodePtr &node) {
} }
PatternNode<AnfNodePtr> x; PatternNode<AnfNodePtr> x;
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
auto tmp_node = node->cast<CNodePtr>(); auto reduce2 = node->cast<CNodePtr>();
auto arg_node = tmp_node->input(1); auto reduce1 = reduce2->input(1);
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis")); // if the "keep_dims" of two nodes are different, the result "keep_dims" could not match the output shape.
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis")); // inferring shape and another "Reshape" is needed, or skip this case.
ShapeVector new_dimensions; if (AnfAlgo::GetBooleanAttr(reduce1, "keep_dims") != AnfAlgo::GetBooleanAttr(reduce2, "keep_dims")) {
for (size_t i = 0; i < arg_dimensions.size(); ++i) { return nullptr;
for (size_t j = 0; j < reduce_dimensions.size(); ++j) { }
if (reduce_dimensions[j] >= arg_dimensions[i]) { auto reduce1_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce1, "axis"));
++reduce_dimensions[j]; auto reduce2_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce2, "axis"));
ShapeVector new_axis;
// offset the second node's reduction axes.
if (!AnfAlgo::GetBooleanAttr(reduce1, "keep_dims")) {
for (size_t i = 0; i < reduce1_axis.size(); ++i) {
for (size_t j = 0; j < reduce2_axis.size(); ++j) {
if (reduce2_axis[j] >= reduce1_axis[i]) {
++reduce2_axis[j];
}
} }
} }
} }
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(), std::merge(reduce1_axis.begin(), reduce1_axis.end(), reduce2_axis.begin(), reduce2_axis.end(),
std::back_inserter(new_dimensions)); std::back_inserter(new_axis));
new_axis.erase(std::unique(new_axis.begin(), new_axis.end()), new_axis.end());
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); SetNodeAttrSafely("axis", MakeValue(new_axis), new_cnode);
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
return new_cnode; return new_cnode;
}; };
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr { auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node); auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum->Clone()), x.GetNode(node)}, node);
AnfAlgo::CopyNodeAttr("axis", node, arg_node); AnfAlgo::CopyNodeAttr("axis", node, arg_node);
AnfAlgo::CopyNodeAttr("keep_dims", node, arg_node); AnfAlgo::CopyNodeAttr("keep_dims", node, arg_node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node);

View File

@@ -51,9 +51,9 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
auto v1 = GetValue<int64_t>(axis); auto v1 = GetValue<int64_t>(axis);
auto v2 = NormAxis(v1, rank); auto v2 = NormAxis(v1, rank);
axis_vec.push_back(v2); axis_vec.push_back(v2);
diff = diff || (v1 != v2); diff = true;
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) { } else if (axis->isa<ValueSequeue>()) {
auto vec = axis->isa<ValueList>() ? axis->cast<ValueListPtr>()->value() : axis->cast<ValueTuplePtr>()->value(); auto vec = axis->cast<ValueSequeuePtr>()->value();
if (vec.empty()) { if (vec.empty()) {
diff = true; diff = true;
for (size_t i = 0; i < rank; i++) { for (size_t i = 0; i < rank; i++) {

View File

@@ -47,8 +47,9 @@ namespace mindspore {
namespace opt { namespace opt {
PassManagerPtr GraphKernelOptimizer::PreProcess() const { PassManagerPtr GraphKernelOptimizer::PreProcess() const {
auto pm = std::make_shared<PassManager>("graphkernel_stage1_preprocess"); auto pm = std::make_shared<PassManager>("graphkernel_stage1_preprocess");
// Add cse at beginning, otherwise the SpreadUpdateState ir size may be huge in yolov3 network. // Do cse before all passes of graphkernel
pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
// Change Assign(p, a, U) to Assign(Depend(p, U), a) // Change Assign(p, a, U) to Assign(Depend(p, U), a)
pm->AddPass(std::make_shared<SplitAssign>()); pm->AddPass(std::make_shared<SplitAssign>());
@@ -59,7 +60,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
// Spread the MakeTuple input of UpdateState // Spread the MakeTuple input of UpdateState
pm->AddPass(std::make_shared<SpreadUpdateState>()); pm->AddPass(std::make_shared<SpreadUpdateState>());
// Add cse to reduce the SpreadUpdateState ir size, which is huge in yolov3 network. // Eliminate the common nodes that generated in SpreadUpdateState
pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
return pm; return pm;
} }