fix control flow bugs.
This commit is contained in:
parent
2c94b631a7
commit
75ab763696
|
@ -99,45 +99,14 @@ void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_
|
|||
}
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(func_graph->output(), branch_id);
|
||||
return;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
|
||||
// If the funcgraph called by the partial returns a partial node, the switch actor should call the funcgraph
|
||||
// of the sub partial. Similarly, the input node should also be the input of the sub partial.
|
||||
is_mulit_call_ = true;
|
||||
CNodePtr sub_partial = func_graph->output()->cast<CNodePtr>();
|
||||
const auto &sub_partial_inputs = sub_partial->inputs();
|
||||
if (sub_partial_inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(sub_partial);
|
||||
}
|
||||
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(sub_partial_inputs[kPartialFuncGraphPos]);
|
||||
|
||||
if (sub_func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(sub_func_graph->output(), branch_id);
|
||||
return;
|
||||
}
|
||||
|
||||
branch_func_graph_[branch_id] = sub_func_graph;
|
||||
const auto &sub_parameters = func_graph->parameters();
|
||||
|
||||
// Record the input that comes with the sub partial node.
|
||||
for (size_t i = kPartialInputStartPos; i < sub_partial_inputs.size(); ++i) {
|
||||
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[i], 0).first;
|
||||
const auto &iter = find(sub_parameters.begin(), sub_parameters.end(), real_partial_input);
|
||||
if ((iter != sub_parameters.end()) &&
|
||||
((iter - sub_parameters.begin()) < SizeToInt(partial_inputs.size() - kPartialInputStartPos))) {
|
||||
AddInput(partial_inputs[iter - sub_parameters.begin() + kPartialInputStartPos], branch_id);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
branch_func_graph_[branch_id] = func_graph;
|
||||
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
|
||||
AddInput(partial_inputs[j], branch_id);
|
||||
}
|
||||
} else if (IsValueNode<FuncGraph>(node)) {
|
||||
const auto func_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
branch_func_graph_[branch_id] = func_graph;
|
||||
} else {
|
||||
AddInput(node, branch_id);
|
||||
}
|
||||
|
@ -158,11 +127,6 @@ void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
|
|||
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
|
||||
InitVectorSize(call_num);
|
||||
|
||||
// If the return is a partial node or funcgraph, this subgraph will not be initialized and no input is required.
|
||||
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial) ||
|
||||
(func_graph->output()->isa<ValueNode>() && IsValueNode<FuncGraph>(func_graph->output()))) {
|
||||
return;
|
||||
}
|
||||
AddCommonInput(func_graph->output());
|
||||
}
|
||||
|
||||
|
@ -210,26 +174,7 @@ void SwitchActor::ParseSwitchLayerInput() {
|
|||
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
|
||||
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
|
||||
} else if (branch_nodes[i]->isa<ValueNode>()) {
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
const auto output = func_graph->output();
|
||||
|
||||
// The switch layer node has a second-order call connected to call. When the called funcgraph returns a partial
|
||||
// node or funcgraph, the switch actor needs to call the funcgraph directly.
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
is_mulit_call_ = true;
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] =
|
||||
GetValueNode<FuncGraphPtr>(output->cast<CNodePtr>()->input(kPartialFuncGraphPos));
|
||||
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
|
||||
is_mulit_call_ = true;
|
||||
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(output);
|
||||
if (sub_func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(sub_func_graph->output(), i - kMakeTupleInputStartPos);
|
||||
continue;
|
||||
}
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] = GetValueNode<FuncGraphPtr>(output);
|
||||
} else {
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] = func_graph;
|
||||
}
|
||||
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -256,9 +201,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b
|
|||
// The value node and weight node need to be placed in the device store. The switch actor has three inputs:
|
||||
// 1) The input of the switch is the value node.
|
||||
// 2) There is a weight node or value node in the return of the sub funcgraph.
|
||||
// 3) When the switch actor is a second-order call, it does not distinguish between weight and parameter.
|
||||
if (((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) || is_mulit_call_) && node->isa<Parameter>() &&
|
||||
HasAbstractRef(node)) ||
|
||||
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
|
||||
node->isa<ValueNode>()) {
|
||||
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
|
||||
if (iter != input_nodes_.end()) {
|
||||
|
@ -308,6 +251,12 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
|
|||
for (size_t i = 0; i < call_output_num; ++i) {
|
||||
AddInput({real_input.first, i}, branch);
|
||||
}
|
||||
} else if (real_input.first->isa<ValueNode>() && real_input.first->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
|
||||
const auto &value = real_input.first->cast<ValueNodePtr>()->value();
|
||||
const auto &tuple_value = value->cast<ValueTuplePtr>();
|
||||
for (size_t i = 0; i < tuple_value->value().size(); ++i) {
|
||||
AddInput({real_input.first, i}, branch);
|
||||
}
|
||||
} else {
|
||||
AddInput(real_input, branch);
|
||||
}
|
||||
|
|
|
@ -171,11 +171,6 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
|
||||
|
||||
// Used to indicate that in the control flow, when the input of the call node is a call node, the switch actor
|
||||
// corresponding to the switch node called by the sub call node. At this time, the funcgraph of the input of
|
||||
// the switch actor will return to a partial node or funcgraph.
|
||||
bool is_mulit_call_{false};
|
||||
};
|
||||
|
||||
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
constexpr size_t kSingleCallDepth = 1;
|
||||
namespace {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
// Fetch all the weight parameters related to node. It runs like this:
|
||||
|
@ -69,6 +68,9 @@ void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParame
|
|||
|
||||
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
|
||||
const auto &partial_node = switch_inputs[i];
|
||||
if (IsValueNode<FuncGraph>(partial_node)) {
|
||||
continue;
|
||||
}
|
||||
const auto &func_graph = GetFuncGraphFromPartial(partial_node);
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
const auto &partial_inputs = partial_node->cast<CNodePtr>()->inputs();
|
||||
|
@ -451,40 +453,6 @@ std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr
|
|||
return parameters;
|
||||
}
|
||||
|
||||
// Get the number of calls, that is, the number of times the input of the call node is the call node.
|
||||
size_t FetchCallDepth(const AnfNodePtr &node) {
|
||||
if (!IsCallNode(node)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
return kSingleCallDepth + FetchCallDepth(inputs[0]);
|
||||
}
|
||||
|
||||
// Get the final subgraph called by fungraph through the depth of calls.
|
||||
FuncGraphPtr FetchFuncGraphByCallDepth(const FuncGraphPtr &func_graph, const size_t call_depth) {
|
||||
if (call_depth <= kSingleCallDepth) {
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
const auto &output = func_graph->output();
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
const auto &cnode = output->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() < kPartialInputStartPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial node:" << AnfAlgo::GetNodeDebugString(output);
|
||||
}
|
||||
const auto &called_func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
|
||||
return FetchFuncGraphByCallDepth(called_func_graph, call_depth - kSingleCallDepth);
|
||||
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
|
||||
return FetchFuncGraphByCallDepth(GetValueNode<FuncGraphPtr>(output), call_depth - kSingleCallDepth);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid output for call depth:" << call_depth << " funcgraph:" << func_graph->ToString()
|
||||
<< " output node:" << AnfAlgo::GetNodeDebugString(output);
|
||||
}
|
||||
}
|
||||
|
||||
// Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
|
||||
FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
|
@ -516,6 +484,26 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
|
||||
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
if (!IsCallNode(real_node)) {
|
||||
return real_node;
|
||||
}
|
||||
if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
|
||||
return nullptr;
|
||||
}
|
||||
(*call_nodes).insert(real_node);
|
||||
|
||||
const auto &func_graphs = FetchFuncGraphbyCallNode(real_node);
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
const auto &output = FetchRealOutputByCallNode(func_graph->output(), call_nodes);
|
||||
if (output != nullptr) {
|
||||
return output;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Return true if the node has Ref abstract.
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
|
@ -602,6 +590,8 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
|||
for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) {
|
||||
if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) {
|
||||
func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i]));
|
||||
} else if (IsValueNode<FuncGraph>(cnode_inputs[i])) {
|
||||
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(cnode_inputs[i]));
|
||||
}
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) &&
|
||||
|
@ -659,6 +649,11 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
|||
break;
|
||||
}
|
||||
total_num += call_output_num;
|
||||
} else if (inputs[i]->isa<ValueNode>() && inputs[i]->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
|
||||
auto value_tuple = inputs[i]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
auto tuple_value = value_tuple->value();
|
||||
total_num += tuple_value.size();
|
||||
} else if (!HasAbstractMonad(inputs[i])) {
|
||||
++total_num;
|
||||
}
|
||||
|
@ -925,8 +920,12 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
}
|
||||
if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) ==
|
||||
front_to_backend_parameters_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend parameter for front parameter:"
|
||||
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]);
|
||||
MS_LOG(INFO) << "Cannot find backend parameter for front parameter:"
|
||||
<< AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos])
|
||||
<< ", used the default format";
|
||||
CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]);
|
||||
front_value_nodes_.push_back({inputs[i], device_contexts[0]});
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first;
|
||||
|
@ -1110,9 +1109,6 @@ void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &con
|
|||
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
|
||||
func_graph_to_call_num_[func_graph] = 1;
|
||||
|
@ -1147,13 +1143,6 @@ void ControlNodeParser::CreateBranchIDForFuncGraph(const std::vector<AnfNodePtr>
|
|||
for (const auto &control_node : control_nodes) {
|
||||
// Root funcgraph does not need to create a gather actor.
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
const auto &cnode = control_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
// If the output of funcgraph is a value node, no need to create gather actor.
|
||||
if (inputs[kReturnInputPos]->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto func_graph = control_node->func_graph();
|
||||
func_graph_to_branch_id_[func_graph] = branch_id++;
|
||||
}
|
||||
|
|
|
@ -53,6 +53,9 @@ bool IsCallNode(const AnfNodePtr &node);
|
|||
// Check if the call node is the input of another call node.
|
||||
bool IsSubCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Recursive interface, find the real output of funcgraph called by call node.
|
||||
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes);
|
||||
|
||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
|
|
|
@ -1253,12 +1253,9 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
continue;
|
||||
}
|
||||
|
||||
// If the output of funcgraph is a value node, no need to create gather actor.
|
||||
if (inputs[kReturnInputPos]->isa<ValueNode>() ||
|
||||
AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto actor_name = func_graph->ToString();
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
for (const auto ¶meter : func_graph->get_inputs()) {
|
||||
|
@ -2079,69 +2076,14 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr
|
|||
|
||||
// Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
|
||||
if (inputs[0]->isa<CNode>()) {
|
||||
// Add the input of call node to switch actor.
|
||||
if (IsCallNode(inputs[0])) {
|
||||
const auto &sub_call_cnode = inputs[0]->cast<CNodePtr>();
|
||||
const auto &sub_inputs = sub_call_cnode->inputs();
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(sub_inputs[0], prim::kPrimSwitchLayer)) {
|
||||
auto actor = FetchActor(sub_inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
}
|
||||
}
|
||||
} else if (IsSubCallNode(cnode)) {
|
||||
// Add the input of sub call node to switch actor.
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
||||
const auto &tuple_node = inputs[0]->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
||||
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
|
||||
int pre_real_parameter_num = 0;
|
||||
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
|
||||
pre_real_parameter_num = (tuple_inputs[i]->cast<CNodePtr>()->inputs().size() - kPartialInputStartPos);
|
||||
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]->cast<CNodePtr>()->input(kPartialFuncGraphPos));
|
||||
} else {
|
||||
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
|
||||
}
|
||||
const auto parameters = func_graph->parameters();
|
||||
const auto &output = func_graph->output();
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
const auto &sub_partial_inputs = output->cast<CNodePtr>()->inputs();
|
||||
|
||||
// Check whether the input node of the sub call node needs to be added to the switch actor. Only when
|
||||
// the final return is a partial node and the partial node needs this input, the input node is added
|
||||
// to the switch actor/
|
||||
for (size_t j = kPartialInputStartPos; j < sub_partial_inputs.size(); ++j) {
|
||||
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[j], 0).first;
|
||||
const auto &iter = find(parameters.begin(), parameters.end(), real_partial_input);
|
||||
|
||||
if ((iter != parameters.end()) && (iter - parameters.begin() >= pre_real_parameter_num) &&
|
||||
(iter - parameters.begin() <
|
||||
SizeToInt(pre_real_parameter_num + inputs.size() - kCallInputStartPos))) {
|
||||
size_t pos = iter - parameters.begin() - pre_real_parameter_num + kCallInputStartPos;
|
||||
switch_actor->AddSingleInput(inputs[pos], i - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2219,10 +2161,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
|
||||
for (const auto &input_with_index : gather_actor->data_nodes_) {
|
||||
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
|
||||
const auto &input_with_index = gather_actor->data_nodes_[i];
|
||||
const auto &from_func_graph = kernel_graph->GetFuncGraph();
|
||||
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor,
|
||||
gather_actor->FetchDataNodePosition(input_with_index));
|
||||
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
|
||||
}
|
||||
}
|
||||
LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set);
|
||||
|
@ -2258,47 +2200,9 @@ void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_w
|
|||
// Fetch all the funcgraph that call node would call.
|
||||
const auto cnode = call_node_with_index.first->cast<CNodePtr>();
|
||||
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
|
||||
const auto &call_inputs = cnode->inputs();
|
||||
auto switch_node = call_inputs[0];
|
||||
if (IsCallNode(switch_node)) {
|
||||
switch_node = call_inputs[0]->cast<CNodePtr>()->input(0);
|
||||
}
|
||||
|
||||
// Collect the output of each funcgraph.
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitchLayer)) {
|
||||
const auto &actor_name = switch_node->DebugString();
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
|
||||
const auto &output_with_index = KernelWithIndex(func_graph->output(), 0);
|
||||
const auto &iter =
|
||||
find(switch_actor->input_nodes_.begin(), switch_actor->input_nodes_.end(), output_with_index);
|
||||
if (iter == switch_actor->input_nodes_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input node for switch actor:" << switch_actor->GetAID()
|
||||
<< " node:" << AnfAlgo::GetNodeDebugString(func_graph->output());
|
||||
}
|
||||
size_t pos = iter - switch_actor->input_nodes_.begin();
|
||||
// Add output for each branch of switch.
|
||||
for (size_t i = 0; i < switch_actor->branch_inputs_pos_.size(); ++i) {
|
||||
const auto poses = switch_actor->branch_inputs_pos_[i];
|
||||
if (find(poses.begin(), poses.end(), pos) == poses.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index);
|
||||
switch_actor->output_branch_arrows_[i].emplace_back(op_arrow);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto actor_name = func_graph->get_return()->DebugString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
|
@ -2443,7 +2347,7 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_
|
|||
// Link switch output.
|
||||
for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
|
||||
auto func_graph = actor->branch_func_graph_[i];
|
||||
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
|
||||
if (func_graph == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -2547,6 +2451,18 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr>
|
|||
// If there is no output from the switch actor branch, it means that the subgraph has no input,
|
||||
// and need to connect a control arrow to the corresponding gather actor.
|
||||
for (auto &switch_actor : (*switch_actors)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
|
||||
const auto &func_graph = switch_actor->node_->func_graph();
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
const auto &actor_name = func_graph->ToString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
|
||||
switch_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
|
||||
const auto &arrows = switch_actor->output_branch_arrows_[i];
|
||||
if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
|
||||
|
@ -2603,7 +2519,7 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap
|
|||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
|
||||
const auto &func_graph = switch_actor->branch_func_graph_[i];
|
||||
if (func_graph == nullptr || func_graph->output()->isa<ValueNode>()) {
|
||||
if (func_graph == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -258,7 +258,6 @@ class GraphScheduler {
|
|||
// send the branch id to the loop count actor.
|
||||
void LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, const ActorSet *actor_set);
|
||||
void PrepareDataForControlNode(HostQueueDataSourceActor *host_data_source_actor,
|
||||
const ControlNodeParserPtr &control_node_parser,
|
||||
|
|
Loading…
Reference in New Issue