forked from mindspore-Ecosystem/mindspore
!16537 Bugfix in arithmetic_simplify of graphkernel
From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
d6be1efedb
|
@ -656,27 +656,36 @@ AnfNodePtr SimplifyReduce2(const AnfNodePtr &node) {
|
|||
}
|
||||
PatternNode<AnfNodePtr> x;
|
||||
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||
auto tmp_node = node->cast<CNodePtr>();
|
||||
auto arg_node = tmp_node->input(1);
|
||||
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
|
||||
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||
ShapeVector new_dimensions;
|
||||
for (size_t i = 0; i < arg_dimensions.size(); ++i) {
|
||||
for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
|
||||
if (reduce_dimensions[j] >= arg_dimensions[i]) {
|
||||
++reduce_dimensions[j];
|
||||
auto reduce2 = node->cast<CNodePtr>();
|
||||
auto reduce1 = reduce2->input(1);
|
||||
// if the "keep_dims" of two nodes are different, the result "keep_dims" could not match the output shape.
|
||||
// inferring shape and another "Reshape" is needed, or skip this case.
|
||||
if (AnfAlgo::GetBooleanAttr(reduce1, "keep_dims") != AnfAlgo::GetBooleanAttr(reduce2, "keep_dims")) {
|
||||
return nullptr;
|
||||
}
|
||||
auto reduce1_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce1, "axis"));
|
||||
auto reduce2_axis = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(reduce2, "axis"));
|
||||
ShapeVector new_axis;
|
||||
// offset the second node's reduction axes.
|
||||
if (!AnfAlgo::GetBooleanAttr(reduce1, "keep_dims")) {
|
||||
for (size_t i = 0; i < reduce1_axis.size(); ++i) {
|
||||
for (size_t j = 0; j < reduce2_axis.size(); ++j) {
|
||||
if (reduce2_axis[j] >= reduce1_axis[i]) {
|
||||
++reduce2_axis[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
|
||||
std::back_inserter(new_dimensions));
|
||||
std::merge(reduce1_axis.begin(), reduce1_axis.end(), reduce2_axis.begin(), reduce2_axis.end(),
|
||||
std::back_inserter(new_axis));
|
||||
new_axis.erase(std::unique(new_axis.begin(), new_axis.end()), new_axis.end());
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||
SetNodeAttrSafely("axis", MakeValue(new_axis), new_cnode);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||
return new_cnode;
|
||||
};
|
||||
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
|
||||
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
|
||||
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum->Clone()), x.GetNode(node)}, node);
|
||||
AnfAlgo::CopyNodeAttr("axis", node, arg_node);
|
||||
AnfAlgo::CopyNodeAttr("keep_dims", node, arg_node);
|
||||
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node);
|
||||
|
|
|
@ -51,9 +51,9 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
|
|||
auto v1 = GetValue<int64_t>(axis);
|
||||
auto v2 = NormAxis(v1, rank);
|
||||
axis_vec.push_back(v2);
|
||||
diff = diff || (v1 != v2);
|
||||
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
|
||||
auto vec = axis->isa<ValueList>() ? axis->cast<ValueListPtr>()->value() : axis->cast<ValueTuplePtr>()->value();
|
||||
diff = true;
|
||||
} else if (axis->isa<ValueSequeue>()) {
|
||||
auto vec = axis->cast<ValueSequeuePtr>()->value();
|
||||
if (vec.empty()) {
|
||||
diff = true;
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
|
|
|
@ -47,8 +47,9 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
PassManagerPtr GraphKernelOptimizer::PreProcess() const {
|
||||
auto pm = std::make_shared<PassManager>("graphkernel_stage1_preprocess");
|
||||
// Add cse at beginning, otherwise the SpreadUpdateState ir size may be huge in yolov3 network.
|
||||
// Do cse before all passes of graphkernel
|
||||
pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
|
||||
// Change Assign(p, a, U) to Assign(Depend(p, U), a)
|
||||
pm->AddPass(std::make_shared<SplitAssign>());
|
||||
|
||||
|
@ -59,7 +60,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
|
|||
|
||||
// Spread the MakeTuple input of UpdateState
|
||||
pm->AddPass(std::make_shared<SpreadUpdateState>());
|
||||
// Add cse to reduce the SpreadUpdateState ir size, which is huge in yolov3 network.
|
||||
// Eliminate the common nodes that generated in SpreadUpdateState
|
||||
pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
return pm;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue