mindspore/predict/common/graph_util.cc

168 lines
4.6 KiB
C++
Raw Normal View History

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_util.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <new>
#include <sstream>
#include <string>
#include "common/mslog.h"
#include "include/errorcode.h"
namespace mindspore {
namespace predict {
/**
 * Builds an OpGraph from a flatbuffer SubGraphDef.
 * Every NodeDef is registered as a node keyed by its op name, and an edge
 * src -> dst is added (via AddEdge) whenever one of src's output indices
 * appears among dst's input indices.
 * @param subGraphDef sub-graph definition to convert.
 * @return a newly allocated graph that the caller must delete, or nullptr on failure.
 */
OpGraph *OpGraph::Build(const SubGraphDef &subGraphDef) {
  // nothrow: plain `new` throws on failure, which would make the check below dead code.
  std::unique_ptr<OpGraph> graph(new (std::nothrow) OpGraph());
  if (graph == nullptr) {
    MS_LOGE("malloc opgraph failed");
    return nullptr;
  }
  auto nodeDefs = subGraphDef.nodes();
  if (nodeDefs == nullptr) {
    MS_LOGE("nodeDefs from subGraphDef is nullptr");
    return nullptr;
  }
  uint32_t opCount = nodeDefs->size();
  for (uint32_t i = 0; i < opCount; i++) {
    auto nodeDef = nodeDefs->GetAs<NodeDef>(i);
    // Checked explicitly (not only MS_ASSERT) because the fields are dereferenced below.
    if (nodeDef == nullptr || nodeDef->opDef() == nullptr || nodeDef->opDef()->name() == nullptr) {
      MS_LOGE("nodeDef %u from subGraphDef is invalid", i);
      return nullptr;
    }
    // Register the node up front so that ops without any producer/consumer edge
    // (e.g. a single-op sub-graph) still appear in GetInputNode/GetOutputNode.
    if (graph->AddNode(std::string(nodeDef->opDef()->name()->c_str())) == nullptr) {
      MS_LOGE("%s add node failed", nodeDef->opDef()->name()->c_str());
      return nullptr;
    }
    auto ret = graph->AddEdge(*nodeDef, *nodeDefs);
    if (ret != RET_OK) {
      MS_LOGE("%s add edge failed. ret:%d", nodeDef->opDef()->name()->c_str(), ret);
      return nullptr;
    }
  }
  return graph.release();
}
/**
 * Adds an edge from srcNodeDef to every node in nodeDefs that consumes one of
 * srcNodeDef's output indices (i.e. lists it in its own inputIndex).
 * @param srcNodeDef producer node.
 * @param nodeDefs all nodes of the sub-graph, scanned for consumers.
 * @return RET_OK on success; RET_ERROR if a node definition is malformed, or
 *         the error returned by AddEdge(srcId, dstId).
 */
int OpGraph::AddEdge(const NodeDef &srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<NodeDef>> &nodeDefs) {
  auto srcOpDef = srcNodeDef.opDef();
  if (srcOpDef == nullptr || srcOpDef->name() == nullptr) {
    MS_LOGE("opDef or opDef name of source node is nullptr");
    return RET_ERROR;
  }
  NODE_ID srcId = std::string(srcOpDef->name()->c_str());
  auto outputIndex = srcOpDef->outputIndex();
  if (outputIndex == nullptr) {
    // An unset flatbuffer vector reads back as nullptr: no outputs means no consumers.
    return RET_OK;
  }
  uint32_t opCount = nodeDefs.size();
  for (auto outIdx : *outputIndex) {
    for (uint32_t i = 0; i < opCount; i++) {
      auto dstNodeDef = nodeDefs.GetAs<NodeDef>(i);
      if (dstNodeDef == nullptr || dstNodeDef->opDef() == nullptr) {
        MS_LOGE("nodeDef %u or its opDef is nullptr", i);
        return RET_ERROR;
      }
      auto inputIndex = dstNodeDef->opDef()->inputIndex();
      // A node without inputs, or whose inputs do not include outIdx, is not a consumer.
      if (inputIndex == nullptr || std::find(inputIndex->begin(), inputIndex->end(), outIdx) == inputIndex->end()) {
        continue;
      }
      if (dstNodeDef->opDef()->name() == nullptr) {
        MS_LOGE("opDef name of nodeDef %u is nullptr", i);
        return RET_ERROR;
      }
      NODE_ID dstId = std::string(dstNodeDef->opDef()->name()->c_str());
      auto ret = AddEdge(srcId, dstId);
      if (ret != RET_OK) {
        return ret;
      }
    }
  }
  return RET_OK;
}
/**
 * Adds the directed edge srcId -> dstId, creating either endpoint on demand.
 * Both endpoints are created before any edge set is touched, so a failed
 * allocation never leaves the source with an out-edge to a missing node.
 * @return RET_OK on success, RET_ERROR if an endpoint could not be created.
 */
int OpGraph::AddEdge(const NODE_ID &srcId, const NODE_ID &dstId) {
  auto srcNode = AddNode(srcId);
  if (srcNode == nullptr) {
    MS_LOGE("add srcNode failed");
    return RET_ERROR;
  }
  auto dstNode = AddNode(dstId);
  if (dstNode == nullptr) {
    MS_LOGE("add dstNode failed");
    return RET_ERROR;
  }
  srcNode->AddOutEdge(dstId);
  dstNode->AddInEdge(srcId);
  return RET_OK;
}
/**
 * Looks up a node by its id.
 * @param nodeId id of the node to find.
 * @return the stored node, or nullptr if no node with this id has been added.
 */
OpNode *OpGraph::GetNode(const NODE_ID &nodeId) {
  const auto found = nodes.find(nodeId);
  return (found != nodes.end()) ? found->second : nullptr;
}
/**
 * Returns the node with the given id, allocating and registering it first if
 * it does not exist yet. The graph keeps the allocated node in `nodes`.
 * @param nodeId id of the node.
 * @return the existing or newly created node, or nullptr if allocation failed.
 */
OpNode *OpGraph::AddNode(const NODE_ID &nodeId) {
  if (auto *present = GetNode(nodeId)) {
    return present;
  }
  auto *created = new (std::nothrow) OpNode(nodeId);
  if (created == nullptr) {
    MS_LOGE("new node failed");
    return nullptr;
  }
  nodes[nodeId] = created;
  return created;
}
std::unordered_set<NODE_ID> OpGraph::GetInputNode() {
std::unordered_set<NODE_ID> inputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
MS_ASSERT(node != nullptr);
if (node->GetAllInEdge().empty()) {
inputNodes.insert(node->ID());
}
}
return inputNodes;
}
std::unordered_set<NODE_ID> OpGraph::GetOutputNode() {
std::unordered_set<NODE_ID> outputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
MS_ASSERT(node != nullptr);
if (node->GetAllOutEdge().empty()) {
outputNodes.insert(node->ID());
}
}
return outputNodes;
}
/**
 * Releases every OpNode allocated by AddNode and empties the node map.
 * NOTE(review): OpGraph owns raw OpNode pointers; unless the header already
 * deletes the copy constructor/assignment, a copy would double-delete here.
 */
OpGraph::~OpGraph() {
  // Iterate by reference: `auto iter` copied each entry, including its std::string key.
  for (auto &iter : nodes) {
    delete iter.second;  // delete on nullptr is a no-op, so no null check is needed
  }
  nodes.clear();
}
NODE_ID OpNode::ID() { return id; }
void OpNode::AddInEdge(const NODE_ID &nodeId) { inEdges.insert(nodeId); }
void OpNode::AddOutEdge(const NODE_ID &nodeId) { outEdges.insert(nodeId); }
std::unordered_set<NODE_ID> OpNode::GetAllInEdge() { return inEdges; }
std::unordered_set<NODE_ID> OpNode::GetAllOutEdge() { return outEdges; }
} // namespace predict
} // namespace mindspore