forked from mindspore-Ecosystem/mindspore
Add abstract for maketuple
This commit is contained in:
parent
33edd67261
commit
fe5ef68c5b
|
@ -892,8 +892,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
|
|||
return cnode_inputs;
|
||||
}
|
||||
|
||||
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
|
||||
const std::vector<AnfNodePtr> &real_inputs) {
|
||||
void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
|
||||
const std::vector<AnfNodePtr> &real_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// func1 =switch(branch1, branch2)
|
||||
// func2 = func1(param1)
|
||||
|
@ -997,7 +997,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto return_input = ret->input(kFirstDataInputIndex);
|
||||
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
|
||||
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs);
|
||||
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
|
||||
}
|
||||
// partial node add input args
|
||||
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
||||
|
@ -1006,7 +1006,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
new_make_tuple_inputs.emplace_back(new_partial);
|
||||
}
|
||||
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
|
||||
new_make_tuple->set_abstract(make_tuple_node->abstract());
|
||||
auto abstract = make_tuple_node->abstract();
|
||||
if (abstract == nullptr) {
|
||||
abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
|
||||
}
|
||||
new_make_tuple->set_abstract(abstract);
|
||||
switch_layer_inputs.emplace_back(new_make_tuple);
|
||||
auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
|
||||
cnode_inputs.emplace_back(new_switch_layer);
|
||||
|
|
|
@ -155,8 +155,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
|
||||
const std::vector<AnfNodePtr> &real_inputs);
|
||||
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
|
||||
|
||||
protected:
|
||||
friend class Executor;
|
||||
|
|
Loading…
Reference in New Issue