forked from mindspore-Ecosystem/mindspore
fix tuple output when 310 infer control graph
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
bbaca2005b
commit
db77397f4f
|
@ -10,7 +10,7 @@ if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
|||
string(REPLACE " -fvisibility=hidden" " -fvisibility=default" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
if(ENABLE_D)
|
||||
if(ENABLE_D OR ENABLE_ACL)
|
||||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"ascend/*.cc"
|
||||
"graph_kernel/*.cc"
|
||||
|
|
|
@ -21,9 +21,12 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "cxx_api/factory.h"
|
||||
#include "vm/backend.h"
|
||||
#include "vm/transform.h"
|
||||
|
@ -56,14 +59,7 @@ class MSTensorRef : public BaseRef {
|
|||
std::vector<MSTensor> res;
|
||||
if (utils::isa<VectorRef>(args)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
for (size_t i = 0; i < args_vec.size(); ++i) {
|
||||
const auto &item = args_vec[i];
|
||||
if (!utils::isa<MSTensorRef>(item)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid item " << item.ToString() << " at index " << i;
|
||||
}
|
||||
auto wrapper = utils::cast<MSTensorRef>(item);
|
||||
res.push_back(wrapper.ms_tensor_);
|
||||
}
|
||||
res = ConvertTuple(args_vec);
|
||||
} else if (utils::isa<MSTensorRef>(args)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(args);
|
||||
res.push_back(wrapper.ms_tensor_);
|
||||
|
@ -101,6 +97,25 @@ class MSTensorRef : public BaseRef {
|
|||
}
|
||||
|
||||
private:
|
||||
static std::vector<MSTensor> ConvertTuple(const VectorRef &args) {
|
||||
std::vector<MSTensor> outs;
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
const auto &item = args[i];
|
||||
if (utils::isa<VectorRef>(item)) {
|
||||
VectorRef args_vec = utils::cast<VectorRef>(args);
|
||||
auto ret = ConvertTuple(args_vec);
|
||||
outs.insert(outs.end(), ret.begin(), ret.end());
|
||||
} else if (utils::isa<MSTensorRef>(item)) {
|
||||
auto wrapper = utils::cast<MSTensorRef>(item);
|
||||
outs.push_back(wrapper.ms_tensor_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid BaseRef " << args.ToString()
|
||||
<< " must be MSTensorRef or VectorRef{MSTensorRef...}";
|
||||
}
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
MSTensor ms_tensor_;
|
||||
};
|
||||
|
||||
|
@ -114,7 +129,11 @@ class MultiGraphAclSession : public session::SessionBasic {
|
|||
void SetOptions(const std::shared_ptr<AclModelOptions> &options) { options_ = options; }
|
||||
|
||||
private:
|
||||
VectorRef ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors);
|
||||
VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque<MSTensor> *out_tensors);
|
||||
|
||||
std::map<GraphId, GraphCell> graphs_ = {};
|
||||
std::map<GraphId, KernelGraphPtr> kernel_graphs_ = {};
|
||||
std::shared_ptr<AclModelOptions> options_ = nullptr;
|
||||
};
|
||||
|
||||
|
@ -138,8 +157,16 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
|
|||
std::shared_ptr<AclModelOptions> options_;
|
||||
};
|
||||
MS_LOG(INFO) << "Start MultiGraph Compile.";
|
||||
auto kernel_graph = ConstructKernelGraph(lst, outputs, false);
|
||||
// construct kernel graph
|
||||
auto kernel_graph = SessionBasic::ConstructKernelGraph(lst, outputs, false);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("310_multi_graph_pm");
|
||||
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
// concert to om data
|
||||
ModelConverter model_converter_;
|
||||
model_converter_.set_options(options_);
|
||||
FirstGraphModeGuard guard(options_);
|
||||
|
@ -148,6 +175,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
|
|||
MS_LOG(ERROR) << "Load MindIR failed.";
|
||||
return kMCFailed;
|
||||
}
|
||||
// load
|
||||
std::shared_ptr<Graph> graph = std::make_shared<Graph>(std::make_shared<Graph::GraphData>(om_data, ModelType::kOM));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_cell = GraphCell(graph);
|
||||
|
@ -156,6 +184,7 @@ GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const
|
|||
MS_LOG(EXCEPTION) << "Load failed.";
|
||||
}
|
||||
graphs_[kernel_graph->graph_id()] = graph_cell;
|
||||
kernel_graphs_[kernel_graph->graph_id()] = kernel_graph;
|
||||
MS_LOG(INFO) << "Mulit graph compile success, graph id " << kernel_graph->graph_id();
|
||||
return kernel_graph->graph_id();
|
||||
}
|
||||
|
@ -172,7 +201,61 @@ void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector<MSTensor
|
|||
if (ret != kSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed.";
|
||||
}
|
||||
(*outputs) = MSTensorRef::Convert(out_tensors);
|
||||
|
||||
std::deque<MSTensor> out_tensors_deque(out_tensors.begin(), out_tensors.end());
|
||||
(*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque);
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
auto out_nodes = kernel_graphs_[graph_id]->outputs();
|
||||
for (auto &out : out_nodes) {
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString();
|
||||
}
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(out, 0);
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else {
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
if (!out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size()
|
||||
<< " MSTensor remained.";
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node,
|
||||
std::deque<MSTensor> *out_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(out_tensors);
|
||||
VectorRef outs;
|
||||
for (size_t i = 1; i < tuple_node->inputs().size(); ++i) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(tuple_node->input(i), 0);
|
||||
auto &anf_node = item_with_index.first;
|
||||
if (out_tensors->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << anf_node->DebugString();
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
|
||||
} else {
|
||||
outs.emplace_back(MSTensorRef(out_tensors->front()));
|
||||
out_tensors->pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
return outs;
|
||||
}
|
||||
|
||||
class AclBackend : public compile::MsBackend {
|
||||
|
|
Loading…
Reference in New Issue