!66115 Support inline control flow.

Merge pull request !66115 from gaoyong10/dyn-shape-dev-2
This commit is contained in:
i-robot 2024-03-11 02:31:01 +00:00 committed by Gitee
commit cc6ae8c1c8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
43 changed files with 2939 additions and 101 deletions

View File

@ -84,11 +84,14 @@ void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
auto partial_node = dyn_cast<CNode>(node);
const auto &partial_inputs = partial_node->inputs();
MS_EXCEPTION_IF_NULL(partial_inputs.at(0));
if (!IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
} else if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartialInline)) {
switch_subgraph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(partial_node, kAttrKernelGraph);
} else {
MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
}
MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
} else if (node->isa<ValueNode>()) {
switch_subgraph = GetValueNode<FuncGraphPtr>(node);
} else {

View File

@ -30,6 +30,7 @@
#include "ops/arithmetic_ops.h"
#include "ops/nn_ops.h"
#include "ops/sequence_ops.h"
#include "ops/framework_ops.h"
#include "ops/op_def.h"
#include "ops/op_utils.h"
@ -494,6 +495,9 @@ const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph,
if (!node->isa<CNode>()) {
return nullptr;
}
if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
return nullptr;
}
if ((node->kernel_info() == nullptr) ||
(!dynamic_cast<device::KernelInfo *>(node->kernel_info())->has_build_info()) ||
(common::AnfAlgo::GetCNodeName(node) == "MakeTuple")) {

View File

@ -0,0 +1,78 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/pass/switch_not_cut.h"
#include <memory>
#include <vector>
#include <utility>
#include "ops/other_ops.h"
#include "ops/framework_ops.h"
#include "utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace opt {
// Tag control-flow nodes so the backend keeps Switch/Partial inside one kernel
// graph ("inline control flow") instead of cutting the graph at them.
//   - A Partial that feeds a Switch gets its sub-graph flagged with
//     kFlagSwitchInline so graph compilation skips separate compilation of it.
//   - Partial/Switch nodes themselves, and calls whose primitive input is a
//     Switch, get the kAttrNotCut primal attribute.
// Only active on the GE backend and when MS_ENABLE_SWITCH_INLINE=1.
// Always returns false: the pass only annotates, it never mutates topology.
bool SwitchNotCut::Run(const FuncGraphPtr &func_graph) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  static const bool is_enable_ge = (context->backend_policy() == "ge");
  if (!is_enable_ge) {
    // only support ge backend
    return false;
  }
  // Feature is opt-in via environment variable.
  static const auto is_enable_switch_inline = (common::GetEnv("MS_ENABLE_SWITCH_INLINE") == "1");
  if (!is_enable_switch_inline) {
    return false;
  }
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr return_node = func_graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
      auto iter = manager->node_users().find(node);
      if (iter == manager->node_users().end()) {
        continue;
      }
      // Only Partial nodes consumed by a Switch carry an inlinable branch graph.
      if (!std::any_of(iter->second.begin(), iter->second.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
            return IsPrimitiveCNode(node_index.first, prim::kPrimSwitch);
          })) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      // Fix: cast result and extracted graph were dereferenced unchecked below.
      MS_EXCEPTION_IF_NULL(cnode);
      auto partial_graph = cnode->input(kIndex1);
      auto sub_graph = common::AnfAlgo::GetValueNodeFuncGraph(partial_graph);
      MS_EXCEPTION_IF_NULL(sub_graph);
      sub_graph->set_flag(kFlagSwitchInline, true);
    }
    if (IsOneOfPrimitiveCNode(node, {prim::kPrimPartial, prim::kPrimSwitch})) {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      cnode->AddPrimalAttr(kAttrNotCut, MakeValue(true));
    } else if (utils::isa<CNodePtr>(node)) {
      // A call node whose primitive position holds a Switch (i.e. the result of
      // the switch is invoked directly) must not be cut either.
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      auto primitive_input = cnode->input(kAnfPrimitiveIndex);
      if (IsPrimitiveCNode(primitive_input, prim::kPrimSwitch)) {
        cnode->AddPrimalAttr(kAttrNotCut, MakeValue(true));
      }
    }
  }
  return false;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_
#include <string>
#include "include/backend/optimizer/pass.h"
namespace mindspore {
namespace opt {
// Backend pass that annotates Switch/Partial control-flow nodes so they are
// kept inside a single kernel graph (inline control flow) instead of becoming
// graph-cut points. The actual tagging logic lives in Run().
class SwitchNotCut : public Pass {
 public:
  explicit SwitchNotCut(const std::string &name = "switch_not_cut") : Pass(name) {}
  ~SwitchNotCut() override = default;
  // Annotates nodes of `func_graph`; returns false (graph structure unchanged).
  bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_

View File

@ -1578,7 +1578,8 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr>
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
MS_EXCEPTION_IF_NULL(root_graph->output());
MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract({root_graph->output()->abstract()});
abstract::AbstractBasePtrList abs_list{root_graph->output()->abstract()};
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
root_graph->set_output(make_tuple);
}

View File

@ -1957,7 +1957,7 @@ KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, c
// create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
MS_EXCEPTION_IF_NULL(new_cnode);
if (IsPrimitiveCNode(new_cnode, prim::kPrimCall)) {
if (IsOneOfPrimitiveCNode(new_cnode, {prim::kPrimCall, prim::kPrimPartial})) {
auto fn = new_cnode->input(kIndexOne);
MS_EXCEPTION_IF_NULL(fn);
auto child_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(fn);
@ -2809,22 +2809,52 @@ void CopyCNodeInfo(const FuncGraphPtr &func_graph, const uint32_t &target_graph_
common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
}
}
// Re-register every (ConditionGather -> ConditionSwitch) pair of `kernel_graph`
// on `target_kernel_graph`, translating both nodes through `condition_node_map`
// (old node -> inlined clone). Raises if either node has no clone in the map.
void UpdateConditionNodePair(const KernelGraphPtr &kernel_graph, const KernelGraphPtr &target_kernel_graph,
                             const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &condition_node_map) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  const auto &gather_to_switch = kernel_graph->condition_gather_to_switch();
  for (const auto &[gather_node, switch_node] : gather_to_switch) {
    MS_EXCEPTION_IF_NULL(gather_node);
    MS_EXCEPTION_IF_NULL(switch_node);
    const auto &new_gather = condition_node_map.find(gather_node);
    const auto &new_switch = condition_node_map.find(switch_node);
    // Both ends of the pair must have been cloned during inlining.
    if (new_gather == condition_node_map.end() || new_switch == condition_node_map.end()) {
      MS_LOG(EXCEPTION) << "Failed to get new gather node:" << gather_node->fullname_with_scope()
                        << " or switch node:" << switch_node->fullname_with_scope()
                        << " in graph:" << kernel_graph->ToString();
    }
    MS_EXCEPTION_IF_NULL(new_gather->second);
    MS_EXCEPTION_IF_NULL(new_switch->second);
    if (target_kernel_graph == nullptr) {
      continue;
    }
    target_kernel_graph->AddConditionGatherSwitchPair(new_gather->second, new_switch->second);
    MS_LOG(INFO) << "Add condition node pair:" << new_gather->second->fullname_with_scope()
                 << " and:" << new_switch->second->fullname_with_scope()
                 << " for graph:" << target_kernel_graph->ToString();
  }
}
} // namespace
AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
const uint32_t &target_graph_id,
const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
const KernelGraphPtr &graph) {
const KernelGraphPtr &graph, bool is_switch_inline) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
KernelGraphPtr target_kernel_graph = nullptr;
if (target_func_graph->isa<KernelGraph>()) {
target_kernel_graph = target_func_graph->cast<KernelGraphPtr>();
}
Cloner cloner({}, false);
if (scope != nullptr) {
cloner.set_scope(scope);
}
cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
auto node_list = TopoSort(func_graph->output());
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_node_map;
for (auto &ori_node : node_list) {
MS_EXCEPTION_IF_NULL(ori_node);
if (ori_node->isa<Parameter>()) {
@ -2837,8 +2867,34 @@ AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGr
MS_EXCEPTION_IF_NULL(value_node);
graph->AddValueNodeToGraph(value_node);
}
// Add sub graph kernel for switch inline kernel graph.
if (new_node->isa<CNode>() && target_kernel_graph != nullptr && is_switch_inline) {
MS_LOG(DEBUG) << "Add inline sub graph for kernel:" << new_node->fullname_with_scope()
<< " graph:" << func_graph->ToString();
std::string sub_graph_name = func_graph->ToString();
if (func_graph->isa<KernelGraph>()) {
const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &sub_graph_iter = kernel_graph->inline_sub_graph_kernels().find(ori_node);
if (sub_graph_iter != kernel_graph->inline_sub_graph_kernels().end()) {
sub_graph_name = sub_graph_iter->second;
}
}
target_kernel_graph->AddInlineSubgraphKernel(new_node, sub_graph_name);
if (common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionGather) ||
common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionSwitch)) {
condition_node_map[ori_node] = new_node;
}
}
CopyCNodeInfo(func_graph, target_graph_id, ori_node, new_node);
}
// Collect condition gather node and condition switch node.
if (func_graph->isa<KernelGraph>() && is_switch_inline) {
const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
UpdateConditionNodePair(kernel_graph, target_kernel_graph, condition_node_map);
}
for (const auto &kv : ref_map) {
auto final_pair = kv.first;
auto origin_pair = kv.second;

View File

@ -104,7 +104,7 @@ class BACKEND_EXPORT KernelGraphMgr {
const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
const uint32_t &target_graph_id,
const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
const KernelGraphPtr &graph);
const KernelGraphPtr &graph, bool is_switch_inline);
private:
void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;

View File

@ -2239,6 +2239,14 @@ SomasNodePtr Somas::GetSomasNode(size_t node_id) const {
}
}
namespace {
// A MakeTuple wrapping exactly one element is a no-op from the memory planner's
// point of view: it forwards its single input unchanged.
bool IsNopMakeTuple(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
    return false;
  }
  return common::AnfAlgo::GetInputNum(cnode) == 1;
}
} // namespace
common::KernelWithIndex Somas::GetVisitKernelWithReturnType(const AnfNodePtr &ori_node, size_t ori_index) {
auto prenode = common::AnfAlgo::VisitKernelWithReturnType(ori_node, ori_index, false);
MS_EXCEPTION_IF_NULL(prenode.first);
@ -2247,7 +2255,7 @@ common::KernelWithIndex Somas::GetVisitKernelWithReturnType(const AnfNodePtr &or
auto anf_node = prenode.first;
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!common::AnfAlgo::IsNopNode(cnode)) {
if (!common::AnfAlgo::IsNopNode(cnode) && !IsNopMakeTuple(cnode)) {
MS_LOG(INTERNAL_EXCEPTION) << "Node[" << ori_node->fullname_with_scope() << "] find input node["
<< cnode->fullname_with_scope()
<< "] doesn't exist in nodes_map and is not a nop node!!!!";

View File

@ -26,6 +26,7 @@
#include "pipeline/jit/ps/parse/data_converter.h"
#include "backend/graph_compiler/transform.h"
#include "backend/common/pass/erase_invalid_micro_depend.h"
#include "backend/common/pass/switch_not_cut.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "include/common/utils/callbacks.h"
#include "include/common/utils/scoped_long_running.h"
@ -530,6 +531,7 @@ namespace {
void DoUnifyMindIRPass(const FuncGraphPtr &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_LOG(INFO) << "Do unify mindir pass for graph " << graph->ToString();
#ifdef ENABLE_DUMP_IR
if (context_ptr->CanDump(kIntroductory)) {
std::string file_name = "hwopt_before_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
@ -539,6 +541,7 @@ void DoUnifyMindIRPass(const FuncGraphPtr &graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
unify_mindir_pm->AddPass(std::make_shared<opt::EraseInvalidMicroDepend>());
unify_mindir_pm->AddPass(std::make_shared<opt::SwitchNotCut>());
optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(graph);
#ifdef ENABLE_DUMP_IR
@ -597,6 +600,11 @@ void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) const {
}
}
DoUnifyMindIRPass(root_graph);
const auto &sub_graphs = root_graph->manager()->func_graphs_used_total(root_graph);
for (const auto &sub_graph : sub_graphs) {
MS_EXCEPTION_IF_NULL(sub_graph);
DoUnifyMindIRPass(sub_graph);
}
}
void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
@ -617,7 +625,8 @@ void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::
for (const auto &sub_graph : cand_graph) {
MS_EXCEPTION_IF_NULL(sub_graph);
bool skip_inline_graph =
sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline;
(sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) ||
sub_graph->has_flag(kFlagSwitchInline);
if (sub_graph != func_graph && sub_graph != nullptr && !sub_graph->has_flag(kFlagJitCallGraph) &&
!skip_inline_graph) {
MS_LOG(INFO) << "Compile sub graph " << sub_graph->ToString();

View File

@ -536,6 +536,19 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
void InferType();
void PostNewCNode(const CNodePtr &cnode) const;
void SetKernelInfoForNode(const AnfNodePtr &node) const;
void AddInlineSubgraphKernel(const AnfNodePtr &node, const std::string &graph_name) {
inline_sub_graph_kernels_[node] = graph_name;
}
const mindspore::HashMap<AnfNodePtr, std::string> &inline_sub_graph_kernels() const {
return inline_sub_graph_kernels_;
}
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch() const { return condition_gather_to_switch_; }
void AddConditionGatherSwitchPair(const AnfNodePtr &condition_gather, const AnfNodePtr &condition_switch) {
condition_gather_to_switch_[condition_gather] = condition_switch;
}
void RemoveConditionGatherSwitchPair(const AnfNodePtr &condition_gather) {
condition_gather_to_switch_.erase(condition_gather);
}
void set_is_from_pynative(const bool &from_pynative) { from_pynative_ = from_pynative; }
bool is_from_pynative() const { return from_pynative_; }
@ -577,6 +590,10 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
// parameters that will be updated when graph is executed
mindspore::HashSet<ParameterPtr> updated_parameters_;
// Kernel in inline subgraph for switch node.
mindspore::HashMap<AnfNodePtr, std::string> inline_sub_graph_kernels_;
// Record the relationship between condition gather and condition switch.
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch_;
// graph needn't execute
bool executable_{false};
@ -628,6 +645,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
// key:parallel op ptr, value:vector of <send op receive op > pairs
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;
std::atomic<size_t> pre_graph_finished_count_{0};
std::atomic<size_t> post_graph_finished_count_{0};
bool first_step_{true};

View File

@ -65,7 +65,8 @@ class COMMON_EXPORT AnfAlgo {
static std::vector<KernelWithIndex> GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {},
bool need_make_tuple = false);
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {});
static std::vector<KernelWithIndex> GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node);
// get cnode primitive
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);

View File

@ -389,6 +389,9 @@ constexpr char kAttrTransposeX1[] = "transpose_x1";
constexpr char kAttrTransposeX2[] = "transpose_x2";
constexpr char kAttrCommTurn[] = "comm_turn";
constexpr char kAttrGatherIndex[] = "gather_index";
constexpr char kAttrBranchOutputNum[] = "branch_output_num";
constexpr char kAttrBranchGraphName[] = "branch_graph_name";
constexpr char kInlineSubGraphName[] = "inline_sub_graph_name";
// FuncGraph Flags
constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
@ -402,6 +405,7 @@ constexpr auto kFlagIsPyNativeBpropKernelGraph = "is_pynative_bprop_kernel_graph
constexpr auto kFlagPyNativeWithJitCallGraph = "pynative_with_jit_call_graph";
constexpr auto kFlagJitCallGraph = "jit_call_graph";
constexpr auto kFlagJitGraph = "jit_graph";
constexpr auto kFlagSwitchInline = "switch_inline_graph";
constexpr auto kAttrPackFunction = "pack_func";
// custom operator func type

View File

@ -560,8 +560,9 @@ std::tuple<bool, std::string, ExceptionType> SelectKernelInfoWithMsg(const Kerne
}
// for backend inline
if (IsPrimitiveCNode(node, prim::kPrimCallInline)) {
GenerateKernelBuildInfo(node, KernelType::UNKNOWN_KERNEL_TYPE);
if (IsOneOfPrimitiveCNode(node, {prim::kPrimCallInline, prim::kPrimSwitch, prim::kPrimPartialInline,
prim::kPrimConditionSwitch, prim::kPrimConditionGather})) {
GenerateKernelBuildInfo(node, KernelType::RT_KERNEL);
return result;
}

View File

@ -21,8 +21,8 @@
#include <vector>
#include "include/backend/optimizer/helper.h"
#include "utils/ms_context.h"
#include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
#include "ops/framework_op_name.h"
#include "mindspore/core/ops/framework_ops.h"
namespace mindspore {
namespace device {
@ -100,7 +100,42 @@ void AclSomas::InitEventInfo(const session::KernelGraph &graph) {
MS_LOG(DEBUG) << "Acl Somas InitEventInfo end.";
}
bool AclSomas::DevSpecNodeProcess(const session::KernelGraph &graph) { return true; }
// For every ConditionGather kernel, mark its somas input/output tensors as
// union tensors so the outputs produced by the two condition branches share
// memory with the gather's outputs (no copy is needed when the branch result
// is gathered). Raises if a gather kernel has no somas node.
bool AclSomas::RuntimeNodeProcess(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    if (!IsPrimitiveCNode(kernel, {prim::kPrimConditionGather})) {
      continue;
    }
    auto iter = nodes_map_.find(kernel.get());
    if (iter != nodes_map_.end()) {
      auto &node = iter->second.at(0);
      MS_EXCEPTION_IF_NULL(node);
      auto input_tensors = node->input_tensors_;
      auto output_tensors = node->output_tensors_;
      // A gather merges two branches, so inputs are expected to be laid out as
      // [branch_a outputs..., branch_b outputs...] — exactly twice the outputs.
      constexpr size_t value_two = 2;
      MS_EXCEPTION_IF_CHECK_FAIL(input_tensors.size() == output_tensors.size() * value_two,
                                 "Invalid input and output tensors size" + std::to_string(input_tensors.size()) +
                                   ", " + std::to_string(output_tensors.size()));
      // Union group i = {output_i, input_i, input_{i + n_outputs}}: output i and
      // the corresponding output of each branch alias the same memory.
      std::vector<std::vector<size_t>> union_tensors;
      for (auto &tensor : output_tensors) {
        tensor->type_ = somas::kUnion;
        union_tensors.push_back({tensor->GetId()});
      }
      for (size_t i = 0; i < output_tensors.size(); i++) {
        input_tensors[i]->type_ = somas::kUnion;
        input_tensors[i + output_tensors.size()]->type_ = somas::kUnion;
        union_tensors[i].push_back(input_tensors[i]->GetId());
        union_tensors[i].push_back(input_tensors[i + output_tensors.size()]->GetId());
      }
      union_tensors_list_.insert(union_tensors_list_.end(), union_tensors.begin(), union_tensors.end());
    } else {
      MS_LOG(EXCEPTION) << "Can't find somas node for inplace node " << kernel->fullname_with_scope();
    }
  }
  return true;
}
bool AclSomas::DevSpecNodeProcess(const session::KernelGraph &graph) { return RuntimeNodeProcess(graph); }
void AclSomas::CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const {
if (tensors.size() != ALONE) {

View File

@ -43,6 +43,7 @@ class AclSomas : public somas::Somas {
bool InitDevSpecControlTensors(const session::KernelGraph &graph) override;
bool DevSpecNodeProcess(const session::KernelGraph &graph) override;
bool RuntimeNodeProcess(const session::KernelGraph &graph);
void InitEventInfo(const session::KernelGraph &graph);
std::map<uint32_t, somas::EventPair> event_map_;

View File

@ -60,6 +60,11 @@
namespace mindspore::device::ascend {
namespace {
constexpr size_t kSwitchInputSize = 3;
constexpr size_t kSwitchCondIndex = 1;
constexpr size_t kSwitchBranchTrueIndex = 2;
constexpr size_t kSwitchBranchFalseIndex = 3;
bool GenerateKernelMod(const std::vector<CNodePtr> &kernels) {
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
@ -83,6 +88,8 @@ bool GenerateKernelMod(const std::vector<CNodePtr> &kernels) {
} else if (kernel_type == KernelType::AKG_KERNEL) {
kernel_mod_ptr = kernel::DvmOpBuild(kernel);
#endif
} else if (AnfAlgo::GetKernelType(kernel) == KernelType::RT_KERNEL) {
kernel_mod_ptr = kernel::RtOpBuild(kernel);
} else {
MS_LOG(EXCEPTION) << "The kernel: " << kernel->fullname_with_scope() << " kernel build failed, kernel type: "
<< kernel::KernelTypeLabel(AnfAlgo::GetKernelType(kernel));
@ -128,6 +135,15 @@ void SetAclOpPrecisionMode() {
}
}
// Select the kernel build info for `kernel`; on failure, hand the node over to
// the backoff path (e.g. run on an alternative kernel type) instead of raising.
void SelectKernelInfo(const KernelGraphPtr &kernel_graph, const CNodePtr &kernel) {
  auto [is_selected, message, exception_type] = device::ascend::SelectKernelInfoWithMsg(kernel_graph, kernel);
  if (is_selected) {
    return;
  }
  MS_LOG(INFO) << "node is " << kernel->fullname_with_scope() << " should backoff";
  device::ascend::HandleKernelSelectFailure(kernel_graph, kernel, std::make_pair(message, exception_type));
}
void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *const memo) {
// select kernel
MS_EXCEPTION_IF_NULL(memo);
@ -137,12 +153,7 @@ void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *
memo->insert(kernel_graph);
const auto &kernels = kernel_graph->execution_order();
for (const auto &kernel : kernels) {
auto [select_res, msg, etype] = device::ascend::SelectKernelInfoWithMsg(kernel_graph, kernel);
if (!select_res) {
MS_LOG(INFO) << "node is " << kernel->fullname_with_scope() << " should backoff";
std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
device::ascend::HandleKernelSelectFailure(kernel_graph, kernel, failure_info);
}
SelectKernelInfo(kernel_graph, kernel);
}
if (!kernel_graph->is_from_single_op()) {
kernel_graph->SetKernelObjectTypesForUnrealNodes();
@ -152,7 +163,54 @@ void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *
}
}
void InlineSubGraph(const KernelGraphPtr &graph) {
// Inline the sub graph attached to `kernel_cnode` (via kAttrKernelGraph) into
// its caller graph.
//   graph            - kernel graph that owns the new nodes.
//   kernel_cnode     - CallInline/PartialInline-style node to replace.
//   last_call        - in/out: previous inline boundary; when non-null, inputs
//                      are chained behind it with Depend nodes and a
//                      TensorMove boundary marker is produced as the new value.
//   is_switch_inline - true when inlining a switch branch (forwarded to
//                      DoInline so condition gather/switch pairs are tracked).
void InlineSubGraph(const KernelGraphPtr &graph, CNodePtr kernel_cnode, AnfNodePtr *last_call, bool is_switch_inline) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel_cnode);
  auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
  MS_EXCEPTION_IF_NULL(sub_graph);
  MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->fullname_with_scope() << ", sub graph: " << sub_graph->graph_id()
               << ", need inline: " << sub_graph->need_inline();
  auto main_graph = kernel_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(main_graph);
  auto mng = main_graph->manager();
  // Fix: `mng` was dereferenced below (mng->Replace) without a null check.
  MS_EXCEPTION_IF_NULL(mng);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_cnode->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  AnfNodePtrList inp;
  auto &call_input = kernel_cnode->inputs();
  // let operators on different subgraphs will not be executed interleavedly
  for (size_t i = 1; i < call_input.size(); i++) {
    if (last_call != nullptr && (*last_call) != nullptr) {
      auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), call_input[i], (*last_call)});
      MS_EXCEPTION_IF_NULL(depend);
      depend->set_abstract(call_input[i]->abstract());
      inp.push_back(depend);
    } else {
      inp.push_back(call_input[i]);
    }
  }
  const auto &ref_map = sub_graph->GetRefMap();
  auto out = session::KernelGraphMgr::DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope(),
                                               kernel_info->graph_id(), ref_map, graph, is_switch_inline);
  (void)mng->Replace(kernel_cnode, out);
  // Inline graph boundary: MakeTuple---->Depend---->Tensormove
  // Avoid long link times at runtime
  if (last_call != nullptr) {
    auto value_node = graph->NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(1)));
    MS_EXCEPTION_IF_NULL(value_node);
    auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), value_node, out});
    MS_EXCEPTION_IF_NULL(depend);
    depend->set_abstract(value_node->abstract());
    auto tensor_move =
      graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorMove->name())), depend});
    MS_EXCEPTION_IF_NULL(tensor_move);
    tensor_move->set_abstract(value_node->abstract());
    common::AnfAlgo::SetNodeAttr(kAttrKernelGraphBoundary, MakeValue(sub_graph), tensor_move);
    // select kernel
    SelectKernelInfo(graph, tensor_move);
    (*last_call) = tensor_move;
  }
}
void InlineCallGraph(const KernelGraphPtr &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
#ifdef ENABLE_DUMP_IR
@ -167,53 +225,7 @@ void InlineSubGraph(const KernelGraphPtr &graph) {
for (auto &kernel_cnode : kernel_cnodes) {
MS_EXCEPTION_IF_NULL(kernel_cnode);
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimCallInline)) {
auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
MS_EXCEPTION_IF_NULL(sub_graph);
MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->fullname_with_scope()
<< ", sub graph: " << sub_graph->graph_id() << ", need inline: " << sub_graph->need_inline();
auto main_graph = kernel_cnode->func_graph();
MS_EXCEPTION_IF_NULL(main_graph);
auto mng = main_graph->manager();
auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_cnode->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
AnfNodePtrList inp;
auto &call_input = kernel_cnode->inputs();
// let operators on different subgraphs will not be executed interleavedly
for (size_t i = 1; i < call_input.size(); i++) {
if (last_call != nullptr) {
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), call_input[i], last_call});
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(call_input[i]->abstract());
inp.push_back(depend);
} else {
inp.push_back(call_input[i]);
}
}
const auto &ref_map = sub_graph->GetRefMap();
auto out = session::KernelGraphMgr::DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope(),
kernel_info->graph_id(), ref_map, graph);
(void)mng->Replace(kernel_cnode, out);
// Inline graph boundary: MakeTuple---->Depend---->Tensormove
// Avoid long link times at runtime
auto value_node = graph->NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(1)));
MS_EXCEPTION_IF_NULL(value_node);
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), value_node, out});
MS_EXCEPTION_IF_NULL(depend);
depend->set_abstract(value_node->abstract());
auto tensor_move = graph->NewCNode({NewValueNode(prim::kPrimTensorMove), depend});
MS_EXCEPTION_IF_NULL(tensor_move);
tensor_move->set_abstract(value_node->abstract());
common::AnfAlgo::SetNodeAttr(kAttrKernelGraphBoundary, MakeValue(sub_graph), tensor_move);
// select kernel
auto [select_res, msg, etype] = device::ascend::SelectKernelInfoWithMsg(graph, tensor_move);
if (!select_res) {
MS_LOG(INFO) << "node is " << tensor_move->fullname_with_scope() << " should backoff";
std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
device::ascend::HandleKernelSelectFailure(graph, tensor_move, failure_info);
}
last_call = tensor_move;
InlineSubGraph(graph, kernel_cnode, &last_call, false);
}
}
GEGraphOptimization::GetInstance().OptimizeACLGraphAfterInline(graph);
@ -224,6 +236,418 @@ void InlineSubGraph(const KernelGraphPtr &graph) {
}
#endif
}
// Build a ConditionSwitch node: inputs are (cond, branch inputs ordered by the
// index stored in `branch_input`), outputs mirror the branch inputs one-to-one.
// For every branch input a TupleGetItem over the new node is recorded in
// `branch_tuple_getitem` so branch nodes can consume the switched value.
// NOTE(review): the ref map below ties output k-1 to input k, which presumably
// makes each switched output alias its original input — confirm against the
// ConditionSwitch kernel contract.
CNodePtr GetCondSwitchNode(const KernelGraphPtr &graph, const std::map<AnfNodePtr, size_t> &branch_input,
                           const AnfNodePtr &cond, std::map<AnfNodePtr, AnfNodePtr> *branch_tuple_getitem) {
  // Slot 0: primitive, slot 1: condition, slots 2..: branch inputs at their
  // recorded index (+2 offset for primitive and condition).
  std::vector<AnfNodePtr> cond_switch_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionSwitch->name()))};
  cond_switch_inputs.resize(branch_input.size() + kIndex2);
  cond_switch_inputs[kIndex1] = cond;
  for (auto &kv : branch_input) {
    cond_switch_inputs[kv.second + kIndex2] = kv.first;
  }
  auto cond_switch_node = graph->NewCNode(cond_switch_inputs);
  MS_EXCEPTION_IF_NULL(cond_switch_node);
  // One TupleGetItem per distinct branch input, reused by both branches.
  for (auto &kv : branch_input) {
    if (branch_tuple_getitem->find(kv.first) == branch_tuple_getitem->end()) {
      auto tuple_getitem_node =
        graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cond_switch_node, NewValueNode(SizeToLong(kv.second))});
      MS_EXCEPTION_IF_NULL(tuple_getitem_node);
      tuple_getitem_node->set_abstract(kv.first->abstract());
      (*branch_tuple_getitem)[kv.first] = tuple_getitem_node;
    }
  }
  // Output abstract is a tuple of the forwarded inputs' abstracts (cond excluded).
  AbstractBasePtrList abstract_list;
  for (size_t i = kIndex2; i < cond_switch_inputs.size(); ++i) {
    (void)abstract_list.emplace_back(cond_switch_inputs[i]->abstract());
  }
  cond_switch_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  SelectKernelInfo(graph, cond_switch_node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(cond_switch_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  for (size_t input_index = 1; input_index < common::AnfAlgo::GetInputTensorNum(cond_switch_node); ++input_index) {
    kernel_info->AddRefMap(input_index - 1, input_index);
  }
  return cond_switch_node;
}
// Rebuild a branch node as a PartialInline whose inputs are the TupleGetItem
// projections of the ConditionSwitch (looked up in `branch_tuple_getitem`)
// instead of the original values. Abstract and attributes are copied from the
// old branch node. Raises if some input has no recorded projection — every
// branch input must have been registered by GetCondSwitchNode beforehand.
CNodePtr GetBranchNode(const KernelGraphPtr &graph, const CNodePtr &old_branch_node,
                       const std::map<AnfNodePtr, AnfNodePtr> &branch_tuple_getitem) {
  std::vector<AnfNodePtr> branch_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartialInline->name()))};
  for (size_t i = 0; i < common::AnfAlgo::GetInputNum(old_branch_node); i++) {
    auto input = common::AnfAlgo::GetInputNode(old_branch_node, i);
    if (branch_tuple_getitem.find(input) != branch_tuple_getitem.end()) {
      branch_inputs.push_back(branch_tuple_getitem.at(input));
    } else {
      MS_LOG(EXCEPTION) << "Invalid input of branch node: " << old_branch_node->fullname_with_scope() << ", " << i
                        << ", " << input->fullname_with_scope();
    }
  }
  auto branch_node = graph->NewCNode(branch_inputs);
  MS_EXCEPTION_IF_NULL(branch_node);
  branch_node->set_abstract(old_branch_node->abstract());
  SelectKernelInfo(graph, branch_node);
  common::AnfAlgo::CopyNodeAttrs(old_branch_node, branch_node);
  return branch_node;
}
// Convert every Switch node in `graph` into a ConditionSwitch/ConditionGather
// pair and inline both branch subgraphs.  Child graphs are processed first
// (post-order recursion) so nested control flow is flattened bottom-up; `memo`
// guards against visiting a graph twice.
void InlineSwitchGraph(const KernelGraphPtr &graph, std::set<KernelGraphPtr> *const memo) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph);
  for (auto &child_graph : graph->child_graph_order()) {
    InlineSwitchGraph(child_graph.lock(), memo);
  }
#ifdef ENABLE_DUMP_IR
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  if (save_graphs) {
    std::string file_name = "hwopt_d_before_inline_switch_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
  // process ConditionSwitch/ConditionGather
  auto kernel_cnodes = graph->execution_order();
  auto mng = graph->manager();
  if (mng == nullptr) {
    // Replace() below needs a manager; create one lazily.
    auto manager = MakeManager({graph});
    MS_EXCEPTION_IF_NULL(manager);
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
    mng = graph->manager();
  }
  std::vector<CNodePtr> partial_inline_cnode;
  for (auto &kernel_cnode : kernel_cnodes) {
    if (!IsPrimitiveCNode(kernel_cnode, prim::kPrimSwitch)) {
      continue;
    }
    auto input_num = common::AnfAlgo::GetInputNum(kernel_cnode);
    MS_EXCEPTION_IF_CHECK_FAIL(input_num == kSwitchInputSize,
                               "Invalid input num of switch node: " + kernel_cnode->DebugString());
    auto cond = kernel_cnode->input(kSwitchCondIndex);
    auto true_branch = kernel_cnode->input(kSwitchBranchTrueIndex);
    auto false_branch = kernel_cnode->input(kSwitchBranchFalseIndex);
    // Both branches must already have been rewritten to PartialInline
    // (see ProcessPartialInline pass).
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(true_branch, prim::kPrimPartialInline),
                               "Invalid true branch of switch node: " + kernel_cnode->DebugString());
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(false_branch, prim::kPrimPartialInline),
                               "Invalid false branch of switch node: " + kernel_cnode->DebugString());
    auto true_branch_cnode = true_branch->cast<CNodePtr>();
    auto false_branch_cnode = false_branch->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(true_branch_cnode);
    MS_EXCEPTION_IF_NULL(false_branch_cnode);
    // Collect the union of both branches' inputs: each distinct input gets a
    // slot index feeding the ConditionSwitch node (fix: use size_t for the
    // counter, it is stored into a size_t-valued map).
    std::map<AnfNodePtr, size_t> branch_input;
    std::map<AnfNodePtr, AnfNodePtr> branch_tuple_getitem;
    size_t now_input_cnt = 0;
    const size_t true_input_num = common::AnfAlgo::GetInputNum(true_branch_cnode);
    for (size_t i = 0; i < true_input_num; i++) {
      auto input = common::AnfAlgo::GetInputNode(true_branch_cnode, i);
      if (branch_input.find(input) != branch_input.end()) {
        continue;
      }
      branch_input[input] = now_input_cnt++;
    }
    const size_t false_input_num = common::AnfAlgo::GetInputNum(false_branch_cnode);
    for (size_t i = 0; i < false_input_num; i++) {
      auto input = common::AnfAlgo::GetInputNode(false_branch_cnode, i);
      if (branch_input.find(input) != branch_input.end()) {
        continue;
      }
      branch_input[input] = now_input_cnt++;
    }
    auto cond_switch_node = GetCondSwitchNode(graph, branch_input, cond, &branch_tuple_getitem);
    auto true_branch_node = GetBranchNode(graph, true_branch_cnode, branch_tuple_getitem);
    auto false_branch_node = GetBranchNode(graph, false_branch_cnode, branch_tuple_getitem);
    auto cond_gather_node =
      graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionGather->name())), true_branch_node,
                       false_branch_node});
    cond_gather_node->set_abstract(kernel_cnode->abstract());
    SelectKernelInfo(graph, cond_gather_node);
    partial_inline_cnode.push_back(true_branch_node);
    partial_inline_cnode.push_back(false_branch_node);
    // Record the branch info for condition node.
    auto false_sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(false_branch_cnode, kAttrKernelGraph);
    auto true_sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(true_branch_cnode, kAttrKernelGraph);
    MS_EXCEPTION_IF_NULL(false_sub_graph);
    MS_EXCEPTION_IF_NULL(true_sub_graph);
    std::vector<ValuePtr> branch_graph_names;
    branch_graph_names.emplace_back(std::make_shared<StringImm>(false_sub_graph->ToString()));
    branch_graph_names.emplace_back(std::make_shared<StringImm>(true_sub_graph->ToString()));
    cond_switch_node->AddAttr(kInlineSubGraphName, std::make_shared<ValueTuple>(branch_graph_names));
    // The gather node records the branch names in the opposite order from the
    // switch node (fix: qualify std::reverse instead of relying on ADL, which
    // breaks when the iterator type is a raw pointer).
    std::reverse(branch_graph_names.begin(), branch_graph_names.end());
    cond_gather_node->AddAttr(kAttrBranchGraphName, std::make_shared<ValueTuple>(branch_graph_names));
    graph->AddConditionGatherSwitchPair(cond_gather_node, cond_switch_node);
    MS_LOG(DEBUG) << "Add new condition gather node:" << cond_gather_node->fullname_with_scope()
                  << " and condition switch actor:" << cond_switch_node->fullname_with_scope()
                  << " for graph:" << graph->ToString();
    (void)mng->Replace(kernel_cnode, cond_gather_node);
  }
  // inline switch graph
  for (auto &kernel_cnode : partial_inline_cnode) {
    MS_EXCEPTION_IF_NULL(kernel_cnode);
    if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimPartialInline)) {
      InlineSubGraph(graph, kernel_cnode, nullptr, true);
    } else {
      MS_LOG(EXCEPTION) << "Invalid node type, node: " << kernel_cnode->fullname_with_scope();
    }
  }
  graph->SetExecOrderByDefault();
#ifdef ENABLE_DUMP_IR
  if (save_graphs) {
    std::string file_name = "hwopt_d_after_inline_switch_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
}
// Flatten the input abstract, and record the index construct.
// Leaves (non-sequence or dynamic-length abstracts) are marked -1 in the
// construct index; fixed-length sequences become a tuple of child markers.
// eg. tuple(tuple(1, 2), 3) -> {1, 2, 3} ((-1, -1), -1)
AbstractBasePtrList CollectAbstract(const abstract::AbstractBasePtr &abstract, ValuePtr *abstract_construct_index) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(abstract_construct_index);
  // Non-sequence abstract: a single leaf.
  if (!abstract->isa<abstract::AbstractSequence>()) {
    *abstract_construct_index = MakeValue(-1);
    return {abstract};
  }
  const auto &seq_abs = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(seq_abs);
  // Dynamic-length sequences cannot be expanded statically; keep them whole.
  if (seq_abs->dynamic_len()) {
    *abstract_construct_index = MakeValue(-1);
    return {seq_abs};
  }
  AbstractBasePtrList flattened;
  ValuePtrList index_tuple;
  for (const auto &element : seq_abs->elements()) {
    ValuePtr element_index = nullptr;
    const AbstractBasePtrList element_flattened = CollectAbstract(element, &element_index);
    (void)flattened.insert(flattened.end(), element_flattened.begin(), element_flattened.end());
    (void)index_tuple.emplace_back(element_index);
  }
  *abstract_construct_index = std::make_shared<ValueTuple>(index_tuple);
  return flattened;
}
// Rebuild the output construct by construct index.
// `abstract_construct_index` is the nested tuple of -1 markers produced by
// CollectAbstract.  Each -1 leaf consumes the next TupleGetItem node from the
// front of `get_item_list`; each sequence value becomes a MakeTuple node over
// its recursively-rebuilt children.
// NOTE: leaves must be consumed in exactly the order CollectAbstract flattened
// them, which this pre-order traversal guarantees.
CNodePtr ConstructMakeTupleRecursion(const ValuePtr &abstract_construct_index, std::deque<CNodePtr> *get_item_list,
                                     const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(abstract_construct_index);
  MS_EXCEPTION_IF_NULL(get_item_list);
  MS_EXCEPTION_IF_NULL(graph);
  if (!abstract_construct_index->isa<ValueSequence>()) {
    // Leaf marker: pop the next pre-built TupleGetItem node.
    if (get_item_list->empty()) {
      // More leaves in the construct index than get-items were created.
      MS_LOG(EXCEPTION) << "Failed to get item node by value:" << abstract_construct_index->ToString();
    } else {
      auto top = get_item_list->front();
      get_item_list->pop_front();
      return top;
    }
  }
  // Build node and abstract for tuple construct.
  const auto &seq_value = abstract_construct_index->cast<ValueSequencePtr>();
  MS_EXCEPTION_IF_NULL(seq_value);
  AnfNodePtrList node_list{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  AbstractBasePtrList abs_list;
  for (const auto &sub_value : seq_value->value()) {
    MS_EXCEPTION_IF_NULL(sub_value);
    const auto &new_node = ConstructMakeTupleRecursion(sub_value, get_item_list, graph);
    MS_EXCEPTION_IF_NULL(new_node);
    node_list.emplace_back(new_node);
    abs_list.emplace_back(new_node->abstract());
  }
  const auto &make_tuple = graph->NewCNode(node_list);
  // The MakeTuple's abstract mirrors the children so downstream passes see the
  // original nested structure.
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
  return make_tuple;
}
// Expand a node whose output is a (possibly nested) fixed-length tuple into a
// flat list of TupleGetItem nodes, one per scalar element.  Non-tuple nodes
// and dynamic-length tuples are returned unchanged as a single-element list.
AnfNodePtrList CreateTupleGetItemForTupleOutput(const AnfNodePtr &node, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract = node->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "Invalid abstract for node:" << node->DebugString();
  }
  if (!abstract->isa<abstract::AbstractSequence>()) {
    return {node};
  }
  const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(sequence_abstract);
  if (sequence_abstract->dynamic_len()) {
    // Dynamic-length tuples have no static element count to expand.
    return {node};
  }
  AnfNodePtrList flat_outputs;
  const auto &elements = sequence_abstract->elements();
  for (size_t index = 0; index < elements.size(); ++index) {
    const auto &element_abstract = elements[index];
    MS_EXCEPTION_IF_NULL(element_abstract);
    auto get_item = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleGetItem->name())), node,
                                     NewValueNode(MakeValue<int64_t>(SizeToLong(index)))});
    get_item->set_abstract(element_abstract);
    // Recurse: the element may itself be a tuple.
    const auto &nested_outputs = CreateTupleGetItemForTupleOutput(get_item, graph);
    (void)flat_outputs.insert(flat_outputs.end(), nested_outputs.begin(), nested_outputs.end());
  }
  return flat_outputs;
}
// Flatten the tuple input of condition gather.
// Every tuple input of `kernel` is expanded into flat TupleGetItem outputs so
// the rebuilt ConditionGather node only consumes scalar inputs; the original
// nested tuple output structure is then reconstructed with MakeTuple nodes and
// the old node is replaced in the graph.  Returns the new ConditionGather node.
CNodePtr FlattenConditionGatherNodeInput(const CNodePtr &kernel, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(graph);
  auto mng = graph->manager();
  if (mng == nullptr) {
    // Replace() below needs a manager; create one lazily.
    auto manager = MakeManager({graph});
    MS_EXCEPTION_IF_NULL(manager);
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
    mng = graph->manager();
  }
  AnfNodePtrList new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionGather->name()))};
  size_t output_num = SIZE_MAX;
  // Collect inputs.
  for (size_t i = 1; i < kernel->inputs().size(); ++i) {
    const auto &input = kernel->inputs()[i];
    MS_EXCEPTION_IF_NULL(input);
    AnfNodePtrList outputs = CreateTupleGetItemForTupleOutput(input, graph);
    // All input branch should have same output num.
    if (output_num != SIZE_MAX && output_num != outputs.size()) {
      MS_LOG(EXCEPTION) << "Invalid output size:" << output_num << " and " << outputs.size()
                        << " for kernel:" << kernel->fullname_with_scope();
    }
    output_num = outputs.size();
    new_inputs.insert(new_inputs.end(), outputs.begin(), outputs.end());
  }
  // Fix: validate before `output_num` is used (the original checked it only
  // after the abstract-size comparison and the attr was set, so an input-less
  // gather produced a misleading diagnostic).
  if (output_num == SIZE_MAX) {
    MS_LOG(EXCEPTION) << "Invalid output size:" << output_num << " for kernel:" << kernel->fullname_with_scope();
  }
  // Create new condition gather node.
  auto new_kernel = graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_kernel);
  ValuePtr abstract_construct_index = nullptr;
  AbstractBasePtrList new_abstract_list = CollectAbstract(kernel->abstract(), &abstract_construct_index);
  MS_EXCEPTION_IF_NULL(abstract_construct_index);
  MS_LOG(INFO) << "Abstract construct index:" << abstract_construct_index->ToString()
               << " for rebuild the abstract of kernel:" << new_kernel->DebugString();
  // The flattened abstracts must line up one-to-one with the flat outputs.
  if (new_abstract_list.size() != output_num) {
    MS_LOG(EXCEPTION) << "Invalid abstract list size:" << new_abstract_list.size() << " and output size:" << output_num
                      << " for kernel:" << kernel->DebugString() << " abstract:" << kernel->abstract()->ToString();
  }
  new_kernel->set_abstract(std::make_shared<abstract::AbstractTuple>(new_abstract_list));
  SelectKernelInfo(graph, new_kernel);
  new_kernel->AddAttr(kAttrBranchOutputNum, MakeValue<size_t>(output_num));
  if (kernel->HasAttr(kAttrBranchGraphName)) {
    new_kernel->AddAttr(kAttrBranchGraphName, kernel->GetAttr(kAttrBranchGraphName));
  }
  // Rebuild the output construct for condition gather node.
  std::deque<CNodePtr> get_item_list;
  for (size_t i = 0; i < output_num; ++i) {
    // Fix: use SizeToLong for the index literal, consistent with
    // CreateTupleGetItemForTupleOutput.
    auto get_item = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleGetItem->name())),
                                     new_kernel, NewValueNode(MakeValue<int64_t>(SizeToLong(i)))});
    get_item_list.emplace_back(get_item);
    get_item->set_abstract(new_abstract_list[i]);
  }
  auto make_tuple = ConstructMakeTupleRecursion(abstract_construct_index, &get_item_list, graph);
  (void)mng->Replace(kernel, make_tuple);
  return new_kernel;
}
// Flatten the tuple input of condition node.
// For every ConditionGather kernel in the execution order, rebuild it with
// flattened inputs (FlattenConditionGatherNodeInput) and migrate the graph
// bookkeeping — the gather<->switch pair and the inline-subgraph record —
// from the old node to the new one.
void FlattenConditionNodeInput(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &kernel : graph->execution_order()) {
    if (!IsPrimitiveCNode(kernel, prim::kPrimConditionGather)) {
      continue;
    }
    const auto &new_kernel = FlattenConditionGatherNodeInput(kernel, graph);
    MS_EXCEPTION_IF_NULL(new_kernel);
    // Every gather must already be registered against its switch node
    // (done in InlineSwitchGraph).
    const auto &iter = graph->condition_gather_to_switch().find(kernel);
    if (iter == graph->condition_gather_to_switch().end()) {
      MS_LOG(EXCEPTION) << "Failed to get condition switch node for gather:" << kernel->DebugString();
    }
    // Carry over the inline-subgraph membership, if the old gather had one.
    const auto &inline_iter = graph->inline_sub_graph_kernels().find(kernel);
    if (inline_iter != graph->inline_sub_graph_kernels().end()) {
      graph->AddInlineSubgraphKernel(new_kernel, inline_iter->second);
      MS_LOG(INFO) << "Add new condition gather node:" << new_kernel->fullname_with_scope()
                   << " subgraph name:" << inline_iter->second << " to graph:" << graph->ToString();
    }
    // Re-key the gather->switch pair: add the new entry before removing the
    // old one (iter->second is read before the removal).
    graph->AddConditionGatherSwitchPair(new_kernel, iter->second);
    graph->RemoveConditionGatherSwitchPair(kernel);
    MS_LOG(INFO) << "Add new condition gather node:" << new_kernel->fullname_with_scope()
                 << " to replace node:" << kernel->fullname_with_scope() << " branch name:"
                 << (kernel->HasAttr(kAttrBranchGraphName) ? new_kernel->GetAttr(kAttrBranchGraphName)->ToString()
                                                           : " null");
  }
  graph->SetExecOrderByDefault();
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  if (save_graphs) {
    std::string file_name = "hwopt_d_after_flatten_gather_input_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
}
// Return the inline-subgraph branch name `kernel` belongs to, or the graph's
// own name when the kernel is not part of any inline subgraph.
std::string GetBranchName(const KernelGraphPtr &graph, const CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto &branch_iter = graph->inline_sub_graph_kernels().find(kernel);
  if (branch_iter == graph->inline_sub_graph_kernels().end()) {
    return graph->ToString();
  }
  return branch_iter->second;
}
// Put the kernels belonging to the same inline subgraph together in the execution order.
// For each ConditionGather/ConditionSwitch pair: kernels between the switch
// and the gather that belong to a *different* branch than the switch are
// deferred and re-emitted right after the switch node, so each branch's
// kernels form a contiguous run.  The whole order is rebuilt once per pair.
void FixExecutionOrderForInlineControlFlowGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (graph->condition_gather_to_switch().empty()) {
    return;
  }
  auto execution_order = graph->execution_order();
  for (const auto &condition_node_pair : graph->condition_gather_to_switch()) {
    // first = ConditionGather node, second = ConditionSwitch node.
    std::vector<CNodePtr> new_order;
    std::vector<CNodePtr> new_order_after_switch;
    MS_EXCEPTION_IF_NULL(condition_node_pair.first);
    MS_EXCEPTION_IF_NULL(condition_node_pair.second);
    std::string current_branch = GetBranchName(graph, condition_node_pair.second->cast<CNodePtr>());
    bool is_get_switch = false;
    for (auto iter = execution_order.begin(); iter != execution_order.end(); ++iter) {
      if (*iter == condition_node_pair.second) {
        // Hold the switch back; it is re-inserted just before the gather.
        is_get_switch = true;
        continue;
      }
      if (*iter == condition_node_pair.first) {
        if (!is_get_switch) {
          MS_LOG(EXCEPTION) << "Condition gather:" << condition_node_pair.first->fullname_with_scope()
                            << " is in front of condition switch: "
                            << condition_node_pair.second->fullname_with_scope();
        }
        // Emit: switch, then the deferred off-branch kernels, then the gather
        // and everything after it unchanged.
        new_order.emplace_back(condition_node_pair.second->cast<CNodePtr>());
        new_order.insert(new_order.end(), new_order_after_switch.begin(), new_order_after_switch.end());
        new_order.insert(new_order.end(), iter, execution_order.end());
        break;
      }
      if (!is_get_switch || current_branch == GetBranchName(graph, *iter)) {
        new_order.emplace_back(*iter);
      } else {
        // Between switch and gather, but on another branch: defer.
        new_order_after_switch.emplace_back(*iter);
      }
    }
    // Sanity check: reordering must neither drop nor duplicate kernels.
    if (execution_order.size() != new_order.size()) {
      MS_LOG(EXCEPTION) << "Failed to reorder execution kernel for graph:" << graph->ToString();
    }
    execution_order = new_order;
  }
  graph->set_execution_order(execution_order);
}
} // namespace
void GeKernelExecutor::Initialize() {
@ -276,7 +700,10 @@ void GeKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
memo.clear();
GEGraphOptimization::GetInstance().OptimizeACLGraphAfterKernelSelect(kernel_graph, &memo);
memo.clear();
InlineSubGraph(kernel_graph);
InlineCallGraph(kernel_graph);
memo.clear();
InlineSwitchGraph(kernel_graph, &memo);
FlattenConditionNodeInput(kernel_graph);
OptimizeExecutionOrder(NOT_NULL(graph));
profiler::CollectHostInfo("Ascend", "Graph Optimization", "GeOptimizeGraph", 1, 0, 1);
}
@ -375,6 +802,7 @@ void GeKernelExecutor::OptimizeExecutionOrder(const FuncGraphPtr &graph) const {
kernel_graph->DisableRuntimeCache();
kernel_graph->set_execution_order(execution_order);
MS_LOG(DEBUG) << "Status record: end optimize execution order. graph id: " << kernel_graph->graph_id();
FixExecutionOrderForInlineControlFlowGraph(kernel_graph);
}
void GeKernelExecutor::PreprocessBeforeRun(const FuncGraphPtr &graph) const {

View File

@ -41,11 +41,7 @@ file(GLOB_RECURSE SRC_IN_910B RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"opapi/*.cc"
"pyboost/*.cc"
"pyboost/auto_generate/*.cc"
"rts/send.cc"
"rts/recv.cc"
"rts/rt_kernel.cc"
"rts/rt_kernel_build.cc"
"rts/rt_kernel_info.cc"
"rts/*.cc"
)
list(REMOVE_ITEM SRC_IN_910B ${AICPU_OPS_SRC})

View File

@ -78,7 +78,7 @@ bool IsNotHcclCommunicationOp(const std::string &op_name) {
}
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
static const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
static const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
kNumberTypeFloat32, kNumberTypeInt16, kNumberTypeBFloat16};
MS_EXCEPTION_IF_NULL(kernel_info_list);
MS_EXCEPTION_IF_NULL(kernel_node);

View File

@ -0,0 +1,38 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/rts/condition_gather.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
// Out-of-line empty destructor; the kernel owns no resources.
ConditionGatherKernel::~ConditionGatherKernel() {}

// Initialize the kernel: gather the node's input/output kernel tensors and
// run Resize so the size lists are populated.  Always succeeds.
bool ConditionGatherKernel::Init(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node);
  std::vector<KernelTensor *> output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node);
  Resize(input_kernel_tensors, output_kernel_tensors);
  return true;
}

// No-op launch: ConditionGather performs no device computation here.
// NOTE(review): the actual branch-gathering presumably happens in the runtime
// actor layer (ConditionGatherActor) — confirm against the scheduler code.
bool ConditionGatherKernel::Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
                                   const std::vector<KernelTensor *> &, void *) {
  return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H
#include <memory>
#include <vector>
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
namespace mindspore {
namespace kernel {
// RT kernel registered for the ConditionGather op of inline control flow.
// Init only sizes the kernel tensors; Launch is a no-op (the op carries
// control-flow semantics rather than device computation).
class ConditionGatherKernel : public RtKernel {
 public:
  ConditionGatherKernel() = default;
  ~ConditionGatherKernel() override;
  // Builds input/output kernel tensors from the node and resizes; returns true.
  bool Init(const AnfNodePtr &anf_node) override;
  // No-op; returns true.
  bool Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
              const std::vector<KernelTensor *> &, void *) override;
  // Not supported for RT kernels; always throws.
  std::vector<KernelAttr> GetOpSupport() override { MS_LOG(EXCEPTION) << "This interface is not support in RtKernel."; }
};
MS_REG_RTKERNEL(conditiongather, ConditionGatherKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H

View File

@ -0,0 +1,38 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/kernel/rts/condition_switch.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
// Out-of-line empty destructor; the kernel owns no resources.
ConditionSwitchKernel::~ConditionSwitchKernel() {}

// Initialize the kernel: gather the node's input/output kernel tensors and
// run Resize so the size lists are populated.  Always succeeds.
bool ConditionSwitchKernel::Init(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node);
  std::vector<KernelTensor *> output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node);
  Resize(input_kernel_tensors, output_kernel_tensors);
  return true;
}

// No-op launch: ConditionSwitch performs no device computation here.
// NOTE(review): branch selection presumably happens in the runtime actor
// layer (ConditionSwitchActor) — confirm against the scheduler code.
bool ConditionSwitchKernel::Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
                                   const std::vector<KernelTensor *> &, void *) {
  return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H
#include <memory>
#include <vector>
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
namespace mindspore {
namespace kernel {
// RT kernel registered for the ConditionSwitch op of inline control flow.
// Init only sizes the kernel tensors; Launch is a no-op (the op carries
// control-flow semantics rather than device computation).
class ConditionSwitchKernel : public RtKernel {
 public:
  ConditionSwitchKernel() = default;
  ~ConditionSwitchKernel() override;
  // Builds input/output kernel tensors from the node and resizes; returns true.
  bool Init(const AnfNodePtr &anf_node) override;
  // No-op; returns true.
  bool Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
              const std::vector<KernelTensor *> &, void *) override;
  // Not supported for RT kernels; always throws.
  std::vector<KernelAttr> GetOpSupport() override { MS_LOG(EXCEPTION) << "This interface is not support in RtKernel."; }
};
MS_REG_RTKERNEL(conditionswitch, ConditionSwitchKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H

View File

@ -0,0 +1,63 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ge/process_partial_inline.h"
#include <memory>
#include <string>
#include <vector>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/optimizer/helper.h"
#include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
// Match any Partial node: the kPrimPartial primitive followed by an arbitrary
// sequence of inputs.
const BaseRef ProcessPartialInline::DefinePattern() const {
  VarPtr partial_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimPartial, partial_inputs});
}
// Rewrite a Partial node into a PartialInline node so its sub kernel graph can
// later be inlined.  Returns the replacement node, or nullptr when the pass
// does not apply (not KByK executor mode, or the partial lacks kAttrNotCut).
const AnfNodePtr ProcessPartialInline::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  // Only the kernel-by-kernel executor consumes PartialInline nodes.
  if (!context->IsKByKExecutorMode()) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Only partials marked "not cut" are eligible for inlining.
  if (!cnode->HasPrimalAttr(kAttrNotCut)) {
    return nullptr;
  }
  // Input 1 of a Partial is the called graph; resolve its kernel graph.
  auto partial_graph = cnode->input(kIndex1);
  auto sub_kernel_graph = session::AnfRuntimeAlgorithm::GetValueNodeKernelGraph(partial_graph);
  std::vector<AnfNodePtr> partial_inline_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimPartialInline->name()))};
  // Copy the partial's real arguments (GetInputNode index 0 is the graph
  // value node, so start at kIndex1 to skip it).
  for (size_t i = kIndex1; i < common::AnfAlgo::GetInputNum(cnode); i++) {
    partial_inline_inputs.emplace_back(common::AnfAlgo::GetInputNode(cnode, i));
  }
  auto partial_inline = graph->NewCNode(partial_inline_inputs);
  MS_EXCEPTION_IF_NULL(partial_inline);
  partial_inline->set_abstract(cnode->abstract());
  // Attach the sub kernel graph; InlineSwitchGraph reads this attr back.
  common::AnfAlgo::SetNodeAttr(kAttrKernelGraph, MakeValue(sub_kernel_graph), partial_inline);
  return partial_inline;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_
#include "include/backend/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class ProcessPartialInline : public PatternProcessPass {
public:
explicit ProcessPartialInline(bool multi_graph = true) : PatternProcessPass("process partial inline", multi_graph) {}
~ProcessPartialInline() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_

View File

@ -41,6 +41,7 @@
#include "plugin/device/ascend/optimizer/ge/unfold_nested_output.h"
#include "plugin/device/ascend/optimizer/ge/resize_bilinear_add_attr.h"
#include "plugin/device/ascend/optimizer/ge/process_call_inline.h"
#include "plugin/device/ascend/optimizer/ge/process_partial_inline.h"
#include "plugin/device/ascend/optimizer/format_type/deal_ref_output.h"
#include "plugin/device/ascend/optimizer/ge/hcom/insert_load_for_allgather.h"
#include "plugin/device/ascend/optimizer/format_type/set_fracz_group_attr.h"
@ -125,6 +126,7 @@ void GEBackendOptimizeACL(const KernelGraphPtr &kernel_graph) {
opt_acl_pm->AddPass(std::make_shared<InsertTensorMoveForCommunication>());
opt_acl_pm->AddPass(std::make_shared<opt::TransDependValueToInt32>());
opt_acl_pm->AddPass(std::make_shared<opt::ProcessCallInline>());
opt_acl_pm->AddPass(std::make_shared<opt::ProcessPartialInline>());
opt_acl_pm->AddPass(std::make_shared<opt::ExpanderFallback>());
optimizer->AddPassManager(opt_acl_pm);
(void)optimizer->Optimize(kernel_graph);

View File

@ -524,20 +524,16 @@ void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *de
kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
MS_LOG(DEBUG) << "Kernel tensor created without set stream id, but set after device address created.";
auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
if (user_data != nullptr) {
device_address->SetNodeIndex(kernel, i);
}
device_address->SetNodeIndex(kernel, i);
if (is_from_persistent_mem) {
device_address->set_from_persistent_mem(true);
}
if (find(outputs.begin(), outputs.end(), kernel) != outputs.end()) {
device_address->SetNodeIndex(kernel, i);
}
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address << " type:" << device_address->type_id()
<< ", kernel tensor addr:" << kernel_tensor.get()
<< ", kernel tensor: " << kernel_tensor->ToString() << " addr size:" << address_size
<< " real size:" << device_address->GetSize();
<< " real size:" << device_address->GetSize()
<< " origin ref count:" << device_address->original_ref_count();
device_address->set_stream_id(AnfAlgo::GetStreamId(kernel));
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
}
@ -735,6 +731,8 @@ void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr
MS_EXCEPTION_IF_NULL(group_node_device_address);
// Update the reference count of device address.
device_address->IncreaseOriginalRefCount();
MS_LOG(DEBUG) << "After increase ref count for device address:" << device_address
<< " ref count:" << device_address->original_ref_count();
device_address->ResetRefCount();
group_node_device_address->set_pointer_ref_count(device_address->pointer_ref_count());
}
@ -798,6 +796,8 @@ void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur
cur_node_output_addr->DecreaseOriginalRefCount();
cur_node_output_addr->ResetRefCount();
origin_node_output_addr->IncreaseOriginalRefCount();
MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
<< " ref count:" << origin_node_output_addr->original_ref_count();
origin_node_output_addr->ResetRefCount();
cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
cur_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);

View File

@ -26,11 +26,14 @@ void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<
// The unused data may be invalid ptr.
if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
!TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed)) {
MS_LOG(EXCEPTION) << "The input_data does not have a valid ptr of actor:" << GetAID().Name()
<< " with index:" << input_data->index_ << ", flag:" << input_data->data_->flag()
<< " device address:" << input_data->data_ << " ref count:" << input_data->data_->ref_count()
<< " dynamic ref count:" << input_data->data_->dynamic_ref_count()
<< " origin ref count:" << input_data->data_->original_ref_count();
std::string error_info = "The input_data does not have a valid ptr of actor:" + GetAID().Name() +
" with index:" + std::to_string(input_data->index_) +
", flag:" + std::to_string(input_data->data_->flag()) +
" device address:" + std::to_string((int64_t)(input_data->data_)) +
" ref count:" + std::to_string(input_data->data_->ref_count()) +
" dynamic ref count:" + std::to_string(input_data->data_->dynamic_ref_count()) +
" origin ref count:" + std::to_string(input_data->data_->original_ref_count());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);

View File

@ -161,6 +161,15 @@ bool IsRpcActor(const AnfNodePtr &node) {
return false;
}
// A node is an inner control flow actor when it is a kernel actor whose cnode
// is ConditionSwitch or ConditionGather.
bool IsInnerControlFlowActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsKernelActor(node)) {
    return false;
  }
  const auto &kernel_name = common::AnfAlgo::GetCNodeName(node);
  return kernel_name == "ConditionSwitch" || kernel_name == "ConditionGather";
}
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {

View File

@ -116,7 +116,10 @@ enum class KernelTransformType {
// Memory actor type.
kMemoryAllocActor,
kMemoryFreeActor,
kMemorySwapActor
kMemorySwapActor,
// Inner control flow actor type.
kConditionGatherActor,
kConditionSwitchActor
};
#define SET_OPCONTEXT_FAIL_RET_WITH_ERROR(op_context, message) \
@ -302,6 +305,7 @@ bool IsSkippedKernelActor(const AnfNodePtr &node);
bool IsRpcActor(const AnfNodePtr &node);
bool IsInnerControlFlowActor(const AnfNodePtr &node);
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
// related to the input of this kernel graph.
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);

View File

@ -46,6 +46,8 @@
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
#include "runtime/graph_scheduler/actor/memory/memory_swap_actor.h"
#ifdef ENABLE_RPC_ACTOR

View File

@ -0,0 +1,147 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
namespace mindspore {
namespace runtime {
// Constructor only forwards all arguments to the KernelActor base; the gather-specific
// members (branch names, per-branch input counts) are filled in later by the
// InlineControlFlowScheduler, which is declared as a friend of this class.
ConditionGatherActor::ConditionGatherActor(const std::string &name, const CNodePtr &kernel,
                                           const DeviceContext *device_context, const AID &memory_manager_aid,
                                           const AID *debug_aid, const AID *recorder_aid,
                                           GraphExecutionStrategy strategy,
                                           const std::set<size_t> &modifiable_ref_input_indexes,
                                           const std::set<size_t> &modifiable_ref_output_indexes,
                                           const KernelTransformType &type)
    : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
// Receive the taken branch name from the paired condition switch actor and reset the
// expected input data/control counts so that only the taken branch's inputs are awaited.
// A branch absent from a count map contributes zero inputs of that kind.
void ConditionGatherActor::RunBranchName(const std::string &branch_name, OpContext<DeviceTensor> *const context) {
  MS_LOG(DEBUG) << "Condition gather actor:" << GetAID() << " branch name:" << branch_name;
  current_branch_name_ = branch_name;
  // Look each count up once via the iterator instead of find + operator[] (double lookup).
  const auto &data_num_iter = branch_name_to_input_data_num_.find(current_branch_name_);
  input_datas_num_ = (data_num_iter == branch_name_to_input_data_num_.end()) ? 0 : data_num_iter->second;
  const auto &control_num_iter = branch_name_to_input_control_num_.find(current_branch_name_);
  input_controls_num_ =
    (control_num_iter == branch_name_to_input_control_num_.end()) ? 0 : control_num_iter->second;
  // A runnable branch must provide at least one data or control input, otherwise this
  // actor could never be triggered for the branch.
  if (input_datas_num_ == 0 && input_controls_num_ == 0) {
    MS_LOG(EXCEPTION) << "No input data and input control, branch id:" << current_branch_name_
                      << " for actor:" << GetAID();
  }
  MS_LOG(DEBUG) << "Input data num:" << input_datas_num_ << " control num:" << input_controls_num_
                << " for actor:" << GetAID();
}
// One-time initialization: validates the device context and sizes the input buffer.
// Only branch_output_num_ slots are needed because, at run time, exactly one branch's
// outputs (branch_output_num_ of them) are gathered; branch_output_num_ is set by the
// InlineControlFlowScheduler before Init runs — TODO confirm ordering with scheduler.
void ConditionGatherActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  input_device_tensors_.resize(branch_output_num_);
  InitOutputData();
}
// Collect the inputs of the currently taken branch and wire them into the output data.
// Inputs of branch k occupy the global index range
// [k * branch_output_num_, (k + 1) * branch_output_num_); start_index rebases those
// global indices onto the local input_device_tensors_ buffer.
void ConditionGatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto iter = std::find(branch_names_.begin(), branch_names_.end(), current_branch_name_);
  if (iter == branch_names_.end()) {
    MS_LOG(EXCEPTION) << "Invalid current branch name:" << current_branch_name_ << " total:" << branch_names_
                      << " for actor:" << GetAID();
  }
  // First global input index belonging to the current branch.
  size_t start_index = branch_output_num_ * (iter - branch_names_.begin());
  memory_free_list_.clear();
  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      // Reject data whose global index falls outside the current branch's range.
      if (IntToSize(input_data->index_) < start_index ||
          IntToSize(input_data->index_) - start_index >= input_device_tensors_.size()) {
        std::string error_info =
          "Invalid input index:" + std::to_string(input_data->index_) + " start:" + std::to_string(start_index) +
          " total:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_) - start_index] = input_data->data_;
      memory_free_list_.emplace_back(input_data->data_);
    }
  }
  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    // Store keys for other branches are silently skipped; only the taken branch matters.
    if (device_tensor_store_key.first < start_index ||
        device_tensor_store_key.first - start_index >= input_device_tensors_.size()) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
                                                                device_contexts_[0]->GetDeviceType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first - start_index] = device_tensor.get();
  }
  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "Invalid output data size:" << output_data_.size()
                      << " and output data arrow size:" << output_data_arrows_.size() << " for actor:" << GetAID();
  }
  // Forward the gathered branch inputs directly as this actor's outputs: arrow
  // from_output_index_ addresses the local (rebased) input buffer.
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_data_arrows_[i]);
    MS_EXCEPTION_IF_NULL(output_data_[i].first);
    const auto &from_index = output_data_arrows_[i]->from_output_index_;
    if (IntToSize(from_index) >= input_device_tensors_.size() || from_index < 0) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " to actor:" << output_data_arrows_[i]->to_op_id_
                        << " to index:" << output_data_arrows_[i]->to_input_index_ << " for actor:" << GetAID();
    }
    if (input_device_tensors_[from_index] == nullptr) {
      std::string error_info =
        GetAID().Name() + " get input device tensor index:" + std::to_string(from_index) + " failed.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    output_data_[i].first->data_ = input_device_tensors_[from_index];
  }
}
// Main entry per activation: gather the taken branch's inputs, release their memory,
// erase the consumed op inputs and forward outputs. Note this actor launches no kernel;
// it only routes device tensors, so there is no Launch step between free and send.
void ConditionGatherActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    FetchInput(context);
    if (memory_free_list_.size() > 0) {
      SendMemoryFreeReq(context);
    }
    MS_LOG(DEBUG) << "My executor order log launch kernel:" << kernel_->fullname_with_scope();
    EraseInput(context);
    SendOutput(context);
  } catch (const std::exception &e) {
    // Record the exception so the frontend thread can rethrow it, then fail the context.
    MsException::Instance().SetException();
    std::string error_info =
      "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}
} // namespace runtime
} // namespace mindspore

View File

@ -0,0 +1,67 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_
#include <set>
#include <vector>
#include <string>
#include <memory>
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/kernel_actor.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
// Condition gather actor is used to collect the output of different branch from condition switch actor.
// It receives the taken branch name from the paired ConditionSwitchActor at run time and
// then only awaits/forwards the inputs that belong to that branch.
class ConditionGatherActor : public KernelActor {
 public:
  ConditionGatherActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
                       const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                       GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
                       const std::set<size_t> &modifiable_ref_output_indexes,
                       const KernelTransformType &type = KernelTransformType::kConditionGatherActor);
  ~ConditionGatherActor() override = default;

  // Receive the branch name from condition switch actor.
  void RunBranchName(const std::string &branch_name, OpContext<DeviceTensor> *const context);

 protected:
  void Init() override;
  void FetchInput(OpContext<DeviceTensor> *const context);
  void Run(OpContext<DeviceTensor> *const context) override;

 private:
  friend class InlineControlFlowScheduler;
  // Output num of each branch. Value-initialized so Init() never reads an
  // indeterminate value if the scheduler has not set it yet.
  size_t branch_output_num_{0};
  // The order of each branch name.
  std::vector<std::string> branch_names_;
  // The current execute branch between switch and gather actor.
  std::string current_branch_name_;
  // The id of each branch name.
  mindspore::HashMap<std::string, size_t> branch_name_to_id_;
  // Input data and control num for each branch.
  mindspore::HashMap<std::string, size_t> branch_name_to_input_data_num_;
  mindspore::HashMap<std::string, size_t> branch_name_to_input_control_num_;
};
using ConditionGatherActorPtr = std::shared_ptr<ConditionGatherActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_

View File

@ -0,0 +1,191 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
namespace mindspore {
namespace runtime {
// Constructor only forwards all arguments to the KernelActor base; branch metadata
// (branch_names_, branch_origin_ref_count_, arrow branch indexes, gather_aid_) is
// populated later by the InlineControlFlowScheduler, a friend of this class.
ConditionSwitchActor::ConditionSwitchActor(const std::string &name, const CNodePtr &kernel,
                                           const DeviceContext *device_context, const AID &memory_manager_aid,
                                           const AID *debug_aid, const AID *recorder_aid,
                                           GraphExecutionStrategy strategy,
                                           const std::set<size_t> &modifiable_ref_input_indexes,
                                           const std::set<size_t> &modifiable_ref_output_indexes,
                                           const KernelTransformType &type)
    : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
// One-time initialization: validates the device context, sizes the input buffer from
// the kernel's input count, and groups the prepared output data by output index so
// FetchInput can update all arrows of one output in a single pass.
void ConditionSwitchActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  input_device_tensors_.resize(common::AnfAlgo::GetInputTensorNum(kernel_));
  InitOutputData();
  output_data_by_output_index_.resize(AnfAlgo::GetOutputTensorNum(kernel_));
  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
  }
  // Bucket the raw output data pointers by the arrow's from-output index; ownership
  // stays with output_data_, so storing raw pointers here is safe for the actor's lifetime.
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    const auto &output_data = output_data_[i].first;
    const auto &data_arrow = output_data_arrows_[i];
    MS_EXCEPTION_IF_NULL(output_data);
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &from_index = data_arrow->from_output_index_;
    if (IntToSize(from_index) >= output_data_by_output_index_.size()) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
                        << " and output size:" << output_data_by_output_index_.size() << " for actor:" << GetAID();
    }
    output_data_by_output_index_[from_index].emplace_back(output_data.get());
  }
}
// Route outputs to the taken branch only: first tell the paired gather actor which
// branch ran, then forward exactly the data/control arrows tagged with this branch index.
void ConditionSwitchActor::SendOutput(OpContext<DeviceTensor> *const context, size_t index) {
  MS_EXCEPTION_IF_NULL(gather_aid_);
  // Guard the branch_names_[index] access; branch metadata is filled by the scheduler.
  if (index >= branch_names_.size()) {
    MS_LOG(EXCEPTION) << "Invalid branch index:" << index << " total branch num:" << branch_names_.size()
                      << " for actor:" << GetAID();
  }
  MS_LOG(DEBUG) << "condition actor run for index:" << index << " branch name:" << branch_names_[index]
                << " for actor:" << GetAID();
  // Notify the gather actor so it can reset its expected input counts for this branch.
  ActorDispatcher::Send(*gather_aid_, &ConditionGatherActor::RunBranchName, branch_names_[index], context);
  if (output_data_arrows_.size() != output_data_nodes_.size() || output_data_nodes_.size() != output_data_.size() ||
      output_data_.size() != output_data_branch_indexes_.size()) {
    MS_LOG(EXCEPTION) << "Invalid data arrow size:" << output_data_arrows_.size()
                      << " node size:" << output_data_nodes_.size() << " data size:" << output_data_.size()
                      << " index size:" << output_data_branch_indexes_.size() << " for actor:" << GetAID();
  }
  // Forward only the data arrows belonging to the taken branch.
  for (size_t i = 0; i < output_data_branch_indexes_.size(); ++i) {
    if (output_data_branch_indexes_[i] == index) {
      MS_EXCEPTION_IF_NULL(output_data_arrows_[i]);
      ActorDispatcher::Send(output_data_arrows_[i]->to_op_id_, &OpActor::RunOpData, output_data_[i].first.get(),
                            context);
    }
  }
  if (output_control_arrows_.size() != output_control_branch_indexes_.size()) {
    // Label both sizes; the original message streamed the two numbers back to back.
    MS_LOG(EXCEPTION) << "Invalid control arrow size:" << output_control_arrows_.size()
                      << " and index size:" << output_control_branch_indexes_.size() << " for actor:" << GetAID();
  }
  // Forward only the control arrows belonging to the taken branch.
  for (size_t i = 0; i < output_control_branch_indexes_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_control_arrows_[i]);
    if (output_control_branch_indexes_[i] == index) {
      ActorDispatcher::Send(output_control_arrows_[i]->to_op_id_, &OpActor::RunOpControl, const_cast<AID *>(&GetAID()),
                            context);
    }
  }
}
// Main entry per activation: read the boolean condition from input 0, pre-free the
// memory reserved for the branch that will NOT run, then send outputs to the taken branch.
void ConditionSwitchActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    FetchInput(context);
    MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
    MS_EXCEPTION_IF_NULL(input_device_tensors_[0]->kernel_tensor());
    // Input 0 is the condition; false selects branch 0, true selects branch 1.
    bool index = input_device_tensors_[0]->kernel_tensor()->GetValueWithCheck<bool>();
    // Guard against incompletely-initialized branch metadata before indexing by it.
    if (static_cast<size_t>(index) >= branch_names_.size()) {
      MS_LOG(EXCEPTION) << "Invalid branch index:" << index << " total branch num:" << branch_names_.size()
                        << " for actor:" << GetAID();
    }
    MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
    EraseInput(context);
    CollectMemoryFreeList(index);
    if (memory_free_list_.size() > 0) {
      SendMemoryFreeReq(context);
    }
    MS_LOG(DEBUG) << "Launch kernel:" << kernel_->fullname_with_scope();
    SendOutput(context, index);
  } catch (const std::exception &e) {
    // Record the exception so the frontend thread can rethrow it, then fail the context.
    MsException::Instance().SetException();
    std::string error_info =
      "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}
// Build the memory free list for this activation. The ref counts of all branches are
// superimposed on the switch outputs, so for every branch that will NOT execute we
// pre-release its recorded ref counts (branch_origin_ref_count_[i][j] extra frees of
// input j+1; index 0 is the condition and carries no branch ref count).
void ConditionSwitchActor::CollectMemoryFreeList(size_t index) {
  memory_free_list_.clear();
  memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin(), input_device_tensors_.end());
  for (size_t i = 0; i < branch_origin_ref_count_.size(); ++i) {
    if (i == index) {
      continue;
    }
    if (branch_origin_ref_count_[i].size() + 1 != input_device_tensors_.size()) {
      // Stream the size, not the vector: the message reports a size mismatch.
      MS_LOG(EXCEPTION) << "Invalid origin ref count size:" << branch_origin_ref_count_[i].size()
                        << " and input size:" << input_device_tensors_.size() << " for actor:" << GetAID();
    }
    MS_LOG(DEBUG) << "Free memory for branch:" << i << " for actor:" << GetAID();
    for (size_t j = 0; j < branch_origin_ref_count_[i].size(); ++j) {
      std::fill_n(back_inserter(memory_free_list_), branch_origin_ref_count_[i][j], input_device_tensors_[j + 1]);
    }
  }
}
// Collect all inputs (from op data and the device tensor store) and bind each real
// input i+1 to the output data of output i. Output k of the switch mirrors input k+1
// because input 0 is the condition.
void ConditionSwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "Invalid input index, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  }
  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
                                                                device_contexts_[0]->GetDeviceType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info =
        "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
        " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(device_tensor);
    input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
  }
  // Invariant: one output per real input, plus the condition input at index 0.
  if (output_data_by_output_index_.size() + 1 != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "Invalid output size:" << output_data_by_output_index_.size()
                      << " and input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
  }
  // Propagate each real input to every output data entry cached for that output index.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i + 1];
    MS_EXCEPTION_IF_NULL(data);
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }
}
} // namespace runtime
} // namespace mindspore

View File

@ -0,0 +1,73 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/kernel_actor.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
// Condition switch actor is used to execute the branch according to the input condition in kernel graph.
// Input 0 is the boolean condition; the remaining inputs mirror the outputs. Branch
// metadata is filled in by the InlineControlFlowScheduler (declared friend below).
class ConditionSwitchActor : public KernelActor {
 public:
  ConditionSwitchActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
                       const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                       GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
                       const std::set<size_t> &modifiable_ref_output_indexes,
                       const KernelTransformType &type = KernelTransformType::kConditionSwitchActor);
  ~ConditionSwitchActor() override = default;

 protected:
  void Init() override;
  void Run(OpContext<DeviceTensor> *const context) override;
  void FetchInput(OpContext<DeviceTensor> *const context);
  // Send data/control outputs only to the arrows tagged with the taken branch index.
  void SendOutput(OpContext<DeviceTensor> *const context, size_t index);

 private:
  friend class InlineControlFlowScheduler;
  // Collect memory free list, as the ref counts of different branches are superimposed on the output,
  // so the excess reference counts of other branches need to be subtracted in advance.
  void CollectMemoryFreeList(size_t index);

  // Graph name of each branch.
  std::vector<std::string> branch_names_;
  // Ref count of each branch.
  std::vector<std::vector<size_t>> branch_origin_ref_count_;
  // Branch of data arrow and control arrow.
  std::vector<size_t> output_data_branch_indexes_;
  std::vector<size_t> output_control_branch_indexes_;
  // Cache output data by output index to modify the output data effectively.
  std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
  // Switch needs to send current branch name to the corresponding gather actor to check its inputs.
  AID *gather_aid_{nullptr};
};
using ConditionSwitchActorPtr = std::shared_ptr<ConditionSwitchActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_

View File

@ -552,7 +552,7 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
(void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageLink, 1, 0, 0);
Link(actor_set.get(), graph_compiler_info);
(void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageLink, 1, 0, 1);
inline_control_flow_scheduler_.Link(actor_set.get(), graph_compiler_info);
DumpActor(actor_set.get(), graph_compiler_info);
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
SchedulerHelper::CheckActorValid(actor_set.get());
@ -1062,6 +1062,8 @@ void GraphScheduler::UpdateDeviceAddressByRefInternalParameter(const GraphCompil
cur_node_output_addr->DecreaseOriginalRefCount();
cur_node_output_addr->ResetRefCount();
origin_node_output_addr->IncreaseOriginalRefCount();
MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
<< " ref count:" << origin_node_output_addr->original_ref_count();
origin_node_output_addr->ResetRefCount();
cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
}
@ -1285,6 +1287,9 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
KernelActorPtr kernel_actor = nullptr;
if (IsRpcActor(kernel)) {
kernel_actor = GenerateRpcActor(kernel, real_device_context, strategy, ref_input_indexes, ref_output_indexes);
} else if (IsInnerControlFlowActor(kernel)) {
kernel_actor =
GenerateInnerControlFlowActor(kernel, real_device_context, strategy, ref_input_indexes, ref_output_indexes);
} else {
kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, real_device_context,
memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
@ -1518,6 +1523,27 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
return nullptr;
}
// Create the inner control flow actor (ConditionSwitchActor or ConditionGatherActor)
// for a ConditionSwitch/ConditionGather kernel; any other kernel is an internal error.
KernelActorPtr GraphScheduler::GenerateInnerControlFlowActor(const CNodePtr &kernel,
                                                             const DeviceContext *device_context,
                                                             GraphExecutionStrategy strategy,
                                                             const std::set<size_t> &ref_input_indexes,
                                                             const std::set<size_t> &ref_output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel);
  // Query the kernel name once instead of three separate lookups.
  const auto &kernel_name = common::AnfAlgo::GetCNodeName(kernel);
  if (kernel_name != "ConditionSwitch" && kernel_name != "ConditionGather") {
    MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Kernel " << kernel->fullname_with_scope()
                               << " is not a inner control flow kernel.";
  }
  if (kernel_name == "ConditionSwitch") {
    return std::make_shared<ConditionSwitchActor>(kernel->fullname_with_scope(), kernel, device_context,
                                                  memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
                                                  ref_input_indexes, ref_output_indexes);
  }
  return std::make_shared<ConditionGatherActor>(kernel->fullname_with_scope(), kernel, device_context,
                                                memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
                                                ref_input_indexes, ref_output_indexes);
}
namespace {
void GetAllUInputByCNode(const CNodePtr &cnode,
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
@ -2385,7 +2411,7 @@ void GraphScheduler::LinkDataArrowForCustomActor(const ActorSet *actor_set,
void GraphScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
const GraphCompilerInfo &graph_compiler_info) const {
MS_EXCEPTION_IF_NULL(graph);
if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
if (graph->is_graph_run_mode() || graph->is_any_type_input() || !graph->inline_sub_graph_kernels().empty()) {
return;
}

View File

@ -29,6 +29,7 @@
#include "utils/hash_set.h"
#include "runtime/graph_scheduler/control_node_scheduler.h"
#include "runtime/graph_scheduler/any_type_graph_scheduler.h"
#include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
#include "runtime/graph_scheduler/mem_swap_scheduler.h"
#include "runtime/graph_scheduler/actor/actor_set.h"
#include "runtime/graph_scheduler/graph_compiler.h"
@ -131,7 +132,11 @@ class BACKEND_EXPORT GraphScheduler {
KernelActorPtr GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes);
// Generate inner control flow actor in execution order.
KernelActorPtr GenerateInnerControlFlowActor(const CNodePtr &kernel, const DeviceContext *device_context,
GraphExecutionStrategy strategy,
const std::set<size_t> &ref_input_indexes,
const std::set<size_t> &ref_output_indexes);
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
@ -247,6 +252,8 @@ class BACKEND_EXPORT GraphScheduler {
ControlNodeScheduler control_node_scheduler_;
// If there is an any type input in graph, it will be used to transform it.
AnyTypeGraphScheduler any_type_graph_scheduler_;
// If there is inline control flow in kernel graph, it will be used to transform it.
InlineControlFlowScheduler inline_control_flow_scheduler_;
// Build and link swap actor when memory offload is enabled.
MemSwapScheduler swap_node_scheduler_;

View File

@ -0,0 +1,610 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
#include <vector>
#include "runtime/graph_scheduler/scheduler_helper.h"
#include "ops/framework_ops.h"
namespace mindspore {
namespace runtime {
// Link control arrows between consecutive kernels of the same inline branch, so each
// branch executes in order while different branches stay independent. Kernels not in
// inline_sub_graph_kernels belong to the root "branch" named after the graph itself.
// A ConditionSwitch ends its branch's chain; every other kernel extends it.
void InlineControlFlowScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
                                                                  const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto &inline_sub_graph_kernels = graph->inline_sub_graph_kernels();
  if (inline_sub_graph_kernels.empty()) {
    return;
  }
  MS_LOG(DEBUG) << "Link control arrow for graph:" << graph->ToString();
  // Only link control arrow between kernels in the same graph.
  mindspore::HashMap<std::string, AbstractActor *> branch_last_actor;
  for (size_t i = 0; i < graph->execution_order().size(); ++i) {
    const auto &to_kernel = graph->execution_order()[i];
    if (IsRpcActor(to_kernel)) {
      // Guard i == 0: the original indexed execution_order()[i - 1] unconditionally,
      // which underflows when the first kernel is an RPC op.
      if (i > 0) {
        MS_LOG(INFO) << "Rpc op is not available in the execution order, from kernel: "
                     << graph->execution_order()[i - 1]->fullname_with_scope()
                     << ", to kernel:" << graph->execution_order()[i]->fullname_with_scope();
      } else {
        MS_LOG(INFO) << "Rpc op is not available in the execution order, to kernel:"
                     << graph->execution_order()[i]->fullname_with_scope();
      }
      continue;
    }
    const auto &iter = inline_sub_graph_kernels.find(to_kernel);
    std::string current_branch = graph->ToString();
    if (iter != inline_sub_graph_kernels.end()) {
      current_branch = iter->second;
      MS_LOG(DEBUG) << "Kernel:" << to_kernel->fullname_with_scope() << " branch:" << current_branch;
    }
    const auto to_kernel_type = FetchKernelTransformType(to_kernel, graph, {}, GraphExecutionStrategy::kPipeline);
    auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_kernel, graph);
    MS_EXCEPTION_IF_NULL(to_actor);
    const auto &actor_iter = branch_last_actor.find(current_branch);
    if (actor_iter == branch_last_actor.end()) {
      // First kernel seen for this branch: it becomes the chain head unless it is the
      // switch itself, which must not receive a predecessor arrow from its own branch.
      if (!common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
        branch_last_actor[current_branch] = to_actor;
        MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
      }
      continue;
    }
    MS_LOG(DEBUG) << "Add control arrow between " << actor_iter->second->GetAID() << " and " << to_actor->GetAID();
    SchedulerHelper::AddControlArrow(actor_iter->second, to_actor);
    if (common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
      // The control relation end after the condition switch node in graph.
      branch_last_actor.erase(current_branch);
      MS_LOG(DEBUG) << "For branch:" << current_branch << " end actor:" << to_actor->GetAID();
    } else {
      // The control relation start first kernel in graph.
      branch_last_actor[current_branch] = to_actor;
      MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
    }
  }
}
// Get the branch name by input data arrow: the gather's inputs are laid out as
// branch-sized groups, so to_input_index_ / branch_output_num selects the branch,
// and the kAttrBranchGraphName tuple maps that branch index to its graph name.
std::string InlineControlFlowScheduler::GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
                                                                            KernelActor *condition_gather_actor,
                                                                            DataArrow *data_arrow,
                                                                            const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  MS_EXCEPTION_IF_NULL(condition_gather_actor);
  MS_EXCEPTION_IF_NULL(data_arrow);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(condition_gather_actor->kernel());
  const auto &condition_pair_iter = kernel_graph->condition_gather_to_switch().find(condition_gather_actor->kernel());
  if (condition_pair_iter == kernel_graph->condition_gather_to_switch().end() ||
      condition_pair_iter->second != condition_switch_actor->kernel()) {
    // Stream the gather actor's AID, not the raw pointer, so the message is readable.
    MS_LOG(EXCEPTION) << "Condition switch actor:" << condition_switch_actor->GetAID()
                      << " and gather actor:" << condition_gather_actor->GetAID() << " is not match.";
  }
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchOutputNum)) {
    MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
  }
  // Get the output branch index in condition gather actor.
  const auto &output_value = condition_gather_actor->kernel()->GetAttr(kAttrBranchOutputNum);
  MS_EXCEPTION_IF_NULL(output_value);
  size_t branch_output_num = GetValue<size_t>(output_value);
  // Guard the division below against a zero attribute value.
  if (branch_output_num == 0) {
    MS_LOG(EXCEPTION) << "Branch output num is 0 for actor:" << condition_gather_actor->GetAID();
  }
  size_t branch_index = data_arrow->to_input_index_ / branch_output_num;
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchGraphName)) {
    MS_LOG(EXCEPTION) << "Failed to get branch graph name by actor:" << condition_gather_actor->GetAID();
  }
  // Get output branch name by branch index.
  const auto &branch_graph_names = condition_gather_actor->kernel()->GetAttr(kAttrBranchGraphName);
  MS_EXCEPTION_IF_NULL(branch_graph_names);
  MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
                << " for actor:" << condition_gather_actor->GetAID();
  if (!branch_graph_names->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
                      << " for actor:" << condition_gather_actor->GetAID();
  }
  const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_name);
  if (branch_index >= tuple_name->size()) {
    MS_LOG(EXCEPTION) << "Invalid to index:" << data_arrow->to_input_index_
                      << " output num:" << branch_output_num
                      << " branch graph name:" << tuple_name->ToString()
                      << " from actor:" << condition_switch_actor->GetAID()
                      << " to actor:" << condition_gather_actor->GetAID();
  }
  MS_EXCEPTION_IF_NULL(tuple_name->value()[branch_index]);
  return GetValue<std::string>(tuple_name->value()[branch_index]);
}
// Walk the kernel graph's ref-node map and, for every ref pair whose real (recursively
// resolved) origin is an output of this condition switch kernel, add one ref count to that
// output for the branch containing the ref user, so the buffer stays alive for that branch.
void InlineControlFlowScheduler::FixRefCountByKernelGraphRefMap(ConditionSwitchActor *const condition_switch_actor,
                                                                const KernelGraphPtr &kernel_graph) {
  // Mapping from each inlined kernel to the name of the inline subgraph (branch) it belongs to.
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  for (const auto &ref_pair : kernel_graph->GetRefMap()) {
    const auto &output_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
                  << " origin node:" << origin_pair.first->fullname_with_scope();
    // Resolve chains of ref nodes down to the real origin node and its output index.
    const auto &recursive_origin_pair = kernel_graph->GetRefNodeRecursive(output_pair);
    // If the input node of ref node pair is a condition switch node , the ref count of corresponding switch node input
    // should add 1.
    if (recursive_origin_pair.first == condition_switch_actor->kernel() && output_pair.first != nullptr) {
      MS_LOG(DEBUG) << "Condition switch node is an input of ref node:" << output_pair.first->fullname_with_scope();
      if (inline_sub_graph_kernels.find(output_pair.first) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline subgraph name by ref node:"
                          << output_pair.first->fullname_with_scope();
      }
      // Get the branch index for ref output.
      const auto &current_branch_name = inline_sub_graph_kernels.at(output_pair.first);
      const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                   condition_switch_actor->branch_names_.end(), current_branch_name);
      if (iter == condition_switch_actor->branch_names_.end()) {
        MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                          << " total branch name:" << condition_switch_actor->branch_names_
                          << " for actor:" << condition_switch_actor->GetAID();
      }
      size_t branch_index = iter - condition_switch_actor->branch_names_.begin();
      // Bounds check both indexes before touching branch_origin_ref_count_.
      if (recursive_origin_pair.second >= output_num || branch_index >= condition_switch_actor->branch_names_.size()) {
        MS_LOG(EXCEPTION) << "Invalid output index:" << recursive_origin_pair.second << " total:" << output_num
                          << " and branch index:" << branch_index
                          << " total:" << condition_switch_actor->branch_names_.size()
                          << " for actor:" << condition_switch_actor->GetAID();
      }
      // The ref count of the corresponding branch add 1.
      condition_switch_actor->branch_origin_ref_count_[branch_index][recursive_origin_pair.second]++;
      MS_LOG(DEBUG) << "Add ref count for current branch:" << current_branch_name << " branch index:" << branch_index
                    << " output index:" << recursive_origin_pair.second
                    << " of actor:" << condition_switch_actor->GetAID();
    }
  }
}
// Propagate each output's total ref count of the condition switch actor onto the matching
// input (input i feeds output i - 1; input 0 is the condition), then fix up counts implied
// by the kernel graph's ref-node map.
void InlineControlFlowScheduler::FixRefCountByConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                                   const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  // Collect all the output ref count of condition switch actor.
  std::vector<size_t> total_ref_count;
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  for (size_t i = 0; i < output_num; ++i) {
    const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_switch_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    total_ref_count.emplace_back(device_address->original_ref_count());
    MS_LOG(DEBUG) << "For actor:" << condition_switch_actor->GetAID() << " output device address:" << device_address
                  << " output index:" << i << " ref_count:" << total_ref_count.back();
  }
  size_t input_num = common::AnfAlgo::GetInputTensorNum(condition_switch_actor->kernel());
  // Input num should same as the output num and the condition of switch node.
  if (input_num != output_num + 1) {
    MS_LOG(EXCEPTION) << "Invalid input num:" << input_num << " and output num:" << output_num
                      << " for actor:" << condition_switch_actor->GetAID();
  }
  // Add the ref count to the input of condition switch actor.
  // Start at 1 to skip the condition input; input i maps to output i - 1.
  for (size_t i = 1; i < input_num; ++i) {
    const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_switch_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    MS_LOG(DEBUG) << "For actor::" << condition_switch_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
    // SIZE_MAX marks a ref count that is never decreased; leave it untouched.
    if (device_address->original_ref_count() == SIZE_MAX) {
      continue;
    }
    // Merge the output's count into the input; minus 1 for the count the switch itself held.
    device_address->set_original_ref_count(device_address->original_ref_count() + total_ref_count[i - 1] - 1);
    device_address->ResetRefCount();
    MS_LOG(DEBUG) << "For actor::" << condition_switch_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
  }
  FixRefCountByKernelGraphRefMap(condition_switch_actor, kernel_graph);
}
// For every output data arrow of the condition switch actor, determine which branch the
// destination actor belongs to, record that branch index in output_data_branch_indexes_,
// and add one per-branch ref count for the arrow's source output.
void InlineControlFlowScheduler::InitOutputDataBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  condition_switch_actor->output_data_branch_indexes_.resize(condition_switch_actor->output_data_arrows().size());
  // Get the index for each output data arrow.
  for (size_t i = 0; i < condition_switch_actor->output_data_arrows().size(); ++i) {
    const auto &output_node = condition_switch_actor->output_data_nodes()[i];
    const auto &data_arrow = condition_switch_actor->output_data_arrows()[i];
    MS_EXCEPTION_IF_NULL(output_node);
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &to_actor = FetchActor(data_arrow->to_op_id_.Name());
    if (to_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get actor:" << data_arrow->to_op_id_.Name()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    // Only the paired gather actor or a kernel actor inside a branch may consume the output.
    if (to_actor->type() != KernelTransformType::kConditionGatherActor &&
        to_actor->type() != KernelTransformType::kKernelActor) {
      MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
    std::string current_branch_name;
    if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
      // A gather destination is not listed in inline_sub_graph_kernels; resolve the branch
      // name through the gather kernel's branch attributes instead.
      current_branch_name =
        GetBranchNameByConditionGatherActor(condition_switch_actor, to_kernel_actor, data_arrow.get(), kernel_graph);
    } else {
      if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                          << to_kernel_actor->kernel()->fullname_with_scope()
                          << " in actor:" << condition_switch_actor->GetAID();
      }
      MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
                    << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
                    << " in actor:" << condition_switch_actor->GetAID()
                    << " from index:" << data_arrow->from_output_index_ << " to actor:" << data_arrow->to_op_id_
                    << " to index:" << data_arrow->to_input_index_;
      current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
    }
    // Get branch index for output data arrow.
    const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                 condition_switch_actor->branch_names_.end(), current_branch_name);
    if (iter == condition_switch_actor->branch_names_.end()) {
      MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                        << " total branch name:" << condition_switch_actor->branch_names_
                        << " from actor:" << condition_switch_actor->GetAID() << " to actor:" << to_actor->GetAID();
    }
    size_t branch_index = iter - condition_switch_actor->branch_names_.begin();
    // Bounds check before indexing branch_origin_ref_count_.
    if (IntToSize(data_arrow->from_output_index_) >= output_num ||
        branch_index >= condition_switch_actor->branch_names_.size()) {
      MS_LOG(EXCEPTION) << "Invalid output index:" << data_arrow->from_output_index_ << " total:" << output_num
                        << " and branch index:" << branch_index
                        << " total:" << condition_switch_actor->branch_names_.size()
                        << " for actor:" << condition_switch_actor->GetAID();
    }
    // Every data arrow contributes one ref count to its branch for the source output.
    condition_switch_actor->branch_origin_ref_count_[branch_index][data_arrow->from_output_index_]++;
    condition_switch_actor->output_data_branch_indexes_[i] = branch_index;
  }
}
// For every output control arrow of the condition switch actor, record the branch index of
// the destination kernel in output_control_branch_indexes_. Arrows to the paired condition
// gather actor get SIZE_MAX since they belong to no single branch.
void InlineControlFlowScheduler::InitOutputControlBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  condition_switch_actor->output_control_branch_indexes_.resize(condition_switch_actor->output_control_arrows().size());
  // Get the index for each output control arrow.
  for (size_t i = 0; i < condition_switch_actor->output_control_arrows().size(); ++i) {
    const auto &arrow = condition_switch_actor->output_control_arrows()[i];
    MS_EXCEPTION_IF_NULL(arrow);
    const auto &to_actor = FetchActor(arrow->to_op_id_.Name());
    if (to_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get actor:" << arrow->to_op_id_.Name()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    // SIZE_MAX marks a control arrow to the gather actor: fired regardless of branch.
    if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
      condition_switch_actor->output_control_branch_indexes_[i] = SIZE_MAX;
      continue;
    }
    if (to_actor->type() != KernelTransformType::kKernelActor &&
        to_actor->type() != KernelTransformType::kConditionSwitchActor) {
      MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
    if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
      MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                        << to_kernel_actor->kernel()->fullname_with_scope()
                        << " in actor:" << condition_switch_actor->GetAID();
    }
    MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
                  << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
                  << " in actor:" << condition_switch_actor->GetAID() << " to actor:" << arrow->to_op_id_;
    const auto &current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
    // Translate branch name to branch index via the actor's ordered branch name list.
    const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                 condition_switch_actor->branch_names_.end(), current_branch_name);
    if (iter == condition_switch_actor->branch_names_.end()) {
      MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                        << " total branch name:" << condition_switch_actor->branch_names_
                        << " for actor:" << condition_switch_actor->GetAID();
    }
    size_t branch_index = iter - condition_switch_actor->branch_names_.begin();
    condition_switch_actor->output_control_branch_indexes_[i] = branch_index;
  }
}
// Fill in the branch information of all output data and control arrows for a condition
// switch actor, after validating that every output data arrow has a matching output node.
void InlineControlFlowScheduler::InitOutputBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &data_nodes = condition_switch_actor->output_data_nodes();
  const auto &data_arrows = condition_switch_actor->output_data_arrows();
  // The nodes and arrows are parallel arrays; a size mismatch means broken linking.
  if (data_nodes.size() != data_arrows.size()) {
    MS_LOG(EXCEPTION) << "Invalid data node size:" << data_nodes.size()
                      << " and arrow size:" << data_arrows.size()
                      << " for actor:" << condition_switch_actor->GetAID();
  }
  InitOutputDataBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
  InitOutputControlBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
  MS_LOG(DEBUG) << "Branch origin ref count:" << condition_switch_actor->branch_origin_ref_count_
                << " output data branch index:" << condition_switch_actor->output_data_branch_indexes_
                << " output control branch index:" << condition_switch_actor->output_control_branch_indexes_
                << " for actor:" << condition_switch_actor->GetAID();
}
void InlineControlFlowScheduler::HandleConditionSwitchActor(const KernelActorPtr &kernel_actor) {
MS_EXCEPTION_IF_NULL(kernel_actor);
const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(kernel_actor.get());
MS_EXCEPTION_IF_NULL(condition_switch_actor);
MS_EXCEPTION_IF_NULL(condition_switch_actor->kernel());
const auto &graph = condition_switch_actor->kernel()->func_graph();
if (graph == nullptr || !graph->isa<KernelGraph>()) {
MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_switch_actor->GetAID();
}
const auto &kernel_graph = graph->cast<KernelGraphPtr>();
MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
<< " by actor:" << condition_switch_actor->GetAID();
if (!condition_switch_actor->kernel()->HasAttr(kInlineSubGraphName)) {
MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_switch_actor->GetAID();
}
const auto &inline_sub_graph_names = condition_switch_actor->kernel()->GetAttr(kInlineSubGraphName);
MS_EXCEPTION_IF_NULL(inline_sub_graph_names);
MS_LOG(DEBUG) << "inline sub graph name:" << inline_sub_graph_names->ToString()
<< " for actor:" << condition_switch_actor->GetAID();
if (!inline_sub_graph_names->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Invalid input subgraph name:" << inline_sub_graph_names->ToString()
<< " for actor:" << condition_switch_actor->GetAID();
}
const auto &tuple_name = inline_sub_graph_names->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_name);
std::vector<std::string> branch_names;
for_each(tuple_name->value().begin(), tuple_name->value().end(),
[&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
condition_switch_actor->branch_names_ = branch_names;
// Fix ref count.
size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
condition_switch_actor->branch_origin_ref_count_ =
std::vector<std::vector<size_t>>(tuple_name->size(), vector<size_t>(output_num, 0));
FixRefCountByConditionSwitchActor(condition_switch_actor, kernel_graph);
InitOutputBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
}
// Walk the kernel graph's ref-node map and, for every ref pair whose real origin is an
// output of this condition gather kernel, add one ref count to the matching input of every
// branch (same relative output position, stride branch_output_num_).
void InlineControlFlowScheduler::FixRefCountByKernelGraphRefMap(ConditionGatherActor *const condition_gather_actor,
                                                                const KernelGraphPtr &kernel_graph) {
  // If the input node of ref node pair is a condition gather node , the ref count of corresponding gather node input
  // should add 1.
  for (const auto &ref_pair : kernel_graph->GetRefMap()) {
    const auto &output_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
                  << " origin node:" << origin_pair.first->fullname_with_scope();
    // Resolve chains of ref nodes down to the real origin node and its output index.
    const auto &recursive_origin_pair = kernel_graph->GetRefNodeRecursive(output_pair);
    if (recursive_origin_pair.first == condition_gather_actor->kernel() && output_pair.first != nullptr) {
      MS_LOG(DEBUG) << "Condition gather node output index:" << recursive_origin_pair.second
                    << " is an input of ref node:" << output_pair.first->fullname_with_scope()
                    << " to index:" << output_pair.second
                    << " need update ref count for actor:" << condition_gather_actor->GetAID();
      // Inputs are grouped by branch; every input at the same relative position supplies
      // this output, so each one gets one additional ref count.
      for (size_t i = recursive_origin_pair.second; i < common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
           i += condition_gather_actor->branch_output_num_) {
        const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i);
        MS_EXCEPTION_IF_NULL(device_address);
        MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                      << " input index:" << i << " ref_count:" << device_address->original_ref_count();
        // SIZE_MAX marks a ref count that is never decreased; leave it untouched.
        if (device_address->original_ref_count() == SIZE_MAX) {
          continue;
        }
        size_t pre_origin_ref_count = device_address->original_ref_count();
        device_address->set_original_ref_count(device_address->original_ref_count() + 1);
        device_address->ResetRefCount();
        MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                      << " input index:" << i << " fix ref count from:" << pre_origin_ref_count
                      << " to:" << device_address->original_ref_count();
      }
    }
  }
}
// Propagate each gather output's total ref count onto all of the actor's inputs: input i
// corresponds to output (i % branch_output_num_) of its branch. Then fix up counts implied
// by the kernel graph's ref-node map.
void InlineControlFlowScheduler::FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                                   const KernelGraphPtr &kernel_graph) {
  // Collect all the output ref counts of the condition gather actor first.
  std::vector<size_t> total_ref_count;
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_gather_actor->kernel());
  for (size_t i = 0; i < output_num; ++i) {
    const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_gather_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    total_ref_count.emplace_back(device_address->original_ref_count());
    MS_LOG(DEBUG) << "For actor:" << condition_gather_actor->GetAID() << " output device address:" << device_address
                  << " output index:" << i << " ref_count:" << total_ref_count.back();
  }
  size_t input_num = common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
  // The inputs are the outputs of all branches, so the total input num must be a non-zero
  // multiple of the per-branch output num.
  if (input_num == 0 || input_num % condition_gather_actor->branch_output_num_ != 0) {
    MS_LOG(EXCEPTION) << "Invalid input num:" << input_num
                      << " branch output num:" << condition_gather_actor->branch_output_num_
                      << " for actor:" << condition_gather_actor->GetAID();
  }
  for (size_t i = 0; i < input_num; ++i) {
    const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
    // SIZE_MAX marks a ref count that is never decreased; leave it untouched.
    if (device_address->original_ref_count() == SIZE_MAX) {
      continue;
    }
    size_t pre_origin_ref_count = device_address->original_ref_count();
    // The real ref count is the relative position of this branch output.
    device_address->set_original_ref_count(device_address->original_ref_count() +
                                           total_ref_count[i % condition_gather_actor->branch_output_num_] - 1);
    device_address->ResetRefCount();
    MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " fix ref count from:" << pre_origin_ref_count
                  << " to:" << device_address->original_ref_count();
  }
  FixRefCountByKernelGraphRefMap(condition_gather_actor, kernel_graph);
}
// For each input data arrow of the condition gather actor, determine the branch it comes
// from, assign branch ids in discovery order (branch_name_to_id_), and count the input data
// arrows per branch (branch_name_to_input_data_num_).
void InlineControlFlowScheduler::InitInputDataBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
                << " by actor:" << condition_gather_actor->GetAID();
  for (const auto &pair : condition_gather_actor->input_data_arrow_aids_) {
    const auto &from_aid = pair.first;
    const auto &data_arrow = pair.second;
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &from_actor = FetchActor(from_aid.Name());
    if (from_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
    }
    // Only a kernel actor inside a branch or the paired switch actor may feed the gather.
    if (from_actor->type() != KernelTransformType::kKernelActor &&
        from_actor->type() != KernelTransformType::kConditionSwitchActor) {
      MS_LOG(EXCEPTION) << "Invalid to actor:" << from_actor->GetAID()
                        << " from actor:" << condition_gather_actor->GetAID();
    }
    const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
    std::string current_branch_name;
    if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
      // A switch source is not listed in inline_sub_graph_kernels; resolve the branch name
      // through the gather kernel's branch attributes instead.
      current_branch_name =
        GetBranchNameByConditionGatherActor(from_kernel_actor, condition_gather_actor, data_arrow, kernel_graph);
    } else {
      if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                          << from_kernel_actor->kernel()->fullname_with_scope()
                          << " in actor:" << condition_gather_actor->GetAID();
      }
      MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
                    << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
                    << " in actor:" << condition_gather_actor->GetAID();
      current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
    }
    // Assign a fresh branch id the first time a branch name is seen.
    const auto &iter = condition_gather_actor->branch_name_to_id_.find(current_branch_name);
    if (iter == condition_gather_actor->branch_name_to_id_.end()) {
      condition_gather_actor->branch_name_to_id_[current_branch_name] =
        condition_gather_actor->branch_name_to_id_.size();
      MS_LOG(DEBUG) << "Add branch index:" << condition_gather_actor->branch_name_to_id_[current_branch_name]
                    << " branch name:" << current_branch_name << " for actor:" << condition_gather_actor->GetAID();
    }
    // Get the input data num of each branch.
    if (condition_gather_actor->branch_name_to_input_data_num_.find(current_branch_name) ==
        condition_gather_actor->branch_name_to_input_data_num_.end()) {
      condition_gather_actor->branch_name_to_input_data_num_[current_branch_name] = 1;
    } else {
      condition_gather_actor->branch_name_to_input_data_num_[current_branch_name]++;
    }
  }
}
// For each input control arrow of the condition gather actor, count the control arrows per
// source branch (branch_name_to_input_control_num_). Arrows from the paired condition
// switch actor are skipped because they carry no branch information.
void InlineControlFlowScheduler::InitInputControlBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
                << " by actor:" << condition_gather_actor->GetAID();
  for (const auto &pair : condition_gather_actor->input_control_arrow_aids_) {
    const auto &from_aid = pair.first;
    const auto &from_actor = FetchActor(from_aid.Name());
    if (from_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
    }
    // A control arrow from the switch actor does not belong to any single branch.
    if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
      continue;
    }
    if (from_actor->type() != KernelTransformType::kKernelActor &&
        from_actor->type() != KernelTransformType::kConditionGatherActor) {
      MS_LOG(EXCEPTION) << "Invalid from actor:" << from_actor->GetAID()
                        << " to actor:" << condition_gather_actor->GetAID();
    }
    const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
    if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
      MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                        << from_kernel_actor->kernel()->fullname_with_scope()
                        << " in actor:" << condition_gather_actor->GetAID();
    }
    MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
                  << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
                  << " in actor:" << condition_gather_actor->GetAID();
    const auto &current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
    // Get input op control num of each branch.
    if (condition_gather_actor->branch_name_to_input_control_num_.find(current_branch_name) ==
        condition_gather_actor->branch_name_to_input_control_num_.end()) {
      condition_gather_actor->branch_name_to_input_control_num_[current_branch_name] = 1;
    } else {
      condition_gather_actor->branch_name_to_input_control_num_[current_branch_name]++;
    }
  }
}
// Collect the per-branch input information of a condition gather actor by delegating to the
// data-arrow initializer and then the control-arrow initializer.
void InlineControlFlowScheduler::InitInputBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  InitInputDataBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
  InitInputControlBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
}
// Initialize a condition gather actor: register its AID with the paired condition switch
// actor, read the per-branch output num and branch graph names from kernel attributes, then
// fix ref counts and collect per-branch input info.
void InlineControlFlowScheduler::HandleConditionGatherActor(const KernelActorPtr &kernel_actor) {
  // Guard against a null actor pointer, consistent with HandleConditionSwitchActor.
  MS_EXCEPTION_IF_NULL(kernel_actor);
  const auto &condition_gather_actor = dynamic_cast<ConditionGatherActor *>(kernel_actor.get());
  MS_EXCEPTION_IF_NULL(condition_gather_actor);
  MS_EXCEPTION_IF_NULL(condition_gather_actor->kernel());
  const auto &graph = condition_gather_actor->kernel()->func_graph();
  if (graph == nullptr || !graph->isa<KernelGraph>()) {
    MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_gather_actor->GetAID();
  }
  const auto &kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Find the condition switch node paired with this gather node and tell the switch actor
  // where to send its branch decision.
  const auto &gather_to_switch_iter = kernel_graph->condition_gather_to_switch().find(condition_gather_actor->kernel());
  if (gather_to_switch_iter == kernel_graph->condition_gather_to_switch().end()) {
    MS_LOG(EXCEPTION) << "Failed to get switch node by gather node:"
                      << condition_gather_actor->kernel()->fullname_with_scope();
  }
  MS_EXCEPTION_IF_NULL(gather_to_switch_iter->second);
  const auto &actor = FetchActor(gather_to_switch_iter->second->fullname_with_scope());
  MS_EXCEPTION_IF_NULL(actor);
  const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  condition_switch_actor->gather_aid_ = const_cast<AID *>(&condition_gather_actor->GetAID());
  // Read the number of outputs each branch produces.
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchOutputNum)) {
    MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
  }
  const auto &output_value = condition_gather_actor->kernel()->GetAttr(kAttrBranchOutputNum);
  MS_EXCEPTION_IF_NULL(output_value);
  condition_gather_actor->branch_output_num_ = GetValue<size_t>(output_value);
  // Read the ordered tuple of branch graph names.
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchGraphName)) {
    MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_gather_actor->GetAID();
  }
  const auto &branch_graph_names = condition_gather_actor->kernel()->GetAttr(kAttrBranchGraphName);
  MS_EXCEPTION_IF_NULL(branch_graph_names);
  MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
                << " for actor:" << condition_gather_actor->GetAID();
  if (!branch_graph_names->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
                      << " for actor:" << condition_gather_actor->GetAID();
  }
  const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_name);
  std::vector<std::string> branch_names;
  std::for_each(tuple_name->value().begin(), tuple_name->value().end(),
                [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
  condition_gather_actor->branch_names_ = branch_names;
  // Fix ref count.
  FixRefCountByConditionGatherActor(condition_gather_actor, kernel_graph);
  InitInputBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
}
// Entry point of the inline control flow scheduler: when the memory optimize level is not
// O0, link control arrows following each graph's execution order; then complete the branch
// information of every condition switch/gather actor in the actor set.
void InlineControlFlowScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool need_execution_order_link =
    ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0;
  if (need_execution_order_link) {
    for (const auto &kernel_graph : graph_compiler_info.graphs_) {
      MS_EXCEPTION_IF_NULL(kernel_graph);
      LinkControlArrowByExecutionOrder(kernel_graph, graph_compiler_info);
    }
  }
  // Dispatch each condition actor to its dedicated handler; other kernel actors need nothing.
  for (const auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    switch (actor->type()) {
      case KernelTransformType::kConditionSwitchActor:
        HandleConditionSwitchActor(actor);
        break;
      case KernelTransformType::kConditionGatherActor:
        HandleConditionGatherActor(actor);
        break;
      default:
        break;
    }
  }
}
} // namespace runtime
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
#include <string>
#include "runtime/graph_scheduler/actor/actor_set.h"
namespace mindspore {
namespace runtime {
// Scheduler pass that completes inline control flow support after the common graph
// scheduler has built the actor set: it links control arrows and fills in the runtime
// branch information of condition switch/gather actors.
class InlineControlFlowScheduler {
 public:
  InlineControlFlowScheduler() = default;
  ~InlineControlFlowScheduler() = default;
  DISABLE_COPY_AND_ASSIGN(InlineControlFlowScheduler);
  // Link control arrows and fix the member variables for condition actors.
  void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
 private:
  // Link control arrows between kernel actors of the graph following its execution order.
  void LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info);
  // Fix the member variables for condition actors.
  void HandleConditionSwitchActor(const KernelActorPtr &kernel_actor);
  void HandleConditionGatherActor(const KernelActorPtr &kernel_actor);
  // Init the output branch info for condition actor.
  // For condition switch actor, the output arrow include all the output branch and should be distinguished.
  void InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                   const KernelGraphPtr &kernel_graph);
  void InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                          const KernelGraphPtr &kernel_graph);
  void InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                       const KernelGraphPtr &kernel_graph);
  // Init the input branch info for condition gather actor: branch ids and per-branch input
  // data/control arrow counts.
  void InitInputBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                  const KernelGraphPtr &kernel_graph);
  void InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                      const KernelGraphPtr &kernel_graph);
  void InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                         const KernelGraphPtr &kernel_graph);
  // Fix ref count for condition actors.
  // In condition switch actor, the ref count of actor should be change to total num for both branch.
  // In condition gather actor, the ref count of gather input should add the ref count of gather output.
  // The ref count of ref node should be add to the input of condition actor.
  void FixRefCountByConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                         const KernelGraphPtr &kernel_graph);
  void FixRefCountByKernelGraphRefMap(ConditionSwitchActor *const condition_switch_actor,
                                      const KernelGraphPtr &kernel_graph);
  void FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                         const KernelGraphPtr &kernel_graph);
  void FixRefCountByKernelGraphRefMap(ConditionGatherActor *const condition_gather_actor,
                                      const KernelGraphPtr &kernel_graph);
  // Get the branch name of a data arrow between a condition switch actor and its paired
  // condition gather actor from the gather kernel's branch attributes.
  std::string GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
                                                  KernelActor *condition_gather_actor, DataArrow *data_arrow,
                                                  const KernelGraphPtr &kernel_graph);
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_

View File

@ -717,7 +717,9 @@ void SchedulerHelper::AddMemorySign(AbstractActor *const from_actor, AbstractAct
KernelGraphPtr SchedulerHelper::FetchKernelGraphByActor(AbstractActor *const actor) {
MS_EXCEPTION_IF_NULL(actor);
AnfNode *from_kernel = nullptr;
if (actor->type() == KernelTransformType::kKernelActor) {
if (actor->type() == KernelTransformType::kKernelActor ||
actor->type() == KernelTransformType::kConditionGatherActor ||
actor->type() == KernelTransformType::kConditionSwitchActor) {
auto kernel_actor = dynamic_cast<KernelActor *>(actor);
MS_EXCEPTION_IF_NULL(kernel_actor);
from_kernel = kernel_actor->kernel().get();

View File

@ -119,7 +119,15 @@ bool IsMultiLayerTuple(const abstract::AbstractBasePtr &abstract) {
});
}
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node) {
namespace {
bool IsMultiOutput(const AnfNodePtr &node) {
return node != nullptr && node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>() &&
node->abstract()->cast<abstract::AbstractSequencePtr>()->size() > 1;
}
} // namespace
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Output node: " << node->fullname_with_scope();
std::vector<KernelWithIndex> ret;
@ -129,17 +137,26 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->size(); i++) {
auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i));
auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i), return_types);
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
}
return ret;
}
if (std::any_of(return_types.begin(), return_types.end(), [&node](const PrimitivePtr &prim_type) -> bool {
return common::AnfAlgo::CheckPrimitiveType(node, prim_type);
})) {
if (IsMultiOutput(node)) {
MS_LOG(EXCEPTION) << "Invalid get all output with index node:" << node->DebugString()
<< " abstract:" << node->abstract()->ToString();
}
MS_LOG(DEBUG) << "Need node flatten output of node:" << node->DebugString();
return {KernelWithIndex(node, 0)};
}
// The depend node need get the real node.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
auto depend_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_node);
auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend));
auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend), return_types);
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
return ret;
}
@ -196,7 +213,7 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
// The MakeTuple/MakeSparse node need recurse.
if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
auto output_vector = GetAllOutputWithIndexInner(output_with_index.first);
auto output_vector = GetAllOutputWithIndexInner(output_with_index.first, return_types);
if (output_vector.size() <= output_with_index.second) {
MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_with_index.second
<< " for outputs of node:" << output_with_index.first->DebugString();
@ -438,8 +455,9 @@ std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithOutMonadAndParameter(const
return real_output;
}
std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node) {
auto ret = GetAllOutputWithIndexInner(node);
std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types) {
auto ret = GetAllOutputWithIndexInner(node, return_types);
std::map<AnfNodePtr, size_t> value_node_index;
// Unify the output of the front and back end to the ValueTuple

View File

@ -169,12 +169,17 @@ GVAR_DEF(PrimitivePtr, kPrimCall, std::make_shared<Primitive>(kCallOpName));
GVAR_DEF(PrimitivePtr, kPrimRaise,
std::make_shared<Primitive>(kRaiseOpName, mindspore::HashMap<std::string, ValuePtr>(
{{std::string(GRAPH_FLAG_SIDE_EFFECT_IO), MakeValue(true)}})));
GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("call_inline"));
GVAR_DEF(PrimitivePtr, kPrimSwitchLayer, std::make_shared<Primitive>("switch_layer"));
GVAR_DEF(PrimitivePtr, kPrimStringUpper, std::make_shared<Primitive>(kStringUpperOpName));
GVAR_DEF(PrimitivePtr, kPrimStringLower, std::make_shared<Primitive>(kStringLowerOpName));
GVAR_DEF(PrimitivePtr, kPrimFormat, std::make_shared<Primitive>(kFormatOpName));
// Backend Inline
GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("CallInline"));
GVAR_DEF(PrimitivePtr, kPrimPartialInline, std::make_shared<Primitive>("PartialInline"));
GVAR_DEF(PrimitivePtr, kPrimConditionSwitch, std::make_shared<Primitive>("ConditionSwitch"));
GVAR_DEF(PrimitivePtr, kPrimConditionGather, std::make_shared<Primitive>("ConditionGather"));
// Pack
GVAR_DEF(PrimitivePtr, kPrimPackFunc, std::make_shared<Primitive>(kPackFuncOpName));
} // namespace prim

View File

@ -602,7 +602,8 @@ bool AnfUtils::NeedJumpMonadOutput(const AnfNodePtr &node) {
return false;
}
std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName};
std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName, prim::kPrimConditionSwitch->name(),
prim::kPrimConditionGather->name()};
if (std::find(jump_monad_output_nodes.begin(), jump_monad_output_nodes.end(), GetCNodeName(cnode)) !=
jump_monad_output_nodes.end()) {
return true;

View File

@ -0,0 +1,672 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from mindspore import context, Tensor, jit, ops, mutable
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path='./log/')
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    ret1 = foo(x, x, param_a, param_b)
    # Second call reuses the compiled graph; param_b was updated in place by
    # the first call, so only truthiness is asserted for ret2.
    ret2 = foo(x, x, param_a, param_b)
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns a Parameter directly.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            return param_a
        return param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_valuenode():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns a constant ValueNode.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x):
        if x < 3:
            return 1
        return 2

    # The original call passed undefined names param_a/param_b (NameError once
    # the skip marker is removed); foo only depends on x.
    ret1 = foo(Tensor(1))
    assert ret1
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where each branch returns a graph input unchanged.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y, z):
        if x < 3:
            return y
        return z

    ret1 = foo(Tensor(1), Tensor(2), Tensor(3))
    assert ret1
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_value_node_output_in_single_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; one branch returns a
        constant Tensor as part of its output tuple.
    Expectation: Not throw exception.
    """
    @jit
    def BranchReturnTensor(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            return y, Tensor(2, mstype.int32)
        return x, y

    x = Tensor(2, mstype.int32)
    # Repeated calls exercise graph reuse with the value-node output.
    ret1 = BranchReturnTensor(x, x)
    ret2 = BranchReturnTensor(x, x)
    ret3 = BranchReturnTensor(x, x)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diff_ref_count_in_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; the two branches use
        their inputs a different number of times (different ref counts).
    Expectation: Not throw exception.
    """
    @jit
    def BranchDiffRefCount(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        if x < 5:
            x = x + 3
            y = x + y
        else:
            # Deliberately longer chain so x/y ref counts differ from the
            # true branch.
            x = x + 3
            x = x + 4
            x = x + 5
            y = x + y
            y = x + y
            y = x + y
        return x, y

    x = Tensor(2, mstype.int32)
    ret1 = BranchDiffRefCount(x, x)
    # Second call takes the other branch (x starts at 4).
    x = Tensor(4, mstype.int32)
    ret2 = BranchDiffRefCount(x, x)
    assert ret1
    assert ret2
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_kernel_backoff():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph; one branch contains a
        kernel (Reshape with a mutable shape) that may back off to CPU.
    Expectation: Not throw exception.
    """
    @jit
    def BranchKernelBackOff(x, y, shape):
        x = x + Tensor(2, mstype.int32)
        if y < 5:
            z = ops.reshape(x, shape)
        else:
            z = x
        return z + 1

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = BranchKernelBackOff(x, y, mutable((2, 3)))
    ret2 = BranchKernelBackOff(x, y, mutable((2, 3)))
    ret3 = BranchKernelBackOff(x, y, mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where both branches update the same Parameter.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))

    @jit
    def foo(x, param_a):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
        else:
            param_a = param_a + x
        return param_a

    ret1 = foo(Tensor(1), param_a)
    ret2 = foo(Tensor(1), param_a)
    ret3 = foo(Tensor(1), param_a)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_and_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if where both branches update Parameters and
        return them (the true branch returns early).
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input_in_branch():
    """
    Feature: Control flow inline.
    Description: Control flow if; the true branch returns x, which is also an
        input of the switch condition.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return x, param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input():
    """
    Feature: Control flow inline.
    Description: Control flow if; the joint return includes the switch input x
        and a constant alongside a Parameter.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
        else:
            param_a = param_a + x
            param_b = param_b + param_a
        return x, param_b, 3

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_args_to_dynamic_tuple_para():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches operate on a shape tuple
        (tuple arg flows into the switch as a dynamic tuple parameter).
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y):
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape[0]

    ret1 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret2 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret3 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if taking a tuple input; Unique makes y's shape
        dynamic before the switch.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y, dst_shape):
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if taking a dynamic-length mutable tuple input.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, dyn_tuple):
        if x < 3:
            dyn_tuple = dyn_tuple * 2
        else:
            dyn_tuple = dyn_tuple * 3
        return dyn_tuple

    ret1 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret2 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret3 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_condition():
    """
    Feature: Control flow inline.
    Description: Control flow if where both branches return the switch
        condition itself alongside the result.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, cond):
        if cond:
            x = x * 2
            return x, cond
        x = x * 3
        return x, cond

    ret1 = foo(Tensor(1), Tensor(True))
    ret2 = foo(Tensor(1), Tensor(True))
    ret3 = foo(Tensor(1), Tensor(True))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_include_other_output():
    """
    Feature: Control flow inline.
    Description: Control flow if whose joint output also includes y, computed
        by a chain of ops outside the switch.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y):
        # Computation chain on y is independent of the branch on x.
        y = y + 2
        y = y * 3
        y = y / 4
        y = y - 5
        y = y * y
        if x < 5:
            x = x * 2
        else:
            x = x + 2
        return x, y

    ret1 = foo(Tensor(1), Tensor(2))
    ret2 = foo(Tensor(1), Tensor(2))
    ret3 = foo(Tensor(1), Tensor(2))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode():
    """
    Feature: Control flow inline.
    Description: Control flow if whose true-branch output flows through
        view/ref-style ops (ExpandDims/Flatten) on a dynamic-shape input.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y, dst_shape):
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_include_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: Control flow if on a dynamic-shape input (output of Unique);
        the false branch fans out into several ops before joining.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y):
        y, _ = ops.unique(y)
        if x < 3:
            y = y * 2
        else:
            z1 = y / 6
            z2 = y * 2
            z3 = y - Tensor([[6, 12, 18], [24, 30, 36]])
            z4 = y + Tensor([[1, 2, 3], [4, 5, 6]])
            y = z1 + z2 + z3 + z4
        return y

    # Different input sizes drive different dynamic shapes out of Unique.
    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [12, 18, 30], [18, 24, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_control_arrow_from_switch_to_gather():
    """
    Feature: Control flow inline.
    Description: Control flow if; the false branch's last compute (x is dead
        after `x = x + param_a`) exercises a control arrow from the condition
        switch to the condition gather.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        x = x + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_only_u_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where the true branch consumes only the
        monad (U) input via a side-effect op (Print).
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print_("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print_("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_u_input_and_input():
    """
    Feature: Control flow inline.
    Description: Control flow if where a branch consumes both the monad (U)
        input and a data input.
    Expectation: Not throw exception.
    """
    # NOTE(review): this body is byte-identical to test_branch_only_u_input;
    # judging by the name, it was presumably meant to use y in the true branch
    # as well — confirm the intended variant with the author.
    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print_("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print_("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_real_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches each return a real tuple
        (Shape output); the true branch's shape is dynamic via Unique.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y):
        if x < 3:
            y, _ = ops.unique(y)
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            z = ops.shape(y)
        else:
            z = ops.shape(y)
        return z

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    # Second call takes the false branch (x = 5).
    ret2 = foo(Tensor(5), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_dynamic_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branch outputs are shape tuples and the
        reshape target is a dynamic-length mutable tuple.
    Expectation: Not throw exception.
    """
    @jit
    def foo(x, y, shape):
        if y < 5:
            z = ops.reshape(x, shape)
            out = ops.shape(z)
        else:
            out = ops.shape(x)
        return out

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3), dynamic_len=True))
    assert ret1