forked from mindspore-Ecosystem/mindspore
link child graphs
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
d9c74e0acd
commit
af5019b94f
|
@ -29,6 +29,7 @@
|
|||
#include "hccl/hcom.h"
|
||||
#include "common/trans.h"
|
||||
#include "runtime/context.h"
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
#include "device/ascend/ascend_stream_assign.h"
|
||||
#include "device/ascend/ascend_memory_pool.h"
|
||||
#include "framework/ge_runtime/model_runner.h"
|
||||
|
@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
|
||||
AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance();
|
||||
AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance();
|
||||
// the streams' flag not HEAD_STREAM
|
||||
std::vector<uint32_t> wait_active_stream_list;
|
||||
assign_instance.GetWaitStreams(&wait_active_stream_list);
|
||||
auto force_copy_stream_list = assign_instance.hcom_streams();
|
||||
stream_assign_instance.GetWaitStreams(&wait_active_stream_list);
|
||||
auto force_copy_stream_list = stream_assign_instance.hcom_streams();
|
||||
|
||||
MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum()
|
||||
<< ", total event num:" << assign_instance.total_event_num()
|
||||
MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum()
|
||||
<< ", total event num:" << stream_assign_instance.total_event_num()
|
||||
<< ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph))
|
||||
<< ", wait_active_stream_list size:" << wait_active_stream_list.size()
|
||||
<< ", force_copy_stream_list size:" << force_copy_stream_list.size();
|
||||
|
||||
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
|
||||
std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
|
||||
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
|
||||
0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0);
|
||||
0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
|
||||
stream_assign_instance.total_event_num(), 0);
|
||||
|
||||
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
|
||||
if (!ret.second) {
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
|
@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
|
|||
uint32_t goto_label_id = GetValue<uint32_t>(value);
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
|
||||
node->set_inputs({node->input(0)});
|
||||
}
|
||||
|
||||
static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
|
||||
|
@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
|
|||
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
|
||||
node->set_inputs({node->input(0), node->input(1)});
|
||||
}
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) {
|
||||
auto cnode_list = graph->execution_order();
|
||||
// 1 assign label id to label_set
|
||||
uint32_t cur_label_id = 0;
|
||||
for (auto &node : cnode_list) {
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) {
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node);
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id;
|
||||
++cur_label_id;
|
||||
}
|
||||
static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,
|
||||
NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
|
||||
if (memo->find(graph.get()) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
// 2 update label_switch / label_goto
|
||||
for (auto &node : cnode_list) {
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) {
|
||||
UpdateLabelGoto(NOT_NULL(node));
|
||||
|
||||
MS_LOG(INFO) << "Assign label for " << graph->ToString();
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
for (auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) {
|
||||
UpdateLabelSwitch(NOT_NULL(node));
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string node_name = AnfAlgo::GetCNodeName(node);
|
||||
if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(*label_id), node);
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id;
|
||||
++(*label_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cg : graph->child_graph_order()) {
|
||||
AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo);
|
||||
}
|
||||
}
|
||||
|
||||
static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph,
|
||||
NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
|
||||
if (memo->find(graph.get()) != memo->end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
|
||||
auto nodes = TopoSort(graph->get_return());
|
||||
for (auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string node_name = AnfAlgo::GetCNodeName(node);
|
||||
if (node_name == kLabelGotoOpName) {
|
||||
UpdateLabelGoto(NOT_NULL(cnode));
|
||||
cnode->set_abstract(nullptr);
|
||||
}
|
||||
|
||||
if (node_name == kLabelSwitchOpName) {
|
||||
UpdateLabelSwitch(NOT_NULL(cnode));
|
||||
}
|
||||
}
|
||||
for (auto &cg : graph->child_graph_order()) {
|
||||
AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
|
||||
MS_LOG(INFO) << "Assign label start.";
|
||||
std::set<std::shared_ptr<session::KernelGraph>> memo;
|
||||
uint32_t label_id = 0;
|
||||
AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(label_num_mutex_);
|
||||
label_num_[graph.get().get()] = label_id;
|
||||
}
|
||||
AssignLabelForGotoSwitch(graph, NOT_NULL(&memo));
|
||||
MS_LOG(INFO) << "Assign label end.";
|
||||
}
|
||||
|
||||
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) {
|
||||
std::lock_guard<std::mutex> lock(label_num_mutex_);
|
||||
auto iter = label_num_.find(graph.get());
|
||||
if (iter == label_num_.end()) {
|
||||
MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label.";
|
||||
return 1;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
|
||||
return GetLabelNum(NOT_NULL(graph.get().get()));
|
||||
}
|
||||
|
||||
} // namespace ascend
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "session/kernel_graph.h"
|
||||
#include "utils/contract.h"
|
||||
|
||||
|
@ -35,11 +36,16 @@ class AscendLabelAssign {
|
|||
AscendLabelAssign(const AscendLabelAssign &) = delete;
|
||||
AscendLabelAssign &operator=(const AscendLabelAssign &) = delete;
|
||||
|
||||
void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph);
|
||||
void AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph);
|
||||
uint32_t GetLabelNum(NotNull<const session::KernelGraph *> graph);
|
||||
uint32_t GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph);
|
||||
|
||||
private:
|
||||
AscendLabelAssign() = default;
|
||||
~AscendLabelAssign() = default;
|
||||
|
||||
std::map<const session::KernelGraph *, uint32_t> label_num_;
|
||||
std::mutex label_num_mutex_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "kernel/rts/label_switch.h"
|
||||
#include <asm-generic/param.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "runtime/stream.h"
|
||||
#include "framework/ge_runtime/task_info.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
|
|||
MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id;
|
||||
std::vector<TaskInfoPtr> task_info_list;
|
||||
cond_ = inputs[0]->addr;
|
||||
// std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr =
|
||||
// std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, &label_list_, cond_);
|
||||
// need updata ge task info define
|
||||
std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_);
|
||||
// todo: need update ge task info define
|
||||
auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, 0);
|
||||
// auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_);
|
||||
MS_EXCEPTION_IF_NULL(task_info_ptr);
|
||||
task_info_list.emplace_back(task_info_ptr);
|
||||
return task_info_list;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
|
||||
|
||||
vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT};
|
||||
vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool};
|
||||
if (input_format.size() != input_type.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
|
||||
<< input_type.size();
|
||||
}
|
||||
for (size_t i = 0; i < input_format.size(); ++i) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetInputsFormat({input_format[i]});
|
||||
builder.SetInputsDeviceType({input_type[i]});
|
||||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
label_switch_build_info.emplace_back(builder.Build());
|
||||
}
|
||||
return label_switch_build_info;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel {
|
|||
void *cond_;
|
||||
};
|
||||
|
||||
class LabelSwitchDesc : public RtKerDesc {
|
||||
public:
|
||||
LabelSwitchDesc() = default;
|
||||
~LabelSwitchDesc() override = default;
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() override;
|
||||
};
|
||||
|
||||
MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc);
|
||||
MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() {
|
|||
return _this;
|
||||
}
|
||||
|
||||
static bool IsDefaultKernelInfo(const std::string &name) {
|
||||
static const std::set<std::string> white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName,
|
||||
kLabelGotoOpName};
|
||||
return white_list.find(name) != white_list.end();
|
||||
}
|
||||
|
||||
void GetRtKelInfo(const CNodePtr &kernel_node,
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
|
@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
|
|||
}
|
||||
// if can't find kernel info in kernel info database, use the default kernel info
|
||||
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (node_name == "StreamSwitch" || node_name == "StreamActive") {
|
||||
if (IsDefaultKernelInfo(node_name)) {
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set input infos
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
|
|
@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
auto graph_id = res->results()[kOutput].cast<GraphId>();
|
||||
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>();
|
||||
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
|
||||
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
|
||||
MS_EXCEPTION_IF_NULL(msbc_ptr);
|
||||
compile::VmEvalFuncPtr run =
|
||||
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
|
||||
MS_LOG(INFO) << "Execute args size" << args.size();
|
||||
auto outs = bc_ptr->RunGraph(graph_id, args);
|
||||
MS_LOG(DEBUG) << "out size" << outs.size();
|
||||
std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
|
||||
MS_LOG(INFO) << "Execute args size " << args.size();
|
||||
auto outs = msbc_ptr->RunGraph(graph_id, args);
|
||||
MS_LOG(DEBUG) << "out size " << outs.size();
|
||||
return outs[0];
|
||||
});
|
||||
res->results()[kOutput] = run;
|
||||
|
|
|
@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
)
|
||||
|
||||
if (ENABLE_GPU)
|
||||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"gpu_session.cc"
|
||||
)
|
||||
list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
|
||||
endif ()
|
||||
|
||||
if (ENABLE_CPU)
|
||||
file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"cpu_session.cc"
|
||||
)
|
||||
list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST})
|
||||
endif ()
|
||||
|
||||
if (ENABLE_D)
|
||||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"ascend_session.cc"
|
||||
"ascend_control_parser.cc"
|
||||
)
|
||||
list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST})
|
||||
endif ()
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "session/ascend_control_parser.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
||||
static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) {
|
||||
VectorRef call_args;
|
||||
for (auto iter = iter_begin; iter != iter_end; ++iter) {
|
||||
if (utils::isa<ValueNode>(*iter)) {
|
||||
call_args.push_back(GetValueNode(*iter));
|
||||
} else {
|
||||
call_args.push_back(*iter);
|
||||
}
|
||||
}
|
||||
return call_args;
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo));
|
||||
}
|
||||
|
||||
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label, const VectorRef &args,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
|
||||
// 0. recursive condition
|
||||
if (memo->find(kg) != memo->end()) {
|
||||
MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString();
|
||||
return NOT_NULL(kg->get_start_label());
|
||||
}
|
||||
|
||||
// 2. args replace placeholder
|
||||
LinkParentGraph(kg, last_node, last_label, args);
|
||||
// 3. topological sort
|
||||
std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return()));
|
||||
if (nodes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
|
||||
}
|
||||
// 4. insert first_label
|
||||
auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
for (auto node : nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
|
||||
InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
kg->set_start_label(start_label);
|
||||
// 5. traverse
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
auto &cnode = nodes[i];
|
||||
if (cnode->size() < kCNodePrim + 1) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
}
|
||||
AnfNodePtr fn = cnode->input(kCNodePrim);
|
||||
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
|
||||
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr arg = cnode->input(kCNodeCallArg);
|
||||
if (IsValueNode<KernelGraph>(arg)) {
|
||||
RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo);
|
||||
} else if (!arg->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
|
||||
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
|
||||
auto arg_cnode = arg->cast<CNodePtr>();
|
||||
cnode->set_inputs(cnode->inputs());
|
||||
RecurseSwitch(kg, NOT_NULL(cnode), memo);
|
||||
} else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
|
||||
auto arg_cnode = arg->cast<CNodePtr>();
|
||||
cnode->set_inputs(cnode->inputs());
|
||||
RecurseSwitchLayer(kg, NOT_NULL(cnode), memo);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
|
||||
return NOT_NULL(start_label);
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) {
|
||||
std::vector<CNodePtr> out;
|
||||
for (auto &node : in) {
|
||||
if (node->isa<CNode>()) {
|
||||
out.push_back(node->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
|
||||
auto return_node = kg->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
inputs.push_back(return_node->input(1));
|
||||
inputs.push_back(attch_node.get());
|
||||
auto depend_node = kg->NewCNode(inputs);
|
||||
return_node->set_input(1, depend_node);
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
||||
NotNull<AnfNodePtr> second_node) {
|
||||
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
|
||||
<< ", the second node is " << second_node->DebugString();
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
|
||||
first_node, second_node};
|
||||
auto control_depend = kg->NewCNode(inputs);
|
||||
InsertDependToGraph(kg, NOT_NULL(control_depend));
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
||||
const CNodePtr &last_label, const VectorRef &args) {
|
||||
if (from_graph_call_node != nullptr) {
|
||||
SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args);
|
||||
}
|
||||
|
||||
auto origin_return = kg->get_return();
|
||||
std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs();
|
||||
// if entry graph, replace return with make_tuple
|
||||
if (from_graph_call_node == nullptr || last_label == nullptr) {
|
||||
MS_LOG(INFO) << kg->ToString() << " is entry graph.";
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)};
|
||||
make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end());
|
||||
auto make_tuple = kg->NewCNode(make_tuple_inputs);
|
||||
origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple});
|
||||
} else {
|
||||
// else replace return with label_goto
|
||||
auto label_goto =
|
||||
kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
|
||||
InsertDependToGraph(kg, NOT_NULL(label_goto));
|
||||
}
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process call func " << cur_node->DebugString();
|
||||
|
||||
// 1 get kernel graph
|
||||
auto origin_inputs = cur_node->inputs();
|
||||
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
|
||||
auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end());
|
||||
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
|
||||
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
|
||||
return;
|
||||
}
|
||||
// 2 return label
|
||||
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
// 3 add depend relationship
|
||||
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
|
||||
if (next_node != nullptr && next_node != kg->get_return()) {
|
||||
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
|
||||
}
|
||||
auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
|
||||
// 4 modify call op to goto op
|
||||
cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
|
||||
// 5 recurse sub graph
|
||||
CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo);
|
||||
new_inputs.push_back(sub_label);
|
||||
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
|
||||
cur_node->set_inputs(new_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
|
||||
|
||||
if (cur_node->size() < kCNodeSwitchLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
|
||||
}
|
||||
// 1 return label
|
||||
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)});
|
||||
// 2 recurse sub graph
|
||||
auto origin_switch_inputs = cur_node->inputs();
|
||||
std::vector<AnfNodePtr> new_switch_inputs = {
|
||||
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
|
||||
// 2.1 branch kernel graph and args
|
||||
CNodePtr partial;
|
||||
KernelGraphPtr branch_fg;
|
||||
VectorRef call_args;
|
||||
std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
// 2.2 add depend relationship
|
||||
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
|
||||
// 2.3 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
}
|
||||
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
||||
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
|
||||
|
||||
if (cur_node->size() < kCNodeSwitchLayerLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
|
||||
}
|
||||
|
||||
auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
|
||||
MS_EXCEPTION_IF_NULL(branch_tuple);
|
||||
if (!branch_tuple->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
|
||||
}
|
||||
auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
|
||||
// 1 return label
|
||||
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))});
|
||||
// 2 recurse sub graph
|
||||
auto origin_switch_inputs = cur_node->inputs();
|
||||
std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
for (size_t i = 0; i < branch_partial.size(); ++i) {
|
||||
// 2.1 branch kernel graph and args
|
||||
CNodePtr partial;
|
||||
KernelGraphPtr branch_fg;
|
||||
VectorRef call_args;
|
||||
std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
|
||||
// 2.2 add depend relationship
|
||||
InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
|
||||
// 2.3 recurse sub graph
|
||||
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo);
|
||||
new_switch_inputs.push_back(branch_label);
|
||||
}
|
||||
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
||||
if (!node.get()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
|
||||
}
|
||||
// 2.1 branch kernel graph and args
|
||||
auto partial_cnode = utils::cast<CNodePtr>(node.get());
|
||||
if (partial_cnode->size() < kCNodePartialLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
|
||||
}
|
||||
auto partial_inputs = partial_cnode->inputs();
|
||||
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
|
||||
auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end());
|
||||
|
||||
return {partial_cnode, branch_kg, call_args};
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
|
||||
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
|
||||
return;
|
||||
}
|
||||
if (from.get() == to.get()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
|
||||
<< to->DebugString();
|
||||
// config inputs of assign node
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
|
||||
// generate a new cnode
|
||||
auto assign_node = kg->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(assign_node);
|
||||
assign_node->set_abstract(to->abstract());
|
||||
// append the assign at the end of from graph
|
||||
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
||||
}
|
||||
|
||||
size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node,
|
||||
size_t input_index) {
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
return input_index + output_num;
|
||||
}
|
||||
|
||||
auto &graph_inputs = kg->inputs();
|
||||
if (input_index >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
auto backend_parameter = graph_inputs[input_index];
|
||||
if (node.get()->isa<Parameter>()) {
|
||||
MS_EXCEPTION_IF_NULL(backend_parameter);
|
||||
MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString()
|
||||
<< "] will be replaced.";
|
||||
kg->ReplaceNode(backend_parameter, node);
|
||||
return input_index;
|
||||
}
|
||||
InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter));
|
||||
return input_index + 1;
|
||||
}
|
||||
|
||||
void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node,
|
||||
const VectorRef &args) {}
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|
||||
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include "session/kernel_graph.h"
|
||||
#include "utils/base_ref.h"
|
||||
#include "utils/contract.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
||||
class AscendControlParser {
|
||||
public:
|
||||
static void LinkGraph(NotNull<KernelGraphPtr> kg);
|
||||
|
||||
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
|
||||
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
||||
NotNull<AnfNodePtr> second_node);
|
||||
|
||||
private:
|
||||
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label, const VectorRef &args,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
|
||||
static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in);
|
||||
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
||||
const CNodePtr &last_label, const VectorRef &args);
|
||||
static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node,
|
||||
const VectorRef &args);
|
||||
static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node);
|
||||
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index);
|
||||
|
||||
static constexpr size_t kCNodePrim = 0;
|
||||
static constexpr size_t kCNodeCallArg = 1;
|
||||
static constexpr size_t kCNodeSwitchCond = 1;
|
||||
static constexpr size_t kCNodeSwitchTrue = 2;
|
||||
static constexpr size_t kCNodeSwitchFalse = 3;
|
||||
static constexpr size_t kCNodeSwitchLength = 4;
|
||||
static constexpr size_t kCNodePartialLength = 2;
|
||||
static constexpr size_t kCNodePartialFunc = 1;
|
||||
static constexpr size_t kCNodeSwitchLayerCond = 1;
|
||||
static constexpr size_t kCNodeSwitchLayerBranch = 2;
|
||||
static constexpr size_t kCNodeSwitchLayerLength = 3;
|
||||
};
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
|
|
@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
|
|||
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||
std::vector<CNodePtr> cnodes = {};
|
||||
size_t i = 0;
|
||||
for (const auto anf : anf_nodes) {
|
||||
for (auto anf : anf_nodes) {
|
||||
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (anf->isa<CNode>()) {
|
||||
cnodes.push_back(anf->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
return std::move(cnodes);
|
||||
return cnodes;
|
||||
}
|
||||
|
||||
std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
|
||||
|
@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
|
|||
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
|
||||
}
|
||||
}
|
||||
return std::move(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void UpdateRealInput(KernelGraph *graph) {
|
||||
|
@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) {
|
|||
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
|
||||
partial_cnode->set_inputs(
|
||||
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
|
||||
return std::move(ret);
|
||||
return ret;
|
||||
};
|
||||
bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
|
||||
bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
|
||||
|
@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
// split switch
|
||||
SplitGraph(graph);
|
||||
// insert goto labels and label_sets
|
||||
LinkChildGraphs(graph.get());
|
||||
LinkChildGraphs(NOT_NULL(graph));
|
||||
// resource initialize
|
||||
InitRuntimeResource();
|
||||
// ir fusion
|
||||
IRFusion(graph);
|
||||
// kernel select
|
||||
SelectKernelGraphKernel(*graph);
|
||||
// convert model of predict module
|
||||
ConvertPredictModel(graph);
|
||||
// hardware optimize
|
||||
HardwareOptimizeGraphs(graph);
|
||||
// assign label
|
||||
AssignLabel(NOT_NULL(graph));
|
||||
if (!graph->executable()) {
|
||||
return graph->graph_id();
|
||||
}
|
||||
for (auto iter : graphs_) {
|
||||
if (iter.second == graph) {
|
||||
MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id();
|
||||
final_graph_id_ = graph->graph_id();
|
||||
}
|
||||
MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString();
|
||||
CompileChildGraph(iter.second);
|
||||
}
|
||||
// adjust kernel
|
||||
AdjustKernel(graph);
|
||||
// root graph valiate,include genearte execute order and so on
|
||||
RootGraphExecutorValidate(graph.get());
|
||||
// assign stream
|
||||
AssignStream(graph);
|
||||
// assign label
|
||||
AssignLabel(NOT_NULL(graph));
|
||||
// build kernel if node is cnode
|
||||
BuildKernel(graph);
|
||||
// alloc mem
|
||||
MemoryAlloc(graph.get());
|
||||
// task generate
|
||||
|
@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const {
|
||||
void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
|
@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived
|
|||
}
|
||||
|
||||
void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) {
|
||||
MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString();
|
||||
auto graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
|
||||
auto return_node = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
inputs.push_back(return_node->input(1));
|
||||
inputs.push_back(attch_node);
|
||||
auto depend_node = graph->NewCNode(inputs);
|
||||
return_node->set_input(1, depend_node);
|
||||
AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node));
|
||||
}
|
||||
|
||||
void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node,
|
||||
const AnfNodePtr &second_node) {
|
||||
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
|
||||
<< ", the second node is " << second_node->DebugString();
|
||||
auto graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))};
|
||||
inputs.push_back(first_node);
|
||||
inputs.push_back(second_node);
|
||||
auto control_depend = graph->NewCNode(inputs);
|
||||
InsertDependToGraph(graph_id, control_depend);
|
||||
AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node),
|
||||
NOT_NULL(second_node));
|
||||
}
|
||||
|
||||
size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) {
|
||||
|
@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
|||
SplitGraph(child_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "session/kernel_graph.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "session/session_factory.h"
|
||||
#include "session/ascend_control_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
@ -74,7 +75,7 @@ class AscendSession : public SessionBasic {
|
|||
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const;
|
||||
void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
|
||||
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void MemoryAlloc(KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
|
||||
|
@ -96,7 +97,8 @@ class AscendSession : public SessionBasic {
|
|||
void SetFinalGraphOutput(const VectorRef &vec_output);
|
||||
|
||||
void SplitGraph(const KernelGraphPtr &graph);
|
||||
void LinkChildGraphs(KernelGraph *graph) {}
|
||||
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
||||
|
||||
void IRFusion(const KernelGraphPtr &graph) {}
|
||||
void SelectKernelGraphKernel(const KernelGraph &graph) {}
|
||||
void ConvertPredictModel(const KernelGraphPtr graph) {}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "utils/contract.h"
|
||||
#include "device/kernel_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph {
|
|||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
|
||||
// checkout whether current graph is leaf graph
|
||||
bool IsLeafGraph() const;
|
||||
|
||||
// set input_tensors pointer of control parameter
|
||||
void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
|
||||
input_ctrl_tensors_ = input_tensors_ptr;
|
||||
|
@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph {
|
|||
// used to dump ir
|
||||
std::string ToString() const override;
|
||||
|
||||
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
|
||||
CNodePtr get_start_label() { return start_label_; }
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||
|
@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph {
|
|||
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
|
||||
// child graph execute order in root graph
|
||||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
|
||||
|
||||
// input_tensors of control parameter
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
|
||||
|
||||
// parameter graph
|
||||
std::shared_ptr<KernelGraph> parent_graph_;
|
||||
// record real parameters,inputs_ is the formal parameters
|
||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
|
||||
|
||||
CNodePtr start_label_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
|
@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/transform/*.cc"
|
||||
"../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc"
|
||||
"../../../mindspore/ccsrc/session/ascend_session.cc"
|
||||
"../../../mindspore/ccsrc/session/ascend_control_parser.cc"
|
||||
"../../../mindspore/ccsrc/session/kernel_graph.cc"
|
||||
"../../../mindspore/ccsrc/session/session_basic.cc"
|
||||
"../../../mindspore/ccsrc/session/session_factory.cc"
|
||||
|
|
|
@ -22,7 +22,9 @@ namespace mindspore {
|
|||
namespace device {
|
||||
namespace ascend {
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {}
|
||||
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {}
|
||||
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; }
|
||||
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; }
|
||||
|
||||
void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; }
|
||||
|
||||
|
@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
|
|||
} // namespace ascend
|
||||
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
|
||||
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
return true;
|
||||
}
|
||||
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return true; }
|
||||
bool KernelAdjust::NeedInsertSwitch() { return true; }
|
||||
void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; }
|
||||
} // namespace device
|
||||
|
|
Loading…
Reference in New Issue