!66115 Support inline control flow.
Merge pull request !66115 from gaoyong10/dyn-shape-dev-2
This commit is contained in:
commit
cc6ae8c1c8
|
@ -84,11 +84,14 @@ void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
|
|||
auto partial_node = dyn_cast<CNode>(node);
|
||||
const auto &partial_inputs = partial_node->inputs();
|
||||
MS_EXCEPTION_IF_NULL(partial_inputs.at(0));
|
||||
if (!IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
|
||||
if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
|
||||
MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
|
||||
switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
|
||||
} else if (IsPrimitive(partial_inputs.at(0), prim::kPrimPartialInline)) {
|
||||
switch_subgraph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(partial_node, kAttrKernelGraph);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(partial_inputs.at(kPartialArgsIndex));
|
||||
switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
switch_subgraph = GetValueNode<FuncGraphPtr>(node);
|
||||
} else {
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "ops/arithmetic_ops.h"
|
||||
#include "ops/nn_ops.h"
|
||||
#include "ops/sequence_ops.h"
|
||||
#include "ops/framework_ops.h"
|
||||
#include "ops/op_def.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
|
@ -494,6 +495,9 @@ const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph,
|
|||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
|
||||
return nullptr;
|
||||
}
|
||||
if ((node->kernel_info() == nullptr) ||
|
||||
(!dynamic_cast<device::KernelInfo *>(node->kernel_info())->has_build_info()) ||
|
||||
(common::AnfAlgo::GetCNodeName(node) == "MakeTuple")) {
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/common/pass/switch_not_cut.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "ops/other_ops.h"
|
||||
#include "ops/framework_ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool SwitchNotCut::Run(const FuncGraphPtr &func_graph) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
static const bool is_enable_ge = (context->backend_policy() == "ge");
|
||||
if (!is_enable_ge) {
|
||||
// only support ge backend
|
||||
return false;
|
||||
}
|
||||
static const auto is_enable_switch_inline = (common::GetEnv("MS_ENABLE_SWITCH_INLINE") == "1");
|
||||
if (!is_enable_switch_inline) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
AnfNodePtr return_node = func_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (auto &node : all_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
|
||||
auto iter = manager->node_users().find(node);
|
||||
if (iter == manager->node_users().end()) {
|
||||
continue;
|
||||
}
|
||||
if (!std::any_of(iter->second.begin(), iter->second.end(), [](const std::pair<AnfNodePtr, int> &node_index) {
|
||||
return IsPrimitiveCNode(node_index.first, prim::kPrimSwitch);
|
||||
})) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto partial_graph = cnode->input(kIndex1);
|
||||
auto sub_graph = common::AnfAlgo::GetValueNodeFuncGraph(partial_graph);
|
||||
sub_graph->set_flag(kFlagSwitchInline, true);
|
||||
}
|
||||
if (IsOneOfPrimitiveCNode(node, {prim::kPrimPartial, prim::kPrimSwitch})) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->AddPrimalAttr(kAttrNotCut, MakeValue(true));
|
||||
} else if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto primitive_input = cnode->input(kAnfPrimitiveIndex);
|
||||
if (IsPrimitiveCNode(primitive_input, prim::kPrimSwitch)) {
|
||||
cnode->AddPrimalAttr(kAttrNotCut, MakeValue(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_
|
||||
#include <string>
|
||||
#include "include/backend/optimizer/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SwitchNotCut : public Pass {
|
||||
public:
|
||||
explicit SwitchNotCut(const std::string &name = "switch_not_cut") : Pass(name) {}
|
||||
~SwitchNotCut() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_SWITCH_NOT_CUT_H_
|
|
@ -1578,7 +1578,8 @@ void AnfRuntimeAlgorithm::InsertMakeTupleForOutput(const NotNull<KernelGraphPtr>
|
|||
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
|
||||
MS_EXCEPTION_IF_NULL(root_graph->output());
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
make_tuple->set_abstract({root_graph->output()->abstract()});
|
||||
abstract::AbstractBasePtrList abs_list{root_graph->output()->abstract()};
|
||||
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
|
||||
root_graph->set_output(make_tuple);
|
||||
}
|
||||
|
||||
|
|
|
@ -1957,7 +1957,7 @@ KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, c
|
|||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
if (IsPrimitiveCNode(new_cnode, prim::kPrimCall)) {
|
||||
if (IsOneOfPrimitiveCNode(new_cnode, {prim::kPrimCall, prim::kPrimPartial})) {
|
||||
auto fn = new_cnode->input(kIndexOne);
|
||||
MS_EXCEPTION_IF_NULL(fn);
|
||||
auto child_kernel_graph = AnfRuntimeAlgorithm::GetValueNodeKernelGraph(fn);
|
||||
|
@ -2809,22 +2809,52 @@ void CopyCNodeInfo(const FuncGraphPtr &func_graph, const uint32_t &target_graph_
|
|||
common::AnfAlgo::SetNodeAttr(kAttrPreKernelGraph, MakeValue(func_graph), new_node);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateConditionNodePair(const KernelGraphPtr &kernel_graph, const KernelGraphPtr &target_kernel_graph,
|
||||
const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &condition_node_map) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const auto &gather_to_switch = kernel_graph->condition_gather_to_switch();
|
||||
for (const auto &pair : gather_to_switch) {
|
||||
MS_EXCEPTION_IF_NULL(pair.first);
|
||||
MS_EXCEPTION_IF_NULL(pair.second);
|
||||
const auto &gather_iter = condition_node_map.find(pair.first);
|
||||
const auto &switch_iter = condition_node_map.find(pair.second);
|
||||
if (gather_iter == condition_node_map.end() || switch_iter == condition_node_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get new gather node:" << pair.first->fullname_with_scope()
|
||||
<< " or switch node:" << pair.second->fullname_with_scope()
|
||||
<< " in graph:" << kernel_graph->ToString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(gather_iter->second);
|
||||
MS_EXCEPTION_IF_NULL(switch_iter->second);
|
||||
if (target_kernel_graph != nullptr) {
|
||||
target_kernel_graph->AddConditionGatherSwitchPair(gather_iter->second, switch_iter->second);
|
||||
MS_LOG(INFO) << "Add condition node pair:" << gather_iter->second->fullname_with_scope()
|
||||
<< " and:" << switch_iter->second->fullname_with_scope()
|
||||
<< " for graph:" << target_kernel_graph->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
|
||||
const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
|
||||
const uint32_t &target_graph_id,
|
||||
const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
|
||||
const KernelGraphPtr &graph) {
|
||||
const KernelGraphPtr &graph, bool is_switch_inline) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(target_func_graph);
|
||||
KernelGraphPtr target_kernel_graph = nullptr;
|
||||
if (target_func_graph->isa<KernelGraph>()) {
|
||||
target_kernel_graph = target_func_graph->cast<KernelGraphPtr>();
|
||||
}
|
||||
Cloner cloner({}, false);
|
||||
if (scope != nullptr) {
|
||||
cloner.set_scope(scope);
|
||||
}
|
||||
cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
|
||||
auto node_list = TopoSort(func_graph->output());
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_node_map;
|
||||
for (auto &ori_node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(ori_node);
|
||||
if (ori_node->isa<Parameter>()) {
|
||||
|
@ -2837,8 +2867,34 @@ AnfNodePtr KernelGraphMgr::DoInline(const FuncGraphPtr &func_graph, const FuncGr
|
|||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
graph->AddValueNodeToGraph(value_node);
|
||||
}
|
||||
// Add sub graph kernel for switch inline kernel graph.
|
||||
if (new_node->isa<CNode>() && target_kernel_graph != nullptr && is_switch_inline) {
|
||||
MS_LOG(DEBUG) << "Add inline sub graph for kernel:" << new_node->fullname_with_scope()
|
||||
<< " graph:" << func_graph->ToString();
|
||||
std::string sub_graph_name = func_graph->ToString();
|
||||
if (func_graph->isa<KernelGraph>()) {
|
||||
const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const auto &sub_graph_iter = kernel_graph->inline_sub_graph_kernels().find(ori_node);
|
||||
if (sub_graph_iter != kernel_graph->inline_sub_graph_kernels().end()) {
|
||||
sub_graph_name = sub_graph_iter->second;
|
||||
}
|
||||
}
|
||||
target_kernel_graph->AddInlineSubgraphKernel(new_node, sub_graph_name);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionGather) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(new_node, prim::kPrimConditionSwitch)) {
|
||||
condition_node_map[ori_node] = new_node;
|
||||
}
|
||||
}
|
||||
CopyCNodeInfo(func_graph, target_graph_id, ori_node, new_node);
|
||||
}
|
||||
// Collect condition gather node and condition switch node.
|
||||
if (func_graph->isa<KernelGraph>() && is_switch_inline) {
|
||||
const auto &kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
UpdateConditionNodePair(kernel_graph, target_kernel_graph, condition_node_map);
|
||||
}
|
||||
|
||||
for (const auto &kv : ref_map) {
|
||||
auto final_pair = kv.first;
|
||||
auto origin_pair = kv.second;
|
||||
|
|
|
@ -104,7 +104,7 @@ class BACKEND_EXPORT KernelGraphMgr {
|
|||
const AnfNodePtrList &func_graph_args, const ScopePtr &scope,
|
||||
const uint32_t &target_graph_id,
|
||||
const std::map<session::AnfWithOutIndex, session::AnfWithOutIndex> &ref_map,
|
||||
const KernelGraphPtr &graph);
|
||||
const KernelGraphPtr &graph, bool is_switch_inline);
|
||||
|
||||
private:
|
||||
void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
|
||||
|
|
|
@ -2239,6 +2239,14 @@ SomasNodePtr Somas::GetSomasNode(size_t node_id) const {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsNopMakeTuple(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_num = common::AnfAlgo::GetInputNum(cnode);
|
||||
return IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && input_num == 1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
common::KernelWithIndex Somas::GetVisitKernelWithReturnType(const AnfNodePtr &ori_node, size_t ori_index) {
|
||||
auto prenode = common::AnfAlgo::VisitKernelWithReturnType(ori_node, ori_index, false);
|
||||
MS_EXCEPTION_IF_NULL(prenode.first);
|
||||
|
@ -2247,7 +2255,7 @@ common::KernelWithIndex Somas::GetVisitKernelWithReturnType(const AnfNodePtr &or
|
|||
auto anf_node = prenode.first;
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!common::AnfAlgo::IsNopNode(cnode)) {
|
||||
if (!common::AnfAlgo::IsNopNode(cnode) && !IsNopMakeTuple(cnode)) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "Node[" << ori_node->fullname_with_scope() << "] find input node["
|
||||
<< cnode->fullname_with_scope()
|
||||
<< "] doesn't exist in nodes_map and is not a nop node!!!!";
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "pipeline/jit/ps/parse/data_converter.h"
|
||||
#include "backend/graph_compiler/transform.h"
|
||||
#include "backend/common/pass/erase_invalid_micro_depend.h"
|
||||
#include "backend/common/pass/switch_not_cut.h"
|
||||
#include "include/backend/distributed/recovery/recovery_context.h"
|
||||
#include "include/common/utils/callbacks.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
|
@ -530,6 +531,7 @@ namespace {
|
|||
void DoUnifyMindIRPass(const FuncGraphPtr &graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_LOG(INFO) << "Do unify mindir pass for graph " << graph->ToString();
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (context_ptr->CanDump(kIntroductory)) {
|
||||
std::string file_name = "hwopt_before_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
|
||||
|
@ -539,6 +541,7 @@ void DoUnifyMindIRPass(const FuncGraphPtr &graph) {
|
|||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::EraseInvalidMicroDepend>());
|
||||
unify_mindir_pm->AddPass(std::make_shared<opt::SwitchNotCut>());
|
||||
optimizer->AddPassManager(unify_mindir_pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
@ -597,6 +600,11 @@ void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) const {
|
|||
}
|
||||
}
|
||||
DoUnifyMindIRPass(root_graph);
|
||||
const auto &sub_graphs = root_graph->manager()->func_graphs_used_total(root_graph);
|
||||
for (const auto &sub_graph : sub_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
DoUnifyMindIRPass(sub_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
|
||||
|
@ -617,7 +625,8 @@ void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::
|
|||
for (const auto &sub_graph : cand_graph) {
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
bool skip_inline_graph =
|
||||
sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline;
|
||||
(sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) ||
|
||||
sub_graph->has_flag(kFlagSwitchInline);
|
||||
if (sub_graph != func_graph && sub_graph != nullptr && !sub_graph->has_flag(kFlagJitCallGraph) &&
|
||||
!skip_inline_graph) {
|
||||
MS_LOG(INFO) << "Compile sub graph " << sub_graph->ToString();
|
||||
|
|
|
@ -536,6 +536,19 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
void InferType();
|
||||
void PostNewCNode(const CNodePtr &cnode) const;
|
||||
void SetKernelInfoForNode(const AnfNodePtr &node) const;
|
||||
void AddInlineSubgraphKernel(const AnfNodePtr &node, const std::string &graph_name) {
|
||||
inline_sub_graph_kernels_[node] = graph_name;
|
||||
}
|
||||
const mindspore::HashMap<AnfNodePtr, std::string> &inline_sub_graph_kernels() const {
|
||||
return inline_sub_graph_kernels_;
|
||||
}
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch() const { return condition_gather_to_switch_; }
|
||||
void AddConditionGatherSwitchPair(const AnfNodePtr &condition_gather, const AnfNodePtr &condition_switch) {
|
||||
condition_gather_to_switch_[condition_gather] = condition_switch;
|
||||
}
|
||||
void RemoveConditionGatherSwitchPair(const AnfNodePtr &condition_gather) {
|
||||
condition_gather_to_switch_.erase(condition_gather);
|
||||
}
|
||||
|
||||
void set_is_from_pynative(const bool &from_pynative) { from_pynative_ = from_pynative; }
|
||||
bool is_from_pynative() const { return from_pynative_; }
|
||||
|
@ -577,6 +590,10 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
|
||||
// parameters that will be updated when graph is executed
|
||||
mindspore::HashSet<ParameterPtr> updated_parameters_;
|
||||
// Kernel in inline subgraph for switch node.
|
||||
mindspore::HashMap<AnfNodePtr, std::string> inline_sub_graph_kernels_;
|
||||
// Record the relationship between condition gather and condition switch.
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch_;
|
||||
|
||||
// graph needn't execute
|
||||
bool executable_{false};
|
||||
|
@ -628,6 +645,7 @@ class BACKEND_EXPORT KernelGraph : public FuncGraph {
|
|||
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
|
||||
// key:parallel op ptr, value:vector of <send op receive op > pairs
|
||||
mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;
|
||||
|
||||
std::atomic<size_t> pre_graph_finished_count_{0};
|
||||
std::atomic<size_t> post_graph_finished_count_{0};
|
||||
bool first_step_{true};
|
||||
|
|
|
@ -65,7 +65,8 @@ class COMMON_EXPORT AnfAlgo {
|
|||
static std::vector<KernelWithIndex> GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {},
|
||||
bool need_make_tuple = false);
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {});
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node);
|
||||
// get cnode primitive
|
||||
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
||||
|
|
|
@ -389,6 +389,9 @@ constexpr char kAttrTransposeX1[] = "transpose_x1";
|
|||
constexpr char kAttrTransposeX2[] = "transpose_x2";
|
||||
constexpr char kAttrCommTurn[] = "comm_turn";
|
||||
constexpr char kAttrGatherIndex[] = "gather_index";
|
||||
constexpr char kAttrBranchOutputNum[] = "branch_output_num";
|
||||
constexpr char kAttrBranchGraphName[] = "branch_graph_name";
|
||||
constexpr char kInlineSubGraphName[] = "inline_sub_graph_name";
|
||||
|
||||
// FuncGraph Flags
|
||||
constexpr auto kFlagIsPynativeBpropGraph = "is_pynative_bprop_graph";
|
||||
|
@ -402,6 +405,7 @@ constexpr auto kFlagIsPyNativeBpropKernelGraph = "is_pynative_bprop_kernel_graph
|
|||
constexpr auto kFlagPyNativeWithJitCallGraph = "pynative_with_jit_call_graph";
|
||||
constexpr auto kFlagJitCallGraph = "jit_call_graph";
|
||||
constexpr auto kFlagJitGraph = "jit_graph";
|
||||
constexpr auto kFlagSwitchInline = "switch_inline_graph";
|
||||
constexpr auto kAttrPackFunction = "pack_func";
|
||||
|
||||
// custom operator func type
|
||||
|
|
|
@ -560,8 +560,9 @@ std::tuple<bool, std::string, ExceptionType> SelectKernelInfoWithMsg(const Kerne
|
|||
}
|
||||
|
||||
// for backend inline
|
||||
if (IsPrimitiveCNode(node, prim::kPrimCallInline)) {
|
||||
GenerateKernelBuildInfo(node, KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
if (IsOneOfPrimitiveCNode(node, {prim::kPrimCallInline, prim::kPrimSwitch, prim::kPrimPartialInline,
|
||||
prim::kPrimConditionSwitch, prim::kPrimConditionGather})) {
|
||||
GenerateKernelBuildInfo(node, KernelType::RT_KERNEL);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
#include <vector>
|
||||
#include "include/backend/optimizer/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
|
||||
#include "ops/framework_op_name.h"
|
||||
#include "mindspore/core/ops/framework_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -100,7 +100,42 @@ void AclSomas::InitEventInfo(const session::KernelGraph &graph) {
|
|||
MS_LOG(DEBUG) << "Acl Somas InitEventInfo end.";
|
||||
}
|
||||
|
||||
bool AclSomas::DevSpecNodeProcess(const session::KernelGraph &graph) { return true; }
|
||||
bool AclSomas::RuntimeNodeProcess(const session::KernelGraph &graph) {
|
||||
auto &kernels = graph.execution_order();
|
||||
for (auto &kernel : kernels) {
|
||||
if (!IsPrimitiveCNode(kernel, {prim::kPrimConditionGather})) {
|
||||
continue;
|
||||
}
|
||||
auto iter = nodes_map_.find(kernel.get());
|
||||
if (iter != nodes_map_.end()) {
|
||||
auto &node = iter->second.at(0);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto input_tensors = node->input_tensors_;
|
||||
auto output_tensors = node->output_tensors_;
|
||||
constexpr size_t value_two = 2;
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(input_tensors.size() == output_tensors.size() * value_two,
|
||||
"Invalid input and output tensors size" + std::to_string(input_tensors.size()) + ", " +
|
||||
std::to_string(output_tensors.size()));
|
||||
std::vector<std::vector<size_t>> union_tensors;
|
||||
for (auto &tensor : output_tensors) {
|
||||
tensor->type_ = somas::kUnion;
|
||||
union_tensors.push_back({tensor->GetId()});
|
||||
}
|
||||
for (size_t i = 0; i < output_tensors.size(); i++) {
|
||||
input_tensors[i]->type_ = somas::kUnion;
|
||||
input_tensors[i + output_tensors.size()]->type_ = somas::kUnion;
|
||||
union_tensors[i].push_back(input_tensors[i]->GetId());
|
||||
union_tensors[i].push_back(input_tensors[i + output_tensors.size()]->GetId());
|
||||
}
|
||||
union_tensors_list_.insert(union_tensors_list_.end(), union_tensors.begin(), union_tensors.end());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Can't find somas node for inplace node " << kernel->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AclSomas::DevSpecNodeProcess(const session::KernelGraph &graph) { return RuntimeNodeProcess(graph); }
|
||||
|
||||
void AclSomas::CommunicationTensorProcess(const std::vector<somas::SomasTensorPtr> &tensors) const {
|
||||
if (tensors.size() != ALONE) {
|
||||
|
|
|
@ -43,6 +43,7 @@ class AclSomas : public somas::Somas {
|
|||
|
||||
bool InitDevSpecControlTensors(const session::KernelGraph &graph) override;
|
||||
bool DevSpecNodeProcess(const session::KernelGraph &graph) override;
|
||||
bool RuntimeNodeProcess(const session::KernelGraph &graph);
|
||||
|
||||
void InitEventInfo(const session::KernelGraph &graph);
|
||||
std::map<uint32_t, somas::EventPair> event_map_;
|
||||
|
|
|
@ -60,6 +60,11 @@
|
|||
|
||||
namespace mindspore::device::ascend {
|
||||
namespace {
|
||||
constexpr size_t kSwitchInputSize = 3;
|
||||
constexpr size_t kSwitchCondIndex = 1;
|
||||
constexpr size_t kSwitchBranchTrueIndex = 2;
|
||||
constexpr size_t kSwitchBranchFalseIndex = 3;
|
||||
|
||||
bool GenerateKernelMod(const std::vector<CNodePtr> &kernels) {
|
||||
for (const auto &kernel : kernels) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
|
@ -83,6 +88,8 @@ bool GenerateKernelMod(const std::vector<CNodePtr> &kernels) {
|
|||
} else if (kernel_type == KernelType::AKG_KERNEL) {
|
||||
kernel_mod_ptr = kernel::DvmOpBuild(kernel);
|
||||
#endif
|
||||
} else if (AnfAlgo::GetKernelType(kernel) == KernelType::RT_KERNEL) {
|
||||
kernel_mod_ptr = kernel::RtOpBuild(kernel);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The kernel: " << kernel->fullname_with_scope() << " kernel build failed, kernel type: "
|
||||
<< kernel::KernelTypeLabel(AnfAlgo::GetKernelType(kernel));
|
||||
|
@ -128,6 +135,15 @@ void SetAclOpPrecisionMode() {
|
|||
}
|
||||
}
|
||||
|
||||
void SelectKernelInfo(const KernelGraphPtr &kernel_graph, const CNodePtr &kernel) {
|
||||
auto [select_res, msg, etype] = device::ascend::SelectKernelInfoWithMsg(kernel_graph, kernel);
|
||||
if (!select_res) {
|
||||
MS_LOG(INFO) << "node is " << kernel->fullname_with_scope() << " should backoff";
|
||||
std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
|
||||
device::ascend::HandleKernelSelectFailure(kernel_graph, kernel, failure_info);
|
||||
}
|
||||
}
|
||||
|
||||
void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *const memo) {
|
||||
// select kernel
|
||||
MS_EXCEPTION_IF_NULL(memo);
|
||||
|
@ -137,12 +153,7 @@ void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *
|
|||
memo->insert(kernel_graph);
|
||||
const auto &kernels = kernel_graph->execution_order();
|
||||
for (const auto &kernel : kernels) {
|
||||
auto [select_res, msg, etype] = device::ascend::SelectKernelInfoWithMsg(kernel_graph, kernel);
|
||||
if (!select_res) {
|
||||
MS_LOG(INFO) << "node is " << kernel->fullname_with_scope() << " should backoff";
|
||||
std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
|
||||
device::ascend::HandleKernelSelectFailure(kernel_graph, kernel, failure_info);
|
||||
}
|
||||
SelectKernelInfo(kernel_graph, kernel);
|
||||
}
|
||||
if (!kernel_graph->is_from_single_op()) {
|
||||
kernel_graph->SetKernelObjectTypesForUnrealNodes();
|
||||
|
@ -152,7 +163,54 @@ void SelectKernel(const KernelGraphPtr &kernel_graph, std::set<KernelGraphPtr> *
|
|||
}
|
||||
}
|
||||
|
||||
void InlineSubGraph(const KernelGraphPtr &graph) {
|
||||
void InlineSubGraph(const KernelGraphPtr &graph, CNodePtr kernel_cnode, AnfNodePtr *last_call, bool is_switch_inline) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_cnode);
|
||||
auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->fullname_with_scope() << ", sub graph: " << sub_graph->graph_id()
|
||||
<< ", need inline: " << sub_graph->need_inline();
|
||||
auto main_graph = kernel_cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(main_graph);
|
||||
auto mng = main_graph->manager();
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_cnode->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
AnfNodePtrList inp;
|
||||
auto &call_input = kernel_cnode->inputs();
|
||||
// let operators on different subgraphs will not be executed interleavedly
|
||||
for (size_t i = 1; i < call_input.size(); i++) {
|
||||
if (last_call != nullptr && (*last_call) != nullptr) {
|
||||
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), call_input[i], (*last_call)});
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_abstract(call_input[i]->abstract());
|
||||
inp.push_back(depend);
|
||||
} else {
|
||||
inp.push_back(call_input[i]);
|
||||
}
|
||||
}
|
||||
const auto &ref_map = sub_graph->GetRefMap();
|
||||
auto out = session::KernelGraphMgr::DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope(),
|
||||
kernel_info->graph_id(), ref_map, graph, is_switch_inline);
|
||||
(void)mng->Replace(kernel_cnode, out);
|
||||
// Inline graph boundary: MakeTuple---->Depend---->Tensormove
|
||||
// Avoid long link times at runtime
|
||||
if (last_call != nullptr) {
|
||||
auto value_node = graph->NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(1)));
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), value_node, out});
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_abstract(value_node->abstract());
|
||||
auto tensor_move =
|
||||
graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorMove->name())), depend});
|
||||
MS_EXCEPTION_IF_NULL(tensor_move);
|
||||
tensor_move->set_abstract(value_node->abstract());
|
||||
common::AnfAlgo::SetNodeAttr(kAttrKernelGraphBoundary, MakeValue(sub_graph), tensor_move);
|
||||
// select kernel
|
||||
SelectKernelInfo(graph, tensor_move);
|
||||
(*last_call) = tensor_move;
|
||||
}
|
||||
}
|
||||
|
||||
void InlineCallGraph(const KernelGraphPtr &graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
@ -167,53 +225,7 @@ void InlineSubGraph(const KernelGraphPtr &graph) {
|
|||
for (auto &kernel_cnode : kernel_cnodes) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_cnode);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimCallInline)) {
|
||||
auto sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(kernel_cnode, kAttrKernelGraph);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
MS_LOG(INFO) << "InlineSubGraph: " << kernel_cnode->fullname_with_scope()
|
||||
<< ", sub graph: " << sub_graph->graph_id() << ", need inline: " << sub_graph->need_inline();
|
||||
auto main_graph = kernel_cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(main_graph);
|
||||
auto mng = main_graph->manager();
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_cnode->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
AnfNodePtrList inp;
|
||||
auto &call_input = kernel_cnode->inputs();
|
||||
// let operators on different subgraphs will not be executed interleavedly
|
||||
for (size_t i = 1; i < call_input.size(); i++) {
|
||||
if (last_call != nullptr) {
|
||||
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), call_input[i], last_call});
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_abstract(call_input[i]->abstract());
|
||||
inp.push_back(depend);
|
||||
} else {
|
||||
inp.push_back(call_input[i]);
|
||||
}
|
||||
}
|
||||
const auto &ref_map = sub_graph->GetRefMap();
|
||||
auto out = session::KernelGraphMgr::DoInline(sub_graph, main_graph, inp, kernel_cnode->input(0)->scope(),
|
||||
kernel_info->graph_id(), ref_map, graph);
|
||||
(void)mng->Replace(kernel_cnode, out);
|
||||
// Inline graph boundary: MakeTuple---->Depend---->Tensormove
|
||||
// Avoid long link times at runtime
|
||||
auto value_node = graph->NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(1)));
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto depend = graph->NewCNode({NewValueNode(prim::kPrimDepend), value_node, out});
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_abstract(value_node->abstract());
|
||||
auto tensor_move = graph->NewCNode({NewValueNode(prim::kPrimTensorMove), depend});
|
||||
MS_EXCEPTION_IF_NULL(tensor_move);
|
||||
tensor_move->set_abstract(value_node->abstract());
|
||||
common::AnfAlgo::SetNodeAttr(kAttrKernelGraphBoundary, MakeValue(sub_graph), tensor_move);
|
||||
|
||||
// select kernel
|
||||
auto [select_res, msg, etype] = device::ascend::SelectKernelInfoWithMsg(graph, tensor_move);
|
||||
if (!select_res) {
|
||||
MS_LOG(INFO) << "node is " << tensor_move->fullname_with_scope() << " should backoff";
|
||||
std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
|
||||
device::ascend::HandleKernelSelectFailure(graph, tensor_move, failure_info);
|
||||
}
|
||||
|
||||
last_call = tensor_move;
|
||||
InlineSubGraph(graph, kernel_cnode, &last_call, false);
|
||||
}
|
||||
}
|
||||
GEGraphOptimization::GetInstance().OptimizeACLGraphAfterInline(graph);
|
||||
|
@ -224,6 +236,418 @@ void InlineSubGraph(const KernelGraphPtr &graph) {
|
|||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Build a ConditionSwitch CNode over the condition plus every distinct branch input, and
// create one TupleGetItem per input so that the rewritten branch nodes consume the switch
// outputs instead of the original values.
//   branch_input:         maps each original input node to its slot index in the switch.
//   branch_tuple_getitem: out-param, filled with input -> TupleGetItem(cond_switch, slot).
// Returns the new ConditionSwitch node with kernel info selected and input->output ref
// pairs registered (each data input is passed through to the matching output).
CNodePtr GetCondSwitchNode(const KernelGraphPtr &graph, const std::map<AnfNodePtr, size_t> &branch_input,
                           const AnfNodePtr &cond, std::map<AnfNodePtr, AnfNodePtr> *branch_tuple_getitem) {
  // Input layout: [primitive, cond, branch inputs placed at (slot + kIndex2)].
  std::vector<AnfNodePtr> cond_switch_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionSwitch->name()))};
  cond_switch_inputs.resize(branch_input.size() + kIndex2);
  cond_switch_inputs[kIndex1] = cond;
  for (auto &kv : branch_input) {
    cond_switch_inputs[kv.second + kIndex2] = kv.first;
  }
  auto cond_switch_node = graph->NewCNode(cond_switch_inputs);
  MS_EXCEPTION_IF_NULL(cond_switch_node);

  // One TupleGetItem per distinct input; skip inputs that already have one recorded.
  for (auto &kv : branch_input) {
    if (branch_tuple_getitem->find(kv.first) == branch_tuple_getitem->end()) {
      auto tuple_getitem_node =
        graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cond_switch_node, NewValueNode(SizeToLong(kv.second))});
      MS_EXCEPTION_IF_NULL(tuple_getitem_node);
      // The getitem forwards the original input, so it keeps that input's abstract.
      tuple_getitem_node->set_abstract(kv.first->abstract());
      (*branch_tuple_getitem)[kv.first] = tuple_getitem_node;
    }
  }
  // The switch output abstract is the tuple of all data inputs (cond excluded).
  AbstractBasePtrList abstract_list;
  for (size_t i = kIndex2; i < cond_switch_inputs.size(); ++i) {
    (void)abstract_list.emplace_back(cond_switch_inputs[i]->abstract());
  }
  cond_switch_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  SelectKernelInfo(graph, cond_switch_node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(cond_switch_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  // Data input i (1-based, after cond) is passed through as output i-1: record ref pairs.
  for (size_t input_index = 1; input_index < common::AnfAlgo::GetInputTensorNum(cond_switch_node); ++input_index) {
    kernel_info->AddRefMap(input_index - 1, input_index);
  }
  return cond_switch_node;
}
|
||||
|
||||
// Rebuild a branch (PartialInline) node so that every input is taken from the
// TupleGetItem of the ConditionSwitch node instead of the original value.
//   old_branch_node:      the original PartialInline branch of the Switch.
//   branch_tuple_getitem: input -> TupleGetItem(cond_switch, slot) mapping produced
//                         by GetCondSwitchNode; every input of the branch must be in it.
// Returns the new branch node with the old node's abstract and attrs copied and
// kernel info selected. Throws if an input has no recorded getitem.
CNodePtr GetBranchNode(const KernelGraphPtr &graph, const CNodePtr &old_branch_node,
                       const std::map<AnfNodePtr, AnfNodePtr> &branch_tuple_getitem) {
  std::vector<AnfNodePtr> branch_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartialInline->name()))};
  for (size_t i = 0; i < common::AnfAlgo::GetInputNum(old_branch_node); i++) {
    auto input = common::AnfAlgo::GetInputNode(old_branch_node, i);
    // Single map lookup instead of find() + at() (which searched the map twice).
    const auto &iter = branch_tuple_getitem.find(input);
    if (iter != branch_tuple_getitem.end()) {
      branch_inputs.push_back(iter->second);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input of branch node: " << old_branch_node->fullname_with_scope() << ", " << i
                        << ", " << input->fullname_with_scope();
    }
  }
  auto branch_node = graph->NewCNode(branch_inputs);
  MS_EXCEPTION_IF_NULL(branch_node);
  branch_node->set_abstract(old_branch_node->abstract());
  SelectKernelInfo(graph, branch_node);
  // Keep attrs such as kAttrKernelGraph so the branch can still be inlined later.
  common::AnfAlgo::CopyNodeAttrs(old_branch_node, branch_node);
  return branch_node;
}
|
||||
|
||||
// Rewrite every Switch node of `graph` (and, recursively, of its child graphs) into the
// inline control-flow form: ConditionSwitch -> two PartialInline branches -> ConditionGather,
// then inline the branch subgraphs into `graph`. `memo` guards against visiting a graph twice.
void InlineSwitchGraph(const KernelGraphPtr &graph, std::set<KernelGraphPtr> *const memo) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (memo->find(graph) != memo->end()) {
    return;
  }
  memo->insert(graph);
  // Process children first so nested switches are already flattened.
  for (auto &child_graph : graph->child_graph_order()) {
    InlineSwitchGraph(child_graph.lock(), memo);
  }
#ifdef ENABLE_DUMP_IR
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  if (save_graphs) {
    std::string file_name = "hwopt_d_before_inline_switch_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
  // process ConditionSwitch/ConditionGather
  auto kernel_cnodes = graph->execution_order();
  auto mng = graph->manager();
  if (mng == nullptr) {
    // The manager is required for Replace() below; create one lazily.
    auto manager = MakeManager({graph});
    MS_EXCEPTION_IF_NULL(manager);
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
    mng = graph->manager();
  }
  std::vector<CNodePtr> partial_inline_cnode;
  for (auto &kernel_cnode : kernel_cnodes) {
    if (!IsPrimitiveCNode(kernel_cnode, prim::kPrimSwitch)) {
      continue;
    }
    // A Switch must be exactly (cond, true_branch, false_branch) with PartialInline branches.
    auto input_num = common::AnfAlgo::GetInputNum(kernel_cnode);
    MS_EXCEPTION_IF_CHECK_FAIL(input_num == kSwitchInputSize,
                               "Invalid input num of switch node: " + kernel_cnode->DebugString());
    auto cond = kernel_cnode->input(kSwitchCondIndex);
    auto true_branch = kernel_cnode->input(kSwitchBranchTrueIndex);
    auto false_branch = kernel_cnode->input(kSwitchBranchFalseIndex);
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(true_branch, prim::kPrimPartialInline),
                               "Invalid true branch of switch node: " + kernel_cnode->DebugString());
    MS_EXCEPTION_IF_CHECK_FAIL(IsPrimitiveCNode(false_branch, prim::kPrimPartialInline),
                               "Invalid false branch of switch node: " + kernel_cnode->DebugString());
    auto true_branch_cnode = true_branch->cast<CNodePtr>();
    auto false_branch_cnode = false_branch->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(true_branch_cnode);
    MS_EXCEPTION_IF_NULL(false_branch_cnode);
    // Collect the union of both branches' inputs, assigning each distinct input a slot.
    std::map<AnfNodePtr, size_t> branch_input;
    std::map<AnfNodePtr, AnfNodePtr> branch_tuple_getitem;
    auto now_input_cnt = 0;
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(true_branch_cnode); i++) {
      auto input = common::AnfAlgo::GetInputNode(true_branch_cnode, i);
      if (branch_input.find(input) != branch_input.end()) {
        continue;
      }
      branch_input[input] = now_input_cnt++;
    }
    for (size_t i = 0; i < common::AnfAlgo::GetInputNum(false_branch_cnode); i++) {
      auto input = common::AnfAlgo::GetInputNode(false_branch_cnode, i);
      if (branch_input.find(input) != branch_input.end()) {
        continue;
      }
      branch_input[input] = now_input_cnt++;
    }

    // Build ConditionSwitch, route both branches through its outputs, and gather the results.
    auto cond_switch_node = GetCondSwitchNode(graph, branch_input, cond, &branch_tuple_getitem);
    auto true_branch_node = GetBranchNode(graph, true_branch_cnode, branch_tuple_getitem);
    auto false_branch_node = GetBranchNode(graph, false_branch_cnode, branch_tuple_getitem);
    auto cond_gather_node =
      graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionGather->name())), true_branch_node,
                       false_branch_node});
    // The gather stands in for the Switch, so it takes over the Switch's abstract.
    cond_gather_node->set_abstract(kernel_cnode->abstract());
    SelectKernelInfo(graph, cond_gather_node);
    partial_inline_cnode.push_back(true_branch_node);
    partial_inline_cnode.push_back(false_branch_node);

    // Record the branch info for condition node.
    auto false_sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(false_branch_cnode, kAttrKernelGraph);
    auto true_sub_graph = common::AnfAlgo::GetNodeAttr<KernelGraphPtr>(true_branch_cnode, kAttrKernelGraph);
    MS_EXCEPTION_IF_NULL(false_sub_graph);
    MS_EXCEPTION_IF_NULL(true_sub_graph);
    std::vector<ValuePtr> branch_graph_names;
    branch_graph_names.emplace_back(std::make_shared<StringImm>(false_sub_graph->ToString()));
    branch_graph_names.emplace_back(std::make_shared<StringImm>(true_sub_graph->ToString()));
    cond_switch_node->AddAttr(kInlineSubGraphName, std::make_shared<ValueTuple>(branch_graph_names));
    // NOTE(review): the switch attr lists (false, true) and the gather attr the reversed
    // order (true, false) — presumably matching each node's input order; confirm.
    reverse(branch_graph_names.begin(), branch_graph_names.end());
    cond_gather_node->AddAttr(kAttrBranchGraphName, std::make_shared<ValueTuple>(branch_graph_names));
    graph->AddConditionGatherSwitchPair(cond_gather_node, cond_switch_node);
    MS_LOG(DEBUG) << "Add new condition gather node:" << cond_gather_node->fullname_with_scope()
                  << " and condition switch actor:" << cond_switch_node->fullname_with_scope()
                  << " for graph:" << graph->ToString();
    (void)mng->Replace(kernel_cnode, cond_gather_node);
  }

  // inline switch graph
  for (auto &kernel_cnode : partial_inline_cnode) {
    MS_EXCEPTION_IF_NULL(kernel_cnode);
    if (common::AnfAlgo::CheckPrimitiveType(kernel_cnode, prim::kPrimPartialInline)) {
      InlineSubGraph(graph, kernel_cnode, nullptr, true);
    } else {
      MS_LOG(EXCEPTION) << "Invalid node type, node: " << kernel_cnode->fullname_with_scope();
    }
  }
  graph->SetExecOrderByDefault();
#ifdef ENABLE_DUMP_IR
  if (save_graphs) {
    std::string file_name = "hwopt_d_after_inline_switch_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
}
|
||||
|
||||
// Flatten the input abstract, and record the index construct.
// eg. tuple(tuple(1, 2), 3) -> {1, 2, 3} ((-1, -1), -1)
// Leaves (non-sequence or dynamic-length sequence abstracts) are encoded as -1 in the
// construct index; nested static tuples become nested ValueTuples mirroring the shape.
// `abstract_construct_index` is an out-param later consumed by ConstructMakeTupleRecursion.
AbstractBasePtrList CollectAbstract(const abstract::AbstractBasePtr &abstract, ValuePtr *abstract_construct_index) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(abstract_construct_index);
  // Non-sequence abstract: a single leaf.
  if (!abstract->isa<abstract::AbstractSequence>()) {
    *abstract_construct_index = MakeValue(-1);
    return {abstract};
  }
  const auto &seq_abs = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(seq_abs);
  // Dynamic-length sequences cannot be statically flattened; treat as one leaf.
  if (seq_abs->dynamic_len()) {
    *abstract_construct_index = MakeValue(-1);
    return {seq_abs};
  }

  // Static tuple: flatten each element recursively and record its sub-construct.
  AbstractBasePtrList abs_list;
  ValuePtrList construct_index_list;
  for (const auto &sub_abs : seq_abs->elements()) {
    ValuePtr sub_construct_index = nullptr;
    AbstractBasePtrList sub_list = CollectAbstract(sub_abs, &sub_construct_index);
    abs_list.insert(abs_list.end(), sub_list.begin(), sub_list.end());
    construct_index_list.emplace_back(sub_construct_index);
  }
  *abstract_construct_index = std::make_shared<ValueTuple>(construct_index_list);
  return abs_list;
}
|
||||
|
||||
// Rebuild the output construct by construct index.
// Walks the construct index produced by CollectAbstract; each -1 leaf consumes the next
// node from `get_item_list` (in order), and each ValueSequence becomes a MakeTuple over
// its recursively rebuilt children. `get_item_list` is consumed front-to-front, so the
// caller must supply the flattened outputs in the original flattening order.
CNodePtr ConstructMakeTupleRecursion(const ValuePtr &abstract_construct_index, std::deque<CNodePtr> *get_item_list,
                                     const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(abstract_construct_index);
  MS_EXCEPTION_IF_NULL(get_item_list);
  MS_EXCEPTION_IF_NULL(graph);
  if (!abstract_construct_index->isa<ValueSequence>()) {
    // Leaf: pop the next flattened output node.
    if (get_item_list->empty()) {
      MS_LOG(EXCEPTION) << "Failed to get item node by value:" << abstract_construct_index->ToString();
    } else {
      auto top = get_item_list->front();
      get_item_list->pop_front();
      return top;
    }
  }

  // Build node and abstract for tuple construct.
  const auto &seq_value = abstract_construct_index->cast<ValueSequencePtr>();
  MS_EXCEPTION_IF_NULL(seq_value);
  AnfNodePtrList node_list{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  AbstractBasePtrList abs_list;
  for (const auto &sub_value : seq_value->value()) {
    MS_EXCEPTION_IF_NULL(sub_value);
    const auto &new_node = ConstructMakeTupleRecursion(sub_value, get_item_list, graph);
    MS_EXCEPTION_IF_NULL(new_node);
    node_list.emplace_back(new_node);
    abs_list.emplace_back(new_node->abstract());
  }
  const auto &make_tuple = graph->NewCNode(node_list);
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
  return make_tuple;
}
|
||||
|
||||
// Flatten a (possibly nested) tuple-output node into a list of scalar outputs by
// inserting TupleGetItem nodes recursively. Non-tuple nodes and dynamic-length tuples
// are returned as-is (single-element list). The result order matches the depth-first
// element order of the abstract, i.e. the order CollectAbstract flattens abstracts in.
AnfNodePtrList CreateTupleGetItemForTupleOutput(const AnfNodePtr &node, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract = node->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "Invalid abstract for node:" << node->DebugString();
  }

  if (!abstract->isa<abstract::AbstractSequence>()) {
    return {node};
  }
  const auto &sequence_abstract = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(sequence_abstract);
  // Dynamic-length tuples cannot be statically indexed; keep the node whole.
  if (sequence_abstract->dynamic_len()) {
    return {node};
  }
  AnfNodePtrList outputs;
  for (size_t i = 0; i < sequence_abstract->elements().size(); ++i) {
    const auto &sub_abstract = sequence_abstract->elements()[i];
    MS_EXCEPTION_IF_NULL(sub_abstract);
    auto get_item = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleGetItem->name())), node,
                                     NewValueNode(MakeValue<int64_t>(SizeToLong(i)))});
    get_item->set_abstract(sub_abstract);
    // Recurse in case the element itself is a tuple.
    const auto &sub_outputs = CreateTupleGetItemForTupleOutput(get_item, graph);
    outputs.insert(outputs.end(), sub_outputs.begin(), sub_outputs.end());
  }
  return outputs;
}
|
||||
|
||||
// Flatten the tuple input of condition gather.
|
||||
CNodePtr FlattenConditionGatherNodeInput(const CNodePtr &kernel, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto mng = graph->manager();
|
||||
if (mng == nullptr) {
|
||||
auto manager = MakeManager({graph});
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(graph);
|
||||
graph->set_manager(manager);
|
||||
mng = graph->manager();
|
||||
}
|
||||
|
||||
AnfNodePtrList new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConditionGather->name()))};
|
||||
size_t output_num = SIZE_MAX;
|
||||
// Collect inputs.
|
||||
for (size_t i = 1; i < kernel->inputs().size(); ++i) {
|
||||
const auto &input = kernel->inputs()[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
AnfNodePtrList outputs = CreateTupleGetItemForTupleOutput(input, graph);
|
||||
// All input branch should have same output num.
|
||||
if (output_num != SIZE_MAX && output_num != outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output size:" << output_num << " and " << outputs.size()
|
||||
<< " for kernel:" << kernel->fullname_with_scope();
|
||||
}
|
||||
output_num = outputs.size();
|
||||
new_inputs.insert(new_inputs.end(), outputs.begin(), outputs.end());
|
||||
}
|
||||
|
||||
// Create new condition gather node.
|
||||
auto new_kernel = graph->NewCNode(new_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_kernel);
|
||||
ValuePtr abstract_construct_index = nullptr;
|
||||
AbstractBasePtrList new_abstract_list = CollectAbstract(kernel->abstract(), &abstract_construct_index);
|
||||
MS_EXCEPTION_IF_NULL(abstract_construct_index);
|
||||
MS_LOG(INFO) << "Abstract construct index:" << abstract_construct_index->ToString()
|
||||
<< " for rebuild the abstract of kernel:" << new_kernel->DebugString();
|
||||
if (new_abstract_list.size() != output_num) {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract list size:" << new_abstract_list.size() << " and output size:" << output_num
|
||||
<< " for kernel:" << kernel->DebugString() << " abstract:" << kernel->abstract()->ToString();
|
||||
}
|
||||
new_kernel->set_abstract(std::make_shared<abstract::AbstractTuple>(new_abstract_list));
|
||||
SelectKernelInfo(graph, new_kernel);
|
||||
if (output_num == SIZE_MAX) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output size:" << output_num << " for kernel:" << kernel->fullname_with_scope();
|
||||
}
|
||||
new_kernel->AddAttr(kAttrBranchOutputNum, MakeValue<size_t>(output_num));
|
||||
if (kernel->HasAttr(kAttrBranchGraphName)) {
|
||||
new_kernel->AddAttr(kAttrBranchGraphName, kernel->GetAttr(kAttrBranchGraphName));
|
||||
}
|
||||
|
||||
// Rebuild the output construct for condition gather node.
|
||||
std::deque<CNodePtr> get_item_list;
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto get_item = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimTupleGetItem->name())),
|
||||
new_kernel, NewValueNode(MakeValue<int64_t>(i))});
|
||||
get_item_list.emplace_back(get_item);
|
||||
get_item->set_abstract(new_abstract_list[i]);
|
||||
}
|
||||
auto make_tuple = ConstructMakeTupleRecursion(abstract_construct_index, &get_item_list, graph);
|
||||
(void)mng->Replace(kernel, make_tuple);
|
||||
return new_kernel;
|
||||
}
|
||||
|
||||
// Flatten the tuple input of condition node.
// For every ConditionGather in the execution order, create its flattened replacement
// and migrate the graph-level bookkeeping (inline-subgraph membership and the
// gather<->switch pairing) from the old node to the new one.
void FlattenConditionNodeInput(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &kernel : graph->execution_order()) {
    if (!IsPrimitiveCNode(kernel, prim::kPrimConditionGather)) {
      continue;
    }
    const auto &new_kernel = FlattenConditionGatherNodeInput(kernel, graph);
    MS_EXCEPTION_IF_NULL(new_kernel);
    // Every gather must already be paired with its ConditionSwitch.
    const auto &iter = graph->condition_gather_to_switch().find(kernel);
    if (iter == graph->condition_gather_to_switch().end()) {
      MS_LOG(EXCEPTION) << "Failed to get condition switch node for gather:" << kernel->DebugString();
    }
    // Carry the inline-subgraph membership over to the replacement node, if any.
    const auto &inline_iter = graph->inline_sub_graph_kernels().find(kernel);
    if (inline_iter != graph->inline_sub_graph_kernels().end()) {
      graph->AddInlineSubgraphKernel(new_kernel, inline_iter->second);
      MS_LOG(INFO) << "Add new condition gather node:" << new_kernel->fullname_with_scope()
                   << " subgraph name:" << inline_iter->second << " to graph:" << graph->ToString();
    }
    // Re-pair the switch with the new gather and drop the stale pairing.
    graph->AddConditionGatherSwitchPair(new_kernel, iter->second);
    graph->RemoveConditionGatherSwitchPair(kernel);
    // NOTE(review): the guard checks the old kernel's attr but prints the new kernel's —
    // safe because FlattenConditionGatherNodeInput copies the attr when present.
    MS_LOG(INFO) << "Add new condition gather node:" << new_kernel->fullname_with_scope()
                 << " to replace node:" << kernel->fullname_with_scope() << " branch name:"
                 << (kernel->HasAttr(kAttrBranchGraphName) ? new_kernel->GetAttr(kAttrBranchGraphName)->ToString()
                                                           : " null");
  }
  graph->SetExecOrderByDefault();

#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  if (save_graphs) {
    std::string file_name = "hwopt_d_after_flatten_gather_input_graph_" + std::to_string(graph->graph_id()) + ".ir";
    DumpIR(file_name, graph, true, kWholeStack);
  }
#endif
}
|
||||
|
||||
std::string GetBranchName(const KernelGraphPtr &graph, const CNodePtr &kernel) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::string current_branch = graph->ToString();
|
||||
const auto &iter = graph->inline_sub_graph_kernels().find(kernel);
|
||||
if (iter != graph->inline_sub_graph_kernels().end()) {
|
||||
current_branch = iter->second;
|
||||
}
|
||||
return current_branch;
|
||||
}
|
||||
|
||||
// Put the kernels belonging to the same inline subgraph together in the execution order.
// For each ConditionGather/ConditionSwitch pair: kernels after the switch that share the
// switch's branch stay in place, all others are deferred until the matching gather, so no
// foreign kernel executes interleaved with a branch's kernels.
void FixExecutionOrderForInlineControlFlowGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (graph->condition_gather_to_switch().empty()) {
    return;
  }
  auto execution_order = graph->execution_order();
  for (const auto &condition_node_pair : graph->condition_gather_to_switch()) {
    // first = ConditionGather, second = ConditionSwitch.
    std::vector<CNodePtr> new_order;
    std::vector<CNodePtr> new_order_after_switch;
    MS_EXCEPTION_IF_NULL(condition_node_pair.first);
    MS_EXCEPTION_IF_NULL(condition_node_pair.second);
    std::string current_branch = GetBranchName(graph, condition_node_pair.second->cast<CNodePtr>());
    bool is_get_switch = false;
    for (auto iter = execution_order.begin(); iter != execution_order.end(); ++iter) {
      if (*iter == condition_node_pair.second) {
        // The switch itself is re-emitted just before the gather (see below).
        is_get_switch = true;
        continue;
      }
      if (*iter == condition_node_pair.first) {
        if (!is_get_switch) {
          MS_LOG(EXCEPTION) << "Condition gather:" << condition_node_pair.first->fullname_with_scope()
                            << " is in front of condition switch: "
                            << condition_node_pair.second->fullname_with_scope();
        }
        // Emit: switch, then everything deferred, then the gather and the rest unchanged.
        new_order.emplace_back(condition_node_pair.second->cast<CNodePtr>());
        new_order.insert(new_order.end(), new_order_after_switch.begin(), new_order_after_switch.end());
        new_order.insert(new_order.end(), iter, execution_order.end());
        break;
      }
      // Between switch and gather, only same-branch kernels stay in place; others are deferred.
      if (!is_get_switch || current_branch == GetBranchName(graph, *iter)) {
        new_order.emplace_back(*iter);
      } else {
        new_order_after_switch.emplace_back(*iter);
      }
    }
    // Every kernel must have been re-emitted exactly once.
    if (execution_order.size() != new_order.size()) {
      MS_LOG(EXCEPTION) << "Failed to reorder execution kernel for graph:" << graph->ToString();
    }
    // Later pairs reorder on top of this pass's result.
    execution_order = new_order;
  }
  graph->set_execution_order(execution_order);
}
|
||||
} // namespace
|
||||
|
||||
void GeKernelExecutor::Initialize() {
|
||||
|
@ -276,7 +700,10 @@ void GeKernelExecutor::OptimizeGraph(const FuncGraphPtr &graph) const {
|
|||
memo.clear();
|
||||
GEGraphOptimization::GetInstance().OptimizeACLGraphAfterKernelSelect(kernel_graph, &memo);
|
||||
memo.clear();
|
||||
InlineSubGraph(kernel_graph);
|
||||
InlineCallGraph(kernel_graph);
|
||||
memo.clear();
|
||||
InlineSwitchGraph(kernel_graph, &memo);
|
||||
FlattenConditionNodeInput(kernel_graph);
|
||||
OptimizeExecutionOrder(NOT_NULL(graph));
|
||||
profiler::CollectHostInfo("Ascend", "Graph Optimization", "GeOptimizeGraph", 1, 0, 1);
|
||||
}
|
||||
|
@ -375,6 +802,7 @@ void GeKernelExecutor::OptimizeExecutionOrder(const FuncGraphPtr &graph) const {
|
|||
kernel_graph->DisableRuntimeCache();
|
||||
kernel_graph->set_execution_order(execution_order);
|
||||
MS_LOG(DEBUG) << "Status record: end optimize execution order. graph id: " << kernel_graph->graph_id();
|
||||
FixExecutionOrderForInlineControlFlowGraph(kernel_graph);
|
||||
}
|
||||
|
||||
void GeKernelExecutor::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
|
||||
|
|
|
@ -41,11 +41,7 @@ file(GLOB_RECURSE SRC_IN_910B RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"opapi/*.cc"
|
||||
"pyboost/*.cc"
|
||||
"pyboost/auto_generate/*.cc"
|
||||
"rts/send.cc"
|
||||
"rts/recv.cc"
|
||||
"rts/rt_kernel.cc"
|
||||
"rts/rt_kernel_build.cc"
|
||||
"rts/rt_kernel_info.cc"
|
||||
"rts/*.cc"
|
||||
)
|
||||
list(REMOVE_ITEM SRC_IN_910B ${AICPU_OPS_SRC})
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ bool IsNotHcclCommunicationOp(const std::string &op_name) {
|
|||
}
|
||||
|
||||
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
||||
static const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
|
||||
static const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
|
||||
kNumberTypeFloat32, kNumberTypeInt16, kNumberTypeBFloat16};
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/rts/condition_gather.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Out-of-line destructor definition; the kernel owns no extra resources.
ConditionGatherKernel::~ConditionGatherKernel() {}
|
||||
|
||||
bool ConditionGatherKernel::Init(const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node);
|
||||
std::vector<KernelTensor *> output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node);
|
||||
Resize(input_kernel_tensors, output_kernel_tensors);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Launch performs no device work and always reports success; presumably the runtime
// scheduler handles the actual branch gathering — confirm against the actor framework.
bool ConditionGatherKernel::Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
                                   const std::vector<KernelTensor *> &, void *) {
  return true;
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// RT kernel for the ConditionGather primitive. Init only sizes the kernel from the
// node's tensors and Launch is a no-op placeholder (see condition_gather.cc).
class ConditionGatherKernel : public RtKernel {
 public:
  ConditionGatherKernel() = default;
  ~ConditionGatherKernel() override;
  // Collect input/output kernel tensors from the node and resize; always returns true.
  bool Init(const AnfNodePtr &anf_node) override;
  // No device computation; always returns true.
  bool Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
              const std::vector<KernelTensor *> &, void *) override;
  // Kernel-attr query is not applicable to RT kernels; calling it is an error.
  std::vector<KernelAttr> GetOpSupport() override { MS_LOG(EXCEPTION) << "This interface is not support in RtKernel."; }
};
|
||||
|
||||
MS_REG_RTKERNEL(conditiongather, ConditionGatherKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_GATHER_H
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/rts/condition_switch.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Out-of-line destructor definition; the kernel owns no extra resources.
ConditionSwitchKernel::~ConditionSwitchKernel() {}
|
||||
|
||||
bool ConditionSwitchKernel::Init(const AnfNodePtr &anf_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::vector<KernelTensor *> input_kernel_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(anf_node);
|
||||
std::vector<KernelTensor *> output_kernel_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(anf_node);
|
||||
Resize(input_kernel_tensors, output_kernel_tensors);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Launch performs no device work and always reports success; presumably the runtime
// scheduler handles the actual branch selection — confirm against the actor framework.
bool ConditionSwitchKernel::Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
                                   const std::vector<KernelTensor *> &, void *) {
  return true;
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/ascend/kernel/rts/rt_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// RT kernel for the ConditionSwitch primitive. Init only sizes the kernel from the
// node's tensors and Launch is a no-op placeholder (see condition_switch.cc).
class ConditionSwitchKernel : public RtKernel {
 public:
  ConditionSwitchKernel() = default;
  ~ConditionSwitchKernel() override;
  // Collect input/output kernel tensors from the node and resize; always returns true.
  bool Init(const AnfNodePtr &anf_node) override;
  // No device computation; always returns true.
  bool Launch(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &,
              const std::vector<KernelTensor *> &, void *) override;
  // Kernel-attr query is not applicable to RT kernels; calling it is an error.
  std::vector<KernelAttr> GetOpSupport() override { MS_LOG(EXCEPTION) << "This interface is not support in RtKernel."; }
};
|
||||
|
||||
MS_REG_RTKERNEL(conditionswitch, ConditionSwitchKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_CONDITION_SWITCH_H
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/optimizer/ge/process_partial_inline.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
#include "include/backend/optimizer/helper.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "mindspore/core/ops/framework_ops.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Matches any Partial node: Partial(subgraph, args...).
const BaseRef ProcessPartialInline::DefinePattern() const {
  const VarPtr partial_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimPartial, partial_inputs});
}
|
||||
|
||||
// Rewrites Partial(graph, args...) into PartialInline(args...) and attaches the
// captured sub kernel graph as an attribute, so the branch can be inlined later.
// Returns nullptr (no change) when not in KByK executor mode or when the node is
// not marked as not-cut.
const AnfNodePtr ProcessPartialInline::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
  // Partial nodes are only inlined in the kernel-by-kernel executor mode.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->IsKByKExecutorMode()) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto partial_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(partial_cnode);
  // Only partial nodes marked as not-cut are converted.
  if (!partial_cnode->HasPrimalAttr(kAttrNotCut)) {
    return nullptr;
  }
  // The first real input of Partial is the value node holding the sub graph.
  const auto &sub_graph_input = partial_cnode->input(kIndex1);
  const auto &sub_kernel_graph = session::AnfRuntimeAlgorithm::GetValueNodeKernelGraph(sub_graph_input);
  // Build the PartialInline inputs from the Partial arguments (the graph input stays at kIndex1
  // in the old node's input list, which GetInputNode indexes past implicitly).
  std::vector<AnfNodePtr> new_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimPartialInline->name()))};
  const size_t input_num = common::AnfAlgo::GetInputNum(partial_cnode);
  for (size_t input_index = kIndex1; input_index < input_num; ++input_index) {
    (void)new_inputs.emplace_back(common::AnfAlgo::GetInputNode(partial_cnode, input_index));
  }
  auto partial_inline = graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(partial_inline);
  partial_inline->set_abstract(partial_cnode->abstract());
  // Record the sub kernel graph so later passes can locate the inlined branch.
  common::AnfAlgo::SetNodeAttr(kAttrKernelGraph, MakeValue(sub_kernel_graph), partial_inline);
  return partial_inline;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_
|
||||
|
||||
#include "include/backend/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pass that converts Partial nodes marked as not-cut into PartialInline nodes,
// enabling inline execution of control-flow branches in KByK executor mode.
class ProcessPartialInline : public PatternProcessPass {
 public:
  explicit ProcessPartialInline(bool multi_graph = true) : PatternProcessPass("process partial inline", multi_graph) {}
  ~ProcessPartialInline() override = default;
  // Matches Partial(...) nodes.
  const BaseRef DefinePattern() const override;
  // Returns the replacement PartialInline node, or nullptr when no rewrite applies.
  const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_PROCESS_PARTIAL_INLINE_H_
|
|
@ -41,6 +41,7 @@
|
|||
#include "plugin/device/ascend/optimizer/ge/unfold_nested_output.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/resize_bilinear_add_attr.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/process_call_inline.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/process_partial_inline.h"
|
||||
#include "plugin/device/ascend/optimizer/format_type/deal_ref_output.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/hcom/insert_load_for_allgather.h"
|
||||
#include "plugin/device/ascend/optimizer/format_type/set_fracz_group_attr.h"
|
||||
|
@ -125,6 +126,7 @@ void GEBackendOptimizeACL(const KernelGraphPtr &kernel_graph) {
|
|||
opt_acl_pm->AddPass(std::make_shared<InsertTensorMoveForCommunication>());
|
||||
opt_acl_pm->AddPass(std::make_shared<opt::TransDependValueToInt32>());
|
||||
opt_acl_pm->AddPass(std::make_shared<opt::ProcessCallInline>());
|
||||
opt_acl_pm->AddPass(std::make_shared<opt::ProcessPartialInline>());
|
||||
opt_acl_pm->AddPass(std::make_shared<opt::ExpanderFallback>());
|
||||
optimizer->AddPassManager(opt_acl_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -524,20 +524,16 @@ void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *de
|
|||
kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
|
||||
MS_LOG(DEBUG) << "Kernel tensor created without set stream id, but set after device address created.";
|
||||
auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
|
||||
if (user_data != nullptr) {
|
||||
device_address->SetNodeIndex(kernel, i);
|
||||
}
|
||||
device_address->SetNodeIndex(kernel, i);
|
||||
if (is_from_persistent_mem) {
|
||||
device_address->set_from_persistent_mem(true);
|
||||
}
|
||||
if (find(outputs.begin(), outputs.end(), kernel) != outputs.end()) {
|
||||
device_address->SetNodeIndex(kernel, i);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
|
||||
<< " addr:" << device_address << " type:" << device_address->type_id()
|
||||
<< ", kernel tensor addr:" << kernel_tensor.get()
|
||||
<< ", kernel tensor: " << kernel_tensor->ToString() << " addr size:" << address_size
|
||||
<< " real size:" << device_address->GetSize();
|
||||
<< " real size:" << device_address->GetSize()
|
||||
<< " origin ref count:" << device_address->original_ref_count();
|
||||
device_address->set_stream_id(AnfAlgo::GetStreamId(kernel));
|
||||
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
|
||||
}
|
||||
|
@ -735,6 +731,8 @@ void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr
|
|||
MS_EXCEPTION_IF_NULL(group_node_device_address);
|
||||
// Update the reference count of device address.
|
||||
device_address->IncreaseOriginalRefCount();
|
||||
MS_LOG(DEBUG) << "After increase ref count for device address:" << device_address
|
||||
<< " ref count:" << device_address->original_ref_count();
|
||||
device_address->ResetRefCount();
|
||||
group_node_device_address->set_pointer_ref_count(device_address->pointer_ref_count());
|
||||
}
|
||||
|
@ -798,6 +796,8 @@ void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur
|
|||
cur_node_output_addr->DecreaseOriginalRefCount();
|
||||
cur_node_output_addr->ResetRefCount();
|
||||
origin_node_output_addr->IncreaseOriginalRefCount();
|
||||
MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
|
||||
<< " ref count:" << origin_node_output_addr->original_ref_count();
|
||||
origin_node_output_addr->ResetRefCount();
|
||||
cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
|
||||
cur_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);
|
||||
|
|
|
@ -26,11 +26,14 @@ void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<
|
|||
// The unused data may be invalid ptr.
|
||||
if (!ActorDispatcher::enable_async_launch_kernel() && !input_data->data_->IsPtrValid() &&
|
||||
!TEST_FLAG(input_data->data_->flag(), device::kDeviceAddressFlagNotUsed)) {
|
||||
MS_LOG(EXCEPTION) << "The input_data does not have a valid ptr of actor:" << GetAID().Name()
|
||||
<< " with index:" << input_data->index_ << ", flag:" << input_data->data_->flag()
|
||||
<< " device address:" << input_data->data_ << " ref count:" << input_data->data_->ref_count()
|
||||
<< " dynamic ref count:" << input_data->data_->dynamic_ref_count()
|
||||
<< " origin ref count:" << input_data->data_->original_ref_count();
|
||||
std::string error_info = "The input_data does not have a valid ptr of actor:" + GetAID().Name() +
|
||||
" with index:" + std::to_string(input_data->index_) +
|
||||
", flag:" + std::to_string(input_data->data_->flag()) +
|
||||
" device address:" + std::to_string((int64_t)(input_data->data_)) +
|
||||
" ref count:" + std::to_string(input_data->data_->ref_count()) +
|
||||
" dynamic ref count:" + std::to_string(input_data->data_->dynamic_ref_count()) +
|
||||
" origin ref count:" + std::to_string(input_data->data_->original_ref_count());
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
|
|
|
@ -161,6 +161,15 @@ bool IsRpcActor(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsInnerControlFlowActor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsKernelActor(node) && (common::AnfAlgo::GetCNodeName(node) == "ConditionSwitch" ||
|
||||
common::AnfAlgo::GetCNodeName(node) == "ConditionGather")) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>()) {
|
||||
|
|
|
@ -116,7 +116,10 @@ enum class KernelTransformType {
|
|||
// Memory actor type.
|
||||
kMemoryAllocActor,
|
||||
kMemoryFreeActor,
|
||||
kMemorySwapActor
|
||||
kMemorySwapActor,
|
||||
// Inner control flow actor type.
|
||||
kConditionGatherActor,
|
||||
kConditionSwitchActor
|
||||
};
|
||||
|
||||
#define SET_OPCONTEXT_FAIL_RET_WITH_ERROR(op_context, message) \
|
||||
|
@ -302,6 +305,7 @@ bool IsSkippedKernelActor(const AnfNodePtr &node);
|
|||
|
||||
bool IsRpcActor(const AnfNodePtr &node);
|
||||
|
||||
bool IsInnerControlFlowActor(const AnfNodePtr &node);
|
||||
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
|
||||
// related to the input of this kernel graph.
|
||||
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||
|
|
|
@ -46,6 +46,8 @@
|
|||
#include "runtime/graph_scheduler/actor/control_flow/entrance_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/memory/memory_swap_actor.h"
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// Delegates all construction to KernelActor; gather-specific state (branch names,
// per-branch input counts) is filled in later by the InlineControlFlowScheduler.
ConditionGatherActor::ConditionGatherActor(const std::string &name, const CNodePtr &kernel,
                                           const DeviceContext *device_context, const AID &memory_manager_aid,
                                           const AID *debug_aid, const AID *recorder_aid,
                                           GraphExecutionStrategy strategy,
                                           const std::set<size_t> &modifiable_ref_input_indexes,
                                           const std::set<size_t> &modifiable_ref_output_indexes,
                                           const KernelTransformType &type)
    : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
|
||||
|
||||
void ConditionGatherActor::RunBranchName(const std::string &branch_name, OpContext<DeviceTensor> *const context) {
|
||||
MS_LOG(DEBUG) << "Condition gather actor:" << GetAID() << " branch name:" << branch_name;
|
||||
current_branch_name_ = branch_name;
|
||||
if (branch_name_to_input_data_num_.find(current_branch_name_) == branch_name_to_input_data_num_.end()) {
|
||||
input_datas_num_ = 0;
|
||||
} else {
|
||||
input_datas_num_ = branch_name_to_input_data_num_[current_branch_name_];
|
||||
}
|
||||
if (branch_name_to_input_control_num_.find(current_branch_name_) == branch_name_to_input_control_num_.end()) {
|
||||
input_controls_num_ = 0;
|
||||
} else {
|
||||
input_controls_num_ = branch_name_to_input_control_num_[current_branch_name_];
|
||||
}
|
||||
if (input_datas_num_ == 0 && input_controls_num_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "No input data and input control, branch id:" << current_branch_name_
|
||||
<< " for actor:" << GetAID();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Input data num:" << input_datas_num_ << " control num:" << input_controls_num_
|
||||
<< " for actor:" << GetAID();
|
||||
}
|
||||
|
||||
void ConditionGatherActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  // Only one branch is live at a time, so the actor holds branch_output_num_
  // input slots regardless of the number of branches.
  input_device_tensors_.resize(branch_output_num_);
  InitOutputData();
}
|
||||
|
||||
// Collects the inputs of the currently selected branch and wires them into the
// pre-built output data.  Inputs of branch k occupy the global index range
// [k * branch_output_num_, (k + 1) * branch_output_num_); they are remapped to
// local indexes [0, branch_output_num_) before use.
void ConditionGatherActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto iter = std::find(branch_names_.begin(), branch_names_.end(), current_branch_name_);
  if (iter == branch_names_.end()) {
    MS_LOG(EXCEPTION) << "Invalid current branch name:" << current_branch_name_ << " total:" << branch_names_
                      << " for actor:" << GetAID();
  }
  // First global input index belonging to the current branch.
  size_t start_index = branch_output_num_ * (iter - branch_names_.begin());

  memory_free_list_.clear();
  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      // Reject data that does not belong to the current branch's index range.
      if (IntToSize(input_data->index_) < start_index ||
          IntToSize(input_data->index_) - start_index >= input_device_tensors_.size()) {
        std::string error_info =
          "Invalid input index:" + std::to_string(input_data->index_) + " start:" + std::to_string(start_index) +
          " total:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      // Remap the global branch index to the local slot.
      input_device_tensors_[IntToSize(input_data->index_) - start_index] = input_data->data_;
      memory_free_list_.emplace_back(input_data->data_);
    }
  }

  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    // Keys outside the current branch's range belong to other branches; skip them.
    if (device_tensor_store_key.first < start_index ||
        device_tensor_store_key.first - start_index >= input_device_tensors_.size()) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
                                                                device_contexts_[0]->GetDeviceType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    input_device_tensors_[device_tensor_store_key.first - start_index] = device_tensor.get();
  }

  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "Invalid output data size:" << output_data_.size()
                      << " and output data arrow size:" << output_data_arrows_.size() << " for actor:" << GetAID();
  }

  // Point every outgoing OpData at the device tensor of its source output index.
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_data_arrows_[i]);
    MS_EXCEPTION_IF_NULL(output_data_[i].first);
    const auto &from_index = output_data_arrows_[i]->from_output_index_;
    if (IntToSize(from_index) >= input_device_tensors_.size() || from_index < 0) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index << " to actor:" << output_data_arrows_[i]->to_op_id_
                        << " to index:" << output_data_arrows_[i]->to_input_index_ << " for actor:" << GetAID();
    }
    if (input_device_tensors_[from_index] == nullptr) {
      std::string error_info =
        GetAID().Name() + " get input device tensor index:" + std::to_string(from_index) + " failed.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    output_data_[i].first->data_ = input_device_tensors_[from_index];
  }
}
|
||||
|
||||
// Main entry: gathers the live branch's inputs, frees input memory, and forwards
// the outputs.  The gather node itself launches no device kernel.
void ConditionGatherActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    FetchInput(context);
    if (memory_free_list_.size() > 0) {
      SendMemoryFreeReq(context);
    }
    MS_LOG(DEBUG) << "My executor order log launch kernel:" << kernel_->fullname_with_scope();
    EraseInput(context);
    SendOutput(context);
  } catch (const std::exception &e) {
    // Convert C++ exceptions into an actor-framework failure so the pipeline stops cleanly.
    MsException::Instance().SetException();
    std::string error_info =
      "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
#include "runtime/graph_scheduler/actor/kernel_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
using mindspore::session::KernelWithIndex;
|
||||
|
||||
// Condition gather actor is used to collect the output of different branch from condition switch actor.
|
||||
class ConditionGatherActor : public KernelActor {
|
||||
public:
|
||||
ConditionGatherActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
|
||||
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
||||
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||
const std::set<size_t> &modifiable_ref_output_indexes,
|
||||
const KernelTransformType &type = KernelTransformType::kConditionGatherActor);
|
||||
~ConditionGatherActor() override = default;
|
||||
// Receive the branch name from condition switch actor.
|
||||
void RunBranchName(const std::string &branch_name, OpContext<DeviceTensor> *const context);
|
||||
|
||||
protected:
|
||||
void Init() override;
|
||||
void FetchInput(OpContext<DeviceTensor> *const context);
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class InlineControlFlowScheduler;
|
||||
// Output num of each branch.
|
||||
size_t branch_output_num_;
|
||||
// The order of each branch name.
|
||||
std::vector<std::string> branch_names_;
|
||||
// The current execute branch between switch and gather actor.
|
||||
std::string current_branch_name_;
|
||||
// Input data and control num for each branch.
|
||||
mindspore::HashMap<std::string, size_t> branch_name_to_id_;
|
||||
mindspore::HashMap<std::string, size_t> branch_name_to_input_data_num_;
|
||||
mindspore::HashMap<std::string, size_t> branch_name_to_input_control_num_;
|
||||
};
|
||||
|
||||
using ConditionGatherActorPtr = std::shared_ptr<ConditionGatherActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_GATHER_ACTOR_H_
|
|
@ -0,0 +1,191 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/actor/control_flow/condition_switch_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/condition_gather_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// Delegates all construction to KernelActor; switch-specific state (branch names,
// per-branch ref counts, arrow branch indexes, gather AID) is filled in later by
// the InlineControlFlowScheduler.
ConditionSwitchActor::ConditionSwitchActor(const std::string &name, const CNodePtr &kernel,
                                           const DeviceContext *device_context, const AID &memory_manager_aid,
                                           const AID *debug_aid, const AID *recorder_aid,
                                           GraphExecutionStrategy strategy,
                                           const std::set<size_t> &modifiable_ref_input_indexes,
                                           const std::set<size_t> &modifiable_ref_output_indexes,
                                           const KernelTransformType &type)
    : KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
|
||||
|
||||
void ConditionSwitchActor::Init() {
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  input_device_tensors_.resize(common::AnfAlgo::GetInputTensorNum(kernel_));

  InitOutputData();
  output_data_by_output_index_.resize(AnfAlgo::GetOutputTensorNum(kernel_));
  if (output_data_.size() != output_data_arrows_.size()) {
    MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
  }

  // Group the pre-built output data by source output index, so FetchInput can
  // refresh all OpData sent from one output with a single indexed lookup.
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    const auto &output_data = output_data_[i].first;
    const auto &data_arrow = output_data_arrows_[i];
    MS_EXCEPTION_IF_NULL(output_data);
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &from_index = data_arrow->from_output_index_;
    if (IntToSize(from_index) >= output_data_by_output_index_.size()) {
      MS_LOG(EXCEPTION) << "Invalid from index:" << from_index
                        << " and output size:" << output_data_by_output_index_.size() << " for actor:" << GetAID();
    }
    output_data_by_output_index_[from_index].emplace_back(output_data.get());
  }
}
|
||||
|
||||
// Sends outputs only along the arrows belonging to the chosen branch `index`,
// after telling the paired gather actor which branch was taken so it can adjust
// its expected input counts before any branch output reaches it.
void ConditionSwitchActor::SendOutput(OpContext<DeviceTensor> *const context, size_t index) {
  MS_EXCEPTION_IF_NULL(gather_aid_);
  MS_LOG(DEBUG) << "condition actor run for index:" << index << " branch name:" << branch_names_[index]
                << " for actor:" << GetAID();
  // Notify the gather actor first, so it knows which branch's inputs to wait for.
  ActorDispatcher::Send(*gather_aid_, &ConditionGatherActor::RunBranchName, branch_names_[index], context);

  if (output_data_arrows_.size() != output_data_nodes_.size() || output_data_nodes_.size() != output_data_.size() ||
      output_data_.size() != output_data_branch_indexes_.size()) {
    MS_LOG(EXCEPTION) << "Invalid data arrow size:" << output_data_arrows_.size()
                      << " node size:" << output_data_nodes_.size() << " data size:" << output_data_.size()
                      << " index size:" << output_data_branch_indexes_.size() << " for actor:" << GetAID();
  }
  // Forward data only to consumers on the taken branch.
  for (size_t i = 0; i < output_data_branch_indexes_.size(); ++i) {
    if (output_data_branch_indexes_[i] == index) {
      ActorDispatcher::Send(output_data_arrows_[i]->to_op_id_, &OpActor::RunOpData, output_data_[i].first.get(),
                            context);
    }
  }

  if (output_control_arrows_.size() != output_control_branch_indexes_.size()) {
    // Fixed message: the two sizes were previously streamed back-to-back with no
    // separating label, producing an unreadable concatenated number.
    MS_LOG(EXCEPTION) << "Invalid control arrow size:" << output_control_arrows_.size()
                      << " and branch index size:" << output_control_branch_indexes_.size()
                      << " for actor:" << GetAID();
  }
  // Forward control arrows only to consumers on the taken branch.
  for (size_t i = 0; i < output_control_branch_indexes_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(output_control_arrows_[i]);
    if (output_control_branch_indexes_[i] == index) {
      ActorDispatcher::Send(output_control_arrows_[i]->to_op_id_, &OpActor::RunOpControl, const_cast<AID *>(&GetAID()),
                            context);
    }
  }
}
|
||||
|
||||
// Main entry: reads the boolean condition from input 0, frees the not-taken
// branches' memory, and forwards outputs along the chosen branch only.
void ConditionSwitchActor::Run(OpContext<DeviceTensor> *const context) {
  try {
    FetchInput(context);
    MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
    MS_EXCEPTION_IF_NULL(input_device_tensors_[0]->kernel_tensor());
    // The bool condition is used directly as the branch index (false -> 0, true -> 1).
    // NOTE(review): assumes branch_names_ is ordered so index 1 is the true branch — confirm
    // against the InlineControlFlowScheduler that fills branch_names_.
    bool index = input_device_tensors_[0]->kernel_tensor()->GetValueWithCheck<bool>();
    MS_LOG(DEBUG) << "Index:" << index << " for actor:" << GetAID();
    EraseInput(context);
    // Release ref counts of the branches that will not execute before sending output.
    CollectMemoryFreeList(index);
    if (memory_free_list_.size() > 0) {
      SendMemoryFreeReq(context);
    }
    MS_LOG(DEBUG) << "Launch kernel:" << kernel_->fullname_with_scope();
    SendOutput(context, index);
  } catch (const std::exception &e) {
    // Convert C++ exceptions into an actor-framework failure so the pipeline stops cleanly.
    MsException::Instance().SetException();
    std::string error_info =
      "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context), error_info);
  }
}
|
||||
|
||||
void ConditionSwitchActor::CollectMemoryFreeList(size_t index) {
|
||||
memory_free_list_.clear();
|
||||
memory_free_list_.insert(memory_free_list_.end(), input_device_tensors_.begin(), input_device_tensors_.end());
|
||||
for (size_t i = 0; i < branch_origin_ref_count_.size(); ++i) {
|
||||
if (i == index) {
|
||||
continue;
|
||||
}
|
||||
if (branch_origin_ref_count_[i].size() + 1 != input_device_tensors_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid origin ref count size:" << branch_origin_ref_count_[i]
|
||||
<< " and input size:" << input_device_tensors_.size() << " for actor:" << GetAID();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Free memory for branch:" << i << " for actor:" << GetAID();
|
||||
for (size_t j = 0; j < branch_origin_ref_count_[i].size(); ++j) {
|
||||
std::fill_n(back_inserter(memory_free_list_), branch_origin_ref_count_[i][j], input_device_tensors_[j + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collects all input device tensors (input 0 is the condition, inputs 1..n are the
// branch data) and points the pre-grouped output data at them.  Output i forwards
// input i + 1, since the condition itself is not an output.
void ConditionSwitchActor::FetchInput(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  // Fetch input device tensor from input data.
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
        std::string error_info = "Invalid input index, need:" + std::to_string(input_data->index_) +
                                 " current:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      MS_EXCEPTION_IF_NULL(input_data->data_);
      input_device_tensors_[IntToSize(input_data->index_)] = input_data->data_;
    }
  }

  // Fetch input device tensor from device tensor store.
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
    auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
                                                                device_contexts_[0]->GetDeviceType());
    if (device_tensor == nullptr) {
      std::string error_info =
        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
        ", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceType()));
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }

    if (device_tensor_store_key.first >= input_device_tensors_.size()) {
      std::string error_info =
        "The input index is out of range, need:" + std::to_string(device_tensor_store_key.first) +
        " current:" + std::to_string(input_device_tensors_.size()) + " for actor:" + GetAID().Name();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
    MS_EXCEPTION_IF_NULL(device_tensor);
    input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
  }

  // The condition input has no matching output, hence the +1.
  if (output_data_by_output_index_.size() + 1 != input_device_tensors_.size()) {
    MS_LOG(EXCEPTION) << "Invalid output size:" << output_data_by_output_index_.size()
                      << " and input device tensor size:" << input_device_tensors_.size() << " for actor:" << GetAID();
  }

  // Refresh every outgoing OpData of output i with the tensor of input i + 1.
  for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
    if (output_data_by_output_index_[i].empty()) {
      continue;
    }
    const auto &data = input_device_tensors_[i + 1];
    MS_EXCEPTION_IF_NULL(data);
    for (auto &output_data : output_data_by_output_index_[i]) {
      MS_EXCEPTION_IF_NULL(output_data);
      output_data->data_ = data;
    }
  }
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
#include "runtime/graph_scheduler/actor/kernel_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
using mindspore::session::KernelWithIndex;
|
||||
|
||||
// Condition switch actor is used to execute the branch according to the input condition in kernel graph.
// The first input is the condition; the remaining inputs are forwarded only to the actors of the
// branch that is taken (branch info is filled in by InlineControlFlowScheduler, a friend class).
class ConditionSwitchActor : public KernelActor {
 public:
  ConditionSwitchActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
                       const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                       GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
                       const std::set<size_t> &modifiable_ref_output_indexes,
                       const KernelTransformType &type = KernelTransformType::kConditionSwitchActor);
  ~ConditionSwitchActor() override = default;

 protected:
  void Init() override;
  // Entry point: fetch the inputs, then send outputs to the chosen branch only.
  void Run(OpContext<DeviceTensor> *const context) override;
  // Collect the input device tensors of this actor from the context.
  void FetchInput(OpContext<DeviceTensor> *const context);
  // Send output data/control arrows that belong to branch `index`.
  void SendOutput(OpContext<DeviceTensor> *const context, size_t index);

 private:
  friend class InlineControlFlowScheduler;
  // Collect memory free list, as the ref counts of different branches are superimposed on the output,
  // so the excess reference counts of other branches need to be subtracted in advance.
  void CollectMemoryFreeList(size_t index);

  // Graph name of each branch.
  std::vector<std::string> branch_names_;
  // Ref count of each branch (branch index -> per-output-index origin ref counts).
  std::vector<std::vector<size_t>> branch_origin_ref_count_;
  // Branch index of each output data arrow and control arrow; SIZE_MAX marks an arrow shared by
  // all branches (e.g. to the condition gather actor).
  std::vector<size_t> output_data_branch_indexes_;
  std::vector<size_t> output_control_branch_indexes_;

  // Cache output data by output index to modify the output data effectively.
  std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;

  // Switch needs to send current branch name to the corresponding gather actor to check its inputs.
  AID *gather_aid_{nullptr};
};
|
||||
|
||||
using ConditionSwitchActorPtr = std::shared_ptr<ConditionSwitchActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CONTROLFLOW_CONDITION_SWITCH_ACTOR_H_
|
|
@ -552,7 +552,7 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
|
|||
(void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageLink, 1, 0, 0);
|
||||
Link(actor_set.get(), graph_compiler_info);
|
||||
(void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageLink, 1, 0, 1);
|
||||
|
||||
inline_control_flow_scheduler_.Link(actor_set.get(), graph_compiler_info);
|
||||
DumpActor(actor_set.get(), graph_compiler_info);
|
||||
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
SchedulerHelper::CheckActorValid(actor_set.get());
|
||||
|
@ -1062,6 +1062,8 @@ void GraphScheduler::UpdateDeviceAddressByRefInternalParameter(const GraphCompil
|
|||
cur_node_output_addr->DecreaseOriginalRefCount();
|
||||
cur_node_output_addr->ResetRefCount();
|
||||
origin_node_output_addr->IncreaseOriginalRefCount();
|
||||
MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
|
||||
<< " ref count:" << origin_node_output_addr->original_ref_count();
|
||||
origin_node_output_addr->ResetRefCount();
|
||||
cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
|
||||
}
|
||||
|
@ -1285,6 +1287,9 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
|
|||
KernelActorPtr kernel_actor = nullptr;
|
||||
if (IsRpcActor(kernel)) {
|
||||
kernel_actor = GenerateRpcActor(kernel, real_device_context, strategy, ref_input_indexes, ref_output_indexes);
|
||||
} else if (IsInnerControlFlowActor(kernel)) {
|
||||
kernel_actor =
|
||||
GenerateInnerControlFlowActor(kernel, real_device_context, strategy, ref_input_indexes, ref_output_indexes);
|
||||
} else {
|
||||
kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, real_device_context,
|
||||
memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
|
||||
|
@ -1518,6 +1523,27 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Create the actor for an inline control flow kernel. Only ConditionSwitch and ConditionGather
// kernels are valid; any other kernel is an internal error.
// Interface is unchanged; the node name is now fetched once instead of three times.
KernelActorPtr GraphScheduler::GenerateInnerControlFlowActor(const CNodePtr &kernel,
                                                             const DeviceContext *device_context,
                                                             GraphExecutionStrategy strategy,
                                                             const std::set<size_t> &ref_input_indexes,
                                                             const std::set<size_t> &ref_output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel);
  // Fetch the node name once; the original queried common::AnfAlgo::GetCNodeName three times.
  const auto kernel_name = common::AnfAlgo::GetCNodeName(kernel);
  if (kernel_name != "ConditionSwitch" && kernel_name != "ConditionGather") {
    MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Kernel " << kernel->fullname_with_scope()
                               << " is not a inner control flow kernel.";
  }
  if (kernel_name == "ConditionSwitch") {
    return std::make_shared<ConditionSwitchActor>(kernel->fullname_with_scope(), kernel, device_context,
                                                  memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
                                                  ref_input_indexes, ref_output_indexes);
  }
  return std::make_shared<ConditionGatherActor>(kernel->fullname_with_scope(), kernel, device_context,
                                                memory_manager_aid_, debug_aid_, recorder_aid_, strategy,
                                                ref_input_indexes, ref_output_indexes);
}
|
||||
|
||||
namespace {
|
||||
void GetAllUInputByCNode(const CNodePtr &cnode,
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
|
||||
|
@ -2385,7 +2411,7 @@ void GraphScheduler::LinkDataArrowForCustomActor(const ActorSet *actor_set,
|
|||
void GraphScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
|
||||
const GraphCompilerInfo &graph_compiler_info) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
|
||||
if (graph->is_graph_run_mode() || graph->is_any_type_input() || !graph->inline_sub_graph_kernels().empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "utils/hash_set.h"
|
||||
#include "runtime/graph_scheduler/control_node_scheduler.h"
|
||||
#include "runtime/graph_scheduler/any_type_graph_scheduler.h"
|
||||
#include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
|
||||
#include "runtime/graph_scheduler/mem_swap_scheduler.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_set.h"
|
||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||
|
@ -131,7 +132,11 @@ class BACKEND_EXPORT GraphScheduler {
|
|||
KernelActorPtr GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
|
||||
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||
const std::set<size_t> &modifiable_ref_output_indexes);
|
||||
|
||||
// Generate inner control flow actor in execution order.
|
||||
KernelActorPtr GenerateInnerControlFlowActor(const CNodePtr &kernel, const DeviceContext *device_context,
|
||||
GraphExecutionStrategy strategy,
|
||||
const std::set<size_t> &ref_input_indexes,
|
||||
const std::set<size_t> &ref_output_indexes);
|
||||
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
|
||||
// previous graph and the head of next graph.
|
||||
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
@ -247,6 +252,8 @@ class BACKEND_EXPORT GraphScheduler {
|
|||
ControlNodeScheduler control_node_scheduler_;
|
||||
// If there is an any type input in graph, it will be used to transform it.
|
||||
AnyTypeGraphScheduler any_type_graph_scheduler_;
|
||||
// If there is inline control flow in kernel graph, it will be used to transform it.
|
||||
InlineControlFlowScheduler inline_control_flow_scheduler_;
|
||||
|
||||
// Build and link swap actor when memory offload is enabled.
|
||||
MemSwapScheduler swap_node_scheduler_;
|
||||
|
|
|
@ -0,0 +1,610 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
|
||||
#include <vector>
|
||||
#include "runtime/graph_scheduler/scheduler_helper.h"
|
||||
#include "ops/framework_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void InlineControlFlowScheduler::LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &inline_sub_graph_kernels = graph->inline_sub_graph_kernels();
|
||||
if (inline_sub_graph_kernels.empty()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Link control arrow for graph:" << graph->ToString();
|
||||
// Only link control arrow between kernels in the same graph.
|
||||
mindspore::HashMap<std::string, AbstractActor *> branch_last_actor;
|
||||
for (size_t i = 0; i < graph->execution_order().size(); ++i) {
|
||||
const auto &to_kernel = graph->execution_order()[i];
|
||||
if (IsRpcActor(to_kernel)) {
|
||||
MS_LOG(INFO) << "Rpc op is not available in the execution order, from kernel: "
|
||||
<< graph->execution_order()[i - 1]->fullname_with_scope()
|
||||
<< ", to kernel:" << graph->execution_order()[i]->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
const auto &iter = inline_sub_graph_kernels.find(to_kernel);
|
||||
std::string current_branch = graph->ToString();
|
||||
if (iter != inline_sub_graph_kernels.end()) {
|
||||
current_branch = iter->second;
|
||||
MS_LOG(DEBUG) << "Kernel:" << to_kernel->fullname_with_scope() << " branch:" << current_branch;
|
||||
}
|
||||
|
||||
const auto to_kernel_type = FetchKernelTransformType(to_kernel, graph, {}, GraphExecutionStrategy::kPipeline);
|
||||
auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_kernel, graph);
|
||||
const auto &actor_iter = branch_last_actor.find(current_branch);
|
||||
if (actor_iter == branch_last_actor.end()) {
|
||||
if (!common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
|
||||
branch_last_actor[current_branch] = to_actor;
|
||||
MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Add control arrow between " << actor_iter->second->GetAID() << " and " << to_actor->GetAID();
|
||||
SchedulerHelper::AddControlArrow(actor_iter->second, to_actor);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(to_kernel, prim::kPrimConditionSwitch)) {
|
||||
// The control relation end after the condition switch node in graph.
|
||||
branch_last_actor.erase(current_branch);
|
||||
MS_LOG(DEBUG) << "For branch:" << current_branch << " end actor:" << to_actor->GetAID();
|
||||
} else {
|
||||
// The control relation start first kernel in graph.
|
||||
branch_last_actor[current_branch] = to_actor;
|
||||
MS_LOG(DEBUG) << "For branch:" << current_branch << " start actor:" << to_actor->GetAID();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the branch name by input data arrow.
|
||||
std::string InlineControlFlowScheduler::GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
|
||||
KernelActor *condition_gather_actor,
|
||||
DataArrow *data_arrow,
|
||||
const KernelGraphPtr &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(condition_switch_actor);
|
||||
MS_EXCEPTION_IF_NULL(condition_gather_actor);
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(condition_gather_actor->kernel());
|
||||
const auto &condition_pair_iter = kernel_graph->condition_gather_to_switch().find(condition_gather_actor->kernel());
|
||||
if (condition_pair_iter == kernel_graph->condition_gather_to_switch().end() ||
|
||||
condition_pair_iter->second != condition_switch_actor->kernel()) {
|
||||
MS_LOG(EXCEPTION) << "Condition switch actor:" << condition_switch_actor->GetAID()
|
||||
<< " and gather actor:" << condition_gather_actor << " is not match.";
|
||||
}
|
||||
if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchOutputNum)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
|
||||
}
|
||||
// Get the output branch index in condition gather actor.
|
||||
const auto &output_value = condition_gather_actor->kernel()->GetAttr(kAttrBranchOutputNum);
|
||||
MS_EXCEPTION_IF_NULL(output_value);
|
||||
size_t branch_index = data_arrow->to_input_index_ / GetValue<size_t>(output_value);
|
||||
if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchGraphName)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get branch graph name by actor:" << condition_gather_actor->GetAID();
|
||||
}
|
||||
|
||||
// Get output branch name by branch index.
|
||||
const auto &branch_graph_names = condition_gather_actor->kernel()->GetAttr(kAttrBranchGraphName);
|
||||
MS_EXCEPTION_IF_NULL(branch_graph_names);
|
||||
MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
|
||||
<< " for actor:" << condition_gather_actor->GetAID();
|
||||
if (!branch_graph_names->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
|
||||
<< " for actor:" << condition_gather_actor->GetAID();
|
||||
}
|
||||
const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_name);
|
||||
if (branch_index >= tuple_name->size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid to index:" << data_arrow->to_input_index_
|
||||
<< " output num:" << GetValue<size_t>(output_value)
|
||||
<< " branch graph name:" << tuple_name->ToString()
|
||||
<< " from actor:" << condition_switch_actor->GetAID()
|
||||
<< " to actor:" << condition_gather_actor->GetAID();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(tuple_name->value()[branch_index]);
|
||||
return GetValue<std::string>(tuple_name->value()[branch_index]);
|
||||
}
|
||||
|
||||
// Walk the kernel graph's ref map; when a ref node's (recursive) origin is the condition switch
// kernel, add 1 to the ref count of the corresponding output in the branch owning that ref node.
void InlineControlFlowScheduler::FixRefCountByKernelGraphRefMap(ConditionSwitchActor *const condition_switch_actor,
                                                                const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  for (const auto &ref_pair : kernel_graph->GetRefMap()) {
    const auto &output_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    // Fix: the original dereferenced output_pair.first (and origin_pair.first) for this debug log
    // before its null check further below — guard first.
    if (output_pair.first == nullptr || origin_pair.first == nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
                  << " origin node:" << origin_pair.first->fullname_with_scope();
    const auto &recursive_origin_pair = kernel_graph->GetRefNodeRecursive(output_pair);
    // If the input node of ref node pair is a condition switch node, the ref count of corresponding
    // switch node input should add 1.
    if (recursive_origin_pair.first == condition_switch_actor->kernel()) {
      MS_LOG(DEBUG) << "Condition switch node is an input of ref node:" << output_pair.first->fullname_with_scope();
      if (inline_sub_graph_kernels.find(output_pair.first) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline subgraph name by ref node:"
                          << output_pair.first->fullname_with_scope();
      }
      // Get the branch index for ref output.
      const auto &current_branch_name = inline_sub_graph_kernels.at(output_pair.first);
      const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                   condition_switch_actor->branch_names_.end(), current_branch_name);
      if (iter == condition_switch_actor->branch_names_.end()) {
        MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                          << " total branch name:" << condition_switch_actor->branch_names_
                          << " for actor:" << condition_switch_actor->GetAID();
      }
      size_t branch_index = iter - condition_switch_actor->branch_names_.begin();

      if (recursive_origin_pair.second >= output_num || branch_index >= condition_switch_actor->branch_names_.size()) {
        MS_LOG(EXCEPTION) << "Invalid output index:" << recursive_origin_pair.second << " total:" << output_num
                          << " and branch index:" << branch_index
                          << " total:" << condition_switch_actor->branch_names_.size()
                          << " for actor:" << condition_switch_actor->GetAID();
      }
      // The ref count of the corresponding branch add 1.
      condition_switch_actor->branch_origin_ref_count_[branch_index][recursive_origin_pair.second]++;
      MS_LOG(DEBUG) << "Add ref count for current branch:" << current_branch_name << " branch index:" << branch_index
                    << " output index:" << recursive_origin_pair.second
                    << " of actor:" << condition_switch_actor->GetAID();
    }
  }
}
|
||||
|
||||
// Superimpose the ref counts of the switch outputs onto the switch inputs.
// Input i of the switch (i >= 1; input 0 is the condition) feeds output i - 1, so each input's
// origin ref count absorbs the full ref count of the corresponding output (minus the one count the
// switch itself holds). At runtime the actor subtracts the untaken branches' counts again via
// CollectMemoryFreeList.
void InlineControlFlowScheduler::FixRefCountByConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                                   const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  // Collect all the output ref count of condition switch actor.
  std::vector<size_t> total_ref_count;
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  for (size_t i = 0; i < output_num; ++i) {
    const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_switch_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    total_ref_count.emplace_back(device_address->original_ref_count());
    MS_LOG(DEBUG) << "For actor:" << condition_switch_actor->GetAID() << " output device address:" << device_address
                  << " output index:" << i << " ref_count:" << total_ref_count.back();
  }

  size_t input_num = common::AnfAlgo::GetInputTensorNum(condition_switch_actor->kernel());
  // Input num should same as the output num and the condition of switch node.
  if (input_num != output_num + 1) {
    MS_LOG(EXCEPTION) << "Invalid input num:" << input_num << " and output num:" << output_num
                      << " for actor:" << condition_switch_actor->GetAID();
  }

  // Add the ref count to the input of condition switch actor.
  for (size_t i = 1; i < input_num; ++i) {
    const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_switch_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    MS_LOG(DEBUG) << "For actor::" << condition_switch_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
    // SIZE_MAX appears to mark a ref count that is never decremented (persistent address) — skip it.
    // NOTE(review): confirm this sentinel semantics against DeviceAddress.
    if (device_address->original_ref_count() == SIZE_MAX) {
      continue;
    }
    // -1: the output's own count already includes the switch's use of this input.
    device_address->set_original_ref_count(device_address->original_ref_count() + total_ref_count[i - 1] - 1);
    device_address->ResetRefCount();
    MS_LOG(DEBUG) << "For actor::" << condition_switch_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
  }
  // Ref nodes that alias switch outputs contribute extra counts per branch.
  FixRefCountByKernelGraphRefMap(condition_switch_actor, kernel_graph);
}
|
||||
|
||||
// Fill output_data_branch_indexes_ and bump branch_origin_ref_count_ for every output data arrow
// of the condition switch actor: each arrow is attributed to the branch that its destination actor
// belongs to (via inline_sub_graph_kernels, or via the gather's branch attributes when the
// destination is the condition gather actor).
void InlineControlFlowScheduler::InitOutputDataBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
  condition_switch_actor->output_data_branch_indexes_.resize(condition_switch_actor->output_data_arrows().size());
  // Get the index for each output data arrow.
  for (size_t i = 0; i < condition_switch_actor->output_data_arrows().size(); ++i) {
    const auto &output_node = condition_switch_actor->output_data_nodes()[i];
    const auto &data_arrow = condition_switch_actor->output_data_arrows()[i];
    MS_EXCEPTION_IF_NULL(output_node);
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &to_actor = FetchActor(data_arrow->to_op_id_.Name());
    if (to_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get actor:" << data_arrow->to_op_id_.Name()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    // A switch output may feed either a kernel inside a branch or the matching gather actor.
    if (to_actor->type() != KernelTransformType::kConditionGatherActor &&
        to_actor->type() != KernelTransformType::kKernelActor) {
      MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
                        << " from actor:" << condition_switch_actor->GetAID();
    }

    const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
    std::string current_branch_name;
    if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
      // Destination is the gather: derive the branch from the arrow's to-input index.
      current_branch_name =
        GetBranchNameByConditionGatherActor(condition_switch_actor, to_kernel_actor, data_arrow.get(), kernel_graph);
    } else {
      // Destination is a branch kernel: its branch is recorded in inline_sub_graph_kernels.
      if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                          << to_kernel_actor->kernel()->fullname_with_scope()
                          << " in actor:" << condition_switch_actor->GetAID();
      }
      MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
                    << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
                    << " in actor:" << condition_switch_actor->GetAID()
                    << " from index:" << data_arrow->from_output_index_ << " to actor:" << data_arrow->to_op_id_
                    << " to index:" << data_arrow->to_input_index_;
      current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
    }
    // Get branch index for output data arrow.
    const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                 condition_switch_actor->branch_names_.end(), current_branch_name);
    if (iter == condition_switch_actor->branch_names_.end()) {
      MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                        << " total branch name:" << condition_switch_actor->branch_names_
                        << " from actor:" << condition_switch_actor->GetAID() << " to actor:" << to_actor->GetAID();
    }
    size_t branch_index = iter - condition_switch_actor->branch_names_.begin();
    if (IntToSize(data_arrow->from_output_index_) >= output_num ||
        branch_index >= condition_switch_actor->branch_names_.size()) {
      MS_LOG(EXCEPTION) << "Invalid output index:" << data_arrow->from_output_index_ << " total:" << output_num
                        << " and branch index:" << branch_index
                        << " total:" << condition_switch_actor->branch_names_.size()
                        << " for actor:" << condition_switch_actor->GetAID();
    }
    // One more consumer of this output lives in branch_index.
    condition_switch_actor->branch_origin_ref_count_[branch_index][data_arrow->from_output_index_]++;
    condition_switch_actor->output_data_branch_indexes_[i] = branch_index;
  }
}
|
||||
|
||||
// Fill output_control_branch_indexes_: for each output control arrow, record the branch index of
// the destination actor. Arrows to the condition gather actor get SIZE_MAX (sent for every branch).
void InlineControlFlowScheduler::InitOutputControlBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  condition_switch_actor->output_control_branch_indexes_.resize(condition_switch_actor->output_control_arrows().size());
  // Get the index for each output control arrow.
  for (size_t i = 0; i < condition_switch_actor->output_control_arrows().size(); ++i) {
    const auto &arrow = condition_switch_actor->output_control_arrows()[i];
    MS_EXCEPTION_IF_NULL(arrow);
    const auto &to_actor = FetchActor(arrow->to_op_id_.Name());
    if (to_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get actor:" << arrow->to_op_id_.Name()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    // SIZE_MAX marks an arrow that does not belong to a single branch (the gather is shared).
    if (to_actor->type() == KernelTransformType::kConditionGatherActor) {
      condition_switch_actor->output_control_branch_indexes_[i] = SIZE_MAX;
      continue;
    }
    // Besides the gather, only branch kernels and nested switch kernels may be control targets.
    if (to_actor->type() != KernelTransformType::kKernelActor &&
        to_actor->type() != KernelTransformType::kConditionSwitchActor) {
      MS_LOG(EXCEPTION) << "Invalid to actor:" << to_actor->GetAID()
                        << " from actor:" << condition_switch_actor->GetAID();
    }
    const auto &to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor);
    MS_EXCEPTION_IF_NULL(to_kernel_actor->kernel());
    if (inline_sub_graph_kernels.find(to_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
      MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                        << to_kernel_actor->kernel()->fullname_with_scope()
                        << " in actor:" << condition_switch_actor->GetAID();
    }
    MS_LOG(DEBUG) << "Sub graph kernel:" << to_kernel_actor->kernel()->fullname_with_scope()
                  << " belong graph:" << inline_sub_graph_kernels.at(to_kernel_actor->kernel())
                  << " in actor:" << condition_switch_actor->GetAID() << " to actor:" << arrow->to_op_id_;
    // Map the destination's branch name back to its index in branch_names_.
    const auto &current_branch_name = inline_sub_graph_kernels.at(to_kernel_actor->kernel());
    const auto &iter = std::find(condition_switch_actor->branch_names_.begin(),
                                 condition_switch_actor->branch_names_.end(), current_branch_name);
    if (iter == condition_switch_actor->branch_names_.end()) {
      MS_LOG(EXCEPTION) << "Invalid branch name:" << current_branch_name
                        << " total branch name:" << condition_switch_actor->branch_names_
                        << " for actor:" << condition_switch_actor->GetAID();
    }
    size_t branch_index = iter - condition_switch_actor->branch_names_.begin();
    condition_switch_actor->output_control_branch_indexes_[i] = branch_index;
  }
}
|
||||
|
||||
// Validate the data node/arrow pairing, then fill the per-arrow branch info (data and control)
// for the condition switch actor.
void InlineControlFlowScheduler::InitOutputBranchInfoForConditionSwitchActor(
  ConditionSwitchActor *const condition_switch_actor, const KernelGraphPtr &kernel_graph) {
  const auto &data_nodes = condition_switch_actor->output_data_nodes();
  const auto &data_arrows = condition_switch_actor->output_data_arrows();
  // Each output data arrow must have exactly one matching output data node.
  if (data_nodes.size() != data_arrows.size()) {
    MS_LOG(EXCEPTION) << "Invalid data node size:" << data_nodes.size()
                      << " and arrow size:" << data_arrows.size()
                      << " for actor:" << condition_switch_actor->GetAID();
  }
  InitOutputDataBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
  InitOutputControlBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
  MS_LOG(DEBUG) << "Branch origin ref count:" << condition_switch_actor->branch_origin_ref_count_
                << " output data branch index:" << condition_switch_actor->output_data_branch_indexes_
                << " output control branch index:" << condition_switch_actor->output_control_branch_indexes_
                << " for actor:" << condition_switch_actor->GetAID();
}
|
||||
|
||||
void InlineControlFlowScheduler::HandleConditionSwitchActor(const KernelActorPtr &kernel_actor) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(kernel_actor.get());
|
||||
MS_EXCEPTION_IF_NULL(condition_switch_actor);
|
||||
MS_EXCEPTION_IF_NULL(condition_switch_actor->kernel());
|
||||
const auto &graph = condition_switch_actor->kernel()->func_graph();
|
||||
if (graph == nullptr || !graph->isa<KernelGraph>()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_switch_actor->GetAID();
|
||||
}
|
||||
const auto &kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
|
||||
<< " by actor:" << condition_switch_actor->GetAID();
|
||||
if (!condition_switch_actor->kernel()->HasAttr(kInlineSubGraphName)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_switch_actor->GetAID();
|
||||
}
|
||||
const auto &inline_sub_graph_names = condition_switch_actor->kernel()->GetAttr(kInlineSubGraphName);
|
||||
MS_EXCEPTION_IF_NULL(inline_sub_graph_names);
|
||||
MS_LOG(DEBUG) << "inline sub graph name:" << inline_sub_graph_names->ToString()
|
||||
<< " for actor:" << condition_switch_actor->GetAID();
|
||||
if (!inline_sub_graph_names->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input subgraph name:" << inline_sub_graph_names->ToString()
|
||||
<< " for actor:" << condition_switch_actor->GetAID();
|
||||
}
|
||||
const auto &tuple_name = inline_sub_graph_names->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_name);
|
||||
std::vector<std::string> branch_names;
|
||||
for_each(tuple_name->value().begin(), tuple_name->value().end(),
|
||||
[&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
|
||||
condition_switch_actor->branch_names_ = branch_names;
|
||||
// Fix ref count.
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(condition_switch_actor->kernel());
|
||||
condition_switch_actor->branch_origin_ref_count_ =
|
||||
std::vector<std::vector<size_t>>(tuple_name->size(), vector<size_t>(output_num, 0));
|
||||
|
||||
FixRefCountByConditionSwitchActor(condition_switch_actor, kernel_graph);
|
||||
InitOutputBranchInfoForConditionSwitchActor(condition_switch_actor, kernel_graph);
|
||||
}
|
||||
|
||||
// Walk the kernel graph's ref map; when a ref node's (recursive) origin is the condition gather
// kernel, add 1 to the ref count of the input at the same output offset in every branch
// (branch k's inputs start at k * branch_output_num_).
void InlineControlFlowScheduler::FixRefCountByKernelGraphRefMap(ConditionGatherActor *const condition_gather_actor,
                                                                const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(condition_gather_actor);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Fix: guard the loop stride below — a zero branch_output_num_ would loop forever.
  if (condition_gather_actor->branch_output_num_ == 0) {
    MS_LOG(EXCEPTION) << "Invalid branch output num:0 for actor:" << condition_gather_actor->GetAID();
  }
  // If the input node of ref node pair is a condition gather node , the ref count of corresponding gather node input
  // should add 1.
  for (const auto &ref_pair : kernel_graph->GetRefMap()) {
    const auto &output_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    // Fix: the original dereferenced output_pair.first/origin_pair.first for logging before the
    // null check further below — guard first.
    if (output_pair.first == nullptr || origin_pair.first == nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "output node:" << output_pair.first->fullname_with_scope()
                  << " origin node:" << origin_pair.first->fullname_with_scope();
    const auto &recursive_origin_pair = kernel_graph->GetRefNodeRecursive(output_pair);
    if (recursive_origin_pair.first == condition_gather_actor->kernel()) {
      MS_LOG(DEBUG) << "Condition gather node output index:" << recursive_origin_pair.second
                    << " is an input of ref node:" << output_pair.first->fullname_with_scope()
                    << " to index:" << output_pair.second
                    << " need update ref count for actor:" << condition_gather_actor->GetAID();
      // Step by branch_output_num_ to visit the same logical output in every branch's input group.
      for (size_t i = recursive_origin_pair.second; i < common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
           i += condition_gather_actor->branch_output_num_) {
        const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i);
        MS_EXCEPTION_IF_NULL(device_address);
        MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                      << " input index:" << i << " ref_count:" << device_address->original_ref_count();
        if (device_address->original_ref_count() == SIZE_MAX) {
          continue;
        }
        size_t pre_origin_ref_count = device_address->original_ref_count();
        device_address->set_original_ref_count(device_address->original_ref_count() + 1);
        device_address->ResetRefCount();
        MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                      << " input index:" << i << " fix ref count from:" << pre_origin_ref_count
                      << " to:" << device_address->original_ref_count();
      }
    }
  }
}
|
||||
|
||||
void InlineControlFlowScheduler::FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                                   const KernelGraphPtr &kernel_graph) {
  // Record the ref count of every gather output; each branch input later inherits the ref count of the
  // output position it feeds.
  std::vector<size_t> total_ref_count;
  size_t output_num = AnfAlgo::GetOutputTensorNum(condition_gather_actor->kernel());
  for (size_t i = 0; i < output_num; ++i) {
    const auto &device_address = AnfAlgo::GetMutableOutputAddr(condition_gather_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    total_ref_count.emplace_back(device_address->original_ref_count());
    MS_LOG(DEBUG) << "For actor:" << condition_gather_actor->GetAID() << " output device address:" << device_address
                  << " output index:" << i << " ref_count:" << total_ref_count.back();
  }
  size_t input_num = common::AnfAlgo::GetInputNum(condition_gather_actor->kernel());
  // Inputs must form whole branches: input_num is a non-zero multiple of branch_output_num_.
  if (input_num == 0 || input_num % condition_gather_actor->branch_output_num_ != 0) {
    MS_LOG(EXCEPTION) << "Invalid input num:" << input_num
                      << " branch output num:" << condition_gather_actor->branch_output_num_
                      << " for actor:" << condition_gather_actor->GetAID();
  }
  for (size_t i = 0; i < input_num; ++i) {
    const auto &device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(condition_gather_actor->kernel(), i);
    MS_EXCEPTION_IF_NULL(device_address);
    MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " ref_count:" << device_address->original_ref_count();
    // SIZE_MAX marks a non-managed (never released) address; skip it.
    if (device_address->original_ref_count() == SIZE_MAX) {
      continue;
    }
    size_t pre_origin_ref_count = device_address->original_ref_count();
    // The real ref count is the relative position of this branch output.
    // NOTE(review): total_ref_count has output_num entries but is indexed with i % branch_output_num_;
    // this assumes output_num == branch_output_num_ — confirm against ConditionGather construction.
    device_address->set_original_ref_count(device_address->original_ref_count() +
                                           total_ref_count[i % condition_gather_actor->branch_output_num_] - 1);
    device_address->ResetRefCount();
    MS_LOG(DEBUG) << "For actor::" << condition_gather_actor->GetAID() << " input device address:" << device_address
                  << " input index:" << i << " fix ref count from:" << pre_origin_ref_count
                  << " to:" << device_address->original_ref_count();
  }
  // Also account for ref-node pairs in the graph that alias the gather outputs.
  FixRefCountByKernelGraphRefMap(condition_gather_actor, kernel_graph);
}
|
||||
|
||||
// Collect, for every inline branch that feeds this condition gather actor, the branch id and the number of
// input data arrows coming from that branch.
// @param condition_gather_actor The gather actor whose branch bookkeeping is initialized (branch_name_to_id_,
//        branch_name_to_input_data_num_ are filled in place).
// @param kernel_graph Graph owning the gather kernel; provides the kernel -> inline-subgraph-name map.
// Throws (MS_LOG(EXCEPTION)) when an upstream actor cannot be resolved or has an unexpected type.
void InlineControlFlowScheduler::InitInputDataBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
                << " by actor:" << condition_gather_actor->GetAID();
  for (const auto &pair : condition_gather_actor->input_data_arrow_aids_) {
    const auto &from_aid = pair.first;
    const auto &data_arrow = pair.second;
    MS_EXCEPTION_IF_NULL(data_arrow);
    const auto &from_actor = FetchActor(from_aid.Name());
    if (from_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
    }
    // Only a kernel actor inside a branch or the paired condition switch actor may feed the gather.
    // Fix: the original message swapped the from/to labels.
    if (from_actor->type() != KernelTransformType::kKernelActor &&
        from_actor->type() != KernelTransformType::kConditionSwitchActor) {
      MS_LOG(EXCEPTION) << "Invalid from actor:" << from_actor->GetAID()
                        << " to actor:" << condition_gather_actor->GetAID();
    }
    const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
    std::string current_branch_name;
    if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
      // Arrows from the switch actor carry the branch identity on the arrow itself.
      current_branch_name =
        GetBranchNameByConditionGatherActor(from_kernel_actor, condition_gather_actor, data_arrow, kernel_graph);
    } else {
      // Ordinary kernels are mapped to their branch via the inline subgraph table.
      if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
        MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                          << from_kernel_actor->kernel()->fullname_with_scope()
                          << " in actor:" << condition_gather_actor->GetAID();
      }
      MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
                    << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
                    << " in actor:" << condition_gather_actor->GetAID();
      current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
    }
    // Assign a dense id to each newly seen branch. The id is computed before insertion so the code does not
    // rely on the C++17 evaluation order of `m[k] = m.size()`.
    const auto &iter = condition_gather_actor->branch_name_to_id_.find(current_branch_name);
    if (iter == condition_gather_actor->branch_name_to_id_.end()) {
      const size_t branch_id = condition_gather_actor->branch_name_to_id_.size();
      condition_gather_actor->branch_name_to_id_[current_branch_name] = branch_id;
      MS_LOG(DEBUG) << "Add branch index:" << branch_id << " branch name:" << current_branch_name
                    << " for actor:" << condition_gather_actor->GetAID();
    }
    // Count the input data arrows of each branch; operator[] value-initializes a missing count to 0.
    ++condition_gather_actor->branch_name_to_input_data_num_[current_branch_name];
  }
}
|
||||
|
||||
// Collect, for every inline branch that feeds this condition gather actor, the number of input control
// arrows coming from that branch (control arrows from the paired condition switch actor are excluded).
// @param condition_gather_actor The gather actor whose branch_name_to_input_control_num_ is filled in place.
// @param kernel_graph Graph owning the gather kernel; provides the kernel -> inline-subgraph-name map.
// Throws (MS_LOG(EXCEPTION)) when an upstream actor cannot be resolved or has an unexpected type.
void InlineControlFlowScheduler::InitInputControlBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  const auto &inline_sub_graph_kernels = kernel_graph->inline_sub_graph_kernels();
  MS_LOG(DEBUG) << "Fetch kernel graph:" << kernel_graph->ToString()
                << " by actor:" << condition_gather_actor->GetAID();

  for (const auto &pair : condition_gather_actor->input_control_arrow_aids_) {
    const auto &from_aid = pair.first;
    const auto &from_actor = FetchActor(from_aid.Name());
    if (from_actor == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get from actor:" << from_aid << " to actor:" << condition_gather_actor->GetAID();
    }
    // The switch actor's control arrow does not belong to any branch; skip it.
    if (from_actor->type() == KernelTransformType::kConditionSwitchActor) {
      continue;
    }
    if (from_actor->type() != KernelTransformType::kKernelActor &&
        from_actor->type() != KernelTransformType::kConditionGatherActor) {
      MS_LOG(EXCEPTION) << "Invalid from actor:" << from_actor->GetAID()
                        << " to actor:" << condition_gather_actor->GetAID();
    }
    const auto &from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor);
    MS_EXCEPTION_IF_NULL(from_kernel_actor->kernel());
    // Map the upstream kernel to its branch via the inline subgraph table.
    if (inline_sub_graph_kernels.find(from_kernel_actor->kernel()) == inline_sub_graph_kernels.end()) {
      MS_LOG(EXCEPTION) << "Failed to get inline sub graph name by user node:"
                        << from_kernel_actor->kernel()->fullname_with_scope()
                        << " in actor:" << condition_gather_actor->GetAID();
    }
    MS_LOG(DEBUG) << "Sub graph kernel:" << from_kernel_actor->kernel()->fullname_with_scope()
                  << " belong graph:" << inline_sub_graph_kernels.at(from_kernel_actor->kernel())
                  << " in actor:" << condition_gather_actor->GetAID();
    const auto &current_branch_name = inline_sub_graph_kernels.at(from_kernel_actor->kernel());
    // Count the input control arrows of each branch; operator[] value-initializes a missing count to 0.
    ++condition_gather_actor->branch_name_to_input_control_num_[current_branch_name];
  }
}
|
||||
|
||||
// Initialize all per-branch input bookkeeping for a condition gather actor:
// first the data-arrow counts, then the control-arrow counts.
void InlineControlFlowScheduler::InitInputBranchInfoForConditionGatherActor(
  ConditionGatherActor *const condition_gather_actor, const KernelGraphPtr &kernel_graph) {
  InitInputDataBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
  InitInputControlBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
}
|
||||
|
||||
// Fix the member variables of a condition gather actor after the generic scheduler has built it:
// wire the paired switch actor back to this gather, read branch metadata from kernel attributes,
// then fix ref counts and per-branch input info.
void InlineControlFlowScheduler::HandleConditionGatherActor(const KernelActorPtr &kernel_actor) {
  const auto &condition_gather_actor = dynamic_cast<ConditionGatherActor *>(kernel_actor.get());
  MS_EXCEPTION_IF_NULL(condition_gather_actor);
  MS_EXCEPTION_IF_NULL(condition_gather_actor->kernel());
  const auto &graph = condition_gather_actor->kernel()->func_graph();
  if (graph == nullptr || !graph->isa<KernelGraph>()) {
    MS_LOG(EXCEPTION) << "Failed to get kernel graph by actor:" << condition_gather_actor->GetAID();
  }
  const auto &kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Locate the condition switch node paired with this gather node in the graph.
  const auto &gather_to_switch_iter = kernel_graph->condition_gather_to_switch().find(condition_gather_actor->kernel());
  if (gather_to_switch_iter == kernel_graph->condition_gather_to_switch().end()) {
    MS_LOG(EXCEPTION) << "Failed to get switch node by gather node:"
                      << condition_gather_actor->kernel()->fullname_with_scope();
  }
  MS_EXCEPTION_IF_NULL(gather_to_switch_iter->second);
  const auto &actor = FetchActor(gather_to_switch_iter->second->fullname_with_scope());
  MS_EXCEPTION_IF_NULL(actor);
  const auto &condition_switch_actor = dynamic_cast<ConditionSwitchActor *>(actor);
  MS_EXCEPTION_IF_NULL(condition_switch_actor);
  // Let the switch actor know which gather actor closes its branches.
  condition_switch_actor->gather_aid_ = const_cast<AID *>(&condition_gather_actor->GetAID());

  // Branch output num: how many outputs each branch contributes to the gather.
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchOutputNum)) {
    MS_LOG(EXCEPTION) << "Failed to get branch output num by actor:" << condition_gather_actor->GetAID();
  }
  const auto &output_value = condition_gather_actor->kernel()->GetAttr(kAttrBranchOutputNum);
  MS_EXCEPTION_IF_NULL(output_value);
  condition_gather_actor->branch_output_num_ = GetValue<size_t>(output_value);

  // Branch graph names: a ValueTuple of strings, one per inline branch graph.
  if (!condition_gather_actor->kernel()->HasAttr(kAttrBranchGraphName)) {
    MS_LOG(EXCEPTION) << "Failed to get inline graph name by actor:" << condition_gather_actor->GetAID();
  }
  const auto &branch_graph_names = condition_gather_actor->kernel()->GetAttr(kAttrBranchGraphName);
  MS_EXCEPTION_IF_NULL(branch_graph_names);
  MS_LOG(DEBUG) << "Branch graph name:" << branch_graph_names->ToString()
                << " for actor:" << condition_gather_actor->GetAID();
  if (!branch_graph_names->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Invalid branch group name:" << branch_graph_names->ToString()
                      << " for actor:" << condition_gather_actor->GetAID();
  }
  const auto &tuple_name = branch_graph_names->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_name);
  std::vector<std::string> branch_names;
  std::for_each(tuple_name->value().begin(), tuple_name->value().end(),
                [&branch_names](const auto &value) { branch_names.emplace_back(GetValue<std::string>(value)); });
  condition_gather_actor->branch_names_ = branch_names;
  // Fix ref count.
  FixRefCountByConditionGatherActor(condition_gather_actor, kernel_graph);
  InitInputBranchInfoForConditionGatherActor(condition_gather_actor, kernel_graph);
}
|
||||
|
||||
// Entry point of the inline control flow scheduler: optionally serialize kernels by execution order
// (when the memory optimize level is not O0), then patch up every condition switch/gather actor.
void InlineControlFlowScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool link_by_execution_order = ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0;
  if (link_by_execution_order) {
    for (const auto &kernel_graph : graph_compiler_info.graphs_) {
      MS_EXCEPTION_IF_NULL(kernel_graph);
      LinkControlArrowByExecutionOrder(kernel_graph, graph_compiler_info);
    }
  }
  // Dispatch each condition actor to its dedicated fix-up routine.
  for (const auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    switch (actor->type()) {
      case KernelTransformType::kConditionSwitchActor:
        HandleConditionSwitchActor(actor);
        break;
      case KernelTransformType::kConditionGatherActor:
        HandleConditionGatherActor(actor);
        break;
      default:
        break;
    }
  }
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2024 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
|
||||
|
||||
#include <string>
|
||||
#include "runtime/graph_scheduler/actor/actor_set.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// Scheduler pass that finalizes inline (graph-internal) control flow: after the generic graph scheduler
// has built the actor set, it links the required control arrows and repairs the member variables and
// ref counts of ConditionSwitch/ConditionGather actors.
class InlineControlFlowScheduler {
 public:
  InlineControlFlowScheduler() = default;
  ~InlineControlFlowScheduler() = default;
  DISABLE_COPY_AND_ASSIGN(InlineControlFlowScheduler);

  // Link control arrows and fix the member variables for condition actors.
  void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);

 private:
  // Serialize kernels of the graph with control arrows following the execution order.
  void LinkControlArrowByExecutionOrder(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info);
  // Fix the member variables for condition actors.
  void HandleConditionSwitchActor(const KernelActorPtr &kernel_actor);
  void HandleConditionGatherActor(const KernelActorPtr &kernel_actor);

  // Init the output branch info for condition actor.
  // For condition switch actor, the output arrow include all the output branch and should be distinguished.
  void InitOutputBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                   const KernelGraphPtr &kernel_graph);
  void InitOutputControlBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                          const KernelGraphPtr &kernel_graph);
  void InitOutputDataBranchInfoForConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                                       const KernelGraphPtr &kernel_graph);
  // Init the per-branch input info (data arrows then control arrows) for a condition gather actor.
  void InitInputBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                  const KernelGraphPtr &kernel_graph);
  void InitInputDataBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                      const KernelGraphPtr &kernel_graph);
  void InitInputControlBranchInfoForConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                                         const KernelGraphPtr &kernel_graph);

  // Fix ref count for condition actors.
  // In condition switch actor, the ref count of actor should be change to total num for both branch.
  // In condition gather actor, the ref count of gather input should add the ref count of gather output.
  // The ref count of ref node should be add to the input of condition actor.
  void FixRefCountByConditionSwitchActor(ConditionSwitchActor *const condition_switch_actor,
                                         const KernelGraphPtr &kernel_graph);
  void FixRefCountByKernelGraphRefMap(ConditionSwitchActor *const condition_switch_actor,
                                      const KernelGraphPtr &kernel_graph);
  void FixRefCountByConditionGatherActor(ConditionGatherActor *const condition_gather_actor,
                                         const KernelGraphPtr &kernel_graph);
  void FixRefCountByKernelGraphRefMap(ConditionGatherActor *const condition_gather_actor,
                                      const KernelGraphPtr &kernel_graph);

  // Resolve the branch name an arrow from the paired switch actor belongs to.
  std::string GetBranchNameByConditionGatherActor(KernelActor *condition_switch_actor,
                                                  KernelActor *condition_gather_actor, DataArrow *data_arrow,
                                                  const KernelGraphPtr &kernel_graph);
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_INLINE_CONTROL_FLOW_SCHEDULER_H_
|
|
@ -717,7 +717,9 @@ void SchedulerHelper::AddMemorySign(AbstractActor *const from_actor, AbstractAct
|
|||
KernelGraphPtr SchedulerHelper::FetchKernelGraphByActor(AbstractActor *const actor) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
AnfNode *from_kernel = nullptr;
|
||||
if (actor->type() == KernelTransformType::kKernelActor) {
|
||||
if (actor->type() == KernelTransformType::kKernelActor ||
|
||||
actor->type() == KernelTransformType::kConditionGatherActor ||
|
||||
actor->type() == KernelTransformType::kConditionSwitchActor) {
|
||||
auto kernel_actor = dynamic_cast<KernelActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
from_kernel = kernel_actor->kernel().get();
|
||||
|
|
|
@ -119,7 +119,15 @@ bool IsMultiLayerTuple(const abstract::AbstractBasePtr &abstract) {
|
|||
});
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node) {
|
||||
namespace {
|
||||
bool IsMultiOutput(const AnfNodePtr &node) {
|
||||
return node != nullptr && node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractSequence>() &&
|
||||
node->abstract()->cast<abstract::AbstractSequencePtr>()->size() > 1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Output node: " << node->fullname_with_scope();
|
||||
std::vector<KernelWithIndex> ret;
|
||||
|
@ -129,17 +137,26 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t i = 1; i < make_tuple->size(); i++) {
|
||||
auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i));
|
||||
auto make_tuple_output = GetAllOutputWithIndexInner(make_tuple->input(i), return_types);
|
||||
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (std::any_of(return_types.begin(), return_types.end(), [&node](const PrimitivePtr &prim_type) -> bool {
|
||||
return common::AnfAlgo::CheckPrimitiveType(node, prim_type);
|
||||
})) {
|
||||
if (IsMultiOutput(node)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid get all output with index node:" << node->DebugString()
|
||||
<< " abstract:" << node->abstract()->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Need node flatten output of node:" << node->DebugString();
|
||||
return {KernelWithIndex(node, 0)};
|
||||
}
|
||||
// The depend node need get the real node.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
auto depend_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend));
|
||||
auto real_output = GetAllOutputWithIndexInner(depend_node->input(kRealInputIndexInDepend), return_types);
|
||||
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
@ -196,7 +213,7 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
|
||||
// The MakeTuple/MakeSparse node need recurse.
|
||||
if (IsOneOfPrimitiveCNode(output_with_index.first, expand_prims)) {
|
||||
auto output_vector = GetAllOutputWithIndexInner(output_with_index.first);
|
||||
auto output_vector = GetAllOutputWithIndexInner(output_with_index.first, return_types);
|
||||
if (output_vector.size() <= output_with_index.second) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "Invalid index:" << output_with_index.second
|
||||
<< " for outputs of node:" << output_with_index.first->DebugString();
|
||||
|
@ -438,8 +455,9 @@ std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithOutMonadAndParameter(const
|
|||
return real_output;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
||||
auto ret = GetAllOutputWithIndexInner(node);
|
||||
std::vector<KernelWithIndex> AnfAlgo::GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types) {
|
||||
auto ret = GetAllOutputWithIndexInner(node, return_types);
|
||||
std::map<AnfNodePtr, size_t> value_node_index;
|
||||
|
||||
// Unify the output of the front and back end to the ValueTuple
|
||||
|
|
|
@ -169,12 +169,17 @@ GVAR_DEF(PrimitivePtr, kPrimCall, std::make_shared<Primitive>(kCallOpName));
|
|||
GVAR_DEF(PrimitivePtr, kPrimRaise,
         std::make_shared<Primitive>(kRaiseOpName, mindspore::HashMap<std::string, ValuePtr>(
                                                       {{std::string(GRAPH_FLAG_SIDE_EFFECT_IO), MakeValue(true)}})));
GVAR_DEF(PrimitivePtr, kPrimSwitchLayer, std::make_shared<Primitive>("switch_layer"));
GVAR_DEF(PrimitivePtr, kPrimStringUpper, std::make_shared<Primitive>(kStringUpperOpName));
GVAR_DEF(PrimitivePtr, kPrimStringLower, std::make_shared<Primitive>(kStringLowerOpName));
GVAR_DEF(PrimitivePtr, kPrimFormat, std::make_shared<Primitive>(kFormatOpName));

// Backend Inline
// Fix: removed the stale duplicate `GVAR_DEF(PrimitivePtr, kPrimCallInline, ... "call_inline")` —
// kPrimCallInline was defined twice (redefinition), and the backend op name is "CallInline".
GVAR_DEF(PrimitivePtr, kPrimCallInline, std::make_shared<Primitive>("CallInline"));
GVAR_DEF(PrimitivePtr, kPrimPartialInline, std::make_shared<Primitive>("PartialInline"));
GVAR_DEF(PrimitivePtr, kPrimConditionSwitch, std::make_shared<Primitive>("ConditionSwitch"));
GVAR_DEF(PrimitivePtr, kPrimConditionGather, std::make_shared<Primitive>("ConditionGather"));

// Pack
GVAR_DEF(PrimitivePtr, kPrimPackFunc, std::make_shared<Primitive>(kPackFuncOpName));
|
||||
} // namespace prim
|
||||
|
|
|
@ -602,7 +602,8 @@ bool AnfUtils::NeedJumpMonadOutput(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName};
|
||||
std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName, prim::kPrimConditionSwitch->name(),
|
||||
prim::kPrimConditionGather->name()};
|
||||
if (std::find(jump_monad_output_nodes.begin(), jump_monad_output_nodes.end(), GetCNodeName(cnode)) !=
|
||||
jump_monad_output_nodes.end()) {
|
||||
return true;
|
||||
|
|
|
@ -0,0 +1,672 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
from mindspore import context, Tensor, jit, ops, mutable
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
# Run in graph mode and dump compiled graph IR for debugging the inline control flow passes.
# NOTE(review): save_graphs=True writes IR files to ./log/ on every run — confirm this is intended for CI.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path='./log/')
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5, mstype.int32), name='a')
    param_b = Parameter(Tensor(4, mstype.int32), name='b')

    @jit
    def foo(x, y, param_a, param_b):
        if param_a > param_b:
            param_b += 1
        return x + param_b, y + param_b

    x = Tensor(2, mstype.int32)
    # Call twice to cover both first compilation and cached-graph execution.
    ret1 = foo(x, x, param_a, param_b)
    ret2 = foo(x, x, param_a, param_b)
    assert ret1 == (Tensor(7, mstype.int32), Tensor(7, mstype.int32))
    assert ret2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches each return a parameter.
    Expectation: Not throw exception.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            return param_a
        return param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_valuenode():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches return constant value nodes.
    Expectation: Not throw exception.
    """
    # Fix: param_a / param_b were referenced in the call below but never defined,
    # which raises NameError before the graph is even compiled. Define them like
    # the sibling tests do.
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        if x < 3:
            return 1
        return 2

    ret1 = foo(Tensor(1), param_a, param_b)
    assert ret1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_input():
    """
    Feature: Control flow inline.
    Description: Control flow if whose branches each return a graph input.
    Expectation: Not throw exception.
    """

    @jit
    def foo(x, y, z):
        if x < 3:
            return y
        return z

    ret1 = foo(Tensor(1), Tensor(2), Tensor(3))
    assert ret1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_value_node_output_in_single_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph.
    Expectation: Not throw exception.
    """

    @jit
    def BranchReturnTensor(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        # One branch returns a constant tensor alongside a computed value.
        if x < 5:
            return y, Tensor(2, mstype.int32)
        return x, y

    x = Tensor(2, mstype.int32)
    # Repeat to exercise first compilation and cached-graph runs.
    ret1 = BranchReturnTensor(x, x)
    ret2 = BranchReturnTensor(x, x)
    ret3 = BranchReturnTensor(x, x)
    assert ret1
    assert ret2
    assert ret3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diff_ref_count_in_branch():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph.
    Expectation: Not throw exception.
    """

    @jit
    def BranchDiffRefCount(x, y):
        x = x + Tensor(2, mstype.int32)
        y = x + y
        # The two branches intentionally use x and y a different number of times,
        # so the branches have different ref counts for the switch inputs.
        if x < 5:
            x = x + 3
            y = x + y
        else:
            x = x + 3
            x = x + 4
            x = x + 5
            y = x + y
            y = x + y
            y = x + y
        return x, y

    # First call takes the x < 5 branch, second call the other one.
    x = Tensor(2, mstype.int32)
    ret1 = BranchDiffRefCount(x, x)
    x = Tensor(4, mstype.int32)
    ret2 = BranchDiffRefCount(x, x)
    assert ret1
    assert ret2
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_kernel_backoff():
    """
    Feature: Control flow inline.
    Description: Inline switch node into kernel graph.
    Expectation: Not throw exception.
    """

    @jit
    def BranchKernelBackOff(x, y, shape):
        x = x + Tensor(2, mstype.int32)
        # reshape with a mutable (dynamic) shape inside one branch.
        if y < 5:
            z = ops.reshape(x, shape)
        else:
            z = x
        return z + 1

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = BranchKernelBackOff(x, y, mutable((2, 3)))
    ret2 = BranchKernelBackOff(x, y, mutable((2, 3)))
    ret3 = BranchKernelBackOff(x, y, mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if that updates a parameter in both branches.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))

    @jit
    def foo(x, param_a):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
        else:
            param_a = param_a + x
        return param_a

    ret1 = foo(Tensor(1), param_a)
    ret2 = foo(Tensor(1), param_a)
    ret3 = foo(Tensor(1), param_a)
    assert ret1
    assert ret2
    assert ret3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_update_and_return_parameter():
    """
    Feature: Control flow inline.
    Description: Control flow if that updates parameters and returns them from both branches.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input_in_branch():
    """
    Feature: Control flow inline.
    Description: Control flow if where one branch returns the switch input itself.
    Expectation: Not throw exception.
    """

    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            # x is also the value the switch condition was computed from.
            return x, param_b
        param_a = param_a + x
        param_b = param_b + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_switch_input():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
        else:
            param_a = param_a + x
            param_b = param_b + param_a
        # Returns the switch input x alongside a parameter and a constant.
        return x, param_b, 3

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_args_to_dynamic_tuple_para():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y):
        # Tuple repetition: (a, b) * 2 == (a, b, a, b).
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape[0]

    ret1 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret2 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    ret3 = foo(Tensor(1), Tensor([[6, 6, 6], [6, 6, 6]]))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y, dst_shape):
        # unique + reshape produce a dynamic-shape tensor before the switch.
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        y_shape = ops.shape(y)
        if x < 3:
            y_shape = y_shape * 2
        else:
            y_shape = y_shape * 3
        return y_shape

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_tuple_input_to_switch():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, dyn_tuple):
        if x < 3:
            dyn_tuple = dyn_tuple * 2
        else:
            dyn_tuple = dyn_tuple * 3
        return dyn_tuple

    # dynamic_len=True makes the tuple a dynamic-length sequence input.
    ret1 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret2 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    ret3 = foo(Tensor(1), mutable((2, 3), dynamic_len=True))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_condition():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, cond):
        # Both paths return the switch condition itself.
        if cond:
            x = x * 2
            return x, cond
        x = x * 3
        return x, cond

    ret1 = foo(Tensor(1), Tensor(True))
    ret2 = foo(Tensor(1), Tensor(True))
    ret3 = foo(Tensor(1), Tensor(True))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_include_other_output():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y):
        # y is computed entirely outside the branch but returned with it.
        y = y + 2
        y = y * 3
        y = y / 4
        y = y - 5
        y = y * y
        if x < 5:
            x = x * 2
        else:
            x = x + 2
        return x, y

    ret1 = foo(Tensor(1), Tensor(2))
    ret2 = foo(Tensor(1), Tensor(2))
    ret3 = foo(Tensor(1), Tensor(2))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_include_refnode():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y, dst_shape):
        y, _ = ops.unique(y)
        y = ops.reshape(y, dst_shape)
        # Only the true branch rewrites y; the false path returns it as-is.
        if x < 3:
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]), mutable((2, 3)))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]), mutable((2, 3)))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_include_dynamic_shape():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y):
        # unique makes y dynamic-shaped before entering the branch.
        y, _ = ops.unique(y)
        if x < 3:
            y = y * 2
        else:
            z1 = y / 6
            z2 = y * 2
            z3 = y - Tensor([[6, 12, 18], [24, 30, 36]])
            z4 = y + Tensor([[1, 2, 3], [4, 5, 6]])
            y = z1 + z2 + z3 + z4
        return y

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [12, 18, 30], [18, 24, 36]]))
    ret3 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_control_arrow_from_switch_to_gather():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    param_a = Parameter(Tensor(5))
    param_b = Parameter(Tensor(5))

    @jit
    def foo(x, param_a, param_b):
        x = x + param_a
        if x < 3:
            param_a = param_a + 2
            param_b = param_b - param_a
            return Tensor(2), param_b
        # x is updated after the branch but not returned.
        x = x + param_a
        return param_a, param_b

    ret1 = foo(Tensor(1), param_a, param_b)
    ret2 = foo(Tensor(1), param_a, param_b)
    ret3 = foo(Tensor(1), param_a, param_b)
    assert ret1
    assert ret2
    assert ret3


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_only_u_input():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y):
        x = x + 1
        # True branch carries only the side-effect (U) edge via ops.print.
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_u_input_and_input():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    # NOTE(review): this body is byte-identical to test_branch_only_u_input;
    # presumably the true branch was meant to consume a data input in addition
    # to the U (side-effect) edge — confirm intent before un-skipping.
    @jit
    def foo(x, y):
        x = x + 1
        if x < 3:
            ops.print("this is true")
        else:
            y = ops.reshape(y, (4, 1))
            ops.print("this is false")
        return ops.shape(y)

    ret1 = foo(Tensor(1), Tensor([[1, 2], [3, 4]]))
    assert ret1


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_real_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y):
        # True branch yields a shape of a dynamic tensor; false branch a static one.
        if x < 3:
            y, _ = ops.unique(y)
            y = ops.expand_dims(y, 1)
            y = ops.flatten(y)
            z = ops.shape(y)
        else:
            z = ops.shape(y)
        return z

    ret1 = foo(Tensor(1), Tensor([[6, 12, 18], [24, 30, 36], [6, 18, 36]]))
    ret2 = foo(Tensor(5), Tensor([[6, 12, 18], [24, 30, 36]]))
    assert ret1
    assert ret2


@pytest.mark.skip(reason="No support")
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_branch_output_dynamic_tuple():
    """
    Feature: Control flow inline.
    Description: Control flow if.
    Expectation: AttributeError.
    """
    @jit
    def foo(x, y, shape):
        if y < 5:
            z = ops.reshape(x, shape)
            out = ops.shape(z)
        else:
            out = ops.shape(x)
        return out

    x = Tensor([2, 2, 2, 2, 2, 2], mstype.int32)
    y = Tensor(2, mstype.int32)
    ret1 = foo(x, y, mutable((2, 3), dynamic_len=True))
    assert ret1
|
Loading…
Reference in New Issue