Add abstract for maketuple

This commit is contained in:
yangwei 2021-04-09 10:16:05 +08:00
parent 33edd67261
commit fe5ef68c5b
2 changed files with 9 additions and 6 deletions

View File

@ -892,8 +892,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
return cnode_inputs;
}
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs) {
void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
// func1 =switch(branch1, branch2)
// func2 = func1(param1)
@ -997,7 +997,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
CreateCallNodeReturnFunction(cnode, partial_kernel_graph.get(), real_inputs);
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
}
// partial node add input args
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
@ -1006,7 +1006,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
new_make_tuple_inputs.emplace_back(new_partial);
}
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
new_make_tuple->set_abstract(make_tuple_node->abstract());
auto abstract = make_tuple_node->abstract();
if (abstract == nullptr) {
abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
}
new_make_tuple->set_abstract(abstract);
switch_layer_inputs.emplace_back(new_make_tuple);
auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
cnode_inputs.emplace_back(new_switch_layer);

View File

@ -155,8 +155,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
void CreateCallNodeReturnFunction(const CNodePtr &cnode, KernelGraph *graph,
const std::vector<AnfNodePtr> &real_inputs);
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
protected:
friend class Executor;