diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
index f7d5fa7d19e..e4c2e59ec2f 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
@@ -620,6 +620,23 @@ bool ArithmeticSimplify::DoConstantFold(const graphkernel::LiteGraphPtr &litegra
   return changed;
 }
 
+void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
+  auto &outputs = litegraph->GetOutputs();
+  for (size_t i = 0; i < outputs.size(); i++) {
+    if (outputs[i]->NodeType() == graphkernel::NType::Value) {
+      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::BroadcastToOp>("BroadcastTo", "");
+      std::vector<int64_t> new_shape = {1};
+      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}});
+      litegraph->output()->SetInput(i, op_ptr);
+    } else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
+      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::ReshapeOp>("Reshape", "");
+      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
+      litegraph->output()->SetInput(i, op_ptr);
+    }
+  }
+  return;
+}
+
 bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   auto mng = func_graph->manager();
   bool do_simplify = false;
@@ -637,11 +654,13 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
       change_anf_graph = change_anf_graph || find_pattern;
     }
     if (!change_anf_graph) continue;
+    ReorganizeEmptyGraph(lg);
     AnfNodePtrList outputs;
     auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
     new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
     auto cnode = node->cast<CNodePtr>();
     AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
+    EliminateRedundantParameters(new_funcgraph, &inputs);
     auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
     SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
     mng->Replace(node, new_node);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index e2e41bbe687..2c30e4b02e1 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -130,29 +130,6 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   return JsonDescToAnf(kernel_desc_str);
 }
 
-void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
-  const auto &ori_parameter = func_graph->parameters();
-  auto todos = TopoSort(func_graph->get_return());
-  std::set<AnfNodePtr> used_param;
-  for (auto node : todos) {
-    if (node->isa<Parameter>()) {
-      used_param.insert(node);
-    }
-  }
-  if (used_param.size() == ori_parameter.size()) {
-    return;
-  }
-  AnfNodePtrList new_parameter, new_inputs;
-  for (size_t i = 0; i < ori_parameter.size(); ++i) {
-    if (used_param.count(ori_parameter[i])) {
-      new_parameter.push_back(ori_parameter[i]);
-      new_inputs.push_back((*inputs)[i]);
-    }
-  }
-  func_graph->set_parameters(new_parameter);
-  *inputs = std::move(new_inputs);
-}
-
 AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
   auto func_graph = old_node->func_graph();
   std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
index 382a208d840..fcb53e1cbd4 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.h
@@ -36,7 +36,6 @@ class DefaultExpander : public Expander {
 
  protected:
   virtual bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
-  virtual void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
   virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
   virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
 };
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 6789041ef45..bf25889d1e5 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -877,8 +877,8 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, Anf
     auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
     AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
   }
-  // Create CNodes. the ops is already in topo order
-  for (const auto &op_node : lite_graph->ops()) {
+  // Create CNodes.
+  for (const auto &op_node : lite_graph->GetOrderedNodes()) {
     if (op_node->NodeType() != graphkernel::NType::Primitive) {
       MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node";
     }
@@ -934,5 +934,28 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, Anf
   }
   return func_graph;
 }
+
+void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
+  const auto &ori_parameter = func_graph->parameters();
+  auto todos = TopoSort(func_graph->get_return());
+  std::set<AnfNodePtr> used_param;
+  for (auto node : todos) {
+    if (node->isa<Parameter>()) {
+      used_param.insert(node);
+    }
+  }
+  if (used_param.size() == ori_parameter.size()) {
+    return;
+  }
+  AnfNodePtrList new_parameter, new_inputs;
+  for (size_t i = 0; i < ori_parameter.size(); ++i) {
+    if (used_param.count(ori_parameter[i])) {
+      new_parameter.push_back(ori_parameter[i]);
+      new_inputs.push_back((*inputs)[i]);
+    }
+  }
+  func_graph->set_parameters(new_parameter);
+  *inputs = std::move(new_inputs);
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
index 46854407413..68bdae3638d 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h
@@ -136,6 +136,9 @@ ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t d
 // functions to graphkernel model
 graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph);
 FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs = nullptr);
+
+// Remove the parameters that are not used.
+void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc
index 83c6a6ce884..6a0f9168c03 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc
@@ -115,10 +115,24 @@ NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &basei
 PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
   static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
   if (creators.empty()) {
-    creators = {{"Add", Elemwise},     {"Sub", Elemwise},     {"RealDiv", Elemwise},    {"Mul", Elemwise},
-                {"Log", Elemwise},     {"Pow", Elemwise},     {"Sqrt", Elemwise},       {"Rsqrt", Elemwise},
-                {"Rsqrt", Elemwise},   {"Neg", Elemwise},     {"Reciprocal", Elemwise}, {"Abs", Elemwise},
-                {"ReduceSum", Reduce}, {"ReduceMax", Reduce}, {"ReduceMin", Reduce},    {"Conv2D", Conv2d}};
+    creators = {{"Add", Elemwise},
+                {"Sub", Elemwise},
+                {"RealDiv", Elemwise},
+                {"Mul", Elemwise},
+                {"Log", Elemwise},
+                {"Exp", Elemwise},
+                {"Pow", Elemwise},
+                {"Sqrt", Elemwise},
+                {"Rsqrt", Elemwise},
+                {"Neg", Elemwise},
+                {"Reciprocal", Elemwise},
+                {"Abs", Elemwise},
+                {"BroadcastTo", BroadcastTo},
+                {"Reshape", Reshape},
+                {"ReduceSum", Reduce},
+                {"ReduceMax", Reduce},
+                {"ReduceMin", Reduce},
+                {"Conv2D", Conv2d}};
   }
   auto iter = creators.find(op);
   auto creator = (iter == creators.end() ? Opaque : iter->second);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h
index 9c95b00ce1b..439a172fc58 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.h
@@ -84,6 +84,15 @@ class LiteGraph::GraphBuilder {
   static PrimOpPtr Elemwise(const std::string &op, const std::string &name) {
     return std::make_shared<ElemwiseOp>(op, name);
   }
+
+  static PrimOpPtr BroadcastTo(const std::string &op, const std::string &name) {
+    return std::make_shared<BroadcastToOp>(op, name);
+  }
+
+  static PrimOpPtr Reshape(const std::string &op, const std::string &name) {
+    return std::make_shared<ReshapeOp>(op, name);
+  }
+
   static PrimOpPtr Reduce(const std::string &op, const std::string &name) {
     return std::make_shared<ReduceOp>(op, name);
   }
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc
index 987773284bc..3a03f3cf4b5 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <numeric>
 
 #include "backend/optimizer/graph_kernel/model/node.h"
 
@@ -159,6 +160,24 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
   compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
 }
 
+DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
+  return GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
+}
+
+DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
+  auto new_shape = GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
+  auto origin_shape = inputs[0]->shape;
+  for (size_t i = 0; i < new_shape.size(); i++) {
+    if (new_shape[i] == -1) {
+      auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>());
+      auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>());
+      new_shape[i] = origin_product / new_product * (-1);
+      break;
+    }
+  }
+  return new_shape;
+}
+
 DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
   auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h
index dc46a0e396c..c477bd08488 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h
@@ -63,6 +63,22 @@ class ElemwiseOp : public PrimOp {
   // TODO(dayschan) rewrite InferShape/InferFormat
 };
 
+class ReshapeOp : public PrimOp {
+ public:
+  ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
+
+ protected:
+  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
+};
+
+class BroadcastToOp : public PrimOp {
+ public:
+  BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
+
+ protected:
+  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
+};
+
 class ReduceOp : public PrimOp {
  public:
   ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
diff --git a/tests/st/ops/graph_kernel/test_simplify.py b/tests/st/ops/graph_kernel/test_simplify.py
index 9648947d031..d75d8076a17 100644
--- a/tests/st/ops/graph_kernel/test_simplify.py
+++ b/tests/st/ops/graph_kernel/test_simplify.py
@@ -50,6 +50,19 @@ class Net(Cell):
         return self.reducemin(self.reducemin(red_res, 1), 1)
 
 
+class EmptyNet(Cell):
+    def __init__(self):
+        super(EmptyNet, self).__init__()
+        self.add = P.Add()
+        self.neg = P.Neg()
+
+    def construct(self, x, y):
+        add_res1 = self.add(x, y)
+        neg_res1 = self.neg(x)
+        add_res2 = self.add(add_res1, neg_res1)
+        return add_res2
+
+
 def test_basic():
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@@ -73,18 +86,33 @@ def test_basic():
     assert res
 
 
+def test_empty_graph():
+    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
+    input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
+    expect = input_y
+
+    net = EmptyNet()
+    result = net(Tensor(input_x), Tensor(input_y))
+
+    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
+                      atol=1.e-7, equal_nan=True)
+    assert res
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_basic_gpu():
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
     test_basic()
+    test_empty_graph()
 
 
-@pytest.mark.level0
+@pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 def test_basic_ascend():
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
     test_basic()
+    test_empty_graph()
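
Reviewer note (not part of the patch): a minimal plain-Python sketch of how the new ReshapeOp::InferShape resolves a single -1 ("infer this dimension") entry in the target shape. Because new_product is computed while the -1 placeholder is still in new_shape, it carries a factor of -1, so dividing origin_product by it and flipping the sign yields the missing dimension. The function name infer_reshape and the sample shapes below are made up for illustration.

    # Mirror of ReshapeOp::InferShape's -1 resolution (illustration only).
    from functools import reduce
    from operator import mul

    def infer_reshape(origin_shape, new_shape):
        new_shape = list(new_shape)
        for i, dim in enumerate(new_shape):
            if dim == -1:
                origin_product = reduce(mul, origin_shape, 1)
                # new_product is negative: it still contains the -1 placeholder.
                new_product = reduce(mul, new_shape, 1)
                new_shape[i] = origin_product // new_product * (-1)
                break
        return new_shape

    assert infer_reshape([2, 3, 4, 3], [-1]) == [72]
    assert infer_reshape([2, 3, 4, 3], [6, -1]) == [6, 12]

Relatedly, the EmptyNet test exercises the empty-graph path because x + y + (-x) simplifies to the bare parameter y, which ReorganizeEmptyGraph must wrap in a Reshape so the fused kernel still contains an op; hence expect = input_y.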