!16537 Bugfix in arithmetic_simplify of graphkernel

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-18 20:24:23 +08:00 committed by Gitee
commit d6be1efedb
3 changed files with 28 additions and 18 deletions

View File

@ -656,27 +656,36 @@ AnfNodePtr SimplifyReduce2(const AnfNodePtr &node) {
}
PatternNode<AnfNodePtr> x;
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
auto tmp_node = node->cast<CNodePtr>();
auto arg_node = tmp_node->input(1);
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
ShapeVector new_dimensions;
for (size_t i = 0; i < arg_dimensions.size(); ++i) {
for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
if (reduce_dimensions[j] >= arg_dimensions[i]) {
++reduce_dimensions[j];
auto reduce2 = node->cast<CNodePtr>();
auto reduce1 = reduce2->input(1);
// if the "keep_dims" of two nodes are different, the result "keep_dims" could not match the output shape.
// inferring shape and another "Reshape" is needed, or skip this case.
if (AnfAlgo::GetBooleanAttr(reduce1, "keep_dims") != AnfAlgo::GetBooleanAttr(reduce2, "keep_dims")) {
return nullptr;
}
auto reduce1_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce1, "axis"));
auto reduce2_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce2, "axis"));
ShapeVector new_axis;
// offset the second node's reduction axes.
if (!AnfAlgo::GetBooleanAttr(reduce1, "keep_dims")) {
for (size_t i = 0; i < reduce1_axis.size(); ++i) {
for (size_t j = 0; j < reduce2_axis.size(); ++j) {
if (reduce2_axis[j] >= reduce1_axis[i]) {
++reduce2_axis[j];
}
}
}
}
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
std::back_inserter(new_dimensions));
std::merge(reduce1_axis.begin(), reduce1_axis.end(), reduce2_axis.begin(), reduce2_axis.end(),
std::back_inserter(new_axis));
new_axis.erase(std::unique(new_axis.begin(), new_axis.end()), new_axis.end());
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
SetNodeAttrSafely("axis", MakeValue(new_axis), new_cnode);
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
return new_cnode;
};
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum->Clone()), x.GetNode(node)}, node);
AnfAlgo::CopyNodeAttr("axis", node, arg_node);
AnfAlgo::CopyNodeAttr("keep_dims", node, arg_node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node);

View File

@ -51,9 +51,9 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
auto v1 = GetValue<int64_t>(axis);
auto v2 = NormAxis(v1, rank);
axis_vec.push_back(v2);
diff = diff || (v1 != v2);
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
auto vec = axis->isa<ValueList>() ? axis->cast<ValueListPtr>()->value() : axis->cast<ValueTuplePtr>()->value();
diff = true;
} else if (axis->isa<ValueSequeue>()) {
auto vec = axis->cast<ValueSequeuePtr>()->value();
if (vec.empty()) {
diff = true;
for (size_t i = 0; i < rank; i++) {

View File

@ -47,8 +47,9 @@ namespace mindspore {
namespace opt {
PassManagerPtr GraphKernelOptimizer::PreProcess() const {
auto pm = std::make_shared<PassManager>("graphkernel_stage1_preprocess");
// Add cse at beginning, otherwise the SpreadUpdateState ir size may be huge in yolov3 network.
// Do cse before all passes of graphkernel
pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
// Change Assign(p, a, U) to Assign(Depend(p, U), a)
pm->AddPass(std::make_shared<SplitAssign>());
@ -59,7 +60,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
// Spread the MakeTuple input of UpdateState
pm->AddPass(std::make_shared<SpreadUpdateState>());
// Add cse to reduce the SpreadUpdateState ir size, which is huge in yolov3 network.
// Eliminate the common nodes that generated in SpreadUpdateState
pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
return pm;
}