!20480 Add LiteGraph for graphkernel
Merge pull request !20480 from DeshiChen/0618_litegraph
This commit is contained in:
commit
7d2a07a2bd
|
@ -335,6 +335,10 @@ class CNodeDecoder {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
ShapeVector GetFakeAbstractShape(const ShapeVector &device_shape, const std::string &format) {
|
||||
return AbstractShapeCreator::GetFakeAbstractShape(device_shape, format);
|
||||
}
|
||||
|
||||
ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json ¶meter_json,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(DEBUG) << "start decode parameter, " << parameter_json;
|
||||
|
|
|
@ -40,6 +40,9 @@ class AkgKernelJsonDecoder {
|
|||
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
|
||||
std::map<std::string, AnfNodePtr> nodes_map_;
|
||||
};
|
||||
|
||||
// infer abstract shape by device_shape and data_format
|
||||
ShapeVector GetFakeAbstractShape(const ShapeVector &device_shape, const std::string &format);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
|
||||
|
|
|
@ -14,6 +14,7 @@ if(ENABLE_D)
|
|||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"ascend/*.cc"
|
||||
"graph_kernel/*.cc"
|
||||
"graph_kernel/model/*.cc"
|
||||
)
|
||||
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
||||
endif()
|
||||
|
@ -22,6 +23,7 @@ if(ENABLE_GPU)
|
|||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"gpu/*.cc"
|
||||
"graph_kernel/*.cc"
|
||||
"graph_kernel/model/*.cc"
|
||||
)
|
||||
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
|
||||
endif()
|
||||
|
|
|
@ -323,16 +323,23 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
|
|||
std::vector<TypeId> graph_output_type;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
graph_input_format.push_back(input_format);
|
||||
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
graph_input_type.push_back(input_type);
|
||||
if (kernel_with_index.first->isa<ValueNode>()) {
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
graph_input_format.emplace_back(kOpFormat_DEFAULT);
|
||||
graph_input_type.emplace_back(tensor->data_type());
|
||||
} else {
|
||||
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
graph_input_format.emplace_back(std::move(input_format));
|
||||
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
graph_input_type.emplace_back(input_type);
|
||||
}
|
||||
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
||||
fg->parameters()[i]->set_abstract(input_abs);
|
||||
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
|
||||
para_info_builder.SetOutputsFormat({input_format});
|
||||
para_info_builder.SetOutputsDeviceType({input_type});
|
||||
para_info_builder.SetOutputsFormat({graph_input_format.back()});
|
||||
para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
|
||||
para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i].get());
|
||||
|
@ -675,7 +682,8 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
|
|||
return axis;
|
||||
}
|
||||
|
||||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) {
|
||||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
|
||||
bool use_fake_abstract) {
|
||||
// Limitation: 1. Node's attributes should be set out of this function; 2. only one output.
|
||||
MS_EXCEPTION_IF_NULL(out_info.type);
|
||||
auto out_type = out_info.type;
|
||||
|
@ -688,8 +696,14 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
// Setup abstract.
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, out_info.shape);
|
||||
cnode->set_abstract(abs_tensor);
|
||||
if (use_fake_abstract) {
|
||||
auto abs_shape = kernel::GetFakeAbstractShape(out_info.shape, out_info.format);
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, abs_shape);
|
||||
cnode->set_abstract(abs_tensor);
|
||||
} else {
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, out_info.shape);
|
||||
cnode->set_abstract(abs_tensor);
|
||||
}
|
||||
|
||||
// Setup kernel info.
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
|
@ -800,5 +814,125 @@ void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string>
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
|
||||
graphkernel::LiteGraph::GraphBuilder gb(GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)));
|
||||
std::map<AnfNodePtr, graphkernel::NodePtr> node_map;
|
||||
auto todos = TopoSort(func_graph->output());
|
||||
const auto ¶ms = func_graph->parameters();
|
||||
auto ExtractBuildInfo = [](const AnfNodePtr &node) {
|
||||
auto shape = GetDeviceShape(node);
|
||||
auto type = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
||||
auto format = AnfAlgo::GetOutputFormat(node, 0);
|
||||
return graphkernel::NodeBase({shape, type, format});
|
||||
};
|
||||
// set inputs
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
node_map[params[i]] = gb.Parameter(ExtractBuildInfo(params[i]), std::string("input_") + std::to_string(i));
|
||||
}
|
||||
// set ops
|
||||
for (auto node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) continue;
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) break;
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
graphkernel::NodePtrList inputs;
|
||||
std::transform(cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(inputs),
|
||||
[&node_map, &gb](const AnfNodePtr &no) {
|
||||
auto iter = node_map.find(no);
|
||||
if (iter != node_map.end()) {
|
||||
return iter->second;
|
||||
} else {
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(no);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
return gb.Value(tensor);
|
||||
}
|
||||
});
|
||||
node_map[node] = gb.Op(AnfAlgo::GetCNodeName(node), ExtractBuildInfo(node), inputs, prim->attrs());
|
||||
}
|
||||
// set outputs
|
||||
auto output_node = func_graph->output();
|
||||
if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
|
||||
graphkernel::NodePtrList outputs;
|
||||
auto mt = output_node->cast<CNodePtr>();
|
||||
std::transform(mt->inputs().begin() + 1, mt->inputs().end(), std::back_inserter(outputs),
|
||||
[&node_map](const AnfNodePtr &no) { return node_map[no]; });
|
||||
gb.SetOutputs(std::move(outputs));
|
||||
} else {
|
||||
gb.SetOutputs({node_map[output_node]});
|
||||
}
|
||||
return gb.Get();
|
||||
}
|
||||
|
||||
FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs) {
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
std::map<graphkernel::NodePtr, AnfNodePtr> node_map;
|
||||
for (const auto &inp : lite_graph->inputs()) {
|
||||
auto param = func_graph->add_parameter();
|
||||
node_map[inp] = param;
|
||||
auto abs_shape = kernel::GetFakeAbstractShape(inp->shape, inp->format);
|
||||
param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(inp->type), abs_shape));
|
||||
param->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
|
||||
}
|
||||
// Create CNodes. the ops is already in topo order
|
||||
for (const auto &op_node : lite_graph->ops()) {
|
||||
if (op_node->NodeType() != graphkernel::NType::Primitive) {
|
||||
MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node";
|
||||
}
|
||||
auto op = std::static_pointer_cast<graphkernel::PrimOp>(op_node);
|
||||
AnfNodePtrList inputs = {NewValueNode(std::make_shared<Primitive>(op->op(), op->attrs()))};
|
||||
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(inputs),
|
||||
[&node_map](const graphkernel::NodePtr &inp) -> AnfNodePtr {
|
||||
auto iter = node_map.find(inp);
|
||||
if (iter != node_map.end()) {
|
||||
return iter->second;
|
||||
} else {
|
||||
if (inp->NodeType() != graphkernel::NType::Value) {
|
||||
MS_LOG(EXCEPTION) << "Node " << inp->name() << "should be a Value node";
|
||||
}
|
||||
auto inp_value = inp->As<graphkernel::ConstTensorNode>()->data();
|
||||
auto value_node = NewValueNode(inp_value);
|
||||
value_node->set_abstract(inp_value->ToAbstract());
|
||||
value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, value_node.get());
|
||||
return value_node;
|
||||
}
|
||||
});
|
||||
auto cnode = CreateCNode(inputs, func_graph, {op->format, op->shape, TypeIdToType(op->type)}, true);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
node_map[op_node] = cnode;
|
||||
}
|
||||
if (lite_graph->GetOutputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "The output of LiteGraph " << lite_graph->name() << " is empty.";
|
||||
} else if (lite_graph->GetOutputs().size() == 1) {
|
||||
func_graph->set_output(node_map[lite_graph->GetOutputs()[0]]);
|
||||
if (outputs != nullptr) {
|
||||
outputs->emplace_back(func_graph->output());
|
||||
}
|
||||
} else {
|
||||
AnfNodePtrList mt_inputs;
|
||||
AbstractBasePtrList out_abs_list;
|
||||
std::transform(lite_graph->GetOutputs().begin(), lite_graph->GetOutputs().end(), std::back_inserter(mt_inputs),
|
||||
[&node_map, &out_abs_list](const graphkernel::NodePtr &out) {
|
||||
auto out_node = node_map[out];
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
out_abs_list.emplace_back(out_node->abstract());
|
||||
return out_node;
|
||||
});
|
||||
auto mt = func_graph->NewCNode(prim::kPrimMakeTuple, mt_inputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
|
||||
mt->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
func_graph->AddNode(mt);
|
||||
func_graph->set_output(mt);
|
||||
if (outputs != nullptr) {
|
||||
*outputs = std::move(mt_inputs);
|
||||
}
|
||||
}
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "backend/session/kernel_graph.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -90,7 +91,8 @@ ShapeVector GetShape(const AnfNodePtr &node);
|
|||
ShapeVector GetDeviceShape(const AnfNodePtr &node);
|
||||
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
|
||||
|
||||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
|
||||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
|
||||
bool use_fake_abstract = false);
|
||||
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
|
||||
bool IsKeepBasicNode(const AnfNodePtr &node);
|
||||
void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
|
||||
|
@ -130,6 +132,10 @@ ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t d
|
|||
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
// functions to graphkernel model
|
||||
graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph);
|
||||
FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs = nullptr);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||
|
|
|
@ -91,7 +91,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
|||
}
|
||||
|
||||
PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(OptLevel_2, "highlevelopt1");
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(2, "highlevelopt1");
|
||||
// Reorder Cast and Type-insensitive node
|
||||
pm->AddPass(std::make_shared<ReorderOps>(), OptLevel_2);
|
||||
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
#include "backend/optimizer/graph_kernel/model/op_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
std::string LiteGraph::Dump() const {
|
||||
std::ostringstream os;
|
||||
os << name_ << "(";
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
os << inputs_[i]->name();
|
||||
if (i != inputs_.size() - 1) os << ", ";
|
||||
}
|
||||
os << ") -> ";
|
||||
auto &outputs = GetOutputs();
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
os << outputs[i]->name();
|
||||
if (i != outputs.size() - 1) os << ", ";
|
||||
}
|
||||
os << " {\n";
|
||||
for (NodePtr op : ops_) {
|
||||
os << " " << *op << "\n";
|
||||
}
|
||||
os << "}";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
const NodePtrList &LiteGraph::GetOrderedNodes() {
|
||||
std::unordered_map<NodePtr, size_t> outdegrees;
|
||||
std::function<void(NodePtr)> dfs;
|
||||
std::set<NodePtr> visited;
|
||||
dfs = [&dfs, &outdegrees, &visited](const NodePtr &node) {
|
||||
visited.insert(node);
|
||||
for (auto &input : node->inputs()) {
|
||||
if (input->NodeType() == NType::Primitive) {
|
||||
++outdegrees[input];
|
||||
if (visited.count(input) == 0) {
|
||||
dfs(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
dfs(output_);
|
||||
NodePtrList res;
|
||||
NodePtrList stack;
|
||||
stack.push_back(output_);
|
||||
while (!stack.empty()) {
|
||||
auto cur = stack.back();
|
||||
stack.pop_back();
|
||||
res.push_back(cur);
|
||||
for (auto &input : cur->inputs()) {
|
||||
if (input->NodeType() != NType::Primitive) continue;
|
||||
--outdegrees[input];
|
||||
if (outdegrees[input] == 0) {
|
||||
stack.push_back(input);
|
||||
outdegrees.erase(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!outdegrees.empty()) {
|
||||
MS_LOG(ERROR) << "Circle was found:";
|
||||
for (auto &node : outdegrees) {
|
||||
MS_LOG(ERROR) << " " << *(node.first);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Circle size: " << outdegrees.size();
|
||||
}
|
||||
std::reverse(res.begin(), res.end());
|
||||
res.pop_back(); // erase the output node
|
||||
ops_ = std::move(res);
|
||||
return ops_;
|
||||
}
|
||||
|
||||
NodePtr LiteGraph::GraphBuilder::Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs,
|
||||
std::string node_name) {
|
||||
if (node_name.empty()) node_name = NewName();
|
||||
PrimOpPtr op_ptr = CreateOp(op, node_name);
|
||||
op_ptr->Infer(inputs, attrs);
|
||||
return graph_->Add(op_ptr);
|
||||
}
|
||||
|
||||
NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
|
||||
const DAttrs &attrs, std::string node_name) {
|
||||
auto op_ptr = Emit(op, inputs, attrs, node_name);
|
||||
op_ptr->SetBaseInfo(baseinfo);
|
||||
return op_ptr;
|
||||
}
|
||||
|
||||
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
|
||||
static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
|
||||
if (creators.empty()) {
|
||||
creators = {
|
||||
{"Add", Elemwise},
|
||||
{"Sub", Elemwise},
|
||||
{"ReduceSum", Reduce},
|
||||
{"Conv2D", Conv2d},
|
||||
};
|
||||
}
|
||||
auto iter = creators.find(op);
|
||||
auto creator = (iter == creators.end() ? Opaque : iter->second);
|
||||
return creator(op, node_name);
|
||||
}
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
#include "backend/optimizer/graph_kernel/model/op_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
class LiteGraph {
|
||||
public:
|
||||
class GraphBuilder;
|
||||
explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {}
|
||||
|
||||
NodePtr &Add(PrimOpPtr op) {
|
||||
ops_.emplace_back(op);
|
||||
return ops_.back();
|
||||
}
|
||||
|
||||
const NodePtrList &GetOrderedNodes();
|
||||
|
||||
std::string Dump() const;
|
||||
const std::string &name() const { return name_; }
|
||||
const NodePtrList &ops() const { return ops_; }
|
||||
const NodePtrList &inputs() const { return inputs_; }
|
||||
const NodePtr &output() const { return output_; }
|
||||
const NodePtrList &GetOutputs() const { return output_->inputs(); }
|
||||
|
||||
protected:
|
||||
std::string name_;
|
||||
NodePtrList ops_; // save all operators in topo order
|
||||
NodePtrList inputs_;
|
||||
NodePtr output_;
|
||||
|
||||
private:
|
||||
int name_id_{0};
|
||||
};
|
||||
using LiteGraphPtr = std::shared_ptr<LiteGraph>;
|
||||
|
||||
class LiteGraph::GraphBuilder {
|
||||
public:
|
||||
explicit GraphBuilder(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); }
|
||||
|
||||
NodePtr Parameter(const NodeBase &baseinfo, std::string name = "") {
|
||||
if (name.empty()) name = NewName();
|
||||
auto para = std::make_shared<ParamNode>(name, baseinfo);
|
||||
graph_->inputs_.push_back(para);
|
||||
return para;
|
||||
}
|
||||
NodePtr Value(const tensor::TensorPtr &data, const std::string &name = "") {
|
||||
return std::make_shared<ConstTensorNode>(data, name);
|
||||
}
|
||||
|
||||
void SetOutputs(const NodePtrList &nodes) { graph_->output_->SetInputs(nodes); }
|
||||
|
||||
NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}, std::string node_name = "");
|
||||
NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {},
|
||||
std::string node_name = "");
|
||||
LiteGraphPtr Get() { return graph_; }
|
||||
|
||||
private:
|
||||
static PrimOpPtr Elemwise(const std::string &op, const std::string &name) {
|
||||
return std::make_shared<ElemwiseOp>(op, name);
|
||||
}
|
||||
static PrimOpPtr Reduce(const std::string &op, const std::string &name) {
|
||||
return std::make_shared<ReduceOp>(op, name);
|
||||
}
|
||||
static PrimOpPtr Opaque(const std::string &op, const std::string &name) {
|
||||
return std::make_shared<OpaqueOp>(op, name);
|
||||
}
|
||||
static PrimOpPtr Conv2d(const std::string &op, const std::string &name) {
|
||||
return std::make_shared<Conv2dOp>(op, name);
|
||||
}
|
||||
|
||||
PrimOpPtr CreateOp(const std::string &id, const std::string &name);
|
||||
std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); }
|
||||
|
||||
LiteGraphPtr graph_;
|
||||
};
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
#include "mindspore/core/ir/value.h"
|
||||
#include "mindspore/core/ir/tensor.h"
|
||||
#include "mindspore/core/utils/shape_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
void Node::DumpTensor(std::ostringstream &os) const {
|
||||
os << name_ << "[";
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
os << shape[i];
|
||||
if (i + 1 < shape.size()) os << ",";
|
||||
}
|
||||
os << "]{" << kernel::TypeId2String(type) << "x" << format << "}";
|
||||
}
|
||||
|
||||
void Node::AddInput(const NodePtr &new_input) {
|
||||
MS_EXCEPTION_IF_NULL(new_input);
|
||||
new_input->AddUser(this, inputs_.size());
|
||||
inputs_.emplace_back(new_input);
|
||||
}
|
||||
|
||||
void Node::SetInput(size_t i, const NodePtr &new_input) {
|
||||
MS_EXCEPTION_IF_NULL(new_input);
|
||||
if (i >= inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range " << inputs_.size();
|
||||
}
|
||||
auto &old_input = inputs_[i];
|
||||
old_input->RemoveUser(this, i);
|
||||
new_input->AddUser(this, i);
|
||||
inputs_[i] = new_input;
|
||||
}
|
||||
|
||||
void Node::SetInputs(const NodePtrList &inputs) {
|
||||
if (!inputs_.empty()) {
|
||||
// remove the original inputs
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
inputs_[i]->RemoveUser(this, i);
|
||||
}
|
||||
inputs_.clear();
|
||||
}
|
||||
inputs_.reserve(inputs.size());
|
||||
for (const auto &inp : inputs) {
|
||||
AddInput(inp);
|
||||
}
|
||||
}
|
||||
|
||||
void Node::ReplaceWith(const NodePtr &other_node) {
|
||||
if (this->users_.empty()) return;
|
||||
if (this->NodeType() != NType::Primitive) {
|
||||
MS_LOG(EXCEPTION) << "Only Primitive node can be replaced, but the node type is " << NodeType();
|
||||
}
|
||||
// copy the users before traversal
|
||||
auto users = this->users_;
|
||||
for (auto &user : users) {
|
||||
for (auto idx : user.second) {
|
||||
user.first->SetInput(idx, other_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
#include "mindspore/core/ir/value.h"
|
||||
#include "mindspore/core/ir/tensor.h"
|
||||
#include "mindspore/core/utils/shape_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
enum class NType {
|
||||
Base,
|
||||
Primitive,
|
||||
Parameter,
|
||||
Value,
|
||||
Output,
|
||||
};
|
||||
|
||||
using DFormat = std::string;
|
||||
using DShape = ShapeVector;
|
||||
using DAttrs = std::unordered_map<std::string, ValuePtr>;
|
||||
|
||||
struct NodeBase {
|
||||
DShape shape;
|
||||
TypeId type;
|
||||
DFormat format;
|
||||
};
|
||||
|
||||
class Node;
|
||||
using NodePtr = std::shared_ptr<Node>;
|
||||
using NodePtrList = std::vector<NodePtr>;
|
||||
class Node : public NodeBase {
|
||||
public:
|
||||
Node(const NodeBase &baseinfo, const std::string &name) : NodeBase(baseinfo), name_(name) {}
|
||||
virtual ~Node() {
|
||||
// remove this node from the previous nodes' user.
|
||||
SetInputs({});
|
||||
}
|
||||
|
||||
void SetBaseInfo(NodeBase baseinfo) {
|
||||
this->shape = std::move(baseinfo.shape);
|
||||
this->type = std::move(baseinfo.type);
|
||||
this->format = std::move(baseinfo.format);
|
||||
}
|
||||
virtual NType NodeType() { return NType::Base; }
|
||||
friend std::ostream &operator<<(std::ostream &output, const Node &n) {
|
||||
std::ostringstream os;
|
||||
n.Dump(os);
|
||||
output << os.str();
|
||||
return output;
|
||||
}
|
||||
virtual void Dump(std::ostringstream &os) const = 0;
|
||||
virtual void DumpTensor(std::ostringstream &os) const;
|
||||
|
||||
void AddInput(const NodePtr &new_input);
|
||||
void SetInput(size_t i, const NodePtr &new_input);
|
||||
void SetInputs(const NodePtrList &inputs);
|
||||
void ReplaceWith(const NodePtr &other_node);
|
||||
|
||||
template <typename T>
|
||||
T *As() {
|
||||
return static_cast<T *>(this);
|
||||
}
|
||||
template <typename T>
|
||||
const T *As() const {
|
||||
return static_cast<const T *>(this);
|
||||
}
|
||||
|
||||
const std::string &name() const { return name_; }
|
||||
const DAttrs &attrs() const { return attrs_; }
|
||||
const NodePtrList &inputs() const { return inputs_; }
|
||||
const std::unordered_map<Node *, std::set<size_t>> &users() const { return users_; }
|
||||
|
||||
protected:
|
||||
std::string name_;
|
||||
DAttrs attrs_;
|
||||
NodePtrList inputs_;
|
||||
std::unordered_map<Node *, std::set<size_t>> users_;
|
||||
|
||||
private:
|
||||
// the nodes' users are only maintained by AddInput/SetInput.
|
||||
void AddUser(Node *user, size_t index) { users_[user].insert(index); }
|
||||
void RemoveUser(Node *user, size_t index) {
|
||||
if (auto iter = users_.find(user); iter != users_.end()) {
|
||||
iter->second.erase(index);
|
||||
if (iter->second.empty()) {
|
||||
users_.erase(iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ConstTensorNode : public Node {
|
||||
public:
|
||||
explicit ConstTensorNode(const tensor::TensorPtr &data, const std::string &name = "")
|
||||
: Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}, name), data_(data) {}
|
||||
NType NodeType() override { return NType::Value; }
|
||||
void Dump(std::ostringstream &os) const override { os << ToString(); }
|
||||
void DumpTensor(std::ostringstream &os) const override { os << ToString(); }
|
||||
std::string ToString() const { return data_->data().ToString(this->type, this->shape, false); }
|
||||
const tensor::TensorPtr data() const { return data_; }
|
||||
|
||||
protected:
|
||||
tensor::TensorPtr data_;
|
||||
};
|
||||
|
||||
class ParamNode : public Node {
|
||||
public:
|
||||
ParamNode(const std::string &name, const NodeBase &baseinfo) : Node(baseinfo, name) {}
|
||||
void Dump(std::ostringstream &os) const override { DumpTensor(os); }
|
||||
NType NodeType() override { return NType::Parameter; }
|
||||
};
|
||||
|
||||
class OutputNode : public Node {
|
||||
public:
|
||||
OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, "") {}
|
||||
void Dump(std::ostringstream &os) const override { ; }
|
||||
NType NodeType() override { return NType::Output; }
|
||||
};
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/model/op_node.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
void PrimOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
this->shape = InferShape(inputs, attrs);
|
||||
this->type = InferType(inputs, attrs);
|
||||
this->format = InferFormat(inputs, attrs);
|
||||
this->attrs_ = attrs;
|
||||
SetInputs(inputs);
|
||||
}
|
||||
|
||||
void PrimOp::Dump(std::ostringstream &os) const {
|
||||
DumpTensor(os);
|
||||
os << " = " << this->op_ << "(";
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
inputs_[i]->DumpTensor(os);
|
||||
if (i != inputs_.size() - 1) os << ", ";
|
||||
}
|
||||
os << ")";
|
||||
std::ostringstream attr_os;
|
||||
bool has_attr = false;
|
||||
std::set<std::string> black_list = {"IsFeatureMapInputList", "IsFeatureMapOutput", "output_names", "input_names"};
|
||||
for (auto attr : attrs_) {
|
||||
if (attr.second != nullptr && black_list.count(attr.first) == 0) {
|
||||
if (has_attr) {
|
||||
attr_os << ", ";
|
||||
} else {
|
||||
has_attr = true;
|
||||
}
|
||||
attr_os << attr.first << ": " << attr.second->ToString();
|
||||
}
|
||||
}
|
||||
if (has_attr) {
|
||||
os << " // attr {" << attr_os.str() << "}";
|
||||
}
|
||||
}
|
||||
|
||||
void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
PrimOp::Infer(inputs, attrs);
|
||||
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
|
||||
for (auto &ref : inputs) {
|
||||
if (ref->shape.size() != this->shape.size()) return true;
|
||||
for (size_t i = 0; i < this->shape.size(); ++i) {
|
||||
if (ref->shape[i] != this->shape[i]) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
compute_type_ = IsBroadcast(inputs) ? BROADCAST : ELEMWISE;
|
||||
}
|
||||
|
||||
DShape ReduceOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
auto axis = GetValue<std::vector<int64_t>>(attrs.find("axis")->second);
|
||||
auto keepdims = GetValue<bool>(attrs.find("keep_dims")->second);
|
||||
if (keepdims) {
|
||||
DShape new_shape = inputs[0]->shape;
|
||||
for (auto x : axis) {
|
||||
new_shape[x] = 1;
|
||||
}
|
||||
return new_shape;
|
||||
}
|
||||
DShape new_shape;
|
||||
const auto &input_shape = inputs[0]->shape;
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
|
||||
new_shape.emplace_back(input_shape[i]);
|
||||
}
|
||||
}
|
||||
if (new_shape.empty()) {
|
||||
new_shape.emplace_back(1);
|
||||
}
|
||||
return new_shape;
|
||||
}
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_OP_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace graphkernel {
|
||||
class PrimOp : public Node {
|
||||
public:
|
||||
enum ComputeType {
|
||||
RESHAPE,
|
||||
ELEMWISE,
|
||||
BROADCAST,
|
||||
REDUCE,
|
||||
OPAQUE,
|
||||
};
|
||||
|
||||
PrimOp(const std::string &op, const std::string &node_name, ComputeType compute)
|
||||
: Node({{}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}, node_name), op_(op), compute_type_(compute) {}
|
||||
|
||||
virtual void Infer(const NodePtrList &inputs, const DAttrs &attrs);
|
||||
void Dump(std::ostringstream &os) const override;
|
||||
NType NodeType() override { return NType::Primitive; }
|
||||
|
||||
const std::string &op() const { return op_; }
|
||||
ComputeType compute_type() const { return compute_type_; }
|
||||
|
||||
protected:
|
||||
std::string op_;
|
||||
ComputeType compute_type_;
|
||||
virtual DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->shape; }
|
||||
virtual TypeId InferType(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->type; }
|
||||
virtual DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) { return inputs[0]->format; }
|
||||
};
|
||||
using PrimOpPtr = std::shared_ptr<PrimOp>;
|
||||
|
||||
class ElemwiseOp : public PrimOp {
|
||||
public:
|
||||
ElemwiseOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, ELEMWISE) {}
|
||||
void Infer(const NodePtrList &inputs, const DAttrs &attrs) override;
|
||||
// TODO(dayschan) rewrite InferShape/InferFormat
|
||||
};
|
||||
|
||||
class ReduceOp : public PrimOp {
|
||||
public:
|
||||
ReduceOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, REDUCE) {}
|
||||
|
||||
protected:
|
||||
DShape InferShape(const NodePtrList &inputs, const DAttrs &attrs) override;
|
||||
};
|
||||
|
||||
class OpaqueOp : public PrimOp {
|
||||
public:
|
||||
OpaqueOp(const std::string &op, const std::string &node_name) : PrimOp(op, node_name, OPAQUE) {}
|
||||
};
|
||||
|
||||
class Conv2dOp : public OpaqueOp {
|
||||
public:
|
||||
Conv2dOp(const std::string &op, const std::string &node_name) : OpaqueOp("Conv2D", node_name) {}
|
||||
};
|
||||
} // namespace graphkernel
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -65,8 +65,8 @@ def gen_data():
|
|||
y1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
|
||||
x2_np = np.random.randint(1, 5, 1).astype(np.int32)
|
||||
y2_np = np.random.randint(1, 5, 1).astype(np.int32)
|
||||
x3_np = np.array(768).astype(np.float32)
|
||||
y3_np = np.array(3072.5).astype(np.float32)
|
||||
x3_np = np.array([768]).astype(np.float32)
|
||||
y3_np = np.array([3072.5]).astype(np.float32)
|
||||
|
||||
x0 = Tensor(x0_np)
|
||||
y0 = Tensor(y0_np)
|
|
@ -144,6 +144,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/model/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc"
|
||||
"../../../mindspore/ccsrc/backend/session/ascend_session.cc"
|
||||
"../../../mindspore/ccsrc/backend/session/ascend_auto_monad.cc"
|
||||
|
|
Loading…
Reference in New Issue