forked from mindspore-Ecosystem/mindspore
168 lines
4.6 KiB
C++
168 lines
4.6 KiB
C++
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "common/graph_util.h"
#include <fstream>
#include <new>
#include <sstream>
#include "common/mslog.h"
#include "include/errorcode.h"
|
|
|
|
namespace mindspore {
|
|
namespace predict {
|
|
/**
 * Builds an OpGraph from the node list of a sub-graph definition.
 *
 * Each node in subGraphDef is handed to AddEdge(), which connects it to every
 * node in the same list that consumes one of its outputs.
 *
 * @param subGraphDef sub-graph definition whose nodes() supply the graph nodes.
 * @return a heap-allocated graph released to the caller, or nullptr when the
 *         allocation fails, nodes() is nullptr, or adding an edge fails.
 */
OpGraph *OpGraph::Build(const SubGraphDef &subGraphDef) {
  // Plain `new` signals failure by throwing, which made the nullptr check
  // below unreachable. The nothrow form returns nullptr instead, matching the
  // allocation style used by AddNode().
  std::unique_ptr<OpGraph> graph(new (std::nothrow) OpGraph());
  if (graph == nullptr) {
    MS_LOGE("malloc opgraph failed");
    return nullptr;
  }

  auto nodeDefs = subGraphDef.nodes();
  if (nodeDefs == nullptr) {
    MS_LOGE("nodeDefs from subGraphDef is nullptr");
    return nullptr;
  }

  uint32_t opCount = nodeDefs->size();
  for (uint32_t i = 0; i < opCount; i++) {
    auto nodeDef = nodeDefs->GetAs<NodeDef>(i);
    MS_ASSERT(nodeDef != nullptr);
    auto ret = graph->AddEdge(*nodeDef, *nodeDefs);
    if (ret != RET_OK) {
      MS_LOGE("%s add edge failed. ret:%d", nodeDef->opDef()->name()->c_str(), ret);
      // The unique_ptr destroys the partially built graph on this early return.
      return nullptr;
    }
  }

  // Success: give up ownership so the caller receives the raw pointer.
  return graph.release();
}
|
|
|
|
/**
 * Adds an edge from srcNodeDef to every node that consumes one of its outputs.
 *
 * For each entry of the source node's outputIndex, all nodes in nodeDefs are
 * scanned; any node whose inputIndex contains that entry becomes a
 * destination, and an edge keyed by the two op names is recorded.
 *
 * @param srcNodeDef node whose outputs are being connected.
 * @param nodeDefs   full node list searched for consumers.
 * @return RET_OK on success, otherwise the first non-RET_OK code returned by
 *         AddEdge(srcId, dstId).
 */
int OpGraph::AddEdge(const NodeDef &srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<NodeDef>> &nodeDefs) {
  auto srcOp = srcNodeDef.opDef();
  MS_ASSERT(srcOp != nullptr);
  MS_ASSERT(srcOp->name() != nullptr);
  NODE_ID srcId = std::string(srcOp->name()->c_str());
  uint32_t opCount = nodeDefs.size();

  MS_ASSERT(srcOp->outputIndex() != nullptr);
  for (auto index : *(srcOp->outputIndex())) {
    for (uint32_t pos = 0; pos < opCount; pos++) {
      auto dstNodeDef = nodeDefs.GetAs<NodeDef>(pos);
      MS_ASSERT(dstNodeDef != nullptr);
      MS_ASSERT(dstNodeDef->opDef() != nullptr);
      auto dstInputs = dstNodeDef->opDef()->inputIndex();
      MS_ASSERT(dstInputs != nullptr);

      // A node is a destination only if one of its inputs is this output.
      const bool consumesOutput =
        std::any_of(dstInputs->begin(), dstInputs->end(), [index](int candidate) { return candidate == index; });
      if (consumesOutput) {
        MS_ASSERT(dstNodeDef->opDef()->name() != nullptr);
        NODE_ID dstId = std::string(dstNodeDef->opDef()->name()->c_str());
        auto ret = AddEdge(srcId, dstId);
        if (ret != RET_OK) {
          return ret;
        }
      }
    }
  }

  return RET_OK;
}
|
|
|
|
int OpGraph::AddEdge(const NODE_ID &srcId, const NODE_ID &dstId) {
|
|
auto srcNode = AddNode(srcId);
|
|
if (srcNode == nullptr) {
|
|
MS_LOGE("add srcNode failed");
|
|
return RET_ERROR;
|
|
}
|
|
srcNode->AddOutEdge(dstId);
|
|
auto dstNode = AddNode(dstId);
|
|
if (dstNode == nullptr) {
|
|
MS_LOGE("add dstNode failed");
|
|
return RET_ERROR;
|
|
}
|
|
dstNode->AddInEdge(srcId);
|
|
return RET_OK;
|
|
}
|
|
|
|
/**
 * Looks up a node by id.
 *
 * @param nodeId id to search for in the graph's node table.
 * @return the stored node pointer, or nullptr when the id is not present.
 */
OpNode *OpGraph::GetNode(const NODE_ID &nodeId) {
  const auto found = nodes.find(nodeId);
  return (found != nodes.end()) ? found->second : nullptr;
}
|
|
|
|
/**
 * Returns the node registered under nodeId, creating it if necessary.
 *
 * @param nodeId id of the node to fetch or create.
 * @return the existing or newly allocated node, or nullptr if allocation of a
 *         new node fails.
 */
OpNode *OpGraph::AddNode(const NODE_ID &nodeId) {
  OpNode *existing = GetNode(nodeId);
  if (existing != nullptr) {
    return existing;
  }

  // Not registered yet: allocate a node and store it under its id.
  OpNode *created = new (std::nothrow) OpNode(nodeId);
  if (created == nullptr) {
    MS_LOGE("new node failed");
    return nullptr;
  }
  nodes[nodeId] = created;
  return created;
}
|
|
|
|
std::unordered_set<NODE_ID> OpGraph::GetInputNode() {
|
|
std::unordered_set<NODE_ID> inputNodes;
|
|
for (const auto &iter : nodes) {
|
|
auto node = iter.second;
|
|
MS_ASSERT(node != nullptr);
|
|
if (node->GetAllInEdge().empty()) {
|
|
inputNodes.insert(node->ID());
|
|
}
|
|
}
|
|
return inputNodes;
|
|
}
|
|
|
|
std::unordered_set<NODE_ID> OpGraph::GetOutputNode() {
|
|
std::unordered_set<NODE_ID> outputNodes;
|
|
for (const auto &iter : nodes) {
|
|
auto node = iter.second;
|
|
MS_ASSERT(node != nullptr);
|
|
if (node->GetAllOutEdge().empty()) {
|
|
outputNodes.insert(node->ID());
|
|
}
|
|
}
|
|
return outputNodes;
|
|
}
|
|
|
|
/**
 * Destroys the graph and frees every node stored in the node table.
 *
 * The nodes are the ones allocated with `new` in AddNode().
 */
OpGraph::~OpGraph() {
  // Iterate by const reference: iterating by value copied each key/pointer
  // pair, including the id key, on every step.
  for (const auto &entry : nodes) {
    // `delete` on a null pointer is a no-op, so no explicit check is needed.
    delete entry.second;
  }
  nodes.clear();
}
|
|
|
|
/** Returns this node's id by value. */
NODE_ID OpNode::ID() {
  return this->id;
}
|
|
|
|
void OpNode::AddInEdge(const NODE_ID &nodeId) { inEdges.insert(nodeId); }
|
|
|
|
void OpNode::AddOutEdge(const NODE_ID &nodeId) { outEdges.insert(nodeId); }
|
|
|
|
std::unordered_set<NODE_ID> OpNode::GetAllInEdge() { return inEdges; }
|
|
|
|
std::unordered_set<NODE_ID> OpNode::GetAllOutEdge() { return outEdges; }
|
|
} // namespace predict
|
|
} // namespace mindspore
|