reorganiz empty graph

This commit is contained in:
Yang Jiao 2021-07-30 15:37:47 +08:00
parent 7b20a5adf7
commit fd7ab25fc2
10 changed files with 138 additions and 31 deletions

View File

@ -620,6 +620,23 @@ bool ArithmeticSimplify::DoConstantFold(const graphkernel::LiteGraphPtr &litegra
return changed;
}
void ReorganizeEmptyGraph(const graphkernel::LiteGraphPtr &litegraph) {
auto &outputs = litegraph->GetOutputs();
for (size_t i = 0; i < outputs.size(); i++) {
if (outputs[i]->NodeType() == graphkernel::NType::Value) {
graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::BroadcastToOp>("BroadcastTo", "");
std::vector<int64_t> new_shape = {1};
op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(new_shape)}});
litegraph->output()->SetInput(i, op_ptr);
} else if (outputs[i]->NodeType() == graphkernel::NType::Parameter) {
graphkernel::PrimOpPtr op_ptr = std::make_shared<graphkernel::ReshapeOp>("Reshape", "");
op_ptr->Infer({outputs[i]}, {{"shape", MakeValue(outputs[i]->shape)}});
litegraph->output()->SetInput(i, op_ptr);
}
}
return;
}
bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
bool do_simplify = false;
@ -637,11 +654,13 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
change_anf_graph = change_anf_graph || find_pattern;
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
mng->Replace(node, new_node);

View File

@ -130,29 +130,6 @@ FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
return JsonDescToAnf(kernel_desc_str);
}
void DefaultExpander::EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
const auto &ori_parameter = func_graph->parameters();
auto todos = TopoSort(func_graph->get_return());
std::set<AnfNodePtr> used_param;
for (auto node : todos) {
if (node->isa<Parameter>()) {
used_param.insert(node);
}
}
if (used_param.size() == ori_parameter.size()) {
return;
}
AnfNodePtrList new_parameter, new_inputs;
for (size_t i = 0; i < ori_parameter.size(); ++i) {
if (used_param.count(ori_parameter[i])) {
new_parameter.push_back(ori_parameter[i]);
new_inputs.push_back((*inputs)[i]);
}
}
func_graph->set_parameters(new_parameter);
*inputs = std::move(new_inputs);
}
AnfNodePtr DefaultExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
auto func_graph = old_node->func_graph();
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());

View File

@ -36,7 +36,6 @@ class DefaultExpander : public Expander {
protected:
virtual bool ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json);
virtual void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
virtual AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node);
virtual FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
};

View File

@ -877,8 +877,8 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, Anf
auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
}
// Create CNodes. the ops is already in topo order
for (const auto &op_node : lite_graph->ops()) {
// Create CNodes.
for (const auto &op_node : lite_graph->GetOrderedNodes()) {
if (op_node->NodeType() != graphkernel::NType::Primitive) {
MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node";
}
@ -934,5 +934,28 @@ FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, Anf
}
return func_graph;
}
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
const auto &ori_parameter = func_graph->parameters();
auto todos = TopoSort(func_graph->get_return());
std::set<AnfNodePtr> used_param;
for (auto node : todos) {
if (node->isa<Parameter>()) {
used_param.insert(node);
}
}
if (used_param.size() == ori_parameter.size()) {
return;
}
AnfNodePtrList new_parameter, new_inputs;
for (size_t i = 0; i < ori_parameter.size(); ++i) {
if (used_param.count(ori_parameter[i])) {
new_parameter.push_back(ori_parameter[i]);
new_inputs.push_back((*inputs)[i]);
}
}
func_graph->set_parameters(new_parameter);
*inputs = std::move(new_inputs);
}
} // namespace opt
} // namespace mindspore

View File

@ -136,6 +136,9 @@ ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t d
// functions to graphkernel model
graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph);
FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs = nullptr);
// remove parameter which is not used
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

View File

@ -115,10 +115,24 @@ NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &basei
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
if (creators.empty()) {
creators = {{"Add", Elemwise}, {"Sub", Elemwise}, {"RealDiv", Elemwise}, {"Mul", Elemwise},
{"Log", Elemwise}, {"Pow", Elemwise}, {"Sqrt", Elemwise}, {"Rsqrt", Elemwise},
{"Rsqrt", Elemwise}, {"Neg", Elemwise}, {"Reciprocal", Elemwise}, {"Abs", Elemwise},
{"ReduceSum", Reduce}, {"ReduceMax", Reduce}, {"ReduceMin", Reduce}, {"Conv2D", Conv2d}};
creators = {{"Add", Elemwise},
{"Sub", Elemwise},
{"RealDiv", Elemwise},
{"Mul", Elemwise},
{"Log", Elemwise},
{"Exp", Elemwise},
{"Pow", Elemwise},
{"Sqrt", Elemwise},
{"Rsqrt", Elemwise},
{"Neg", Elemwise},
{"Reciprocal", Elemwise},
{"Abs", Elemwise},
{"BroadcastTo", BroadcastTo},
{"Reshape", Reshape},
{"ReduceSum", Reduce},
{"ReduceMax", Reduce},
{"ReduceMin", Reduce},
{"Conv2D", Conv2d}};
}
auto iter = creators.find(op);
auto creator = (iter == creators.end() ? Opaque : iter->second);

View File

@ -84,6 +84,15 @@ class LiteGraph::GraphBuilder {
static PrimOpPtr Elemwise(const std::string &op, const std::string &name) {
return std::make_shared<ElemwiseOp>(op, name);
}
static PrimOpPtr BroadcastTo(const std::string &op, const std::string &name) {
return std::make_shared<BroadcastToOp>(op, name);
}
static PrimOpPtr Reshape(const std::string &op, const std::string &name) {
return std::make_shared<ReshapeOp>(op, name);
}
static PrimOpPtr Reduce(const std::string &op, const std::string &name) {
return std::make_shared<ReduceOp>(op, name);
}

View File

@ -24,6 +24,7 @@
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <numeric>
#include "backend/optimizer/graph_kernel/model/node.h"
@ -159,6 +160,24 @@ void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
}
DShape BroadcastToOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
return GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
}
DShape ReshapeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto new_shape = GetValue<std::vector<int64_t>>(attrs.find("shape")->second);
auto origin_shape = inputs[0]->shape;
for (size_t i = 0; i < new_shape.size(); i++) {
if (new_shape[i] == -1) {
auto origin_product = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int64_t>());
auto new_product = std::accumulate(new_shape.begin(), new_shape.end(), 1, std::multiplies<int64_t>());
new_shape[i] = origin_product / new_product * (-1);
break;
}
}
return new_shape;
}
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);

View File

@ -63,6 +63,22 @@ class ElemwiseOp : public PrimOp {
// TODO(dayschan) rewrite InferShape/InferFormat
};
class ReshapeOp : public PrimOp {
public:
ReshapeOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, RESHAPE) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class BroadcastToOp : public PrimOp {
public:
BroadcastToOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, BROADCAST) {}
protected:
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class ReduceOp : public PrimOp {
public:
ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}

View File

@ -50,6 +50,19 @@ class Net(Cell):
return self.reducemin(self.reducemin(red_res, 1), 1)
class EmptyNet(Cell):
def __init__(self):
super(EmptyNet, self).__init__()
self.add = P.Add()
self.neg = P.Neg()
def construct(self, x, y):
add_res1 = self.add(x, y)
neg_res1 = self.neg(x)
add_res2 = self.add(add_res1, neg_res1)
return add_res2
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -73,18 +86,33 @@ def test_basic():
assert res
def test_empty_graph():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
expect = input_y
net = EmptyNet()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
test_empty_graph()
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()
test_empty_graph()