!19655 fix switch layer call to call

Merge pull request !19655 from gaoyong10/new_runtime23
This commit is contained in:
i-robot 2021-07-08 13:30:48 +00:00 committed by Gitee
commit de4b9a94fc
9 changed files with 421 additions and 203 deletions

View File

@@ -42,10 +42,10 @@ void GatherActor::Init() {
}
}
size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
size_t GatherActor::FetchDataNodePosition(const KernelWithIndex &data_node) const {
const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
if (iter == data_nodes_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node.first) << " index:" << data_node.second
<< " is not exist in gather actor:" << GetAID();
}
return iter - data_nodes_.begin();
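For reference, the lookup above now matches on the (node, output index) pair rather than the node alone. A minimal standalone sketch of the same pattern, with std::pair<std::string, size_t> standing in for MindSpore's KernelWithIndex:

#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

using NodeWithIndex = std::pair<std::string, size_t>;

// Return the position of key in nodes; nodes.size() means "not found".
size_t FetchPosition(const std::vector<NodeWithIndex> &nodes, const NodeWithIndex &key) {
  const auto iter = std::find(nodes.begin(), nodes.end(), key);
  return static_cast<size_t>(iter - nodes.begin());
}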
@@ -114,7 +114,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
size_t from_index = result_arrow->from_output_index_;
const auto &front_node = data_nodes_[from_index];
const auto &front_node = data_nodes_[from_index].first;
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
input_device_tensors_[from_index]) {

View File

@@ -47,7 +47,7 @@ constexpr size_t kReturnInputPos = 1;
// collected at the entrance of the kernel graph.
class GatherActor : public OpActor<DeviceTensor> {
public:
GatherActor(const std::string &name, const std::vector<AnfNodePtr> &parameters, const bool need_branch_id_input,
GatherActor(const std::string &name, const std::vector<KernelWithIndex> &parameters, const bool need_branch_id_input,
const AID switch_aid, const AID gather_aid, const int branch_id)
: OpActor(name),
data_nodes_(parameters),
@@ -60,7 +60,7 @@ class GatherActor : public OpActor<DeviceTensor> {
~GatherActor() override = default;
// Get the index of the parameter; the data_node needs to be a front node.
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
size_t FetchDataNodePosition(const KernelWithIndex &data_node) const;
// The gather actor runs when it receives the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
@@ -107,7 +107,7 @@ class GatherActor : public OpActor<DeviceTensor> {
std::vector<AID> output_branch_arrows_;
// Parameters of the sub funcgraph, which are front nodes.
std::vector<AnfNodePtr> data_nodes_;
std::vector<KernelWithIndex> data_nodes_;
std::vector<DeviceContext *> device_contexts_;
// Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

View File

@@ -78,7 +78,9 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
auto &sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
if (input_data->data_ == nullptr) {
MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty";
std::string error_info =
"Input data of actor:" + GetAID().Name() + " num:" + std::to_string(input_data->index_) + " is empty";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// When all the inputs are collected, allocate memory and launch the kernel via callback.
if (CheckLaunchCondition(context)) {
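The change above replaces a throwing MS_LOG(EXCEPTION) with SET_OPCONTEXT_FAIL_RET_WITH_ERROR, which records the failure on the op context and returns. A hypothetical re-creation of that pattern (the real macro is defined elsewhere in the MindSpore actor framework and may differ):

#include <string>

// Stand-in context type for the sketch only.
struct OpContextSketch {
  std::string error_info_;
  bool failed_{false};
};

// Hypothetical macro, not MindSpore's actual definition: record the error and
// return from the calling function instead of throwing.
#define SET_OPCONTEXT_FAIL_RET_WITH_ERROR_SKETCH(op_context, message) \
  do {                                                                \
    (op_context).error_info_ = (message);                             \
    (op_context).failed_ = true;                                      \
    return;                                                           \
  } while (false)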

View File

@@ -72,20 +72,20 @@ void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *
input_branch_ids_[sequential_num].push(branch_id);
}
void SwitchActor::Initialize(const ControlNodeParserPtr &parser) {
void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) {
std::vector<AnfNodePtr> inputs = node_->inputs();
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
InitSwitch();
ParseSwitchInput();
} else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
InitReturn(parser);
ParseReturnInput(parser);
} else {
InitSwitchLayer();
ParseSwitchLayerInput();
}
backend_parameters_.resize(input_nodes_.size());
}
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
CNodePtr cnode = node->cast<CNodePtr>();
@@ -93,20 +93,50 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
// [0] ValueNode<Primitive> kPartial.
// [1] ValueNode<FuncGraphPtr>.
// [2..] Inputs.
const auto &node_inputs = cnode->inputs();
if (node_inputs.size() <= kPartialFuncGraphPos) {
auto partial_inputs = cnode->inputs();
if (partial_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
if (func_graph->output()->isa<ValueNode>()) {
AddInput(func_graph->output(), branch_id);
return;
} else if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
// If the funcgraph called by the partial returns a partial node, the switch actor should call the funcgraph
// of the sub partial. Similarly, the input node should also be the input of the sub partial.
is_multi_call_ = true;
CNodePtr sub_partial = func_graph->output()->cast<CNodePtr>();
const auto &sub_partial_inputs = sub_partial->inputs();
if (sub_partial_inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(sub_partial);
}
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(sub_partial_inputs[kPartialFuncGraphPos]);
if (sub_func_graph->output()->isa<ValueNode>()) {
AddInput(sub_func_graph->output(), branch_id);
return;
}
branch_func_graph_[branch_id] = sub_func_graph;
const auto &sub_parameters = func_graph->parameters();
// Record the inputs carried by the sub partial node.
for (size_t i = kPartialInputStartPos; i < sub_partial_inputs.size(); ++i) {
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[i], 0).first;
const auto &iter = find(sub_parameters.begin(), sub_parameters.end(), real_partial_input);
if ((iter != sub_parameters.end()) &&
((iter - sub_parameters.begin()) < SizeToInt(partial_inputs.size() - kPartialInputStartPos))) {
AddInput(partial_inputs[iter - sub_parameters.begin() + kPartialInputStartPos], branch_id);
}
}
return;
}
branch_func_graph_[branch_id] = func_graph;
for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
AddInput(node_inputs[j], branch_id);
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
AddInput(partial_inputs[j], branch_id);
}
} else {
AddInput(node, branch_id);
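The sub-partial branch above peels one level of nesting: a partial over a funcgraph whose own output is another partial resolves to the inner graph. A toy recursion over stand-in types (not AnfNode) that resolves the finally-called graph the same way FetchFuncGraphInNode does later in this change:

#include <memory>

// Stand-in funcgraph: output_partial_graph is non-null when the graph's output
// is itself a partial of (or a value node holding) another graph.
struct ToyGraph {
  std::shared_ptr<ToyGraph> output_partial_graph;
};

std::shared_ptr<ToyGraph> ResolveFinalGraph(const std::shared_ptr<ToyGraph> &graph) {
  if (graph->output_partial_graph == nullptr) {
    return graph;  // output is a plain value: this graph is the one finally called
  }
  return ResolveFinalGraph(graph->output_partial_graph);
}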
@@ -122,15 +152,21 @@ void SwitchActor::InitVectorSize(const size_t num) {
output_branch_branch_arrows_.resize(num);
}
void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) {
void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
const auto &func_graph = node_->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
InitVectorSize(call_num);
// If the return is a partial node or funcgraph, this subgraph will not be initialized and no input is required.
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial) ||
(func_graph->output()->isa<ValueNode>() && IsValueNode<FuncGraph>(func_graph->output()))) {
return;
}
AddCommonInput(func_graph->output());
}
void SwitchActor::InitSwitch() {
void SwitchActor::ParseSwitchInput() {
// The inputs of the switch node:
// [0] ValueNode<Primitive> kSwitch.
// [1] switch condition.
@@ -147,11 +183,11 @@ void SwitchActor::InitSwitch() {
input_nodes_.push_back(cond_node);
input_datas_num_++;
// Parse the two branches of the switch node.
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
}
void SwitchActor::InitSwitchLayer() {
void SwitchActor::ParseSwitchLayerInput() {
// The inputs of the switch layer node:
// [0] ValueNode<Primitive> kSwitchLayer.
// [1] switchLayer index.
@@ -170,11 +206,30 @@ void SwitchActor::InitSwitchLayer() {
InitVectorSize(branch_nodes.size() - 1);
// Parse all branches.
for (size_t i = 1; i < branch_nodes.size(); ++i) {
for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
InitPartial(branch_nodes[i], i - 1);
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
} else if (branch_nodes[i]->isa<ValueNode>()) {
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
const auto &func_graph = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
const auto output = func_graph->output();
// The switch layer node may sit in a second-order call, i.e. a call node whose input is itself a call. When the
// called funcgraph returns a partial node or a funcgraph, the switch actor needs to call that funcgraph directly.
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
is_multi_call_ = true;
branch_func_graph_[i - kMakeTupleInputStartPos] =
GetValueNode<FuncGraphPtr>(output->cast<CNodePtr>()->input(kPartialFuncGraphPos));
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
is_multi_call_ = true;
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(output);
if (sub_func_graph->output()->isa<ValueNode>()) {
AddInput(sub_func_graph->output(), i - kMakeTupleInputStartPos);
continue;
}
branch_func_graph_[i - kMakeTupleInputStartPos] = GetValueNode<FuncGraphPtr>(output);
} else {
branch_func_graph_[i - kMakeTupleInputStartPos] = func_graph;
}
}
}
}
@@ -198,8 +253,12 @@ size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
const auto &node = node_with_index.first;
// Add weight and value node.
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
// The value node and weight node need to be placed in the device tensor store. There are three cases:
// 1) The input of the switch is a value node.
// 2) The return of the sub funcgraph contains a weight node or a value node.
// 3) When the switch actor handles a second-order call, it does not distinguish between weights and parameters.
if (((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) || is_multi_call_) && node->isa<Parameter>() &&
HasAbstractRef(node)) ||
node->isa<ValueNode>()) {
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
if (iter != input_nodes_.end()) {
@@ -243,7 +302,6 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
} else if (IsCallNode(real_input.first)) {
std::vector<AnfNodePtr> call_nodes;
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
if (call_output_num <= 0) {
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
}
@@ -259,25 +317,26 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
if (need_branch_id_input_) {
if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
input_branch_ids_[context->sequential_num_].empty()) {
MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID();
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
}
size_t branch_id = input_branch_ids_[context->sequential_num_].top();
input_branch_ids_[context->sequential_num_].pop();
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id;
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
" branch id:" + std::to_string(branch_id);
}
return branch_id_to_index_[branch_id];
}
DeviceTensor *device_tensor = input_device_tensors_[0];
if (device_tensor == nullptr) {
MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID();
MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name();
}
auto inputs = node_->inputs();
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
size_t size = abstract::TypeIdSize(type_id);
if (size > sizeof(int64_t)) {
MS_LOG(EXCEPTION) << "Index must be Int type.";
MS_LOG(ERROR) << "Index must be Int type.";
}
int64_t index = 0;
@@ -293,7 +352,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
index = static_cast<int64_t>(cond ? 1 : 0);
} else {
MS_LOG(EXCEPTION) << "Index must be Int type.";
MS_LOG(ERROR) << "Index must be Int type.";
}
// SwitchLayer node supports a negative index range [-size, -1].
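A standalone sketch of that normalization, assuming an index in [-size, -1] wraps to [0, size - 1] and anything outside that range is rejected:

#include <cstddef>
#include <cstdint>
#include <stdexcept>

size_t NormalizeBranchIndex(int64_t index, size_t size) {
  const auto signed_size = static_cast<int64_t>(size);
  if (index < 0) {
    index += signed_size;  // e.g. index -1 selects the last branch
  }
  if (index < 0 || index >= signed_size) {
    throw std::out_of_range("switch_layer index out of range");
  }
  return static_cast<size_t>(index);
}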
@@ -352,7 +411,7 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
@@ -370,7 +429,8 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto index = GetIndex(context);
if (index >= output_branch_arrows_.size()) {
MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index);
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Must be the execution order: send branch id --> send result --> send data --> send control, avoid the illegal
@@ -389,8 +449,10 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
auto &result_arrow = output_branch_result_arrow[i];
MS_EXCEPTION_IF_NULL(result_arrow);
if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
MS_LOG(EXCEPTION) << "Invalid from index in switch actor, from index:" << result_arrow->from_output_index_
<< " total:" << branch_inputs_pos_[index].size() << " actor:" << GetAID();
std::string error_info =
"Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) +
" total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
@@ -410,10 +472,12 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
}
}
if (!is_send) {
MS_LOG(EXCEPTION) << "Failed to get backend node of switch actor output, actor:" << GetAID()
<< " branch:" << index << " index:" << result_arrow->from_output_index_ << " output pos"
<< branch_inputs_pos_[index][result_arrow->from_output_index_] << " output index"
<< result_arrow->to_input_index_;
std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() +
" branch:" + std::to_string(index) +
" index:" + std::to_string(result_arrow->from_output_index_) + " output pos" +
std::to_string(branch_inputs_pos_[index][result_arrow->from_output_index_]) +
" output index" + std::to_string(result_arrow->to_input_index_);
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
@@ -426,7 +490,6 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(data);
data->data_ = input_device_tensors_[data_arrow->from_output_index_];
Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
}

View File

@@ -76,22 +76,23 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
// The switch actor runs when it receives the input branch id.
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
// Initialize the input and output information of the switch actor according to node_.
void Initialize(const ControlNodeParserPtr &parser);
// Parse the input node information of the switch actor according to node_.
void ParseInput(const ControlNodeParserPtr &parser);
// Add input for all branches.
void AddCommonInput(const AnfNodePtr &node);
void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
// Fetch the input position of the data node.
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
private:
friend class GraphScheduler;
void InitPartial(const AnfNodePtr &node, const size_t branch_id);
void InitSwitch();
void InitSwitchLayer();
void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
void ParseSwitchInput();
void ParseSwitchLayerInput();
// In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
// initialized with the return node of the subgraph.
void InitReturn(const ControlNodeParserPtr &parser);
void ParseReturnInput(const ControlNodeParserPtr &parser);
// Initialize the size of the vector members.
void InitVectorSize(const size_t num);
// Get index from DeviceTensor.
@@ -170,6 +171,11 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
// Used to indicate a second-order call in control flow: when the input of a call node is itself a call node,
// this switch actor corresponds to the switch node called by the sub call node. In this case, the funcgraph
// input of the switch actor returns a partial node or a funcgraph.
bool is_multi_call_{false};
};
using SwitchActorPtr = std::shared_ptr<SwitchActor>;

View File

@@ -22,7 +22,7 @@
namespace mindspore {
namespace runtime {
constexpr size_t kSingleCallDepth = 1;
namespace {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
// Fetch all the weight parameters related to node. It runs like this:
@@ -339,25 +339,23 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::
const auto func_graphs = FetchFuncGraphbyCallNode(call_node);
for (const auto func_graph : func_graphs) {
if (func_graph->output()->isa<ValueNode>()) {
outputs.push_back(func_graph->output());
} else {
std::vector<AnfNodePtr> sub_call_nodes;
const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
for (const auto &graph_output : graph_outputs) {
if (graph_output->isa<Parameter>()) {
outputs.push_back(graph_output);
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
} else if (IsCallNode(graph_output)) {
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
} else if (graph_output->isa<CNode>()) {
outputs.emplace_back(graph_output);
} else {
MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
}
std::vector<AnfNodePtr> sub_call_nodes;
const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
for (const auto &graph_output : graph_outputs) {
if (graph_output->isa<Parameter>()) {
outputs.push_back(graph_output);
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
} else if (IsCallNode(graph_output)) {
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
} else if (graph_output->isa<CNode>()) {
outputs.emplace_back(graph_output);
} else if (graph_output->isa<ValueNode>()) {
outputs.push_back(graph_output);
} else {
MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
}
}
}
@@ -452,6 +450,70 @@ std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr
}
return parameters;
}
// Get the call depth, that is, the number of nested levels where the input of a call node is itself a call node.
size_t FetchCallDepth(const AnfNodePtr &node) {
if (!IsCallNode(node)) {
return 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return kSingleCallDepth + FetchCallDepth(inputs[0]);
}
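To make the recursion concrete: for a node %1 = call(%0, ...) where %0 = call(@f, ...), the depth is 2. A toy version over a minimal stand-in node type:

#include <cstddef>
#include <memory>

struct ToyNode {
  bool is_call{false};
  std::shared_ptr<ToyNode> first_input;  // stands in for cnode->inputs()[0]
};

size_t ToyCallDepth(const std::shared_ptr<ToyNode> &node) {
  if (node == nullptr || !node->is_call) {
    return 0;
  }
  return 1 + ToyCallDepth(node->first_input);
}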
// Get the final subgraph called by the funcgraph according to the call depth.
FuncGraphPtr FetchFuncGraphByCallDepth(const FuncGraphPtr &func_graph, const size_t call_depth) {
if (call_depth <= kSingleCallDepth) {
return func_graph;
}
const auto &output = func_graph->output();
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
const auto &cnode = output->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
if (inputs.size() < kPartialInputStartPos) {
MS_LOG(EXCEPTION) << "Invalid partial node:" << AnfAlgo::GetNodeDebugString(output);
}
const auto &called_func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
return FetchFuncGraphByCallDepth(called_func_graph, call_depth - kSingleCallDepth);
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
return FetchFuncGraphByCallDepth(GetValueNode<FuncGraphPtr>(output), call_depth - kSingleCallDepth);
} else {
MS_LOG(EXCEPTION) << "Invalid output for call depth:" << call_depth << " funcgraph:" << func_graph->ToString()
<< " output node:" << AnfAlgo::GetNodeDebugString(output);
}
}
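Usage-wise, each depth level beyond kSingleCallDepth peels one output (a partial node or a funcgraph value node) off the current graph. A toy equivalent with a stand-in graph type:

#include <cstddef>
#include <memory>
#include <stdexcept>

struct ToyFuncGraph {
  // Non-null when the output is a partial node or a funcgraph value node.
  std::shared_ptr<ToyFuncGraph> output_graph;
};

std::shared_ptr<ToyFuncGraph> PeelByCallDepth(std::shared_ptr<ToyFuncGraph> graph, size_t call_depth) {
  while (call_depth > 1) {  // call_depth <= 1 returns the graph unchanged
    if (graph->output_graph == nullptr) {
      throw std::invalid_argument("output is neither a partial nor a funcgraph");
    }
    graph = graph->output_graph;
    --call_depth;
  }
  return graph;
}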
// Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
FuncGraphPtr FetchFuncGraphInNode(const AnfNodePtr &node) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
const auto &func_graph = GetFuncGraphFromPartial(node);
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
// When the output of the funcgraph is a partial node, return the funcgraph that is finally called.
return FetchFuncGraphInNode(func_graph->output());
} else if (IsValueNode<FuncGraph>(func_graph->output())) {
// When the output of the funcgraph is a funcgraph value node, return the funcgraph that is finally called.
return FetchFuncGraphInNode(func_graph->output());
}
return func_graph;
} else if (IsValueNode<FuncGraph>(node)) {
const auto &func_graph = GetValueNode<FuncGraphPtr>(node);
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
// When the output of the funcgraph is a partial node, return the funcgraph that is finally called.
return FetchFuncGraphInNode(func_graph->output());
} else if (IsValueNode<FuncGraph>(func_graph->output())) {
// When the output of the funcgraph is a funcgraph value node, return the funcgraph that is finally called.
return FetchFuncGraphInNode(func_graph->output());
}
return func_graph;
}
return nullptr;
}
} // namespace
// Return true if the node has Ref abstract.
@@ -472,24 +534,56 @@ bool IsCallNode(const AnfNodePtr &node) {
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
}
std::vector<AnfNodePtr> FetchAllRealInputNodeByParameter(const AnfNodePtr &node) {
std::vector<AnfNodePtr> parameters;
const auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
bool IsSubCallNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
const auto inputs = node->cast<CNodePtr>()->inputs();
if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
return false;
}
const auto &switch_layer_inputs = inputs[0]->cast<CNodePtr>()->inputs();
const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
if (tuple_inputs.size() <= kMakeTupleInputStartPos) {
return false;
}
// Check whether the funcgraph called by the call node returns a funcgraph or a partial node.
FuncGraphPtr func_graph = nullptr;
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) {
const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast<CNodePtr>()->input(kPartialFuncGraphPos);
func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
} else if (tuple_inputs[kMakeTupleInputStartPos]->isa<ValueNode>() &&
IsValueNode<FuncGraph>(tuple_inputs[kMakeTupleInputStartPos])) {
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[kMakeTupleInputStartPos]);
}
if (func_graph == nullptr) {
return false;
}
const auto &output = func_graph->output();
return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) ||
(output->isa<ValueNode>() && IsValueNode<FuncGraph>(output));
}
std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
std::vector<KernelWithIndex> parameters;
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
const auto &real_node = real_node_with_index.first;
if (real_node->isa<Parameter>()) {
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
parameters.emplace_back(real_node);
parameters.emplace_back(real_node_with_index);
}
} else if (HasAbstractMonad(real_node)) {
return parameters;
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
const auto &inputs = real_node->cast<CNodePtr>()->inputs();
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
const auto &sub_parameters = FetchAllRealInputNodeByParameter(inputs[i]);
const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0});
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
}
} else {
parameters.emplace_back(real_node);
parameters.emplace_back(real_node_with_index);
}
return parameters;
}
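The recursion above flattens nested make_tuple inputs into (node, output index) pairs. A self-contained toy of the same traversal, with stand-in types instead of AnfNode and KernelWithIndex:

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

struct ToyTuple;
using ToyLeaf = std::pair<std::string, size_t>;  // stands in for KernelWithIndex
using ToyElement = std::variant<ToyLeaf, std::shared_ptr<ToyTuple>>;
struct ToyTuple {
  std::vector<ToyElement> elements;  // stands in for make_tuple inputs
};

// Collect every leaf pair, recursing through nested tuples.
void Flatten(const ToyElement &element, std::vector<ToyLeaf> *out) {
  if (std::holds_alternative<ToyLeaf>(element)) {
    out->push_back(std::get<ToyLeaf>(element));
    return;
  }
  for (const auto &sub : std::get<std::shared_ptr<ToyTuple>>(element)->elements) {
    Flatten(sub, out);
  }
}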
@@ -514,13 +608,15 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
// Fetch all funcgraphs in make tuple node.
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i]));
} else if (IsValueNode<FuncGraph>(tuple_inputs[i])) {
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(tuple_inputs[i]));
const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]);
if (func_graph != nullptr) {
func_graphs.emplace_back(func_graph);
}
}
} else if (IsCallNode(cnode)) {
return FetchFuncGraphbyCallNode(cnode);
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
}
@@ -563,7 +659,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
break;
}
total_num += call_output_num;
} else {
} else if (!HasAbstractMonad(inputs[i])) {
++total_num;
}
}
@@ -612,16 +708,16 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
}
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
if (front_node != nullptr) {
return front_node;
return {front_node, 0};
}
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
if (front_node_with_index.first == nullptr) {
MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node);
}
return front_node_with_index.first;
return front_node_with_index;
}
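The lookup order above — direct backend-to-front mapping first, internal parameter mapping (which carries an output index) second — can be sketched with two plain maps standing in for the kernel graph's tables:

#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>

using ToyNodeWithIndex = std::pair<std::string, size_t>;

struct ToyGraphMaps {
  std::unordered_map<std::string, std::string> backend_to_front;         // direct mapping, index 0
  std::unordered_map<std::string, ToyNodeWithIndex> internal_parameter;  // mapping with output index
};

ToyNodeWithIndex FetchFront(const ToyGraphMaps &maps, const std::string &backend_node) {
  const auto direct = maps.backend_to_front.find(backend_node);
  if (direct != maps.backend_to_front.end()) {
    return {direct->second, 0};
  }
  const auto internal = maps.internal_parameter.find(backend_node);
  return internal == maps.internal_parameter.end() ? ToyNodeWithIndex{"", 0} : internal->second;
}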
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
@@ -665,6 +761,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
FetchHostParameterToWeight(real_to_formal_front_parameters);
FetchCallInputKernelGraph(graphs, device_contexts);
FetchFrontValueNode(control_nodes, graphs, device_contexts);
FetchFrontToBackendKernel(graphs, device_contexts);
@@ -757,7 +855,7 @@ AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &nod
for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
for (const auto &front_weight : host_parameter_to_weight.second) {
if (front_weight == node) {
const auto &iter = front_to_backend_parameters_.find(front_weight);
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
if (iter != front_to_backend_parameters_.end()) {
return iter->second.first;
}
@@ -868,6 +966,20 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
}
}
}
// When the funcgraph called by a call node returns a value node, device addresses should be created for these
// value nodes.
for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) {
const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first);
for (const auto &func_graph : func_graphs) {
const auto &output = func_graph->output();
if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
const auto &device_context = call_node_to_backend_parameter.second.second;
CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
front_value_nodes_.push_back({output, device_context});
}
}
}
}
void ControlNodeParser::FetchFrontToFrontParameter(
@@ -940,14 +1052,17 @@ void ControlNodeParser::FetchFrontToFrontParameter(
}
} else if (inputs[0]->isa<CNode>()) {
// Call node whose first input is a switch or switch layer node.
if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) &&
(!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) {
if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
std::vector<AnfNodePtr> call_inputs;
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
switch_input_parse(inputs[0], call_inputs);
} else if (IsCallNode(inputs[0])) {
continue;
} else {
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
<< AnfAlgo::GetNodeDebugString(inputs[0]);
}
std::vector<AnfNodePtr> call_inputs;
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
switch_input_parse(inputs[0], call_inputs);
}
}
}
@@ -992,6 +1107,7 @@ void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &con
for (const auto &control_node : control_nodes) {
if (IsCallNode(control_node)) {
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->output()->isa<ValueNode>()) {
@@ -1019,7 +1135,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
call_input_kernel_graphs_[graph] = device_context;
break;
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
}
}
}
@@ -1084,13 +1200,14 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
return parameters;
}
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
std::vector<AnfNodePtr> parameters;
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
std::vector<KernelWithIndex> parameters;
const auto &graph_parameters = graph->input_nodes();
for (const auto &graph_parameter : graph_parameters) {
const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(graph_parameter).first;
const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
const auto &internal_front_node = internal_front_node_with_index.first;
if (external_front_node == nullptr && internal_front_node == nullptr) {
MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
@@ -1098,9 +1215,9 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
continue;
}
const auto &front_node = (external_front_node != nullptr) ? external_front_node : internal_front_node;
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
const auto &sub_parameters = FetchAllRealInputNodeByParameter(real_front_node);
const auto &front_node_with_index =
((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index);
const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index);
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
}
@@ -1191,6 +1308,8 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr>
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
// Switchlayer node.
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
} else if (IsCallNode(inputs[0])) {
continue;
} else {
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
}
@@ -1232,10 +1351,6 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
const auto graph_output_map = graph->graph_output_map();
for (const auto &output_pair : graph_output_map) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(output_pair.second.first)
<< "index:" << output_pair.second.second << " addr:" << output_pair.second.first
<< " second:" << AnfAlgo::GetNodeDebugString(output_pair.first.first)
<< "index:" << output_pair.first.second << " addr:" << output_pair.first.first;
}
}
}
@@ -1246,6 +1361,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
std::set<KernelWithIndex> *results) {
if (front_output->isa<ValueNode>()) {
(*results).insert({front_output, 0});
const auto &iter = formal_to_real_parameters_.find(front_output);
if (iter != formal_to_real_parameters_.end()) {
for (const auto &node : iter->second) {
@@ -1405,6 +1521,15 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
}
}
for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
for (const auto &front_weight : host_parameter_to_weight.second) {
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
if (iter != front_to_backend_parameters_.end()) {
formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0});
}
}
}
for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
const auto &func_graph = func_graph_to_parameters.first;
std::vector<AnfNodePtr> graph_inputs;
@@ -1453,7 +1578,6 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro
const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
if (iter != front_to_backend_kernels_.end()) {
kernel_to_call_nodes_[iter->second.first.first] = control_node;
MS_LOG(DEBUG) << "Add auto monad control arrow for node:" << AnfAlgo::GetNodeDebugString(node);
}
}
}

View File

@@ -50,6 +50,9 @@ using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>
// 2. First input of node is a funcgraph value node.
bool IsCallNode(const AnfNodePtr &node);
// Check if the call node is the input of another call node.
bool IsSubCallNode(const AnfNodePtr &node);
// Check whether the parameter is a weight. In control flow, weights are passed into subgraphs, and inside the
// subgraph it must be determined whether a parameter is a weight.
bool HasAbstractRef(const AnfNodePtr &node);
@@ -66,7 +69,7 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
// Get the front node corresponding to the backend node; if the front node is not a parameter node, return the
// corresponding cnode.
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
// Get the funcgraph to which the node belongs.
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
@@ -75,7 +78,7 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
// Get parameters in kernel graph.
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
@@ -205,6 +208,9 @@ class ControlNodeParser {
// the input node of gather.
FuncGraphToParameter func_graph_to_parameters_;
// The relationship between the value node inputs of the call node and the backend parameters.
std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_;
// Branch id of funcgraph.
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch

View File

@@ -1180,7 +1180,7 @@ std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompiler
const auto &actor_name = control_node->DebugString();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
control_node->cast<CNodePtr>(), branch_id, false);
switch_actor->Initialize(graph_compiler_info.control_node_parser_);
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
// Fetch all the input nodes of switch actor.
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
@@ -1197,7 +1197,7 @@ std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompiler
const auto &actor_name = return_node->DebugString();
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
return_node->cast<CNodePtr>(), kInvalidBranchID, true);
switch_actor->Initialize(graph_compiler_info.control_node_parser_);
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
// Fetch all the input nodes of switch actor.
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
@@ -1235,7 +1235,6 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
const auto &cnode = control_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
const auto &return_node = func_graph->get_return();
const auto &output_switch_aid = FetchActor(return_node->DebugString())->GetAID();
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
// Root funcgraph does not need to create a gather actor.
@@ -1245,21 +1244,26 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
}
// If the output of the funcgraph is a value node or a partial node, there is no need to create a gather actor.
if (inputs[kReturnInputPos]->isa<ValueNode>()) {
if (inputs[kReturnInputPos]->isa<ValueNode>() ||
AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
continue;
}
auto actor_name = func_graph->ToString();
std::vector<AnfNodePtr> parameters;
std::vector<KernelWithIndex> parameters;
for (const auto &parameter : func_graph->get_inputs()) {
if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
continue;
}
parameters.emplace_back(parameter);
parameters.push_back({parameter, 0});
}
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
const auto &output_switch_actor = FetchActor(return_node->DebugString());
MS_EXCEPTION_IF_NULL(output_switch_actor);
const auto &output_switch_aid = output_switch_actor->GetAID();
auto gather_actor =
std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
@@ -1275,12 +1279,12 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
// Collect the parameters.
std::vector<AnfNodePtr> parameters;
std::vector<KernelWithIndex> parameters;
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
continue;
}
parameters.emplace_back(inputs[i]);
parameters.push_back({inputs[i], 0});
}
auto func_graph = control_node->func_graph();
@@ -1322,9 +1326,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
auto front_node = GetFrontNodeByBackendNode(from_kernel);
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
if (HasAbstractRef(from_kernel)) {
const auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key.get());
const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
const auto &real_front_node_with_index =
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
if (HasAbstractRef(real_front_node_with_index.first)) {
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
real_front_node_with_index.first.get());
return;
}
@@ -1332,9 +1339,8 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
const auto actor_name = graph->ToString();
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
const auto &real_front_node = GetFrontNodeByKernelGraph(from_kernel, graph);
LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), real_front_node, to_actor,
to_kernel_with_input_idx.second);
LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), to_actor, real_front_node_with_index,
to_kernel_with_input_idx);
return;
}
@@ -1355,8 +1361,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get());
return;
}
LinkDataArrowForGatherActor(from_actor, front_node, to_actor, to_kernel_with_input_idx.second);
LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph, tensor, graph_compiler_info.origin_parameters_order_,
graph_compiler_info.strategy_)) {
// Link the data arrows of host queue data source actor.
@@ -2053,25 +2058,84 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
}
}
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) {
for (const auto &node : graph_compiler_info.control_nodes_) {
void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &node : control_nodes) {
CNodePtr cnode = node->cast<CNodePtr>();
const auto &from_func_graph = node->func_graph();
auto inputs = cnode->inputs();
// Before linking data arrows, parameters of the call node in a switch call need to be added to the switch actor.
if (inputs[0]->isa<CNode>()) {
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
// Add the input of call node to switch actor.
if (IsCallNode(inputs[0])) {
const auto &sub_call_cnode = inputs[0]->cast<CNodePtr>();
const auto &sub_inputs = sub_call_cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(sub_inputs[0], prim::kPrimSwitchLayer)) {
auto actor = FetchActor(sub_inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
switch_actor->AddCommonInput(inputs[i]);
}
}
} else if (IsSubCallNode(cnode)) {
// Add the input of sub call node to switch actor.
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
const auto &tuple_node = inputs[0]->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
FuncGraphPtr func_graph = nullptr;
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
int pre_real_parameter_num = 0;
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
pre_real_parameter_num = (tuple_inputs[i]->cast<CNodePtr>()->inputs().size() - kPartialInputStartPos);
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]->cast<CNodePtr>()->input(kPartialFuncGraphPos));
} else {
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
}
const auto parameters = func_graph->parameters();
const auto &output = func_graph->output();
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
const auto &sub_partial_inputs = output->cast<CNodePtr>()->inputs();
// Check whether the input node of the sub call node needs to be added to the switch actor. The input node is
// added only when the final return is a partial node and the partial node needs this input.
for (size_t j = kPartialInputStartPos; j < sub_partial_inputs.size(); ++j) {
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[j], 0).first;
const auto &iter = find(parameters.begin(), parameters.end(), real_partial_input);
if ((iter != parameters.end()) && (iter - parameters.begin() >= pre_real_parameter_num) &&
(iter - parameters.begin() <
SizeToInt(pre_real_parameter_num + inputs.size() - kCallInputStartPos))) {
size_t pos = iter - parameters.begin() - pre_real_parameter_num + kCallInputStartPos;
switch_actor->AddSingleInput(inputs[pos], i - 1);
}
}
}
}
} else {
auto actor = FetchActor(inputs[0]->DebugString());
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
if (HasAbstractMonad(inputs[i])) {
continue;
}
switch_actor->AddCommonInput(inputs[i]);
}
switch_actor->AddCommonInput(inputs[i]);
}
}
}
}
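The parameter-position arithmetic in this function is easy to misread. A worked toy with hypothetical values: the partial already supplies the first pre_real_parameter_num parameters, and the call node supplies the rest starting at inputs[kCallInputStartPos]:

#include <cassert>
#include <cstddef>

constexpr size_t kToyCallInputStartPos = 1;  // stands in for kCallInputStartPos

size_t CallInputPos(size_t param_offset, size_t pre_real_parameter_num) {
  assert(param_offset >= pre_real_parameter_num);
  return param_offset - pre_real_parameter_num + kToyCallInputStartPos;
}

int main() {
  // Funcgraph parameters [p0, p1, p2, p3]; the partial supplies p0 and p1,
  // so the call node supplies p2 at inputs[1] and p3 at inputs[2].
  assert(CallInputPos(2, 2) == 1);
  assert(CallInputPos(3, 2) == 2);
  return 0;
}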
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) {
PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
for (const auto &node : graph_compiler_info.control_nodes_) {
CNodePtr cnode = node->cast<CNodePtr>();
@@ -2142,11 +2206,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
for (const auto &input_node : gather_actor->data_nodes_) {
for (const auto &input_with_index : gather_actor->data_nodes_) {
const auto &from_func_graph = kernel_graph->GetFuncGraph();
const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor,
gather_actor->FetchDataNodePosition(input_node));
gather_actor->FetchDataNodePosition(input_with_index));
}
}
LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set);
@@ -2162,15 +2225,16 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node,
KernelActor *to_actor, const size_t to_index) {
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
const KernelWithIndex &front_node_with_index,
const KernelWithIndex &to_node_with_index) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
auto position = from_actor->FetchDataNodePosition(front_node);
auto position = from_actor->FetchDataNodePosition(front_node_with_index);
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
from_actor->output_data_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
}
@@ -2181,13 +2245,18 @@ void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_w
// Fetch all the funcgraph that call node would call.
const auto cnode = call_node_with_index.first->cast<CNodePtr>();
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
const auto &call_inputs = cnode->inputs();
auto switch_node = call_inputs[0];
if (IsCallNode(switch_node)) {
switch_node = call_inputs[0]->cast<CNodePtr>()->input(0);
}
// Collect the output of each funcgraph.
for (const auto &func_graph : func_graphs) {
if (func_graph->output()->isa<ValueNode>()) {
const auto &call_inputs = cnode->inputs();
if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) {
const auto &actor_name = call_inputs[0]->DebugString();
if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitchLayer)) {
const auto &actor_name = switch_node->DebugString();
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
@@ -2254,7 +2323,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const
}
for (size_t i = start_branch; i < max_branch; ++i) {
if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i;
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
<< " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
<< " to actor:" << to_actor->GetAID() << " to index:" << to_index;
}
auto op_arrow =
std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
@@ -2277,7 +2348,7 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
} else if (IsGatherActor(input_node, actor_name_to_actor_)) {
// The actor input is a parameter in gather actor.
auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
auto position = from_actor->FetchDataNodePosition(input_node);
auto position = from_actor->FetchDataNodePosition({input_node, 0});
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
from_actor->output_data_arrows_.emplace_back(op_arrow);
} else if (IsSwitchActor(input_node)) {
@@ -2338,7 +2409,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
UpdateRefCount(device_tensor.get(), true);
} else {
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node);
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
<< " to actor:" << to_actor->GetAID();
}
}
@@ -2559,65 +2631,6 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
}
}
void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
OutputActor *to_actor = actor_set->output_actor_.get();
MS_EXCEPTION_IF_NULL(to_actor);
for (const auto &func_graph_to_branch_id : graph_compiler_info.control_node_parser_->func_graph_to_branch_id_) {
if (func_graph_to_branch_id.second == kMainBranchID) {
continue;
}
const auto &func_graph = func_graph_to_branch_id.first;
auto actor = FetchActor(func_graph->ToString());
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
const auto front_node = gather_actor->data_nodes_[i];
auto origin_output_with_index = KernelWithIndex(front_node, 0);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
continue;
}
for (auto &output_position : iter->second) {
auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), output_position);
gather_actor->output_result_arrows_.emplace_back(op_arrow);
const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node];
if (backend_nodes.empty()) {
MS_LOG(EXCEPTION) << "No backend node for data node:" << AnfAlgo::GetNodeDebugString(front_node);
}
const auto &backend_node = backend_nodes[0].first;
if (backend_node->isa<Parameter>()) {
std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
auto ds_op_actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(ds_op_actor);
auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(ds_op_actor);
MS_EXCEPTION_IF_NULL(host_ds_actor);
const auto &data_nodes = host_ds_actor->data_nodes_;
const auto &node_iter = find(data_nodes.begin(), data_nodes.end(), backend_node);
if (node_iter == data_nodes.end()) {
MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:"
<< AnfAlgo::GetNodeDebugString(backend_node);
}
to_actor->device_contexts_[output_position] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
} else {
auto actor_base = FetchActor(backend_node->fullname_with_scope());
MS_EXCEPTION_IF_NULL(actor_base);
auto kernel_actor = dynamic_cast<KernelActor *>(actor_base);
MS_EXCEPTION_IF_NULL(kernel_actor);
to_actor->device_contexts_[output_position] = kernel_actor->device_context_;
}
}
}
}
}
bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
MS_EXCEPTION_IF_NULL(actor_set);
// Check the data source actors.
@@ -3098,7 +3111,7 @@ void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &of
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
for (const auto &node : actor->data_nodes_) {
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
}
ofs << "\t\tactor front to backend node:\n";

View File

@@ -233,8 +233,9 @@ class GraphScheduler {
// 4. The processing of control flow linking.
void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set);
void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor,
const size_t to_index);
void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
const KernelWithIndex &front_node_with_index,
const KernelWithIndex &to_node_with_index);
void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor);
// Connect the input of the actor.
void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
@@ -263,6 +264,9 @@ class GraphScheduler {
const ControlNodeParserPtr &control_node_parser,
const std::vector<AnfNodePtr> &origin_parameters,
const std::vector<TensorPtr> &tensors, std::vector<TensorPtr> *host_tensors);
// Add inputs for the switch actor. Since part of the funcgraph inputs are carried on the call node, these inputs
// need to be added to the switch actor.
void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);
// The processing of dynamically linking actors.
// Analyze necessary input data of current actor, generate and cache op arrow