forked from mindspore-Ecosystem/mindspore

reorganize empty graph

parent 7b20a5adf7
commit fd7ab25fc2
@@ -620,6 +620,23 @@ bool ArithmeticSimplify::DoConstantFold(const graphkernel::LiteGraphPtr &litegraph)
   return changed;
 }
 
+void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
+  auto &outputs = litegraph->GetOutputs();
+  for (size_t i = 0; i < outputs.size(); i++) {
+    if (outputs[i]->NodeType() == graphkernel::NType::Value) {
+      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::BroadcastToOp>("BroadcastTo", "");
+      std::vector<int64_t> new_shape = {1};
+      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}});
+      litegraph->output()->SetInput(i, op_ptr);
+    } else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
+      graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::ReshapeOp>("Reshape", "");
+      op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
+      litegraph->output()->SetInput(i, op_ptr);
+    }
+  }
+  return;
+}
+
 bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   auto mng = func_graph->manager();
   bool do_simplify = false;
@@ -637,11 +654,13 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
       change_anf_graph = change_anf_graph || find_pattern;
     }
     if (!change_anf_graph) continue;
+    ReorganizeEmptyGraph(lg);
     AnfNodePtrList outputs;
     auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
     new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
     auto cnode = node->cast<CNodePtr>();
     AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
+    EliminateRedundantParameters(new_funcgraph, &inputs);
     auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
     SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
     mng->Replace(node, new_node);
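Note: this is the "empty graph" case the commit is named for. After simplification, a subgraph output may no longer be a computed op: it can degenerate to a constant (NType::Value) or to a bare input (NType::Parameter), while kernel emission expects Primitive output nodes. The pass above therefore wraps constants in a one-element BroadcastTo and pass-through inputs in an identity Reshape to the node's own shape. A minimal standalone sketch of that dispatch rule follows; WrapperOpFor and the toy NType enum are illustrative, not MindSpore API:

#include <cassert>
#include <string>

// Toy stand-in for graphkernel::NType; illustration only.
enum class NType { Primitive, Value, Parameter };

// Mirrors the wrapping rule in ReorganizeEmptyGraph: a degenerate output gets
// a trivial op inserted so the emitted kernel still computes something.
std::string WrapperOpFor(NType output_type) {
  if (output_type == NType::Value) return "BroadcastTo";  // constant -> materialize as a {1}-shaped tensor
  if (output_type == NType::Parameter) return "Reshape";  // pass-through input -> identity reshape
  return "";                                              // a real Primitive output needs no wrapper
}

int main() {
  assert(WrapperOpFor(NType::Value) == "BroadcastTo");
  assert(WrapperOpFor(NType::Parameter) == "Reshape");
  assert(WrapperOpFor(NType::Primitive).empty());
}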
@@ -130,29 +130,6 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
   return JsonDescToAnf(kernel_desc_str);
 }
 
-void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
-  const auto &ori_parameter = func_graph->parameters();
-  auto todos = TopoSort(func_graph->get_return());
-  std::set<AnfNodePtr> used_param;
-  for (auto node : todos) {
-    if (node->isa<Parameter>()) {
-      used_param.insert(node);
-    }
-  }
-  if (used_param.size() == ori_parameter.size()) {
-    return;
-  }
-  AnfNodePtrList new_parameter, new_inputs;
-  for (size_t i = 0; i < ori_parameter.size(); ++i) {
-    if (used_param.count(ori_parameter[i])) {
-      new_parameter.push_back(ori_parameter[i]);
-      new_inputs.push_back((*inputs)[i]);
-    }
-  }
-  func_graph->set_parameters(new_parameter);
-  *inputs = std::move(new_inputs);
-}
-
 AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
   auto func_graph = old_node->func_graph();
   std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
@@ -36,7 +36,6 @@ class DefaultExpander : public Expander {
 
  protected:
   virtual bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
-  virtual void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
   virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
   virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
 };
@@ -877,8 +877,8 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs)
     auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
     AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
   }
-  // Create CNodes. the ops is already in topo order
-  for (const auto &op_node : lite_graph->ops()) {
+  // Create CNodes.
+  for (const auto &op_node : lite_graph->GetOrderedNodes()) {
     if (op_node->NodeType() != graphkernel::NType::Primitive) {
       MS_LOG(EXCEPTION) << "Node " << op_node->name() << " should be a Primitive node";
     }
@@ -934,5 +934,28 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs)
   }
   return func_graph;
 }
+
+void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
+  const auto &ori_parameter = func_graph->parameters();
+  auto todos = TopoSort(func_graph->get_return());
+  std::set<AnfNodePtr> used_param;
+  for (auto node : todos) {
+    if (node->isa<Parameter>()) {
+      used_param.insert(node);
+    }
+  }
+  if (used_param.size() == ori_parameter.size()) {
+    return;
+  }
+  AnfNodePtrList new_parameter, new_inputs;
+  for (size_t i = 0; i < ori_parameter.size(); ++i) {
+    if (used_param.count(ori_parameter[i])) {
+      new_parameter.push_back(ori_parameter[i]);
+      new_inputs.push_back((*inputs)[i]);
+    }
+  }
+  func_graph->set_parameters(new_parameter);
+  *inputs = std::move(new_inputs);
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -136,6 +136,9 @@ ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t d
 // functions to graphkernel model
 graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph);
 FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs = nullptr);
+
+// remove parameters which are not used
+void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
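Note: EliminateRedundantParameters moves out of DefaultExpander into a free function so ArithmeticSimplify::Run can reuse it. Simplification can cancel an input entirely (see test_empty_graph further down), leaving a graph parameter with no users, and the caller's input list must then be pruned in lock-step to stay index-aligned with the parameter list. A self-contained sketch of that parallel filtering, with plain strings standing in for AnfNodePtr and FilterUnused as a hypothetical name:

#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Same filtering pattern as EliminateRedundantParameters: keep only the
// parameters found in the used set, and drop the matching caller inputs so
// the two lists stay index-aligned.
void FilterUnused(std::vector<std::string> *params, std::vector<std::string> *inputs,
                  const std::set<std::string> &used) {
  if (used.size() == params->size()) {
    return;  // nothing is dead, keep both lists untouched
  }
  std::vector<std::string> new_params, new_inputs;
  for (size_t i = 0; i < params->size(); ++i) {
    if (used.count((*params)[i])) {
      new_params.push_back((*params)[i]);
      new_inputs.push_back((*inputs)[i]);
    }
  }
  *params = std::move(new_params);
  *inputs = std::move(new_inputs);
}

int main() {
  std::vector<std::string> params = {"p0", "p1", "p2"};
  std::vector<std::string> inputs = {"x", "y", "z"};
  FilterUnused(&params, &inputs, {"p0", "p2"});  // p1 became dead after simplification
  assert((params == std::vector<std::string>{"p0", "p2"}));
  assert((inputs == std::vector<std::string>{"x", "z"}));
}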
@@ -115,10 +115,24 @@ NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &basei
 PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
   static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
   if (creators.empty()) {
-    creators = {{"Add", Elemwise},     {"Sub", Elemwise},     {"RealDiv", Elemwise},    {"Mul", Elemwise},
-                {"Log", Elemwise},     {"Pow", Elemwise},     {"Sqrt", Elemwise},       {"Rsqrt", Elemwise},
-                {"Rsqrt", Elemwise},   {"Neg", Elemwise},     {"Reciprocal", Elemwise}, {"Abs", Elemwise},
-                {"ReduceSum", Reduce}, {"ReduceMax", Reduce}, {"ReduceMin", Reduce},    {"Conv2D", Conv2d}};
+    creators = {{"Add", Elemwise},
+                {"Sub", Elemwise},
+                {"RealDiv", Elemwise},
+                {"Mul", Elemwise},
+                {"Log", Elemwise},
+                {"Exp", Elemwise},
+                {"Pow", Elemwise},
+                {"Sqrt", Elemwise},
+                {"Rsqrt", Elemwise},
+                {"Neg", Elemwise},
+                {"Reciprocal", Elemwise},
+                {"Abs", Elemwise},
+                {"BroadcastTo", BroadcastTo},
+                {"Reshape", Reshape},
+                {"ReduceSum", Reduce},
+                {"ReduceMax", Reduce},
+                {"ReduceMin", Reduce},
+                {"Conv2D", Conv2d}};
   }
   auto iter = creators.find(op);
   auto creator = (iter == creators.end() ? Opaque : iter->second);
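Note: CreateOp is a name-to-factory registry with Opaque as the fallback for unregistered ops, so supporting BroadcastTo and Reshape here is just two new entries plus the factory helpers in the next hunk. A minimal standalone version of the pattern; Op, ElemwiseOp, and OpaqueOp below are toy types, not the MindSpore classes:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Toy name->factory registry with an "opaque" fallback, mirroring the shape
// of GraphBuilder::CreateOp.
struct Op {
  virtual ~Op() = default;
  virtual std::string Kind() const = 0;
};
struct ElemwiseOp : Op {
  std::string Kind() const override { return "elemwise"; }
};
struct OpaqueOp : Op {
  std::string Kind() const override { return "opaque"; }
};

using Creator = std::function<std::shared_ptr<Op>()>;

std::shared_ptr<Op> CreateOp(const std::string &name) {
  static const std::map<std::string, Creator> creators = {
      {"Add", [] { return std::make_shared<ElemwiseOp>(); }},
      {"Mul", [] { return std::make_shared<ElemwiseOp>(); }},
  };
  auto iter = creators.find(name);
  // Unregistered ops still get a node, just without special infer handling.
  return iter == creators.end() ? std::make_shared<OpaqueOp>() : iter->second();
}

int main() {
  std::cout << CreateOp("Add")->Kind() << "\n";     // elemwise
  std::cout << CreateOp("MatMul")->Kind() << "\n";  // opaque
}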
@@ -84,6 +84,15 @@ class LiteGraph::GraphBuilder {
   static PrimOpPtr Elemwise(const std::string &op, const std::string &name) {
     return std::make_shared<ElemwiseOp>(op, name);
   }
 
+  static PrimOpPtr BroadcastTo(const std::string &op, const std::string &name) {
+    return std::make_shared<BroadcastToOp>(op, name);
+  }
+
+  static PrimOpPtr Reshape(const std::string &op, const std::string &name) {
+    return std::make_shared<ReshapeOp>(op, name);
+  }
+
   static PrimOpPtr Reduce(const std::string &op, const std::string &name) {
     return std::make_shared<ReduceOp>(op, name);
   }
@@ -24,6 +24,7 @@
 #include <functional>
 #include <unordered_map>
 #include <unordered_set>
+#include <numeric>
 
 #include "backend/optimizer/graph_kernel/model/node.h"
 
@@ -159,6 +160,24 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
   compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
 }
 
+DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
+  return GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
+}
+
+DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
+  auto new_shape = GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
+  auto origin_shape = inputs[0]->shape;
+  for (size_t i = 0; i < new_shape.size(); i++) {
+    if (new_shape[i] == -1) {
+      auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>());
+      auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>());
+      new_shape[i] = origin_product / new_product * (-1);
+      break;
+    }
+  }
+  return new_shape;
+}
+
 DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
   auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
   auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
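Note on the -1 handling in ReshapeOp::InferShape (and why <numeric> is now included for std::accumulate): new_product is accumulated while the placeholder -1 is still in the shape, so origin_product / new_product comes out negated, and the final * (-1) restores the sign — equivalent to dividing the element count by the product of the known dimensions. A standalone sketch of the same arithmetic, assuming at most one -1 and an evenly dividing shape; ResolveShape is a hypothetical name:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Resolve a single -1 placeholder the same way ReshapeOp::InferShape does:
// new_product still includes the -1 factor, so the quotient is the negated
// inferred dimension, and one more sign flip recovers it.
std::vector<int64_t> ResolveShape(const std::vector<int64_t> &origin, std::vector<int64_t> shape) {
  for (size_t i = 0; i < shape.size(); i++) {
    if (shape[i] == -1) {
      auto origin_product = std::accumulate(origin.begin(), origin.end(), int64_t(1), std::multiplies<int64_t>());
      auto new_product = std::accumulate(shape.begin(), shape.end(), int64_t(1), std::multiplies<int64_t>());
      shape[i] = origin_product / new_product * (-1);
      break;
    }
  }
  return shape;
}

int main() {
  // [2, 3, 4] has 24 elements, so [-1, 4] must resolve to [6, 4]:
  // 24 / (-1 * 4) = -6, and -6 * (-1) = 6.
  assert((ResolveShape({2, 3, 4}, {-1, 4}) == std::vector<int64_t>{6, 4}));
}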
@@ -63,6 +63,22 @@ class ElemwiseOp : public PrimOp {
   // TODO(dayschan) rewrite InferShape/InferFormat
 };
 
+class ReshapeOp : public PrimOp {
+ public:
+  ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
+
+ protected:
+  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
+};
+
+class BroadcastToOp : public PrimOp {
+ public:
+  BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
+
+ protected:
+  DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
+};
+
 class ReduceOp : public PrimOp {
  public:
   ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
@@ -50,6 +50,19 @@ class Net(Cell):
         return self.reducemin(self.reducemin(red_res, 1), 1)
 
 
+class EmptyNet(Cell):
+    def __init__(self):
+        super(EmptyNet, self).__init__()
+        self.add = P.Add()
+        self.neg = P.Neg()
+
+    def construct(self, x, y):
+        add_res1 = self.add(x, y)
+        neg_res1 = self.neg(x)
+        add_res2 = self.add(add_res1, neg_res1)
+        return add_res2
+
+
 def test_basic():
     input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
     input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@@ -73,18 +86,33 @@ def test_basic():
     assert res
 
 
+def test_empty_graph():
+    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
+    input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
+    expect = input_y
+
+    net = EmptyNet()
+    result = net(Tensor(input_x), Tensor(input_y))
+
+    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
+                      atol=1.e-7, equal_nan=True)
+    assert res
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_basic_gpu():
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
     test_basic()
+    test_empty_graph()
 
 
-@pytest.mark.level0
+@pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 def test_basic_ascend():
     context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
     test_basic()
+    test_empty_graph()
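Note: in test_empty_graph, expect = input_y because EmptyNet computes (x + y) + (-x), which arithmetic simplification cancels down to y. The fused subgraph's output then degenerates to the bare parameter y — exactly the NType::Parameter case that ReorganizeEmptyGraph wraps in an identity Reshape, with input x becoming a dead parameter for EliminateRedundantParameters to prune.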