forked from mindspore-Ecosystem/mindspore
set abstract of switch
This commit is contained in:
parent
bfdd040728
commit
fa7cb34c8a
|
@ -1696,7 +1696,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
||||||
}
|
}
|
||||||
auto partial_cnode = partial->cast<CNodePtr>();
|
auto partial_cnode = partial->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||||
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
|
auto graph_node = partial_cnode->input(kPartialGraphIndex);
|
||||||
MS_EXCEPTION_IF_NULL(graph_node);
|
MS_EXCEPTION_IF_NULL(graph_node);
|
||||||
auto graph_value_node = graph_node->cast<ValueNodePtr>();
|
auto graph_value_node = graph_node->cast<ValueNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(graph_value_node);
|
MS_EXCEPTION_IF_NULL(graph_value_node);
|
||||||
|
@ -1706,7 +1706,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
||||||
return child_graph;
|
return child_graph;
|
||||||
};
|
};
|
||||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
|
||||||
auto input1 = cnode->input(kCallKernelGraphIndex);
|
auto input1 = cnode->input(kPartialGraphIndex);
|
||||||
MS_EXCEPTION_IF_NULL(input1);
|
MS_EXCEPTION_IF_NULL(input1);
|
||||||
auto value_node = input1->cast<ValueNodePtr>();
|
auto value_node = input1->cast<ValueNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
@ -1714,11 +1714,10 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
return {kernel_graph->cast<KernelGraphPtr>()};
|
return {kernel_graph->cast<KernelGraphPtr>()};
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||||
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
|
return {get_switch_kernel_graph(kSwitchTrueBranchIndex), get_switch_kernel_graph(kSwitchFalseBranchIndex)};
|
||||||
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
|
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
|
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
|
||||||
std::vector<KernelGraphPtr> child_graphs;
|
std::vector<KernelGraphPtr> child_graphs;
|
||||||
for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
|
for (size_t idx = kSwitchLayerBranchesIndex; idx < cnode->inputs().size(); idx++) {
|
||||||
auto child_graph = get_switch_kernel_graph(idx);
|
auto child_graph = get_switch_kernel_graph(idx);
|
||||||
child_graphs.emplace_back(child_graph);
|
child_graphs.emplace_back(child_graph);
|
||||||
}
|
}
|
||||||
|
|
|
@ -642,7 +642,7 @@ class CallInfoFinder {
|
||||||
}
|
}
|
||||||
|
|
||||||
CallBranch GetCallBranch(const CNodePtr &cnode) {
|
CallBranch GetCallBranch(const CNodePtr &cnode) {
|
||||||
auto input_graph = cnode->input(kCallKernelGraphIndex);
|
auto input_graph = cnode->input(kPartialGraphIndex);
|
||||||
MS_EXCEPTION_IF_NULL(input_graph);
|
MS_EXCEPTION_IF_NULL(input_graph);
|
||||||
auto kg = GetValueNode<KernelGraphPtr>(input_graph);
|
auto kg = GetValueNode<KernelGraphPtr>(input_graph);
|
||||||
MS_EXCEPTION_IF_NULL(kg);
|
MS_EXCEPTION_IF_NULL(kg);
|
||||||
|
|
|
@ -826,7 +826,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
||||||
switch_cnode->input(kFirstDataInputIndex)};
|
switch_cnode->input(kFirstDataInputIndex)};
|
||||||
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->inputs().size(); index++) {
|
||||||
auto node = switch_cnode->input(index);
|
auto node = switch_cnode->input(index);
|
||||||
// there is real input in call, should put it to true and false branch in switch
|
// there is real input in call, should put it to true and false branch in switch
|
||||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||||
|
@ -919,7 +919,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
||||||
MS_EXCEPTION_IF_NULL(switch_layer_cnode);
|
MS_EXCEPTION_IF_NULL(switch_layer_cnode);
|
||||||
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
|
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
|
||||||
switch_layer_cnode->input(kFirstDataInputIndex)};
|
switch_layer_cnode->input(kFirstDataInputIndex)};
|
||||||
auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
|
auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
|
||||||
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
||||||
auto node = make_tuple_node->cast<CNodePtr>();
|
auto node = make_tuple_node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -1040,7 +1040,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||||
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
||||||
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
|
for (size_t index = kSwitchTrueBranchIndex; index < cnode->inputs().size(); index++) {
|
||||||
auto node_input = cnode->input(index);
|
auto node_input = cnode->input(index);
|
||||||
auto switch_input = CreateSwitchInput(cnode, node_input, graph);
|
auto switch_input = CreateSwitchInput(cnode, node_input, graph);
|
||||||
(void)cnode_inputs->emplace_back(switch_input);
|
(void)cnode_inputs->emplace_back(switch_input);
|
||||||
|
|
|
@ -296,8 +296,8 @@ class InlinerBase : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
|
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
|
||||||
auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
|
auto true_branch_abstract = sw_inputs[kSwitchTrueBranchIndex]->abstract();
|
||||||
auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
|
auto false_branch_abstract = sw_inputs[kSwitchFalseBranchIndex]->abstract();
|
||||||
// When branch has dead node or poly node, do not perform inline.
|
// When branch has dead node or poly node, do not perform inline.
|
||||||
if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
|
if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace irpass {
|
namespace irpass {
|
||||||
|
const auto kMinInputSizeOfCallWithArgs = 2;
|
||||||
// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
|
// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
|
||||||
class PartialEliminater : public AnfVisitor {
|
class PartialEliminater : public AnfVisitor {
|
||||||
public:
|
public:
|
||||||
|
@ -98,7 +99,7 @@ class PartialEliminater : public AnfVisitor {
|
||||||
|
|
||||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||||
// {prim::kPrimPartial, X, Xs}
|
// {prim::kPrimPartial, X, Xs}
|
||||||
if (inputs.size() < 2) {
|
if (inputs.size() <= 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,8 +131,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||||
// {prim::kPrimPartial, G, Xs}
|
// {prim::kPrimPartial, G}
|
||||||
if (inputs.size() < 3) {
|
if (inputs.size() < kPartialMinInputSize) {
|
||||||
MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
|
MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -218,6 +219,9 @@ class ChoicePartialEliminater : public AnfVisitor {
|
||||||
if (new_params[j] == nullptr) {
|
if (new_params[j] == nullptr) {
|
||||||
TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
|
TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
|
||||||
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
||||||
|
auto new_abs =
|
||||||
|
anchor_fg_params[j]->abstract() == nullptr ? nullptr : anchor_fg_params[j]->abstract()->Clone();
|
||||||
|
param->set_abstract(new_abs);
|
||||||
new_params[j] = param;
|
new_params[j] = param;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -226,6 +230,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
||||||
if (new_params[anchor_args_size + j] == nullptr) {
|
if (new_params[anchor_args_size + j] == nullptr) {
|
||||||
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
|
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
|
||||||
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
||||||
|
auto new_abs = extra_inputs[j]->abstract() == nullptr ? nullptr : extra_inputs[j]->abstract()->Clone();
|
||||||
|
param->set_abstract(new_abs);
|
||||||
new_params[anchor_args_size + j] = param;
|
new_params[anchor_args_size + j] = param;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -242,6 +248,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
||||||
for (size_t i = 0; i < extra_inputs.size(); ++i) {
|
for (size_t i = 0; i < extra_inputs.size(); ++i) {
|
||||||
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
|
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
|
||||||
ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
|
ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
|
||||||
|
auto new_abs = extra_inputs[i]->abstract() == nullptr ? nullptr : extra_inputs[i]->abstract()->Clone();
|
||||||
|
param->set_abstract(new_abs);
|
||||||
new_params.push_back(param);
|
new_params.push_back(param);
|
||||||
}
|
}
|
||||||
// Reorder Zs_ to last;
|
// Reorder Zs_ to last;
|
||||||
|
@ -312,19 +320,19 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
|
auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
|
||||||
if (input0_cnode->size() != 4) {
|
if (input0_cnode->size() != kSwitchInputSize) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
fg_list_.clear();
|
fg_list_.clear();
|
||||||
args_list_.clear();
|
args_list_.clear();
|
||||||
auto &maybe_partial_1 = input0_cnode->input(2);
|
auto &maybe_partial_1 = input0_cnode->input(kSwitchTrueBranchIndex);
|
||||||
Visit(maybe_partial_1);
|
Visit(maybe_partial_1);
|
||||||
auto &maybe_partial_2 = input0_cnode->input(3);
|
auto &maybe_partial_2 = input0_cnode->input(kSwitchFalseBranchIndex);
|
||||||
Visit(maybe_partial_2);
|
Visit(maybe_partial_2);
|
||||||
|
|
||||||
// Either one should be {Partial, G, X}
|
// Either one should be {Partial, G, X}
|
||||||
if (fg_list_.size() != 2 && args_list_.size() != 2) {
|
if (fg_list_.size() != kSwitchBranchesNum && args_list_.size() != kSwitchBranchesNum) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
// Should not continue;
|
// Should not continue;
|
||||||
|
@ -359,11 +367,12 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
|
||||||
TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
|
TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
|
||||||
// {Switch, cond, G1, G2}
|
// {Switch, cond, G1, G2}
|
||||||
auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
|
auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
|
||||||
|
switch_cnode->set_abstract(input0_cnode->abstract());
|
||||||
AnfNodePtrList args{switch_cnode};
|
AnfNodePtrList args{switch_cnode};
|
||||||
(void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
|
(void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
|
||||||
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
||||||
// Zs
|
// Zs
|
||||||
if (old_cnode->size() >= 2) {
|
if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
|
||||||
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
||||||
}
|
}
|
||||||
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
||||||
|
@ -392,14 +401,14 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
||||||
}
|
}
|
||||||
auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
|
auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
|
||||||
// {SwitchLayer, cond, MakeTuple{}}
|
// {SwitchLayer, cond, MakeTuple{}}
|
||||||
if (switch_layer_cnode->size() != 3) {
|
if (switch_layer_cnode->size() != kSwitchLayerInputSize) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) {
|
if (!IsPrimitiveCNode(switch_layer_cnode->input(kSwitchLayerBranchesIndex), prim::kPrimMakeTuple)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
|
auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
|
||||||
if (make_tuple_cnode->size() < 2) {
|
if (make_tuple_cnode->size() <= 1) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,7 +446,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
||||||
private:
|
private:
|
||||||
AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
|
AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
|
||||||
const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
|
const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
|
||||||
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
|
auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
|
||||||
AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
|
AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
|
||||||
make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
|
make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
|
||||||
TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
|
TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
|
||||||
|
@ -452,7 +461,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
||||||
(void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
|
(void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
|
||||||
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
||||||
// Zs
|
// Zs
|
||||||
if (old_cnode->size() >= 2) {
|
if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
|
||||||
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
||||||
}
|
}
|
||||||
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
||||||
|
|
|
@ -563,18 +563,21 @@ constexpr auto kFirstDataInputIndex = 1;
|
||||||
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
||||||
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
||||||
constexpr auto kTupleGetItemInputSize = 3;
|
constexpr auto kTupleGetItemInputSize = 3;
|
||||||
|
// index define of partial
|
||||||
|
constexpr auto kPartialMinInputSize = 2;
|
||||||
|
constexpr auto kPartialGraphIndex = 1;
|
||||||
|
|
||||||
|
// index define of switch
|
||||||
constexpr auto kSwitchInputSize = 4;
|
constexpr auto kSwitchInputSize = 4;
|
||||||
constexpr auto kFirstBranchInSwitch = 2;
|
constexpr auto kSwitchTrueBranchIndex = 2;
|
||||||
constexpr auto kCallKernelGraphIndex = 1;
|
constexpr auto kSwitchFalseBranchIndex = 3;
|
||||||
constexpr auto kSwitchTrueKernelGraphIndex = 2;
|
constexpr auto kSwitchBranchesNum = 2;
|
||||||
constexpr auto kSwitchFalseKernelGraphIndex = 3;
|
|
||||||
constexpr auto kMakeTupleInSwitchLayerIndex = 2;
|
// index define of switch_layer
|
||||||
constexpr auto kSwitchLayerInputSize = 3;
|
constexpr auto kSwitchLayerInputSize = 3;
|
||||||
// index define of control depend
|
constexpr auto kSwitchLayerSelectIndex = 1;
|
||||||
constexpr auto kControlDependPriorIndex = 1;
|
constexpr auto kSwitchLayerBranchesIndex = 2;
|
||||||
constexpr auto kControlDependBehindIndex = 2;
|
|
||||||
constexpr auto kControlDependInputSize = 3;
|
|
||||||
constexpr auto kControlDependMode = "depend_mode";
|
|
||||||
// index define of depend
|
// index define of depend
|
||||||
constexpr auto kRealInputIndexInDepend = 1;
|
constexpr auto kRealInputIndexInDepend = 1;
|
||||||
constexpr auto kDependAttachNodeIndex = 2;
|
constexpr auto kDependAttachNodeIndex = 2;
|
||||||
|
|
|
@ -330,9 +330,9 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
|
||||||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
||||||
}
|
}
|
||||||
VectorRef args;
|
VectorRef args;
|
||||||
args.emplace_back(Ref(inputs[kCallKernelGraphIndex]));
|
args.emplace_back(Ref(inputs[kPartialGraphIndex]));
|
||||||
args.emplace_back(Ref(inputs[kSwitchTrueKernelGraphIndex]));
|
args.emplace_back(Ref(inputs[kSwitchTrueBranchIndex]));
|
||||||
args.emplace_back(Ref(inputs[kSwitchFalseKernelGraphIndex]));
|
args.emplace_back(Ref(inputs[kSwitchFalseBranchIndex]));
|
||||||
AddInst(Instruction::kSwitch, args);
|
AddInst(Instruction::kSwitch, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue