forked from mindspore-Ecosystem/mindspore
!19655 fix switch layer call to call
Merge pull request !19655 from gaoyong10/new_runtime23
This commit is contained in:
commit
de4b9a94fc
|
@ -42,10 +42,10 @@ void GatherActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
|
||||
size_t GatherActor::FetchDataNodePosition(const KernelWithIndex &data_node) const {
|
||||
const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
|
||||
if (iter == data_nodes_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
|
||||
MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node.first) << " index:" << data_node.second
|
||||
<< " is not exist in gather actor:" << GetAID();
|
||||
}
|
||||
return iter - data_nodes_.begin();
|
||||
|
@ -114,7 +114,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
for (const auto &result_arrow : output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
size_t from_index = result_arrow->from_output_index_;
|
||||
const auto &front_node = data_nodes_[from_index];
|
||||
const auto &front_node = data_nodes_[from_index].first;
|
||||
for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
|
||||
if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() ==
|
||||
input_device_tensors_[from_index]) {
|
||||
|
|
|
@ -47,7 +47,7 @@ constexpr size_t kReturnInputPos = 1;
|
|||
// collected at the entrance of the kernel graph.
|
||||
class GatherActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
GatherActor(const std::string &name, const std::vector<AnfNodePtr> ¶meters, const bool need_branch_id_input,
|
||||
GatherActor(const std::string &name, const std::vector<KernelWithIndex> ¶meters, const bool need_branch_id_input,
|
||||
const AID switch_aid, const AID gather_aid, const int branch_id)
|
||||
: OpActor(name),
|
||||
data_nodes_(parameters),
|
||||
|
@ -60,7 +60,7 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
~GatherActor() override = default;
|
||||
|
||||
// Get the index of the parameter, the data_node needs to be the front node.
|
||||
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
|
||||
size_t FetchDataNodePosition(const KernelWithIndex &data_node) const;
|
||||
|
||||
// The gather actor run when receive the input data.
|
||||
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
|
||||
|
@ -107,7 +107,7 @@ class GatherActor : public OpActor<DeviceTensor> {
|
|||
std::vector<AID> output_branch_arrows_;
|
||||
|
||||
// Parameters of sub funcgraph, which is the front node.
|
||||
std::vector<AnfNodePtr> data_nodes_;
|
||||
std::vector<KernelWithIndex> data_nodes_;
|
||||
std::vector<DeviceContext *> device_contexts_;
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
|
||||
|
|
|
@ -78,7 +78,9 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
|
|||
auto &sequential_num = context->sequential_num_;
|
||||
input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
if (input_data->data_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty";
|
||||
std::string error_info =
|
||||
"Input data of actor:" + GetAID().Name() + " num:" + std::to_string(input_data->index_) + " is empty";
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
|
|
|
@ -72,20 +72,20 @@ void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *
|
|||
input_branch_ids_[sequential_num].push(branch_id);
|
||||
}
|
||||
|
||||
void SwitchActor::Initialize(const ControlNodeParserPtr &parser) {
|
||||
void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) {
|
||||
std::vector<AnfNodePtr> inputs = node_->inputs();
|
||||
|
||||
if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
|
||||
InitSwitch();
|
||||
ParseSwitchInput();
|
||||
} else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
|
||||
InitReturn(parser);
|
||||
ParseReturnInput(parser);
|
||||
} else {
|
||||
InitSwitchLayer();
|
||||
ParseSwitchLayerInput();
|
||||
}
|
||||
backend_parameters_.resize(input_nodes_.size());
|
||||
}
|
||||
|
||||
void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
|
||||
void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
|
||||
|
@ -93,20 +93,50 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
|
|||
// [0] ValueNode<Primitive> kPartial.
|
||||
// [1] ValueNode<FuncGraphPtr>.
|
||||
// [2..] Inputs.
|
||||
const auto &node_inputs = cnode->inputs();
|
||||
if (node_inputs.size() <= kPartialFuncGraphPos) {
|
||||
auto partial_inputs = cnode->inputs();
|
||||
if (partial_inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode);
|
||||
}
|
||||
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(node_inputs[kPartialFuncGraphPos]);
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(func_graph->output(), branch_id);
|
||||
return;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
|
||||
// If the funcgraph called by the partial returns a partial node, the switch actor should call the funcgraph
|
||||
// of the sub partial. Similarly, the input node should also be the input of the sub partial.
|
||||
is_mulit_call_ = true;
|
||||
CNodePtr sub_partial = func_graph->output()->cast<CNodePtr>();
|
||||
const auto &sub_partial_inputs = sub_partial->inputs();
|
||||
if (sub_partial_inputs.size() <= kPartialFuncGraphPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(sub_partial);
|
||||
}
|
||||
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(sub_partial_inputs[kPartialFuncGraphPos]);
|
||||
|
||||
if (sub_func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(sub_func_graph->output(), branch_id);
|
||||
return;
|
||||
}
|
||||
|
||||
branch_func_graph_[branch_id] = sub_func_graph;
|
||||
const auto &sub_parameters = func_graph->parameters();
|
||||
|
||||
// Record the input that comes with the sub partial node.
|
||||
for (size_t i = kPartialInputStartPos; i < sub_partial_inputs.size(); ++i) {
|
||||
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[i], 0).first;
|
||||
const auto &iter = find(sub_parameters.begin(), sub_parameters.end(), real_partial_input);
|
||||
if ((iter != sub_parameters.end()) &&
|
||||
((iter - sub_parameters.begin()) < SizeToInt(partial_inputs.size() - kPartialInputStartPos))) {
|
||||
AddInput(partial_inputs[iter - sub_parameters.begin() + kPartialInputStartPos], branch_id);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
branch_func_graph_[branch_id] = func_graph;
|
||||
for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) {
|
||||
AddInput(node_inputs[j], branch_id);
|
||||
for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) {
|
||||
AddInput(partial_inputs[j], branch_id);
|
||||
}
|
||||
} else {
|
||||
AddInput(node, branch_id);
|
||||
|
@ -122,15 +152,21 @@ void SwitchActor::InitVectorSize(const size_t num) {
|
|||
output_branch_branch_arrows_.resize(num);
|
||||
}
|
||||
|
||||
void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) {
|
||||
void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) {
|
||||
const auto &func_graph = node_->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
|
||||
InitVectorSize(call_num);
|
||||
|
||||
// If the return is a partial node or funcgraph, this subgraph will not be initialized and no input is required.
|
||||
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial) ||
|
||||
(func_graph->output()->isa<ValueNode>() && IsValueNode<FuncGraph>(func_graph->output()))) {
|
||||
return;
|
||||
}
|
||||
AddCommonInput(func_graph->output());
|
||||
}
|
||||
|
||||
void SwitchActor::InitSwitch() {
|
||||
void SwitchActor::ParseSwitchInput() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitch.
|
||||
// [1] switch condition.
|
||||
|
@ -147,11 +183,11 @@ void SwitchActor::InitSwitch() {
|
|||
input_nodes_.push_back(cond_node);
|
||||
input_datas_num_++;
|
||||
// Init the two branches of switch node.
|
||||
InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
|
||||
InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
|
||||
ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
|
||||
ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
|
||||
}
|
||||
|
||||
void SwitchActor::InitSwitchLayer() {
|
||||
void SwitchActor::ParseSwitchLayerInput() {
|
||||
// The inputs of the switch node:
|
||||
// [0] ValueNode<Primitive> kSwitchLayer.
|
||||
// [1] switchLayer index.
|
||||
|
@ -170,11 +206,30 @@ void SwitchActor::InitSwitchLayer() {
|
|||
InitVectorSize(branch_nodes.size() - 1);
|
||||
|
||||
// Parse all branches.
|
||||
for (size_t i = 1; i < branch_nodes.size(); ++i) {
|
||||
for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) {
|
||||
InitPartial(branch_nodes[i], i - 1);
|
||||
ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos);
|
||||
} else if (branch_nodes[i]->isa<ValueNode>()) {
|
||||
branch_func_graph_[i - 1] = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(branch_nodes[i]);
|
||||
const auto output = func_graph->output();
|
||||
|
||||
// The switch layer node has a second-order call connected to call. When the called funcgraph returns a partial
|
||||
// node or funcgraph, the switch actor needs to call the funcgraph directly.
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
is_mulit_call_ = true;
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] =
|
||||
GetValueNode<FuncGraphPtr>(output->cast<CNodePtr>()->input(kPartialFuncGraphPos));
|
||||
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
|
||||
is_mulit_call_ = true;
|
||||
const auto &sub_func_graph = GetValueNode<FuncGraphPtr>(output);
|
||||
if (sub_func_graph->output()->isa<ValueNode>()) {
|
||||
AddInput(sub_func_graph->output(), i - kMakeTupleInputStartPos);
|
||||
continue;
|
||||
}
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] = GetValueNode<FuncGraphPtr>(output);
|
||||
} else {
|
||||
branch_func_graph_[i - kMakeTupleInputStartPos] = func_graph;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -198,8 +253,12 @@ size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
|
|||
void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
|
||||
const auto &node = node_with_index.first;
|
||||
|
||||
// Add weight and value node.
|
||||
if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa<Parameter>() && HasAbstractRef(node)) ||
|
||||
// The value node and weight node need to be placed in the device store. The switch actor has three inputs:
|
||||
// 1) The input of the switch is the value node.
|
||||
// 2) There is a weight node or value node in the return of the sub funcgraph.
|
||||
// 3) When the switch actor is a second-order call, it does not distinguish between weight and parameter.
|
||||
if (((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) || is_mulit_call_) && node->isa<Parameter>() &&
|
||||
HasAbstractRef(node)) ||
|
||||
node->isa<ValueNode>()) {
|
||||
const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
|
||||
if (iter != input_nodes_.end()) {
|
||||
|
@ -243,7 +302,6 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
|
|||
} else if (IsCallNode(real_input.first)) {
|
||||
std::vector<AnfNodePtr> call_nodes;
|
||||
const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
|
||||
|
||||
if (call_output_num <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
|
||||
}
|
||||
|
@ -259,25 +317,26 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
if (need_branch_id_input_) {
|
||||
if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
|
||||
input_branch_ids_[context->sequential_num_].empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID();
|
||||
MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name();
|
||||
}
|
||||
size_t branch_id = input_branch_ids_[context->sequential_num_].top();
|
||||
input_branch_ids_[context->sequential_num_].pop();
|
||||
if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id;
|
||||
MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() +
|
||||
" branch id:" + std::to_string(branch_id);
|
||||
}
|
||||
return branch_id_to_index_[branch_id];
|
||||
}
|
||||
|
||||
DeviceTensor *device_tensor = input_device_tensors_[0];
|
||||
if (device_tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID();
|
||||
MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name();
|
||||
}
|
||||
auto inputs = node_->inputs();
|
||||
TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
|
||||
size_t size = abstract::TypeIdSize(type_id);
|
||||
if (size > sizeof(int64_t)) {
|
||||
MS_LOG(EXCEPTION) << "Index must be Int type.";
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
int64_t index = 0;
|
||||
|
@ -293,7 +352,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
|
|||
bool cond = (static_cast<bool *>(static_cast<void *>(buf)))[0];
|
||||
index = static_cast<int64_t>(cond ? 1 : 0);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Index must be Int type.";
|
||||
MS_LOG(ERROR) << "Index must be Int type.";
|
||||
}
|
||||
|
||||
// SwitchLayer node support negative index range [-size, -1].
|
||||
|
@ -352,7 +411,7 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
|
|||
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
|
||||
if (device_tensor == nullptr) {
|
||||
std::string error_info =
|
||||
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
|
||||
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
|
||||
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
@ -370,7 +429,8 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto index = GetIndex(context);
|
||||
if (index >= output_branch_arrows_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
|
||||
std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index);
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
// Must be the execution order: send branch id --> send result --> send data --> send control, avoid the illegal
|
||||
|
@ -389,8 +449,10 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
auto &result_arrow = output_branch_result_arrow[i];
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) {
|
||||
MS_LOG(EXCEPTION) << "Invalid from index in switch actor, from index:" << result_arrow->from_output_index_
|
||||
<< " total:" << branch_inputs_pos_[index].size() << " actor:" << GetAID();
|
||||
std::string error_info =
|
||||
"Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) +
|
||||
" total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_];
|
||||
|
||||
|
@ -410,10 +472,12 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
if (!is_send) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get backend node of switch actor output, actor:" << GetAID()
|
||||
<< " branch:" << index << " index:" << result_arrow->from_output_index_ << " output pos"
|
||||
<< branch_inputs_pos_[index][result_arrow->from_output_index_] << " output index"
|
||||
<< result_arrow->to_input_index_;
|
||||
std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() +
|
||||
" branch:" + std::to_string(index) +
|
||||
" index:" + std::to_string(result_arrow->from_output_index_) + " output pos" +
|
||||
std::to_string(branch_inputs_pos_[index][result_arrow->from_output_index_]) +
|
||||
" output index" + std::to_string(result_arrow->to_input_index_);
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -426,7 +490,6 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
data->data_ = input_device_tensors_[data_arrow->from_output_index_];
|
||||
|
||||
Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
|
||||
}
|
||||
|
||||
|
|
|
@ -76,22 +76,23 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
|
||||
// The switch actor run when receive the input branch id.
|
||||
void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
|
||||
// Initialize the input and output information of the switch actor According to node_.
|
||||
void Initialize(const ControlNodeParserPtr &parser);
|
||||
// Parse the input node information of the switch actor according to node_.
|
||||
void ParseInput(const ControlNodeParserPtr &parser);
|
||||
// Add input for all branches.
|
||||
void AddCommonInput(const AnfNodePtr &node);
|
||||
void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); }
|
||||
// Fetch the input position of the data node.
|
||||
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
void InitPartial(const AnfNodePtr &node, const size_t branch_id);
|
||||
void InitSwitch();
|
||||
void InitSwitchLayer();
|
||||
void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id);
|
||||
void ParseSwitchInput();
|
||||
void ParseSwitchLayerInput();
|
||||
// In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
|
||||
// initialized with the return node of the subgraph.
|
||||
void InitReturn(const ControlNodeParserPtr &parser);
|
||||
void ParseReturnInput(const ControlNodeParserPtr &parser);
|
||||
// Initialize the size of the vector members.
|
||||
void InitVectorSize(const size_t num);
|
||||
// Get index from DeviceTensor.
|
||||
|
@ -170,6 +171,11 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_;
|
||||
|
||||
// Used to indicate that in the control flow, when the input of the call node is a call node, the switch actor
|
||||
// corresponding to the switch node called by the sub call node. At this time, the funcgraph of the input of
|
||||
// the switch actor will return to a partial node or funcgraph.
|
||||
bool is_mulit_call_{false};
|
||||
};
|
||||
|
||||
using SwitchActorPtr = std::shared_ptr<SwitchActor>;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
||||
constexpr size_t kSingleCallDepth = 1;
|
||||
namespace {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
// Fetch all the weight parameters related to node. It runs like this:
|
||||
|
@ -339,25 +339,23 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::
|
|||
const auto func_graphs = FetchFuncGraphbyCallNode(call_node);
|
||||
|
||||
for (const auto func_graph : func_graphs) {
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
outputs.push_back(func_graph->output());
|
||||
} else {
|
||||
std::vector<AnfNodePtr> sub_call_nodes;
|
||||
const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
|
||||
for (const auto &graph_output : graph_outputs) {
|
||||
if (graph_output->isa<Parameter>()) {
|
||||
outputs.push_back(graph_output);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
|
||||
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
|
||||
outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||
} else if (IsCallNode(graph_output)) {
|
||||
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
|
||||
outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||
} else if (graph_output->isa<CNode>()) {
|
||||
outputs.emplace_back(graph_output);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
|
||||
}
|
||||
std::vector<AnfNodePtr> sub_call_nodes;
|
||||
const std::vector<AnfNodePtr> graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes);
|
||||
for (const auto &graph_output : graph_outputs) {
|
||||
if (graph_output->isa<Parameter>()) {
|
||||
outputs.push_back(graph_output);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
|
||||
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
|
||||
outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||
} else if (IsCallNode(graph_output)) {
|
||||
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
|
||||
outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||
} else if (graph_output->isa<CNode>()) {
|
||||
outputs.emplace_back(graph_output);
|
||||
} else if (graph_output->isa<ValueNode>()) {
|
||||
outputs.push_back(graph_output);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -452,6 +450,70 @@ std::vector<AnfNodePtr> FetchParameterByControlNode(const std::vector<AnfNodePtr
|
|||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
// Get the number of calls, that is, the number of times the input of the call node is the call node.
|
||||
size_t FetchCallDepth(const AnfNodePtr &node) {
|
||||
if (!IsCallNode(node)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
return kSingleCallDepth + FetchCallDepth(inputs[0]);
|
||||
}
|
||||
|
||||
// Get the final subgraph called by fungraph through the depth of calls.
|
||||
FuncGraphPtr FetchFuncGraphByCallDepth(const FuncGraphPtr &func_graph, const size_t call_depth) {
|
||||
if (call_depth <= kSingleCallDepth) {
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
const auto &output = func_graph->output();
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
const auto &cnode = output->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() < kPartialInputStartPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial node:" << AnfAlgo::GetNodeDebugString(output);
|
||||
}
|
||||
const auto &called_func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
|
||||
return FetchFuncGraphByCallDepth(called_func_graph, call_depth - kSingleCallDepth);
|
||||
} else if (output->isa<ValueNode>() && IsValueNode<FuncGraph>(output)) {
|
||||
return FetchFuncGraphByCallDepth(GetValueNode<FuncGraphPtr>(output), call_depth - kSingleCallDepth);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid output for call depth:" << call_depth << " funcgraph:" << func_graph->ToString()
|
||||
<< " output node:" << AnfAlgo::GetNodeDebugString(output);
|
||||
}
|
||||
}
|
||||
|
||||
// Get funcgraph from node, the interface only accepts partial node and funcgraph value node.
|
||||
FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
const auto &func_graph = GetFuncGraphFromPartial(node);
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
|
||||
return FetchFuncGraphInNode(func_graph->output());
|
||||
} else if (IsValueNode<FuncGraph>(func_graph->output())) {
|
||||
// When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
|
||||
return FetchFuncGraphInNode(func_graph->output());
|
||||
}
|
||||
|
||||
return func_graph;
|
||||
} else if (IsValueNode<FuncGraph>(node)) {
|
||||
const auto &func_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) {
|
||||
// When the output of funcgraph is a funcgraph, it needs to return the funcgraph that is finally called.
|
||||
return FetchFuncGraphInNode(func_graph->output());
|
||||
} else if (IsValueNode<FuncGraph>(func_graph->output())) {
|
||||
// When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called.
|
||||
return FetchFuncGraphInNode(func_graph->output());
|
||||
}
|
||||
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Return true if the node has Ref abstract.
|
||||
|
@ -472,24 +534,56 @@ bool IsCallNode(const AnfNodePtr &node) {
|
|||
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FetchAllRealInputNodeByParameter(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
const auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
bool IsSubCallNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto inputs = node->cast<CNodePtr>()->inputs();
|
||||
|
||||
if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto &switch_layer_inputs = inputs[0]->cast<CNodePtr>()->inputs();
|
||||
const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
|
||||
if (tuple_inputs.size() <= kMakeTupleInputStartPos) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check whether the funcgraph called by the call node returns funcgraph or partial node.
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) {
|
||||
const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast<CNodePtr>()->input(kPartialFuncGraphPos);
|
||||
func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
|
||||
} else if (tuple_inputs[kMakeTupleInputStartPos]->isa<ValueNode>() &&
|
||||
IsValueNode<FuncGraph>(tuple_inputs[kMakeTupleInputStartPos])) {
|
||||
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[kMakeTupleInputStartPos]);
|
||||
}
|
||||
|
||||
const auto &output = func_graph->output();
|
||||
return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) ||
|
||||
(output->isa<ValueNode>() && IsValueNode<FuncGraph>(output));
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) {
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second);
|
||||
const auto &real_node = real_node_with_index.first;
|
||||
if (real_node->isa<Parameter>()) {
|
||||
if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) {
|
||||
parameters.emplace_back(real_node);
|
||||
parameters.emplace_back(real_node_with_index);
|
||||
}
|
||||
} else if (HasAbstractMonad(real_node)) {
|
||||
return parameters;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
|
||||
const auto &inputs = real_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter(inputs[i]);
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0});
|
||||
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||
}
|
||||
} else {
|
||||
parameters.emplace_back(real_node);
|
||||
parameters.emplace_back(real_node_with_index);
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
@ -514,13 +608,15 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
|||
AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) {
|
||||
const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
|
||||
|
||||
// Fetch all funcgraphs in make tuple node.
|
||||
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
|
||||
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
|
||||
func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i]));
|
||||
} else if (IsValueNode<FuncGraph>(tuple_inputs[i])) {
|
||||
func_graphs.emplace_back(GetValueNode<FuncGraphPtr>(tuple_inputs[i]));
|
||||
const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]);
|
||||
if (func_graph != nullptr) {
|
||||
func_graphs.emplace_back(func_graph);
|
||||
}
|
||||
}
|
||||
} else if (IsCallNode(cnode)) {
|
||||
return FetchFuncGraphbyCallNode(cnode);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
||||
}
|
||||
|
@ -563,7 +659,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
|||
break;
|
||||
}
|
||||
total_num += call_output_num;
|
||||
} else {
|
||||
} else if (!HasAbstractMonad(inputs[i])) {
|
||||
++total_num;
|
||||
}
|
||||
}
|
||||
|
@ -612,16 +708,16 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) {
|
|||
return kernel_graph->GetFrontAnfByBackendAnf(backend_node);
|
||||
}
|
||||
|
||||
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
|
||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node);
|
||||
if (front_node != nullptr) {
|
||||
return front_node;
|
||||
return {front_node, 0};
|
||||
}
|
||||
const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
|
||||
if (front_node_with_index.first == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node);
|
||||
}
|
||||
return front_node_with_index.first;
|
||||
return front_node_with_index;
|
||||
}
|
||||
|
||||
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
|
||||
|
@ -665,6 +761,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
|||
|
||||
FetchHostParameterToWeight(real_to_formal_front_parameters);
|
||||
|
||||
FetchCallInputKernelGraph(graphs, device_contexts);
|
||||
|
||||
FetchFrontValueNode(control_nodes, graphs, device_contexts);
|
||||
|
||||
FetchFrontToBackendKernel(graphs, device_contexts);
|
||||
|
@ -757,7 +855,7 @@ AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &nod
|
|||
for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
|
||||
for (const auto &front_weight : host_parameter_to_weight.second) {
|
||||
if (front_weight == node) {
|
||||
const auto &iter = front_to_backend_parameters_.find(front_weight);
|
||||
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
|
||||
if (iter != front_to_backend_parameters_.end()) {
|
||||
return iter->second.first;
|
||||
}
|
||||
|
@ -868,6 +966,20 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When funcgraph called by call node returns to the value node, device addresses should be created for these
|
||||
// value nodes.
|
||||
for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) {
|
||||
const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first);
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
const auto &output = func_graph->output();
|
||||
if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) {
|
||||
const auto &device_context = call_node_to_backend_parameter.second.second;
|
||||
CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context);
|
||||
front_value_nodes_.push_back({output, device_context});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeParser::FetchFrontToFrontParameter(
|
||||
|
@ -940,14 +1052,17 @@ void ControlNodeParser::FetchFrontToFrontParameter(
|
|||
}
|
||||
} else if (inputs[0]->isa<CNode>()) {
|
||||
// Call node which the first input node is a switch or switchlayer node.
|
||||
if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) &&
|
||||
(!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) {
|
||||
if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||
std::vector<AnfNodePtr> call_inputs;
|
||||
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
|
||||
switch_input_parse(inputs[0], call_inputs);
|
||||
} else if (IsCallNode(inputs[0])) {
|
||||
continue;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
|
||||
<< AnfAlgo::GetNodeDebugString(inputs[0]);
|
||||
}
|
||||
std::vector<AnfNodePtr> call_inputs;
|
||||
call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end());
|
||||
switch_input_parse(inputs[0], call_inputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -992,6 +1107,7 @@ void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &con
|
|||
for (const auto &control_node : control_nodes) {
|
||||
if (IsCallNode(control_node)) {
|
||||
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
|
||||
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
|
@ -1019,7 +1135,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
|
|||
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
|
||||
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
|
||||
call_input_kernel_graphs_[graph] = device_context;
|
||||
break;
|
||||
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1084,13 +1200,14 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
|
|||
return parameters;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
const auto &graph_parameters = graph->input_nodes();
|
||||
|
||||
for (const auto &graph_parameter : graph_parameters) {
|
||||
const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
|
||||
const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(graph_parameter).first;
|
||||
const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
|
||||
const auto &internal_front_node = internal_front_node_with_index.first;
|
||||
|
||||
if (external_front_node == nullptr && internal_front_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
|
||||
|
@ -1098,9 +1215,9 @@ std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph)
|
|||
continue;
|
||||
}
|
||||
|
||||
const auto &front_node = (external_front_node != nullptr) ? external_front_node : internal_front_node;
|
||||
const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node, 0).first;
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter(real_front_node);
|
||||
const auto &front_node_with_index =
|
||||
((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index);
|
||||
const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index);
|
||||
parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||
}
|
||||
|
||||
|
@ -1191,6 +1308,8 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr>
|
|||
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||
// Switchlayer node.
|
||||
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
|
||||
} else if (IsCallNode(inputs[0])) {
|
||||
continue;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
|
||||
}
|
||||
|
@ -1232,10 +1351,6 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP
|
|||
const auto graph_output_map = graph->graph_output_map();
|
||||
for (const auto &output_pair : graph_output_map) {
|
||||
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
|
||||
MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(output_pair.second.first)
|
||||
<< "index:" << output_pair.second.second << " addr:" << output_pair.second.first
|
||||
<< " second:" << AnfAlgo::GetNodeDebugString(output_pair.first.first)
|
||||
<< "index:" << output_pair.first.second << " addr:" << output_pair.first.first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1246,6 +1361,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
std::set<KernelWithIndex> *results) {
|
||||
if (front_output->isa<ValueNode>()) {
|
||||
(*results).insert({front_output, 0});
|
||||
|
||||
const auto &iter = formal_to_real_parameters_.find(front_output);
|
||||
if (iter != formal_to_real_parameters_.end()) {
|
||||
for (const auto &node : iter->second) {
|
||||
|
@ -1405,6 +1521,15 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr>
|
|||
}
|
||||
}
|
||||
|
||||
for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
|
||||
for (const auto &front_weight : host_parameter_to_weight.second) {
|
||||
const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first);
|
||||
if (iter != front_to_backend_parameters_.end()) {
|
||||
formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &func_graph_to_parameters : func_graph_to_parameters_) {
|
||||
const auto &func_graph = func_graph_to_parameters.first;
|
||||
std::vector<AnfNodePtr> graph_inputs;
|
||||
|
@ -1453,7 +1578,6 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro
|
|||
const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0));
|
||||
if (iter != front_to_backend_kernels_.end()) {
|
||||
kernel_to_call_nodes_[iter->second.first.first] = control_node;
|
||||
MS_LOG(DEBUG) << "Add auto monad control arrow for node:" << AnfAlgo::GetNodeDebugString(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,6 +50,9 @@ using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>
|
|||
// 2. First input of node is a funcgraph value node.
|
||||
bool IsCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Check if the call node is the input of another call node.
|
||||
bool IsSubCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
|
@ -66,7 +69,7 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node);
|
|||
|
||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||
// corresponding cnode.
|
||||
AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
|
||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph);
|
||||
|
||||
// Get the funcgraph to which the node belongs.
|
||||
FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
|
||||
|
@ -75,7 +78,7 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
|
|||
std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Get parameters in kernel graph.
|
||||
std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
|
||||
std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
|
||||
|
||||
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
|
||||
class ControlNodeParser {
|
||||
|
@ -205,6 +208,9 @@ class ControlNodeParser {
|
|||
// the input node of gather.
|
||||
FuncGraphToParameter func_graph_to_parameters_;
|
||||
|
||||
// The relationship between the valuenode inputs of the call node and the backend parameter
|
||||
std::map<KernelWithIndex, std::pair<AnfNodePtr, DeviceContext *>> call_node_to_backend_parameters_;
|
||||
|
||||
// Branch id of funcgraph.
|
||||
// In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to
|
||||
// different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch
|
||||
|
|
|
@ -1180,7 +1180,7 @@ std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompiler
|
|||
const auto &actor_name = control_node->DebugString();
|
||||
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
|
||||
control_node->cast<CNodePtr>(), branch_id, false);
|
||||
switch_actor->Initialize(graph_compiler_info.control_node_parser_);
|
||||
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
|
||||
|
||||
// Fetch all the input nodes of switch actor.
|
||||
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
|
||||
|
@ -1197,7 +1197,7 @@ std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompiler
|
|||
const auto &actor_name = return_node->DebugString();
|
||||
auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
|
||||
return_node->cast<CNodePtr>(), kInvalidBranchID, true);
|
||||
switch_actor->Initialize(graph_compiler_info.control_node_parser_);
|
||||
switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
|
||||
|
||||
// Fetch all the input nodes of switch actor.
|
||||
switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
|
||||
|
@ -1235,7 +1235,6 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
const auto &cnode = control_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
const auto &return_node = func_graph->get_return();
|
||||
const auto &output_switch_aid = FetchActor(return_node->DebugString())->GetAID();
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
|
||||
// Root funcgraph does not need to create a gather actor.
|
||||
|
@ -1245,21 +1244,26 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
}
|
||||
|
||||
// If the output of funcgraph is a value node, no need to create gather actor.
|
||||
if (inputs[kReturnInputPos]->isa<ValueNode>()) {
|
||||
if (inputs[kReturnInputPos]->isa<ValueNode>() ||
|
||||
AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto actor_name = func_graph->ToString();
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
for (const auto ¶meter : func_graph->get_inputs()) {
|
||||
if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
|
||||
continue;
|
||||
}
|
||||
parameters.emplace_back(parameter);
|
||||
parameters.push_back({parameter, 0});
|
||||
}
|
||||
|
||||
const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
|
||||
|
||||
const auto &output_switch_actor = FetchActor(return_node->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(output_switch_actor);
|
||||
const auto &output_switch_aid = output_switch_actor->GetAID();
|
||||
|
||||
auto gather_actor =
|
||||
std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
|
||||
gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
|
||||
|
@ -1275,12 +1279,12 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler
|
|||
|
||||
if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
|
||||
// Collect the parameters.
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<KernelWithIndex> parameters;
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
|
||||
continue;
|
||||
}
|
||||
parameters.emplace_back(inputs[i]);
|
||||
parameters.push_back({inputs[i], 0});
|
||||
}
|
||||
|
||||
auto func_graph = control_node->func_graph();
|
||||
|
@ -1322,9 +1326,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
|
|||
auto front_node = GetFrontNodeByBackendNode(from_kernel);
|
||||
|
||||
if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
|
||||
if (HasAbstractRef(from_kernel)) {
|
||||
const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
|
||||
const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
|
||||
const auto &real_front_node_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
|
||||
if (HasAbstractRef(real_front_node_with_index.first)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
|
||||
real_front_node_with_index.first.get());
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1332,9 +1339,8 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
|
|||
const auto actor_name = graph->ToString();
|
||||
auto actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto &real_front_node = GetFrontNodeByKernelGraph(from_kernel, graph);
|
||||
LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), real_front_node, to_actor,
|
||||
to_kernel_with_input_idx.second);
|
||||
LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), to_actor, real_front_node_with_index,
|
||||
to_kernel_with_input_idx);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1355,8 +1361,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf
|
|||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get());
|
||||
return;
|
||||
}
|
||||
|
||||
LinkDataArrowForGatherActor(from_actor, front_node, to_actor, to_kernel_with_input_idx.second);
|
||||
LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
|
||||
} else if (IsHostQueueDSActor(from_kernel, graph, tensor, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.strategy_)) {
|
||||
// Link the data arrows of host queue data source actor.
|
||||
|
@ -2053,25 +2058,84 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) {
|
||||
for (const auto &node : graph_compiler_info.control_nodes_) {
|
||||
void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &node : control_nodes) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
const auto &from_func_graph = node->func_graph();
|
||||
auto inputs = cnode->inputs();
|
||||
|
||||
// Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
|
||||
if (inputs[0]->isa<CNode>()) {
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
// Add the input of call node to switch actor.
|
||||
if (IsCallNode(inputs[0])) {
|
||||
const auto &sub_call_cnode = inputs[0]->cast<CNodePtr>();
|
||||
const auto &sub_inputs = sub_call_cnode->inputs();
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(sub_inputs[0], prim::kPrimSwitchLayer)) {
|
||||
auto actor = FetchActor(sub_inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
}
|
||||
}
|
||||
} else if (IsSubCallNode(cnode)) {
|
||||
// Add the input of sub call node to switch actor.
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
||||
const auto &tuple_node = inputs[0]->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
||||
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
|
||||
int pre_real_parameter_num = 0;
|
||||
if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) {
|
||||
pre_real_parameter_num = (tuple_inputs[i]->cast<CNodePtr>()->inputs().size() - kPartialInputStartPos);
|
||||
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]->cast<CNodePtr>()->input(kPartialFuncGraphPos));
|
||||
} else {
|
||||
func_graph = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
|
||||
}
|
||||
const auto parameters = func_graph->parameters();
|
||||
const auto &output = func_graph->output();
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) {
|
||||
const auto &sub_partial_inputs = output->cast<CNodePtr>()->inputs();
|
||||
|
||||
// Check whether the input node of the sub call node needs to be added to the switch actor. Only when
|
||||
// the final return is a partial node and the partial node needs this input, the input node is added
|
||||
// to the switch actor/
|
||||
for (size_t j = kPartialInputStartPos; j < sub_partial_inputs.size(); ++j) {
|
||||
const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[j], 0).first;
|
||||
const auto &iter = find(parameters.begin(), parameters.end(), real_partial_input);
|
||||
|
||||
if ((iter != parameters.end()) && (iter - parameters.begin() >= pre_real_parameter_num) &&
|
||||
(iter - parameters.begin() <
|
||||
SizeToInt(pre_real_parameter_num + inputs.size() - kCallInputStartPos))) {
|
||||
size_t pos = iter - parameters.begin() - pre_real_parameter_num + kCallInputStartPos;
|
||||
switch_actor->AddSingleInput(inputs[pos], i - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto actor = FetchActor(inputs[0]->DebugString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
|
||||
if (HasAbstractMonad(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
}
|
||||
switch_actor->AddCommonInput(inputs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) {
|
||||
PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
|
||||
|
||||
for (const auto &node : graph_compiler_info.control_nodes_) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
|
@ -2142,11 +2206,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
|
||||
for (const auto &input_node : gather_actor->data_nodes_) {
|
||||
for (const auto &input_with_index : gather_actor->data_nodes_) {
|
||||
const auto &from_func_graph = kernel_graph->GetFuncGraph();
|
||||
const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
|
||||
LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor,
|
||||
gather_actor->FetchDataNodePosition(input_node));
|
||||
gather_actor->FetchDataNodePosition(input_with_index));
|
||||
}
|
||||
}
|
||||
LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set);
|
||||
|
@ -2162,15 +2225,16 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi
|
|||
LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node,
|
||||
KernelActor *to_actor, const size_t to_index) {
|
||||
void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
|
||||
const KernelWithIndex &front_node_with_index,
|
||||
const KernelWithIndex &to_node_with_index) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||
|
||||
auto position = from_actor->FetchDataNodePosition(front_node);
|
||||
auto position = from_actor->FetchDataNodePosition(front_node_with_index);
|
||||
|
||||
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
|
||||
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
to_actor->input_datas_num_++;
|
||||
}
|
||||
|
@ -2181,13 +2245,18 @@ void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_w
|
|||
// Fetch all the funcgraph that call node would call.
|
||||
const auto cnode = call_node_with_index.first->cast<CNodePtr>();
|
||||
std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
|
||||
const auto &call_inputs = cnode->inputs();
|
||||
auto switch_node = call_inputs[0];
|
||||
if (IsCallNode(switch_node)) {
|
||||
switch_node = call_inputs[0]->cast<CNodePtr>()->input(0);
|
||||
}
|
||||
|
||||
// Collect the output of each funcgraph.
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
if (func_graph->output()->isa<ValueNode>()) {
|
||||
const auto &call_inputs = cnode->inputs();
|
||||
if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) {
|
||||
const auto &actor_name = call_inputs[0]->DebugString();
|
||||
if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitchLayer)) {
|
||||
const auto &actor_name = switch_node->DebugString();
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto switch_actor = dynamic_cast<SwitchActor *>(actor);
|
||||
|
@ -2254,7 +2323,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const
|
|||
}
|
||||
for (size_t i = start_branch; i < max_branch; ++i) {
|
||||
if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
|
||||
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i;
|
||||
MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
|
||||
<< " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
|
||||
<< " to actor:" << to_actor->GetAID() << " to index:" << to_index;
|
||||
}
|
||||
auto op_arrow =
|
||||
std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
|
||||
|
@ -2277,7 +2348,7 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
} else if (IsGatherActor(input_node, actor_name_to_actor_)) {
|
||||
// The actor input is a parameter in gather actor.
|
||||
auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
|
||||
auto position = from_actor->FetchDataNodePosition(input_node);
|
||||
auto position = from_actor->FetchDataNodePosition({input_node, 0});
|
||||
auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
|
||||
from_actor->output_data_arrows_.emplace_back(op_arrow);
|
||||
} else if (IsSwitchActor(input_node)) {
|
||||
|
@ -2338,7 +2409,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c
|
|||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node);
|
||||
MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
|
||||
<< " to actor:" << to_actor->GetAID();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2559,65 +2631,6 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
const ActorSet *actor_set) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
OutputActor *to_actor = actor_set->output_actor_.get();
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
||||
for (const auto &func_graph_to_branch_id : graph_compiler_info.control_node_parser_->func_graph_to_branch_id_) {
|
||||
if (func_graph_to_branch_id.second == kMainBranchID) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &func_graph = func_graph_to_branch_id.first;
|
||||
auto actor = FetchActor(func_graph->ToString());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto gather_actor = dynamic_cast<GatherActor *>(actor);
|
||||
|
||||
for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
|
||||
const auto front_node = gather_actor->data_nodes_[i];
|
||||
auto origin_output_with_index = KernelWithIndex(front_node, 0);
|
||||
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
|
||||
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto &output_position : iter->second) {
|
||||
auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), output_position);
|
||||
gather_actor->output_result_arrows_.emplace_back(op_arrow);
|
||||
const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node];
|
||||
if (backend_nodes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "No backend node for data node:" << AnfAlgo::GetNodeDebugString(front_node);
|
||||
}
|
||||
|
||||
const auto &backend_node = backend_nodes[0].first;
|
||||
if (backend_node->isa<Parameter>()) {
|
||||
std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
auto ds_op_actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(ds_op_actor);
|
||||
auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(ds_op_actor);
|
||||
MS_EXCEPTION_IF_NULL(host_ds_actor);
|
||||
|
||||
const auto &data_nodes = host_ds_actor->data_nodes_;
|
||||
const auto &node_iter = find(data_nodes.begin(), data_nodes.end(), backend_node);
|
||||
if (node_iter == data_nodes.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:"
|
||||
<< AnfAlgo::GetNodeDebugString(backend_node);
|
||||
}
|
||||
to_actor->device_contexts_[output_position] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()];
|
||||
} else {
|
||||
auto actor_base = FetchActor(backend_node->fullname_with_scope());
|
||||
MS_EXCEPTION_IF_NULL(actor_base);
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(actor_base);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
to_actor->device_contexts_[output_position] = kernel_actor->device_context_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
// Check the data source actors.
|
||||
|
@ -3098,7 +3111,7 @@ void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &of
|
|||
|
||||
ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
|
||||
for (const auto &node : actor->data_nodes_) {
|
||||
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n';
|
||||
ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
|
||||
}
|
||||
|
||||
ofs << "\t\tactor front to backend node:\n";
|
||||
|
|
|
@ -233,8 +233,9 @@ class GraphScheduler {
|
|||
|
||||
// 4. The processing of control flow linking.
|
||||
void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set);
|
||||
void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor,
|
||||
const size_t to_index);
|
||||
void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor,
|
||||
const KernelWithIndex &front_node_with_index,
|
||||
const KernelWithIndex &to_node_with_index);
|
||||
void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor);
|
||||
// Connect the input of the actor.
|
||||
void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node,
|
||||
|
@ -263,6 +264,9 @@ class GraphScheduler {
|
|||
const ControlNodeParserPtr &control_node_parser,
|
||||
const std::vector<AnfNodePtr> &origin_parameters,
|
||||
const std::vector<TensorPtr> &tensors, std::vector<TensorPtr> *host_tensors);
|
||||
// Add input for switch actor. Since part of the input of funcgraph is on call node, these inputs need to be added
|
||||
// to switch actor.
|
||||
void PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes);
|
||||
|
||||
// The processing of actors link dynamically.
|
||||
// Analyze necessary input data of current actor, generate and cache op arrow
|
||||
|
|
Loading…
Reference in New Issue