!26752 control flow support call to call
Merge pull request !26752 from gaoyong10/runtime_second12
This commit is contained in: commit 1b0a82fc30
@@ -208,9 +208,14 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send Partial.
for (const auto &partial_arrow : output_partial_arrows_) {
MS_EXCEPTION_IF_NULL(partial_arrow);
MS_EXCEPTION_IF_NULL(output_partial_.first);
ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial_.first,
output_partial_.second, IntToSize(partial_arrow->to_input_index_), context);
if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_
<< " current:" << input_partials_.size() << " for actor:" << GetAID();
}
auto output_partial = input_partials_[partial_arrow->from_output_index_];
MS_EXCEPTION_IF_NULL(output_partial.first);
ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first,
output_partial.second, IntToSize(partial_arrow->to_input_index_), context);
}
}
} // namespace runtime

@@ -72,6 +72,22 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
ActorDispatcher::Send(control_arrow, &OpActor::RunOpControl, source_aid, context);
}
}

// 3.Send output partial in output branch.
const auto &partial_iter = output_branch_partial_arrows_.find(output_branch_id_);
if (partial_iter != output_branch_partial_arrows_.end()) {
for (const auto &partial_arrow : partial_iter->second) {
MS_EXCEPTION_IF_NULL(partial_arrow);
if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) {
MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_
<< " current:" << input_partials_.size() << " for actor:" << GetAID();
}
auto output_partial = input_partials_[partial_arrow->from_output_index_];
MS_EXCEPTION_IF_NULL(output_partial.first);
Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first, output_partial.second,
IntToSize(partial_arrow->to_input_index_), context);
}
}
}

void ExitActor::CopyDeviceAddress() {

@@ -29,23 +29,22 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

ControlActor::FetchInput(context);
output_partial_ = input_partials_[0];
MS_EXCEPTION_IF_NULL(output_partial_.first);
MS_EXCEPTION_IF_NULL(input_partials_[0].first);

// Put other real parameter in partial.
for (const auto &device_tensor : input_device_tensors_) {
if (device_tensor != nullptr) {
output_partial_.second.emplace_back(device_tensor);
input_partials_[0].second.emplace_back(device_tensor);
}
}
}

void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send data with branch id.
const auto &iter = output_data_with_branch_id_arrows_.find(output_partial_.first);
const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0].first);
if (iter != output_data_with_branch_id_arrows_.end()) {
for (const auto &data_with_branch_id_arrow : iter->second) {
ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, output_partial_.second,
ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, input_partials_[0].second,
output_branch_id_, context);
}
}
@@ -27,25 +27,56 @@ StackActor::StackActor(const std::string &name, const std::vector<KernelWithInde

void StackActor::Init() {
ControlActor::Init();
for (const auto &formal_parameter : formal_parameters_) {
if (AnfAlgo::IsCallNode(formal_parameter.first)) {
break;
}
++input_parameter_data_num_;
}
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_;
if (input_parameter_data_num_ < device_tensor_store_keys_.size()) {
// The stack actor has 6 parts of input :
// 1. Directly input data.
// 2. Direct input partial.
// 3. Weight.
// 4. Local tensor.
// 5. Call input data.
// 6. Call input partial.
input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_ - input_parameter_partial_num_;
if (input_parameter_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) {
MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_parameter_data_num_
<< " device store num:" << device_tensor_store_keys_.size() << " for actor:" << GetAID();
<< " device store num:" << device_tensor_store_keys_.size() << " local device tensor num"
<< local_device_tensors_.size() << " for actor:" << GetAID();
}

// Fetch the total number of input partial.
int total_partials_num = 0;
for (const auto &formal_parameter : formal_parameters_) {
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
total_partials_num++;
}
}

// Fetch call input data num.
input_datas_num_ = formal_parameters_.size() - total_partials_num - input_parameter_data_num_;
input_partials_num_ = total_partials_num - input_parameter_partial_num_;
// Fetch call input partial num.
input_parameter_data_num_ -= (device_tensor_store_keys_.size() + local_device_tensors_.size());
// Check if the input num is valid.
if (input_parameter_data_num_ + input_parameter_partial_num_ + input_datas_num_ + input_partials_num_ +
device_tensor_store_keys_.size() + local_device_tensors_.size() !=
formal_parameters_.size()) {
MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_parameter_data_num_
<< " input parameter partial num:" << input_parameter_partial_num_
<< " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_
<< " device tensor store size:" << device_tensor_store_keys_.size()
<< " need total size:" << formal_parameters_.size() << " for actor:" << GetAID();
}
input_parameter_data_num_ -= device_tensor_store_keys_.size();
}

void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size()) {
if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size() +
input_parameter_partial_num_ + local_device_tensors_.size()) {
FillStack(input_data, context);
} else {
// The outputs of call nodes are placed directly in the input data.
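The Init() logic above splits the stack actor's formal parameters into six groups (direct input data, direct input partial, weight, local tensor, call input data, call input partial) and then verifies that the per-group counts add back up to the total number of formal parameters. A minimal, self-contained sketch of that consistency check follows; the struct and function names are illustrative stand-ins, not the MindSpore API.

// Sketch of the input-count consistency check described above.
// All types and names here are illustrative, not the actual MindSpore API.
#include <cstddef>
#include <iostream>

struct StackInputCounts {
  size_t parameter_data;     // 1. direct input data from inside the funcgraph
  size_t parameter_partial;  // 2. direct input partials
  size_t weight;             // 3. weights (device tensor store keys)
  size_t local_tensor;       // 4. local device tensors
  size_t call_data;          // 5. data produced by call nodes
  size_t call_partial;       // 6. partials produced by call nodes
};

// Returns true when the six groups exactly cover all formal parameters.
bool CheckCounts(const StackInputCounts &c, size_t formal_parameter_num) {
  size_t total = c.parameter_data + c.parameter_partial + c.weight +
                 c.local_tensor + c.call_data + c.call_partial;
  return total == formal_parameter_num;
}

int main() {
  StackInputCounts counts{2, 1, 1, 0, 1, 1};
  std::cout << std::boolalpha << CheckCounts(counts, 6) << '\n';  // true
  return 0;
}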
@@ -56,6 +87,22 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev
}
}

void StackActor::RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(func_graph);
// The parameters from the inside of the subgraph need to be put into the stack.
if (IntToSize(position) < input_parameter_data_num_ + device_tensor_store_keys_.size() +
input_parameter_partial_num_ + local_device_tensors_.size()) {
input_parameter_partial_[context->sequential_num_][position].push(OpPartial(func_graph, input_data));
} else {
input_op_partials_[context->sequential_num_].emplace_back(position, OpPartial(func_graph, input_data));
}
if (CheckRunningCondition(context)) {
Run(context);
}
}

void StackActor::FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_data);

@@ -122,6 +169,26 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
}

if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
return false;
}
if (partial_iter->second.size() != input_parameter_partial_num_) {
return false;
}

auto iter = input_branch_ids_.find(context->sequential_num_);
if (iter == input_branch_ids_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID();
}
size_t branch_id_size = iter->second.size();
if (std::any_of(partial_iter->second.begin(), partial_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
}
}
return true;
}
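The CheckRunningCondition() change above only lets the stack actor run once every stacked partial input has been pushed as many times as there are recorded branch ids. A rough, self-contained sketch of that rule, with hypothetical types rather than the actual actor classes:

// Illustrative sketch of the running-condition rule: every stacked partial
// input must have been pushed once per recorded branch id before running.
#include <cstddef>
#include <iostream>
#include <map>
#include <stack>

using Stacks = std::map<size_t, std::stack<int>>;  // input index -> stacked values

bool ReadyToRun(const Stacks &stacks, size_t expected_inputs, size_t branch_id_num) {
  if (stacks.size() != expected_inputs) {
    return false;
  }
  for (const auto &one_stack : stacks) {
    if (one_stack.second.size() != branch_id_num) {
      return false;
    }
  }
  return true;
}

int main() {
  Stacks stacks;
  stacks[0].push(1);
  stacks[0].push(2);
  std::cout << std::boolalpha << ReadyToRun(stacks, 1, 2) << '\n';  // true
  return 0;
}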
@@ -133,13 +200,29 @@ void StackActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : data_iter->second) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size()) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() +
local_device_tensors_.size() + input_parameter_partial_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_
<< " for actor:" << GetAID();
}
input_device_tensors_[one_stack.first] = one_stack.second.top();
}
}

if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}
for (const auto &one_stack : partial_iter->second) {
if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() +
local_device_tensors_.size() + input_parameter_partial_num_) {
MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_partial_
<< " for actor:" << GetAID();
}
input_partials_[one_stack.first] = one_stack.second.top();
}
}
ControlActor::FetchInput(context);
}

@@ -160,6 +243,20 @@ void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) {
one_stack.second.pop();
}
}

if (input_parameter_partial_num_ != 0) {
const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_);
if (partial_iter == input_parameter_partial_.end()) {
MS_LOG(ERROR) << "Invalid input for actor:" << GetAID();
}

for (auto &one_stack : partial_iter->second) {
if (one_stack.second.empty()) {
MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID();
}
one_stack.second.pop();
}
}
}
} // namespace runtime
} // namespace mindspore

@@ -39,9 +39,11 @@ class StackActor : public ControlActor {
~StackActor() override = default;

void Init() override;
// The input data of the stack actor needs to be pushed into the stack according to the input index, so it is
// implemented separately.
// The input data and partial of the stack actor needs to be pushed into the stack according to the input index,
// so it is implemented separately.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
void RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position,
OpContext<DeviceTensor> *const context) override;

protected:
void FetchInput(OpContext<DeviceTensor> *const context) override;

@@ -56,12 +58,15 @@ class StackActor : public ControlActor {
// The device tensors created and stored by the stack.
std::vector<DeviceTensorPtr> created_device_tensors_;

// The input data records that the stack actor is copied from the input nodes and needs to be stored in the
// device tensor in the stack.
// The input data and partials records that the stack actor is copied from the input nodes and needs to be
// stored in the device tensor in the stack.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_parameter_data_;
// Input parameter data num represents the number of actor's input come from funcgraph itself, these inputs
// will be ranked at the front of input.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<OpPartial>>> input_parameter_partial_;

// Input parameter num represents the number of actor's input come from funcgraph itself, these inputs will
// be ranked at the front of input.
size_t input_parameter_data_num_{0};
size_t input_parameter_partial_num_{0};
// The backend parameter is used to save the backend node corresponding to the device tensor in the stack.
// When these device tensors are used as output, they need to be placed in the node of the result arrow,
// so these nodes need to be saved.

@@ -54,7 +54,7 @@ void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
if (!output_partial_arrows_.empty()) {
auto func_graph = input_partials_[index + kSwitchCondPos].first;
MS_EXCEPTION_IF_NULL(func_graph);
output_partial_ = input_partials_[index + kSwitchCondPos];
input_partials_[0] = input_partials_[index + kSwitchCondPos];
}

for (auto &output_data : output_data_by_output_index_[kSwitchDefaultOutputNum - 1]) {

@@ -261,7 +261,7 @@ size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node)
MS_EXCEPTION_IF_NULL(data_node);
const auto &iter = data_node_position_map_.find(data_node);
if (iter == data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " is not exist.";
MS_LOG(EXCEPTION) << "Data node: " << data_node->DebugString() << " is not exist.";
}
return iter->second;
}
@@ -261,6 +261,65 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address;
AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get());
}

// Check if there is a recursive call to funcgraph, if a calls b, b calls c, and c calls a, it is a recursive call.
bool IsRecursionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *checked_funcgraphs,
const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &func_graph_call_relation) {
MS_EXCEPTION_IF_NULL(func_graph);
if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) {
return true;
}
checked_funcgraphs->emplace(func_graph);
auto iter = func_graph_call_relation.find(func_graph);
if (iter == func_graph_call_relation.end()) {
return false;
}

for (const auto &called_func_graph : iter->second) {
MS_EXCEPTION_IF_NULL(called_func_graph);
if (IsRecursionFunction(called_func_graph, checked_funcgraphs, func_graph_call_relation)) {
return true;
}
}
return false;
}

// Fetch all inputs of node.
std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
// The node is divided into the following types:
// 1. depend and load.
const auto &node_with_index =
AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
auto real_node = node_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
std::vector<KernelWithIndex> results;
// 2. MakeTuple.
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
const auto &make_tuple_cnode = real_node->cast<CNodePtr>();
const auto &make_tuple_inputs = make_tuple_cnode->inputs();
for (size_t i = kMakeTupleInputStartPos; i < make_tuple_inputs.size(); ++i) {
const auto &sub_results = FetchInputNodeByNode(make_tuple_inputs[i]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
}
return results;
}

// 3. One output node.
const auto &abstract = real_node->abstract();
if (abstract == nullptr || (!abstract->isa<abstract::AbstractTuple>())) {
if (abstract == nullptr) {
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
}
return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)};
}

// 4. Abstract is Tuple.
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
for (size_t i = 0; i < output_num; ++i) {
results.emplace_back(real_node, i);
}
return results;
}
} // namespace

bool HasAbstractRef(const AnfNodePtr &node) {
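IsRecursionFunction() above is a depth-first walk over the funcgraph call relation that reports a recursive call as soon as an already-visited funcgraph is reached again. A self-contained sketch of the same idea, using plain strings in place of FuncGraphPtr; names are illustrative only:

// Sketch of the recursion check: funcgraph a calling b, b calling c,
// and c calling a forms a recursive call chain.
#include <iostream>
#include <map>
#include <set>
#include <string>

using CallRelation = std::map<std::string, std::set<std::string>>;

bool IsRecursion(const std::string &graph, std::set<std::string> *checked,
                 const CallRelation &relation) {
  if (checked->count(graph) > 0) {
    return true;  // Reached an already visited graph: a call cycle exists.
  }
  checked->insert(graph);
  auto iter = relation.find(graph);
  if (iter == relation.end()) {
    return false;
  }
  for (const auto &callee : iter->second) {
    if (IsRecursion(callee, checked, relation)) {
      return true;
    }
  }
  return false;
}

int main() {
  CallRelation relation{{"a", {"b"}}, {"b", {"c"}}, {"c", {"a"}}};
  std::set<std::string> checked;
  std::cout << std::boolalpha << IsRecursion("a", &checked, relation) << '\n';  // true
  return 0;
}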
@@ -285,6 +344,72 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, Kernel
return front_node_with_index;
}

std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return {};
}

std::vector<KernelWithIndex> results;
// The first input of normal cnode is the primitive of node, and the real input starts from the second input,
// but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial.
size_t input_start_pos = kCNodeInputStartPos;
if (AnfAlgo::IsCallNode(node)) {
input_start_pos = 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto inputs = cnode->inputs();

// The first branch of the input of the switch node is the true branch, and the second is the false branch.
// But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input
// of the switch node needs to exchange the positions of the two branches. So deal separately.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
return results;
}

for (size_t i = input_start_pos; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(inputs[i]);
// skip monad node.
if (HasAbstractMonad(inputs[i])) {
continue;
}
const auto &sub_results = FetchInputNodeByNode(inputs[i]);
results.insert(results.end(), sub_results.begin(), sub_results.end());
}
return results;
}

abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractTuple>()) {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
return abstract;
}

auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t real_index = index;
for (const auto &sub_abstract : sub_abstracts) {
size_t tmp_index = AnfAlgo::GetOutputNumByAbstract(sub_abstract);
if (real_index >= tmp_index) {
real_index -= tmp_index;
continue;
}
return FetchAbstractByIndex(sub_abstract, real_index);
}
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
return nullptr;
}

void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph,
const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) {
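FetchAbstractByIndex() above resolves a flattened output index inside a possibly nested tuple abstract by subtracting each element's flattened output count until the index falls within one element. A small self-contained sketch of that lookup with a stand-in node type (not the real AbstractBase hierarchy):

// Sketch of the flattened-index lookup over a nested tuple structure.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;            // leaf name, empty for tuples
  std::vector<Node> elements;  // non-empty only for tuples
};

size_t OutputNum(const Node &n) {
  if (n.elements.empty()) {
    return 1;
  }
  size_t num = 0;
  for (const auto &e : n.elements) {
    num += OutputNum(e);
  }
  return num;
}

const Node *FetchByIndex(const Node &n, size_t index) {
  if (n.elements.empty()) {
    return index == 0 ? &n : nullptr;
  }
  for (const auto &e : n.elements) {
    size_t num = OutputNum(e);
    if (index >= num) {
      index -= num;
      continue;
    }
    return FetchByIndex(e, index);
  }
  return nullptr;
}

int main() {
  // Abstract shaped like (t0, (t1, t2), t3); flattened index 2 selects t2.
  Node abstract{"", {{"t0", {}}, {"", {{"t1", {}}, {"t2", {}}}}, {"t3", {}}}};
  std::cout << FetchByIndex(abstract, 2)->name << '\n';  // t2
  return 0;
}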
@@ -309,12 +434,18 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons

ParseCallNodeToFuncGraph(control_nodes);

FetchNeedStackControlNode(control_nodes);

ParseUnRecursionCallNode();

FetchFrontNodeToKernelGraph(graphs);

ParseFormalToRealParameter(control_nodes);

ParseFrontToBackendParameter(graphs, device_contexts);

CreateDeviceTensorForRootGraphParameter(device_contexts[0]);

FetchFrontToBackendKernel(graphs, device_contexts);

FetchHostParameterToWeight();

@@ -323,7 +454,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons

ParseDeviceContext(control_nodes, graphs, device_contexts, func_graph_to_kernel_graphs);

FetchFrontValueNode(device_contexts[0]);
FetchFrontValueNode(control_nodes, device_contexts[0]);

FetchControlNodeParameter(control_nodes);

@@ -363,6 +494,11 @@ bool ControlNodeParser::IsRootGraphParameter(const AnfNodePtr &node) {
return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end();
}

bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) == unrecursion_call_nodes_.end();
}

void ControlNodeParser::ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes,
const std::vector<KernelGraphPtr> &kernel_graphs,
const std::vector<DeviceContext *> &device_contexts,

@@ -604,6 +740,16 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def

void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
for (const auto &graph : graphs) {
if (graph->execution_order().empty()) {
continue;
}

for (auto &kernel : graph->execution_order()) {
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
if (front_node != nullptr) {
front_node_to_kernel_graph_[front_node] = graph;
}
}
const auto &graph_outputs = graph->graph_output_map();
for (const auto &backend_to_front : graph_outputs) {
front_node_to_kernel_graph_[backend_to_front.second.first] = graph;

@@ -654,7 +800,8 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c
return nullptr;
}

void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes,
DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);

for (const auto &formal_to_real_parameter : formal_to_real_parameters_) {

@@ -688,6 +835,27 @@ void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) {
front_value_nodes_.emplace(front_to_backend_parameters.first, device_context);
}
}

// Create device tensors for those value nodes which direct return by a return node.
for (const auto &control_node : control_nodes) {
if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
continue;
}

auto input_with_indexs = FetchInputNodeByCNode(control_node);
auto iter = control_node_to_device_contexts_.find(control_node);
if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) {
MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString();
}
for (size_t i = 0; i < input_with_indexs.size(); ++i) {
const auto &input_with_index = input_with_indexs[i];
if (input_with_index.first->isa<ValueNode>() && (!IsValueNode<FuncGraph>(input_with_index.first)) &&
front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) {
CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]);
front_value_nodes_.emplace(input_with_index, iter->second[i]);
}
}
}
}

void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) {

@@ -977,7 +1145,9 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP

const auto graph_output_map = graph->graph_output_map();
for (const auto &output_pair : graph_output_map) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
if (output_pair.first.first->isa<CNode>()) {
front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context};
}
}
}
}

@@ -1039,5 +1209,57 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN
}
}
}

void ControlNodeParser::ParseUnRecursionCallNode() {
std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> func_graph_call_relation;
// Collect the call relationship between funcgraphs.
for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
const auto &call_node = call_node_to_func_graphs.first;
MS_EXCEPTION_IF_NULL(call_node);
const auto &func_graph = call_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_call_relation[func_graph].insert(call_node_to_func_graphs.second.begin(),
call_node_to_func_graphs.second.end());
}

for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) {
const auto &call_node = call_node_to_func_graphs.first;
std::set<FuncGraphPtr> checked_func_graphs{call_node->func_graph()};
bool is_recursion_call_node = false;
if (std::any_of(call_node_to_func_graphs.second.begin(), call_node_to_func_graphs.second.end(),
[&is_recursion_call_node, &checked_func_graphs, &func_graph_call_relation](const auto &func_graph) {
return IsRecursionFunction(func_graph, &checked_func_graphs, func_graph_call_relation);
})) {
is_recursion_call_node = true;
}
if (!is_recursion_call_node && need_stack_control_nodes_.find(call_node) == need_stack_control_nodes_.end()) {
unrecursion_call_nodes_.emplace(call_node);
}
}
}

void ControlNodeParser::FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) {
MS_EXCEPTION_IF_NULL(control_node);
if (AnfAlgo::IsCallNode(control_node)) {
auto input_with_indexs = FetchInputNodeByCNode(control_node);
if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(),
[](const auto &input_with_index) { return AnfAlgo::IsCallNode(input_with_index.first); })) {
need_stack_control_nodes_.emplace(control_node);
}
}
}
}

void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) {
MS_EXCEPTION_IF_NULL(default_context);
for (const auto &parameter : root_graph_parameters_) {
KernelWithIndex parameter_with_index(parameter, 0);
if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) {
CreateDeviceTensorForFrontNode(parameter_with_index, default_context);
front_to_backend_parameters_[parameter_with_index].emplace(parameter, default_context);
}
}
}
} // namespace runtime
} // namespace mindspore
@@ -24,6 +24,7 @@
#include <queue>
#include <map>
#include <utility>
#include <unordered_map>
#include <algorithm>
#include "utils/hash_map.h"
#include "runtime/hardware/device_context.h"

@@ -76,7 +77,10 @@ bool HasAbstractRef(const AnfNodePtr &node);
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
// corresponding cnode.
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph);

// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
// Fetch the sub abstract from the top abstract by the index.
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
class ControlNodeParser {
public:

@@ -94,6 +98,7 @@ class ControlNodeParser {
// 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor.
bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node);
bool IsRootGraphParameter(const AnfNodePtr &node);
bool IsRecursionCallNode(const AnfNodePtr &node);

const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; }
const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; }

@@ -117,7 +122,7 @@ class ControlNodeParser {
// value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for
// them separately during initialization.
// The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to.
void FetchFrontValueNode(DeviceContext *default_context);
void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, DeviceContext *default_context);
// Create branch id for all call node in the control flow.
void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes);

@@ -138,6 +143,8 @@ class ControlNodeParser {
const FormalToRealParameter &formal_to_real_parameters,
std::set<KernelWithIndex> *total_real_parameters,
std::set<AnfNodePtr> *invalid_real_parameter);
// Get all the call nodes without a recursion call relation.
void ParseUnRecursionCallNode();

// Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be
// copied between different device memories. The analysis steps:

@@ -179,7 +186,12 @@ class ControlNodeParser {
void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes);
// Fetch the formal parameter in root graph by parameters in subgraph.
AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node);

// Get the control nodes which need to add a stack actor for them.
// When a control node has input that is a call node, you need to add a stack actor for it.
void FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes);
// When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device
// tensor needs to be created for it.
void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context);
// In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded
// in this map.
FuncGraphToKernelGraph func_graph_to_kernel_graphs_;

@@ -220,6 +232,12 @@ class ControlNodeParser {
mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
// Control nodes without a control node input in the topological sorting of funcgraph.
mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
// Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node
// belong.
std::set<AnfNodePtr> unrecursion_call_nodes_;
// Those control nodes that need to create the corresponding stack actor, when there is a call node in the inputs
// of the control node, the stack actor is needed to collect these inputs.
std::set<AnfNodePtr> need_stack_control_nodes_;

// In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding
// to the type needs to be parsed in advance so that it can add some copy operation in the scheduler.
@@ -35,66 +35,41 @@ std::string GetActorName(const AnfNodePtr &node) {
}
}

// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return {};
// Fetch the depend nodes according to the monad node.
void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *depend_nodes) {
// Find the real input node, include the monad node and make tuple node.
const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
prim::kPrimMakeTuple};
const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_types);
auto real_node = node_with_index.first;
MS_EXCEPTION_IF_NULL(real_node);
if (!real_node->isa<CNode>()) {
return;
}

std::vector<KernelWithIndex> results;
// The first input of normal cnode is the primitive of node, and the real input starts from the second input,
// but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial.
size_t input_start_pos = kCNodeInputStartPos;
if (AnfAlgo::IsCallNode(node)) {
input_start_pos = 0;
}
const auto &cnode = node->cast<CNodePtr>();
const auto inputs = cnode->inputs();
const auto &real_cnode = real_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_cnode);
const auto &real_inputs = real_cnode->inputs();

// The first branch of the input of the switch node is the true branch, and the second is the false branch.
// But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input
// of the switch node needs to exchange the positions of the two branches. So deal separately.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
if (inputs.size() != kSwitchInputNum) {
MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString();
// Make tuple node needs to be expanded.
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < real_inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(real_inputs[i]);
FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0));
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0));
return results;
return;
}

for (size_t i = input_start_pos; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(inputs[i]);
// skip monad node.
if (HasAbstractMonad(inputs[i])) {
continue;
}

const auto &node_with_index =
AnfAlgo::VisitKernelWithReturnType(inputs[i], 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
MS_EXCEPTION_IF_NULL(node_with_index.first);
size_t output_num = AnfAlgo::GetOutputTensorNum(node_with_index.first);
for (size_t j = 0; j < output_num; ++j) {
if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) {
const auto &make_tuple_cnode = node_with_index.first->cast<CNodePtr>();
const auto &make_tuple_inputs = make_tuple_cnode->inputs();
if (make_tuple_inputs.size() <= j + kMakeTupleInputStartPos) {
MS_LOG(EXCEPTION) << "Invalid input:" << j + kMakeTupleInputStartPos
<< " for make tuple node:" << make_tuple_cnode->DebugString();
}
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[j + kMakeTupleInputStartPos], 0));
} else if (node_with_index.first->isa<ValueNode>()) {
// When the value node is a value tuple, the value node will have multiple outputs, which need to be directly
// put into the vector, and the output cannot be obtained through the VisitKernelWithReturnType interface.
results.emplace_back(node_with_index.first, j);
} else {
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(node_with_index.first, j));
}
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) {
FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes);
} else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) {
for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) {
FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);
}
} else {
depend_nodes->emplace(real_node);
}
return results;
}
} // namespace
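FetchRealDependNodeByAutoMonad() above walks monad-related nodes: MakeTuple and UpdateState are expanded into their real inputs, Depend and Load follow their attached input, and anything else is collected as a real depend node. A loose, self-contained sketch of that traversal shape follows; the Node type and the input index used below are hypothetical, not the AnfNode API:

// Sketch of the auto-monad depend-node traversal.
#include <memory>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string prim;  // "MakeTuple", "Depend", "Load", "UpdateState" or a kernel name
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

void FetchRealDependNode(const NodePtr &node, std::set<std::string> *depend_nodes) {
  if (node == nullptr) {
    return;
  }
  if (node->prim == "MakeTuple" || node->prim == "UpdateState") {
    // Expand all real inputs.
    for (const auto &input : node->inputs) {
      FetchRealDependNode(input, depend_nodes);
    }
  } else if (node->prim == "Depend" || node->prim == "Load") {
    // Follow the attached input (index chosen for this sketch only).
    if (node->inputs.size() > 1) {
      FetchRealDependNode(node->inputs[1], depend_nodes);
    }
  } else {
    depend_nodes->insert(node->prim);  // A real depend node.
  }
}

int main() {
  auto kernel = std::make_shared<Node>(Node{"MatMul", {}});
  auto load = std::make_shared<Node>(Node{"Load", {nullptr, kernel}});
  std::set<std::string> depend_nodes;
  FetchRealDependNode(load, &depend_nodes);
  return depend_nodes.count("MatMul") == 1 ? 0 : 1;
}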
@@ -283,6 +258,8 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
// Create a corresponding stack actor for each kernel graph that has a call node as input.
for (const auto &graph_with_context : parser->call_input_kernel_graphs_) {
std::vector<KernelWithIndex> formal_parameters;
size_t input_parameter_data_num = 0;

const auto &graph = graph_with_context.first;
const auto &device_context = graph_with_context.second;
MS_EXCEPTION_IF_NULL(graph);

@@ -303,10 +280,12 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
}

// If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
if (AnfAlgo::IsCallNode(front_node_with_index.first) &&
(parser->IsRecursionCallNode(front_node_with_index.first))) {
formal_parameters.emplace_back(front_node_with_index);
} else {
formal_parameters.insert(formal_parameters.begin(), front_node_with_index);
input_parameter_data_num++;
}
}

@@ -315,11 +294,79 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp
stack_actors.emplace_back(stack_actor);
stack_actor->device_contexts_.insert(stack_actor->device_contexts_.begin(), formal_parameters.size(),
device_context);
stack_actor->input_parameter_data_num_ = input_parameter_data_num;
InsertActor(stack_actor.get());
}
// Create stack actors for control nodes.
BuildStackActorForControlNode(graph_compiler_info, &stack_actors);

return stack_actors;
}

void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
std::vector<StackActorPtr> *stack_actors) {
const auto &parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser);

for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
MS_EXCEPTION_IF_NULL(need_stack_control_node);
if (!AnfAlgo::IsCallNode(need_stack_control_node)) {
continue;
}

std::vector<KernelWithIndex> formal_parameters;
std::vector<const DeviceContext *> device_contexts;
size_t input_parameter_data_num = 0;
size_t input_parameter_partials_num = 0;

// Fetch the control actor of control node.
auto gather_actor_name = GetActorName(need_stack_control_node);
auto actor = FetchActor(gather_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
if (gather_actor->formal_parameters_.size() > gather_actor->device_contexts_.size()) {
MS_LOG(EXCEPTION) << "Invalid device context size:" << gather_actor->device_contexts_.size()
<< " and formal parameter size:" << gather_actor->formal_parameters_.size()
<< " for actor:" << gather_actor->GetAID();
}

// Collect formal parameters and device contexts, skip the value nodes.
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
const auto &parameter = gather_actor->formal_parameters_[i];
auto device_context = gather_actor->device_contexts_[i];
if (AnfAlgo::IsCallNode(parameter.first) && (parser->IsRecursionCallNode(parameter.first))) {
formal_parameters.emplace_back(parameter);
device_contexts.emplace_back(device_context);
} else if (parameter.first->isa<ValueNode>()) {
continue;
} else {
formal_parameters.insert(formal_parameters.begin(), parameter);
device_contexts.insert(device_contexts.begin(), device_context);

const auto &abstract = parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
input_parameter_partials_num++;
} else {
input_parameter_data_num++;
}
}
}
// Create stack actor.
const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, formal_parameters);
stack_actor->device_contexts_ = device_contexts;
stack_actor->input_parameter_data_num_ = input_parameter_data_num;
stack_actor->input_parameter_partial_num_ = input_parameter_partials_num;

InsertActor(stack_actor.get());
stack_actors->emplace_back(stack_actor);
}
}

void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
@@ -362,11 +409,18 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
}

for (auto &gather_actor : control_actor_set->gather_actors_) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
MS_EXCEPTION_IF_NULL(gather_actor->node_);
if (parser->need_stack_control_nodes_.find(gather_actor->node_) == parser->need_stack_control_nodes_.end()) {
for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
parser);
}
} else {
// If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
LinkArrowFromStackActor(gather_actor.get());
}
}

for (auto &entrance_actor : control_actor_set->entrance_actors_) {
for (const auto &call_node : entrance_actor->call_nodes_) {
LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser);

@@ -387,6 +441,38 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr
}
}

void ControlNodeScheduler::LinkArrowFromStackActor(ControlActor *to_actor) {
MS_EXCEPTION_IF_NULL(to_actor->node_);
auto stack_actor_name = GetActorName(to_actor->node_) + kStackActorNameSuffix;
auto actor = FetchActor(stack_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto stack_actor = dynamic_cast<StackActor *>(actor);
MS_EXCEPTION_IF_NULL(stack_actor);

for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
const auto &formal_parameter = to_actor->formal_parameters_[to_index];
const auto &from_node = formal_parameter.first;
if (from_node->isa<ValueNode>()) {
LinkArrowByValueNode(from_node, to_actor, formal_parameter.second, to_index);
continue;
}

// Fetch the arrow type of input.
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);

// Link arrow according to abstract.
if (real_abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrow(stack_actor, to_actor, from_index, to_index);
} else {
LinkDataArrow(stack_actor, to_actor, from_index, to_index);
}
}
}

void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,

@@ -394,20 +480,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
const auto &from_node = from_node_with_index.first;
MS_EXCEPTION_IF_NULL(from_node);
if (from_node->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(from_node)) {
// Link local partial.
const auto &func_graph = GetValueNode<FuncGraphPtr>(from_node);
to_actor->local_partials_[to_node_with_index.second] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
if (!AnfAlgo::OutputAddrExist(from_node, from_node_with_index.second)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_node_with_index.second
<< " for value node:" << from_node->DebugString();
}
to_actor->local_device_tensors_[to_node_with_index.second] =
AnfAlgo::GetMutableOutputAddr(from_node, from_node_with_index.second, false).get();
to_actor->local_device_tensors_[to_node_with_index.second]->SetNodeIndex(from_node, from_node_with_index.second);
}
LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second);
} else if (from_node->isa<Parameter>()) {
LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser);
} else if (AnfAlgo::IsCallNode(from_node)) {

@@ -436,6 +509,26 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
}
}

void ControlNodeScheduler::LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor,
size_t from_index, size_t to_index) {
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(to_actor);

if (IsValueNode<FuncGraph>(value_node)) {
// Link local partial.
const auto &func_graph = GetValueNode<FuncGraphPtr>(value_node);
to_actor->local_partials_[to_index] = OpPartial(func_graph.get(), {});
} else {
// Link device store value node.
if (!AnfAlgo::OutputAddrExist(value_node, from_index)) {
MS_LOG(EXCEPTION) << "Invalid output address index:" << from_index
<< " for value node:" << value_node->DebugString();
}
to_actor->local_device_tensors_[to_index] = AnfAlgo::GetMutableOutputAddr(value_node, from_index, false).get();
to_actor->local_device_tensors_[to_index]->SetNodeIndex(value_node, from_index);
}
}

void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index,
const KernelWithIndex &to_node_with_index,

@@ -467,6 +560,11 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont

if (to_actor->type_ != KernelTransformType::kEntranceActor) {
// Link arrow from exit actor to control actor.
const auto &abstract = call_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
MS_EXCEPTION_IF_NULL(real_abstract);

const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(from_node);
for (const auto &func_graph : func_graphs) {
MS_EXCEPTION_IF_NULL(func_graph);

@@ -475,10 +573,19 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
MS_EXCEPTION_IF_NULL(actor);
auto exit_actor = dynamic_cast<ExitActor *>(actor);
size_t branch_id = parser->FetchBranchIDByCallNode(from_node);
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
if (real_abstract->isa<abstract::AbstractFunction>()) {
LinkPartialArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
} else {
LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second,
branch_id);
}
}
if (abstract->isa<abstract::AbstractFunction>()) {
to_actor->input_partials_num_++;
} else {
to_actor->input_datas_num_++;
}
to_actor->input_datas_num_++;
} else {
// Link arrow from gather actor to entrance actor.
const auto &actor_name = GetActorName(from_node);
@@ -604,6 +711,34 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
LinkControlArrow(entrance_actor, control_actor);
}
}

// Link auto monad control arrow for control actor.
std::vector<ControlActor *> control_actors;
(void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
std::back_inserter(control_actors), [](auto &switch_actor) { return switch_actor.get(); });
(void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
std::back_inserter(control_actors), [](auto &gather_actor) { return gather_actor.get(); });
(void)std::transform(control_actor_set->exit_actors_.begin(), control_actor_set->exit_actors_.end(),
std::back_inserter(control_actors), [](auto &exit_actor) { return exit_actor.get(); });
for (auto control_actor : control_actors) {
MS_EXCEPTION_IF_NULL(control_actor);
const auto &node = control_actor->node_;
if (node == nullptr) {
return;
}

const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) {
LinkControlArrowByAutoMonad(control_actor, input, parser);
}
}
}
}

void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,

@@ -613,6 +748,13 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_

// Link control arrow from entrance actors or stack actors to no input kernel actors.
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
// In control flow, when the input of the kernel actor is a parameter, this input needs to be linked to the
// control actor, so the no-input kernel actor collected in the graph scheduler will also collect this actor,
// and it needs to be skipped here.
if ((no_input_kernel_actor->input_datas_num_ != 0) || (no_input_kernel_actor->input_controls_num_ != 0)) {
continue;
}

KernelGraphPtr kernel_graph = nullptr;
if (no_input_kernel_actor->type_ == KernelTransformType::kSuperKernelActor) {
const auto &super_kernel_actor = dynamic_cast<SuperKernelActor *>(no_input_kernel_actor.get());

@@ -665,6 +807,46 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_
}
}

void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
const ControlNodeParserPtr &parser) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);

std::set<AnfNodePtr> depend_nodes;
FetchRealDependNodeByAutoMonad(from_node, &depend_nodes);

for (const auto &depend_node : depend_nodes) {
MS_EXCEPTION_IF_NULL(depend_node);
auto from_actor = FetchActor(depend_node->DebugString());
if (AnfAlgo::IsCallNode(depend_node)) {
int branch_id = parser->FetchBranchIDByCallNode(depend_node);
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(depend_node);
if (func_graphs.empty()) {
MS_LOG(EXCEPTION) << "Failed to get funcgraph by call node:" << depend_node->DebugString();
}
for (const auto func_graph : func_graphs) {
auto exit_actor_name = func_graph->ToString() + kExitActorNameSuffix;
auto actor = FetchActor(exit_actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto exit_actor = dynamic_cast<ExitActor *>(actor);
MS_EXCEPTION_IF_NULL(exit_actor);
LinkControlArrowForExitActor(exit_actor, to_actor, branch_id);
}
to_actor->input_controls_num_ -= (func_graphs.size() - 1);
} else if (from_actor != nullptr) {
LinkControlArrow(from_actor, to_actor);
} else {
auto graph = parser->FetchKernelGraphByFrontNode(depend_node);
if (graph == nullptr) {
MS_LOG(EXCEPTION) << "Failed to find actor for node:" << depend_node->DebugString();
}
from_actor = FetchActor(graph->ToString() + kExitActorNameSuffix);
MS_EXCEPTION_IF_NULL(from_actor);
LinkControlArrow(from_actor, to_actor);
}
}
}

void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set) {
MS_EXCEPTION_IF_NULL(control_actor_set);

@@ -886,6 +1068,15 @@ void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor
(void)to_actor->input_data_arrow_aids_.emplace_back(exit_actor->GetAID());
}

void ControlNodeScheduler::LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
size_t from_index, size_t to_index, int branch_id) {
MS_EXCEPTION_IF_NULL(exit_actor);
MS_EXCEPTION_IF_NULL(to_actor);
auto partial_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
(void)exit_actor->output_branch_partial_arrows_[branch_id].emplace_back(partial_arrow);
(void)to_actor->input_partial_arrow_aids_.emplace_back(exit_actor->GetAID());
}

void ControlNodeScheduler::LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
@@ -50,7 +50,8 @@ class ControlNodeScheduler {
std::vector<EntranceActorPtr> BuildEntranceActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<ExitActorPtr> BuildExitActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<StackActorPtr> BuildStackActor(const GraphCompilerInfo &graph_compiler_info);

void BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
std::vector<StackActorPtr> *stack_actors);
// Interface to link control actors.
void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set);

@@ -67,6 +68,10 @@ class ControlNodeScheduler {
void LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index,
const ControlNodeParserPtr &parser);
void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index,
size_t to_index);
// Link arrow from stack actor to control actor.
void LinkArrowFromStackActor(ControlActor *to_actor);

// Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor.
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);

@@ -75,6 +80,9 @@ class ControlNodeScheduler {
void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
const ControlNodeParserPtr &parser);

// Interface tool to link arrows between actors.
void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor);
// Data arrow with branch id is only exists from gather actor to entrance actor.

@@ -90,6 +98,8 @@ class ControlNodeScheduler {
void LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id);
void LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
void LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index,
size_t to_index, int branch_id);
bool IsNoInputActor(const ControlActor *control_actor);
};
} // namespace runtime