codedex_bot
This commit is contained in:
parent
1c3fc5c49b
commit
1f107d5a8a
|
@ -1220,6 +1220,5 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
namespace mindspore {
|
||||
namespace session {
|
||||
namespace {
|
||||
void UpdateOutputTensors(VectorRef *outputs,
|
||||
void UpdateOutputTensors(const VectorRef *outputs,
|
||||
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
for (auto item : *outputs) {
|
||||
|
|
|
@ -35,7 +35,6 @@ using std::vector;
|
|||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
|
||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||
try {
|
||||
auto session = std::make_shared<MSInferSession>();
|
||||
|
@ -271,36 +270,18 @@ void MSInferSession::RegAllOp() {
|
|||
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
if (c_expression == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(c_expression);
|
||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||
if (c_expression_dict == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get dict from mindspore._c_expression module.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(c_expression_dict);
|
||||
|
||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||
if (op_info_loader_class == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_class);
|
||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||
if (op_info_loader == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader);
|
||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||
if (op_info_loader_ins == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(op_info_loader_ins);
|
||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||
if (all_ops_info_vector_addr_ul == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr.";
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
|
|
|
@ -494,54 +494,52 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
// get primitive of old node
|
||||
std::vector<AnfNodePtr> cnode_inputs;
|
||||
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
if (prim != nullptr) {
|
||||
// push attr to inputs[0] of new cnode
|
||||
cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
|
||||
cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
|
||||
} else {
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto new_fg = BasicClone(fg);
|
||||
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
|
||||
cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
||||
auto origin_inputs = cnode->inputs();
|
||||
bool optimize_depend = false;
|
||||
bool optimize_control_depend = false;
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
|
||||
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
|
||||
optimize_depend = true;
|
||||
}
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) {
|
||||
optimize_control_depend = true;
|
||||
}
|
||||
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
|
||||
origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>();
|
||||
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
|
||||
// if has multiple depends,only select first depend as parameter
|
||||
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
|
||||
auto anf = origin_inputs[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
// anf has been created before
|
||||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
continue;
|
||||
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
||||
cnode_inputs.push_back((*other_graph_cnode)[anf]);
|
||||
cnode_inputs->push_back((*other_graph_cnode)[anf]);
|
||||
continue;
|
||||
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
|
||||
// if input is a value node,
|
||||
auto new_value_node = CreateNewValueNode(anf, graph);
|
||||
if (new_value_node != nullptr) {
|
||||
cnode_inputs.emplace_back(new_value_node);
|
||||
cnode_inputs->emplace_back(new_value_node);
|
||||
}
|
||||
continue;
|
||||
} else if (anf->isa<Parameter>()) {
|
||||
auto new_parameter = CreateNewParameterFromParameter(anf, graph);
|
||||
cnode_inputs.push_back(new_parameter);
|
||||
cnode_inputs->push_back(new_parameter);
|
||||
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
|
||||
graph->FrontBackendlMapAdd(anf, new_parameter);
|
||||
} else {
|
||||
|
@ -549,20 +547,31 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
|||
}
|
||||
continue;
|
||||
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
|
||||
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
|
||||
cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]);
|
||||
continue;
|
||||
} else if (optimize_control_depend) {
|
||||
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
||||
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
||||
} else {
|
||||
// the input node is a cnode from other graph
|
||||
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
|
||||
if (parameter_from_cnode == nullptr) {
|
||||
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
|
||||
}
|
||||
cnode_inputs.push_back(parameter_from_cnode);
|
||||
cnode_inputs->push_back(parameter_from_cnode);
|
||||
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
// get primitive of old node
|
||||
std::vector<AnfNodePtr> cnode_inputs;
|
||||
GetCNodeInfo(cnode, &cnode_inputs);
|
||||
GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNode(cnode_inputs);
|
||||
TraceManager::EndTrace();
|
||||
|
@ -593,6 +602,42 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra
|
|||
return partial_node;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> cnode_inputs = {
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
||||
auto switch_cnode = cnode_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||
if (cnode->inputs().size() < 2) {
|
||||
cnode_inputs = switch_cnode->inputs();
|
||||
return cnode_inputs;
|
||||
}
|
||||
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_cnode->input(kFirstDataInputIndex)};
|
||||
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
||||
auto node = switch_cnode->input(index);
|
||||
// there is real input in call, should put it to true and false branch in switch
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
auto partial_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
|
||||
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
||||
auto new_partial = graph->NewCNode(partial_inputs);
|
||||
switch_inputs.emplace_back(new_partial);
|
||||
}
|
||||
}
|
||||
if (switch_inputs.size() < kSwitchInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
|
||||
}
|
||||
auto switch_node = graph->NewCNode(switch_inputs);
|
||||
cnode_inputs.emplace_back(switch_node);
|
||||
return cnode_inputs;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -618,32 +663,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
|
|||
});
|
||||
return cnode_inputs;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
||||
auto switch_cnode = cnode_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||
if (cnode->inputs().size() < 2) {
|
||||
cnode_inputs = switch_cnode->inputs();
|
||||
return cnode_inputs;
|
||||
}
|
||||
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_cnode->input(kFirstDataInputIndex)};
|
||||
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
||||
auto node = switch_cnode->input(index);
|
||||
// there is real input in call, should put it to true and false branch in switch
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
auto partial_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
|
||||
partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
||||
auto new_partial = graph->NewCNode(partial_inputs);
|
||||
switch_inputs.emplace_back(new_partial);
|
||||
}
|
||||
}
|
||||
if (switch_inputs.size() < kSwitchInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
|
||||
}
|
||||
auto switch_node = graph->NewCNode(switch_inputs);
|
||||
cnode_inputs.emplace_back(switch_node);
|
||||
return cnode_inputs;
|
||||
return CreateCallSwitchInputs(cnode, graph);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
|
||||
}
|
||||
|
|
|
@ -130,6 +130,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs);
|
||||
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
|
||||
protected:
|
||||
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||
|
|
Loading…
Reference in New Issue