fix tuple output when 310 infer control graph

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-09-19 17:10:50 +08:00
parent bbaca2005b
commit db77397f4f
2 changed files with 94 additions and 11 deletions

View File

@ -10,7 +10,7 @@ if("${ENABLE_HIDDEN}" STREQUAL "OFF")
string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
endif()
if(ENABLE_D)
if(ENABLE_D OR ENABLE_ACL)
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"ascend/*.cc"
"graph_kernel/*.cc"

View File

@ -21,9 +21,12 @@
#include <string>
#include <algorithm>
#include <numeric>
#include <deque>
#include <functional>
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "cxx_api/factory.h"
#include "vm/backend.h"
#include "vm/transform.h"
@ -56,14 +59,7 @@ class MSTensorRef : public BaseRef {
std::vector<MSTensor> res;
if (utils::isa<VectorRef>(args)) {
VectorRef args_vec = utils::cast<VectorRef>(args);
for (size_t i = 0; i < args_vec.size(); ++i) {
const auto &item = args_vec[i];
if (!utils::isa<MSTensorRef>(item)) {
MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i;
}
auto wrapper = utils::cast<MSTensorRef>(item);
res.push_back(wrapper.ms_tensor_);
}
res = ConvertTuple(args_vec);
} else if (utils::isa<MSTensorRef>(args)) {
auto wrapper = utils::cast<MSTensorRef>(args);
res.push_back(wrapper.ms_tensor_);
@ -101,6 +97,25 @@ class MSTensorRef : public BaseRef {
}
private:
static std::vector<MSTensor> ConvertTuple(const VectorRef &args) {
std::vector<MSTensor> outs;
for (size_t i = 0; i < args.size(); ++i) {
const auto &item = args[i];
if (utils::isa<VectorRef>(item)) {
VectorRef args_vec = utils::cast<VectorRef>(args);
auto ret = ConvertTuple(args_vec);
outs.insert(outs.end(), ret.begin(), ret.end());
} else if (utils::isa<MSTensorRef>(item)) {
auto wrapper = utils::cast<MSTensorRef>(item);
outs.push_back(wrapper.ms_tensor_);
} else {
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString()
<< " must be MSTensorRef or VectorRef{MSTensorRef...}";
}
}
return outs;
}
MSTensor ms_tensor_;
};
@ -114,7 +129,11 @@ class MultiGraphAclSession : public session::SessionBasic {
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
private:
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
std::map<GraphId, GraphCell> graphs_ = {};
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
std::shared_ptr<AclModelOptions> options_ = nullptr;
};
@ -138,8 +157,16 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
std::shared_ptr<AclModelOptions> options_;
};
MS_LOG(INFO) << "Start MultiGraph Compile.";
auto kernel_graph = ConstructKernelGraph(lst, outputs, false);
// construct kernel graph
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
// concert to om data
ModelConverter model_converter_;
model_converter_.set_options(options_);
FirstGraphModeGuard guard(options_);
@ -148,6 +175,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
MS_LOG(ERROR) << "Load MindIR failed.";
return kMCFailed;
}
// load
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = GraphCell(graph);
@ -156,6 +184,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
MS_LOG(EXCEPTION) << "Load failed.";
}
graphs_[kernel_graph->graph_id()] = graph_cell;
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
return kernel_graph->graph_id();
}
@ -172,7 +201,61 @@ void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor
if (ret != kSuccess) {
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
}
(*outputs) = MSTensorRef::Convert(out_tensors);
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
}
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
auto out_nodes = kernel_graphs_[graph_id]->outputs();
for (auto &out : out_nodes) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString();
}
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(out, 0);
auto &anf_node = item_with_index.first;
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else {
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
if (!out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
<< " MSTensor remained.";
}
return outs;
}
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
std::deque<MSTensor> *out_tensors) {
MS_EXCEPTION_IF_NULL(out_tensors);
VectorRef outs;
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(tuple_node->input(i), 0);
auto &anf_node = item_with_index.first;
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << anf_node->DebugString();
}
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else {
outs.emplace_back(MSTensorRef(out_tensors->front()));
out_tensors->pop_front();
}
}
return outs;
}
class AclBackend : public compile::MsBackend {