PyNative Ascend Task Sink

This commit is contained in:
caifubi 2022-08-11 20:15:31 +08:00
parent 855e06c7ca
commit 2d6fbcd9b3
30 changed files with 643 additions and 108 deletions

View File

@ -550,6 +550,20 @@ bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t outp
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
bool skip_nop_node) { bool skip_nop_node) {
KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
if (kernel_with_index.first->isa<ValueNode>()) {
auto value_node = kernel_with_index.first->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
if (tensor->is_forward_output()) {
return dynamic_cast<const DeviceAddress *>(tensor->device_address().get());
}
}
}
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node); return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
} }

View File

@ -454,13 +454,6 @@ std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompile
return input_tensor_lists; return input_tensor_lists;
} }
// Reports whether the parallel mode read from the ParallelContext singleton is
// kSemiAutoParallel or kAutoParallel. Raises via MS_EXCEPTION_IF_NULL when the
// context instance is null.
bool IsAutoParallel() {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  const auto current_mode = parallel_context->parallel_mode();
  if (current_mode == parallel::kSemiAutoParallel) {
    return true;
  }
  return current_mode == parallel::kAutoParallel;
}
} // namespace } // namespace
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@ -605,6 +598,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE); ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
real_execution_mode_ = ms_execution_mode_; real_execution_mode_ = ms_execution_mode_;
func_graph->set_flag(kFlagPyNativeRunInGraph, real_execution_mode_ == kPynativeMode);
// Compile root graph. // Compile root graph.
graph_id_to_device_context_.clear(); graph_id_to_device_context_.clear();
@ -615,9 +609,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}); device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
MS_EXCEPTION_IF_NULL(device_context); MS_EXCEPTION_IF_NULL(device_context);
bool all_support = device_context->PartitionGraph(func_graph); bool all_support = device_context->PartitionGraph(func_graph);
auto run_mode = device_context->GetRunMode(func_graph);
if (all_support) { if (all_support) {
if (run_mode == device::RunMode::kGraphMode) { auto run_mode = device_context->GetRunMode(func_graph);
if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context); auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
graph_id_to_device_context_[graph_id] = device_context; graph_id_to_device_context_[graph_id] = device_context;
} else { } else {
@ -1004,8 +998,9 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
graph->set_flag(kFlagPyNativeRunInGraph, true); graph->set_flag(kFlagPyNativeRunInGraph, true);
graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph)); graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph));
// The size of control_nodes is at least 1 since there is return node in the graph. // KernelByKernel: The size of control_nodes is at least 1 since there is return node in the graph.
if (control_nodes_.size() == 1 && graphs.size() == 1) { // GraphMode: No control nodes.
if (control_nodes_.size() <= 1 && graphs.size() == 1) {
MS_LOG(INFO) << "Replace parameter format"; MS_LOG(INFO) << "Replace parameter format";
// The input tensors of heterogeneous graphs or control flow graphs are null. // The input tensors of heterogeneous graphs or control flow graphs are null.
// Need to get tensor after ParseControlNodes. // Need to get tensor after ParseControlNodes.
@ -1121,7 +1116,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
// Save grad node to Bucket // Save grad node to Bucket
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
!kernel->is_parallel() && IsAutoParallel()) { !kernel->is_parallel() && pynative::GraphAdapter::IsAutoParallel()) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors); graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
} }
} }

View File

@ -951,13 +951,6 @@ void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink); context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink); backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
}; };
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
// PYNATIVE: no need set any context.
if (mode == kPynativeMode) {
MS_LOG(INFO) << "Run graph mode with pynative.";
set_ctx(false, false, false);
return;
}
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude); const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
// GPU/CPU no need set any context. // GPU/CPU no need set any context.

View File

@ -802,7 +802,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
if (top_cell()->dynamic_shape()) { if (top_cell()->dynamic_shape()) {
bprop_graph->set_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE, true); bprop_graph->set_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE, true);
} }
if (top_cell()->is_dynamic_structure()) { if (top_cell()->is_real_dynamic_structure()) {
bprop_graph->set_flag(kFlagIsDynamicStructure, true); bprop_graph->set_flag(kFlagIsDynamicStructure, true);
} }
// Do opt for final bprop graph // Do opt for final bprop graph
@ -918,7 +918,7 @@ void GradExecutor::CheckNeedCompileGraph() {
EraseTopCellFromTopCellList(pre_top_cell); EraseTopCellFromTopCellList(pre_top_cell);
pre_top_cell->Clear(); pre_top_cell->Clear();
already_run_top_cell_[already_top_cell_id] = new_top_cell; already_run_top_cell_[already_top_cell_id] = new_top_cell;
top_cell()->set_dynamic_structure(true); top_cell()->set_is_real_dynamic_structure(true);
} else { } else {
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again"; MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id()); pre_top_cell->set_input_args_id(new_top_cell->input_args_id());

View File

@ -96,6 +96,10 @@ class TopCellInfo {
size_t grad_order() const { return grad_order_; } size_t grad_order() const { return grad_order_; }
bool is_dynamic_structure() const { return is_dynamic_structure_; } bool is_dynamic_structure() const { return is_dynamic_structure_; }
void set_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; } void set_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
bool is_real_dynamic_structure() const { return is_real_dynamic_structure_; }
void set_is_real_dynamic_structure(bool is_real_dynamic_structure) {
is_real_dynamic_structure_ = is_real_dynamic_structure;
}
bool dynamic_shape() const { return dynamic_shape_; } bool dynamic_shape() const { return dynamic_shape_; }
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; } void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
bool hook_changed() const { return hook_changed_; } bool hook_changed() const { return hook_changed_; }
@ -184,6 +188,8 @@ class TopCellInfo {
private: private:
bool is_topest_{false}; bool is_topest_{false};
bool is_dynamic_structure_{false}; bool is_dynamic_structure_{false};
// Set this flag to true when all_op_info of top_cell is changed.
bool is_real_dynamic_structure_{false};
bool dynamic_shape_{false}; bool dynamic_shape_{false};
bool vm_compiled_{false}; bool vm_compiled_{false};
bool hook_changed_{false}; bool hook_changed_{false};

View File

@ -258,6 +258,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
} else { } else {
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
} }
rt_model_zero_copy_.Release(graph_id);
} }
void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const { void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const {
@ -575,6 +577,11 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph &graph) {
MS_LOG(EXCEPTION) << "Distribute Task Failed, \nerror msg: " << e.what(); MS_LOG(EXCEPTION) << "Distribute Task Failed, \nerror msg: " << e.what();
} }
if (!rt_model_zero_copy_.GenerateZeroCopyTasks(graph)) {
MS_LOG(ERROR) << "Generate ZeroCopyTask failed, graph id " << graph.graph_id();
return false;
}
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
if (ProfilingManager::GetInstance().IsProfilingInitialized()) { if (ProfilingManager::GetInstance().IsProfilingInitialized()) {
auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first); auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
@ -1065,6 +1072,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph &graph) {
return false; return false;
} }
if (!rt_model_zero_copy_.UpdateTaskArgs(graph, compute_stream())) {
MS_LOG(ERROR) << "Update rtModel task args failed, graph id " << graph.graph_id();
return false;
}
try { try {
ModelRunner::Instance().RunModel(graph.graph_id()); ModelRunner::Instance().RunModel(graph.graph_id());
} catch (const std::exception &) { } catch (const std::exception &) {

View File

@ -27,6 +27,7 @@
#include "runtime/device/kernel_runtime.h" #include "runtime/device/kernel_runtime.h"
#include "runtime/context.h" #include "runtime/context.h"
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h" #include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
#include "plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "backend/common/session/session_basic.h" #include "backend/common/session/session_basic.h"
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
@ -114,6 +115,7 @@ class AscendKernelRuntime : public KernelRuntime {
static bool DeleteDumpDir(const std::string &path); static bool DeleteDumpDir(const std::string &path);
static int DeleteDumpFile(std::string path); static int DeleteDumpFile(std::string path);
static std::string GetRealPath(const std::string &path); static std::string GetRealPath(const std::string &path);
void CreateDefaultStream(uint32_t device_id);
rtContext_t rt_context_{nullptr}; rtContext_t rt_context_{nullptr};
bool initialized_{false}; bool initialized_{false};
@ -128,7 +130,7 @@ class AscendKernelRuntime : public KernelRuntime {
std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_; std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_;
std::map<uint32_t, void *> stream_id_map_; std::map<uint32_t, void *> stream_id_map_;
std::set<uint32_t> initialized_device_set_{}; std::set<uint32_t> initialized_device_set_{};
void CreateDefaultStream(uint32_t device_id); tasksink::RtModelZeroCopy rt_model_zero_copy_;
}; };
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);

View File

@ -481,7 +481,7 @@ void AscendStreamAssign::InsertEventForNonTaskSink(const NotNull<KernelGraphPtr>
} }
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
if (graph_ptr->has_flag(kFlagPyNativeRunInGraph)) { if (graph_ptr->has_flag(kFlagPyNativeRunInGraph) && !graph_ptr->is_graph_run_mode()) {
AssignStreamForPynativeGraph(graph_ptr.get()); AssignStreamForPynativeGraph(graph_ptr.get());
return; return;
} }

View File

@ -60,6 +60,15 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList(); return model_iter->second->GetTaskIdList();
} }
const std::vector<std::shared_ptr<Task>> &ModelRunner::GetTaskList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
}
MS_EXCEPTION_IF_NULL(model_iter->second);
return model_iter->second->GetTaskList();
}
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
decltype(runtime_models_)::const_iterator model_iter = runtime_models_.find(model_id); decltype(runtime_models_)::const_iterator model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.cend()) { if (model_iter == runtime_models_.cend()) {

View File

@ -23,6 +23,7 @@
#include <tuple> #include <tuple>
#include <string> #include <string>
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h" #include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
#include "plugin/device/ascend/hal/device/ge_runtime/task/task.h"
namespace mindspore::ge::model_runner { namespace mindspore::ge::model_runner {
class RuntimeModel; class RuntimeModel;
@ -38,6 +39,8 @@ class ModelRunner {
void LoadModelComplete(uint32_t model_id); void LoadModelComplete(uint32_t model_id);
const std::vector<std::shared_ptr<Task>> &GetTaskList(uint32_t model_id) const;
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;

View File

@ -293,5 +293,7 @@ void RuntimeModel::RtEventDestory() noexcept {
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
const std::vector<std::shared_ptr<Task>> &RuntimeModel::GetTaskList() const { return task_list_; }
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace mindspore::ge::model_runner } // namespace mindspore::ge::model_runner

View File

@ -36,6 +36,7 @@ class RuntimeModel {
void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model); void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
void DistributeTask(); void DistributeTask();
void LoadComplete(); void LoadComplete();
const std::vector<std::shared_ptr<Task>> &GetTaskList() const;
const std::vector<uint32_t> &GetTaskIdList() const; const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const; const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }

View File

@ -30,7 +30,8 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
stream_(nullptr), stream_(nullptr),
args_(nullptr), args_(nullptr),
ext_info_(nullptr), ext_info_(nullptr),
input_output_addr_(nullptr) { input_output_addr_(nullptr),
args_size_(0) {
MS_EXCEPTION_IF_NULL(task_info_); MS_EXCEPTION_IF_NULL(task_info_);
auto stream_list = model_context.stream_list(); auto stream_list = model_context.stream_list();
@ -60,20 +61,20 @@ void AicpuTask::Distribute() {
(void)io_addrs.insert(io_addrs.cend(), task_info_->output_data_addrs().cbegin(), (void)io_addrs.insert(io_addrs.cend(), task_info_->output_data_addrs().cbegin(),
task_info_->output_data_addrs().cend()); task_info_->output_data_addrs().cend());
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); io_addrs_size_ = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; uint32_t node_def_len_offset = io_addr_offset + io_addrs_size_;
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + args_size_ = sizeof(aicpu::AicpuParamHead) + io_addrs_size_ + static_cast<uint32_t>(task_info_->node_def().size()) +
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); sizeof(uint32_t);
// Malloc device memory for args // Malloc device memory for args
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); rtError_t rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret; MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret;
} }
SetAicpuParamHead(args_size, io_addrs_num); SetAicpuParamHead(args_size_, io_addrs_num);
SetInputOutputAddrs(io_addrs, io_addr_offset); SetInputOutputAddrs(io_addrs, io_addr_offset);
SetNodeDef(node_def_len_offset, node_def_addr_offset); SetNodeDef(node_def_len_offset, node_def_addr_offset);
@ -82,12 +83,12 @@ void AicpuTask::Distribute() {
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
auto cpu_flag = task_info_->cust_aicpu() ? RT_KERNEL_CUSTOM_AICPU : dump_flag; auto cpu_flag = task_info_->cust_aicpu() ? RT_KERNEL_CUSTOM_AICPU : dump_flag;
MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num =" << io_addrs_num MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size_ << ", io_addrs_num =" << io_addrs_num
<< ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name() << ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
<< ", dump_flag = " << dump_flag; << ", dump_flag = " << dump_flag;
rtArgsEx_t argsInfo = {}; rtArgsEx_t argsInfo = {};
argsInfo.args = args_; argsInfo.args = args_;
argsInfo.argsSize = args_size; argsInfo.argsSize = args_size_;
rt_ret = rtCpuKernelLaunchWithFlag(static_cast<const void *>(task_info_->so_name().data()), rt_ret = rtCpuKernelLaunchWithFlag(static_cast<const void *>(task_info_->so_name().data()),
static_cast<const void *>(task_info_->kernel_name().data()), 1, &argsInfo, nullptr, static_cast<const void *>(task_info_->kernel_name().data()), 1, &argsInfo, nullptr,
stream_, cpu_flag); stream_, cpu_flag);

View File

@ -33,6 +33,8 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
void *Args() const override { return input_output_addr_; } void *Args() const override { return input_output_addr_; }
size_t ArgsSize() const override { return io_addrs_size_; }
std::string task_name() const override { return task_info_->op_name(); } std::string task_name() const override { return task_info_->op_name(); }
private: private:
@ -46,6 +48,8 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
void *args_; void *args_;
void *ext_info_; void *ext_info_;
void *input_output_addr_; void *input_output_addr_;
size_t io_addrs_size_;
size_t args_size_;
}; };
} // namespace mindspore::ge::model_runner } // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_ #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_

View File

@ -33,6 +33,8 @@ class Task {
virtual void *Args() const { return nullptr; } virtual void *Args() const { return nullptr; }
virtual size_t ArgsSize() const { return 0; }
virtual std::string task_name() const { return ""; } virtual std::string task_name() const { return ""; }
void set_model_handle(rtModel_t model_handle) { model_handle_ = model_handle; } void set_model_handle(rtModel_t model_handle) { model_handle_ = model_handle; }

View File

@ -27,7 +27,8 @@ TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTas
task_info_(task_info), task_info_(task_info),
stream_(nullptr), stream_(nullptr),
stub_func_(nullptr), stub_func_(nullptr),
args_(nullptr) { args_(nullptr),
args_size_(0) {
MS_EXCEPTION_IF_NULL(task_info); MS_EXCEPTION_IF_NULL(task_info);
auto stream_list = model_context.stream_list(); auto stream_list = model_context.stream_list();
@ -74,14 +75,14 @@ void TbeTask::Distribute() {
task_info_->output_data_addrs().cend()); task_info_->output_data_addrs().cend());
tensor_device_addrs.insert(tensor_device_addrs.cend(), task_info_->workspace_addrs().cbegin(), tensor_device_addrs.insert(tensor_device_addrs.cend(), task_info_->workspace_addrs().cbegin(),
task_info_->workspace_addrs().cend()); task_info_->workspace_addrs().cend());
auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *)); args_size_ = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size; MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size_;
} }
rt_ret = aclrtMemcpy(args_, args_size, static_cast<void *>(tensor_device_addrs.data()), args_size, rt_ret = aclrtMemcpy(args_, args_size_, static_cast<void *>(tensor_device_addrs.data()), args_size_,
ACL_MEMCPY_HOST_TO_DEVICE); ACL_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret; MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret;
@ -91,10 +92,10 @@ void TbeTask::Distribute() {
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rtArgsEx_t args_info = {}; rtArgsEx_t args_info = {};
args_info.args = args_; args_info.args = args_;
args_info.argsSize = args_size; args_info.argsSize = args_size_;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), &args_info, nullptr, stream_, dump_flag); rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), &args_info, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size; MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size_;
} }
MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag; MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
} }

View File

@ -32,6 +32,8 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
void *Args() const override { return args_; } void *Args() const override { return args_; }
size_t ArgsSize() const override { return args_size_; }
std::string task_name() const override { return task_info_->op_name(); } std::string task_name() const override { return task_info_->op_name(); }
private: private:
@ -39,6 +41,7 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
void *stream_; void *stream_;
void *stub_func_; void *stub_func_;
void *args_; void *args_;
size_t args_size_;
}; };
} // namespace mindspore::ge::model_runner } // namespace mindspore::ge::model_runner
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_ #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_

View File

@ -0,0 +1,283 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.h"
#include <vector>
#include <map>
#include <algorithm>
#include "runtime/rt.h"
#include "external/acl/acl_rt.h"
#include "ir/tensor.h"
#include "include/common/utils/anfalgo.h"
#include "runtime/device/kernel_info.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "plugin/device/ascend/hal/device/ge_runtime/model_runner.h"
#include "plugin/device/ascend/hal/device/tasksink/task_generator.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace tasksink {
using TaskPtr = std::shared_ptr<ge::model_runner::Task>;
namespace {
// Returns true only when `input` is a ValueNode whose value is a tensor for
// which is_forward_output() holds; every other kind of node yields false.
// Raises via MS_EXCEPTION_IF_NULL when `input`, the casted value node, its
// value, or the casted tensor is null.
bool IsForwardOutputValueNode(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<ValueNode>()) {
    return false;
  }

  auto value_node = input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (!value->isa<tensor::Tensor>()) {
    return false;
  }

  auto tensor = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor);
  return tensor->is_forward_output();
}
bool CheckTaskValid(const CNodePtr &node, const std::vector<void *> &args_datas) {
MS_EXCEPTION_IF_NULL(node);
bool task_valid = true;
// Check input/output/workspace
auto input_addrs = TaskGenerator::GetTaskInput(node);
auto output_addrs = TaskGenerator::GetTaskOutput(node);
auto workspace_addrs = TaskGenerator::GetTaskWorkspace(node);
std::vector<AddressPtr> node_addresses;
std::move(input_addrs.begin(), input_addrs.end(), std::back_inserter(node_addresses));
std::move(output_addrs.begin(), output_addrs.end(), std::back_inserter(node_addresses));
std::move(workspace_addrs.begin(), workspace_addrs.end(), std::back_inserter(node_addresses));
if (node_addresses.size() != args_datas.size()) {
MS_LOG(ERROR) << "Check failed, Node " << node->UniqueName() << " total addr size " << node_addresses.size()
<< " is not equal to " << args_datas.size();
return false;
}
for (size_t i = 0; i < node_addresses.size(); ++i) {
auto node_address = node_addresses[i];
MS_EXCEPTION_IF_NULL(node_address);
if (node_address->addr != args_datas[i]) {
MS_LOG(ERROR) << "Node " << node->UniqueName() << " addr " << node_address->addr << " not equal to addr of task "
<< args_datas[i];
task_valid = false;
}
}
return task_valid;
}
// Decides whether `node` is left out of zero-copy task generation.
// Returns true for non-task ops and for kernels whose type is neither
// TBE_KERNEL nor AICPU_KERNEL; returns false otherwise.
bool NeedSkipZeroCopy(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (common::AnfAlgo::IsNonTaskOp(node)) {
    MS_LOG(INFO) << "Skip generate ZeroCopyTask for NonTaskOp " << node->fullname_with_scope();
    return true;
  }

  const auto kernel_type = AnfAlgo::GetKernelType(node);
  const bool is_handled_kernel =
    (kernel_type == KernelType::TBE_KERNEL) || (kernel_type == KernelType::AICPU_KERNEL);
  if (is_handled_kernel) {
    return false;
  }
  MS_LOG(INFO) << "Skip generate ZeroCopyTask for " << node->fullname_with_scope();
  return true;
}
} // namespace
void *ParameterZeroCopyTask::GetAddressPtr() {
auto node = anf_node_.lock();
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Not a parameter node " << node->DebugString();
}
auto kernel_info = dynamic_cast<KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto parameter_address = kernel_info->GetOutputAddr(0);
MS_EXCEPTION_IF_NULL(parameter_address);
return parameter_address->GetMutablePtr();
}
// Returns the mutable pointer of the device address attached to the tensor
// held by the ValueNode this task refers to.
// Raises when the referenced node has expired, is not a ValueNode, holds a
// null value, holds a value that does not cast to a tensor, or when the tensor
// has no device address.
void *ValueNodeZeroCopyTask::GetAddressPtr() {
  auto node = anf_node_.lock();
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "Not a ValueNode " << node->DebugString();
  }

  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  // Guard before dereferencing: a ValueNode with a null value must raise, not crash.
  MS_EXCEPTION_IF_NULL(value);
  auto tensor = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor);
  auto value_node_address = tensor->device_address();
  MS_EXCEPTION_IF_NULL(value_node_address);
  return value_node_address->GetMutablePtr();
}
bool ZeroCopyTask::UpdateArgs(void *stream) {
device_ptr_ = GetAddressPtr();
if (device_ptr_ == nullptr) {
MS_LOG(ERROR) << "Device address ptr is null, task " << task_name_;
return false;
}
if (device_ptr_ == previous_ptr_) {
MS_LOG(DEBUG) << "No need to update task of " << task_name_;
return true;
}
auto ret = aclrtMemcpyAsync(static_cast<uint8_t *>(args_base_) + args_offset_, sizeof(void *), &device_ptr_,
sizeof(void *), ACL_MEMCPY_HOST_TO_DEVICE, stream);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Update task " << task_name_ << " aclrtMemcpy failed, ret " << ret;
return false;
}
previous_ptr_ = device_ptr_;
MS_LOG(INFO) << "Update task " << task_name_ << " args_offset " << args_offset_ << " device_ptr " << device_ptr_
<< " success";
return true;
}
// Creates the zero-copy tasks of `graph` and stores them under the graph id.
// Does nothing and returns true when the graph does not carry
// kFlagPyNativeRunInGraph. Otherwise, for every kernel in execution order that
// is not skipped by NeedSkipZeroCopy, one task is created per input that is a
// Parameter or a forward-output ValueNode; the task records the argument
// buffer of the kernel's runtime task and the byte offset of that input slot.
// Raises when a kernel has no runtime task with a matching name.
// Returns false when tasks were already stored for this graph id.
bool RtModelZeroCopy::GenerateZeroCopyTasks(const session::KernelGraph &graph) {
  if (!graph.has_flag(kFlagPyNativeRunInGraph)) {
    MS_LOG(INFO) << "RtModelZeroCopy is not enabled";
    return true;
  }

  std::vector<ZeroCopyTaskPtr> zero_copy_tasks;
  // Bind by reference instead of copying the vector of shared pointers.
  const auto &task_lists = ge::model_runner::ModelRunner::Instance().GetTaskList(graph.graph_id());
  std::map<std::string, TaskPtr> op_name_to_task;
  std::transform(task_lists.begin(), task_lists.end(), std::inserter(op_name_to_task, op_name_to_task.end()),
                 [](const TaskPtr &task) {
                   MS_EXCEPTION_IF_NULL(task);
                   return std::make_pair(task->task_name(), task);
                 });

  const auto &nodes = graph.execution_order();
  for (const auto &node : nodes) {
    // Check the node before its first use.
    MS_EXCEPTION_IF_NULL(node);
    if (NeedSkipZeroCopy(node)) {
      continue;
    }

    auto op_name = node->UniqueName();
    auto iter = op_name_to_task.find(op_name);
    if (iter == op_name_to_task.end()) {
      MS_LOG(EXCEPTION) << "Cannot found task of op " << op_name;
    }

    auto task = iter->second;
    MS_EXCEPTION_IF_NULL(task);

    auto input_num = common::AnfAlgo::GetInputTensorNum(node);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(node, i);
      auto input = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph, true).first;
      MS_EXCEPTION_IF_NULL(input);
      // Slot i of the argument buffer is the i-th pointer-sized entry.
      if (input->isa<Parameter>()) {
        zero_copy_tasks.emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
          input, task->Args(), i * sizeof(void *), task->task_name()));
        MS_LOG(INFO) << "Generate ZeroCopyTask for Node " << node->fullname_with_scope() << " Parameter "
                     << input->DebugString();
      } else if (IsForwardOutputValueNode(input)) {
        zero_copy_tasks.emplace_back(std::make_shared<tasksink::ValueNodeZeroCopyTask>(
          input, task->Args(), i * sizeof(void *), task->task_name()));
        MS_LOG(INFO) << "Generate ZeroCopyTask for Node " << node->fullname_with_scope() << " ValueNode "
                     << input->DebugString();
      }
    }
  }

  auto iter = graph_zero_copy_tasks_.try_emplace(graph.graph_id(), zero_copy_tasks);
  if (!iter.second) {
    MS_LOG(ERROR) << "Generate ZeroCopyTask failed, Duplicate graph id " << graph.graph_id();
    return false;
  }
  return true;
}
// Refresh the pointers stored in the task args of `graph`, queuing the copies on `stream`.
// Returns true when the graph does not carry kFlagPyNativeRunInGraph or when every task updated
// successfully; returns false if no tasks are registered for the graph id or any update fails.
bool RtModelZeroCopy::UpdateTaskArgs(const session::KernelGraph &graph, void *stream) const {
  if (!graph.has_flag(kFlagPyNativeRunInGraph)) {
    MS_LOG(INFO) << "RtModelZeroCopy is not enabled, no need to update task args.";
    return true;
  }

  auto iter = graph_zero_copy_tasks_.find(graph.graph_id());
  if (iter == graph_zero_copy_tasks_.end()) {
    MS_LOG(ERROR) << "No zero copy tasks found. graph id " << graph.graph_id();
    return false;
  }

  // Bind by reference: this runs on every launch, so copying the vector of shared_ptrs is wasted work.
  const auto &zero_copy_tasks = iter->second;
  if (std::any_of(zero_copy_tasks.begin(), zero_copy_tasks.end(),
                  [stream](const ZeroCopyTaskPtr &task) { return !task->UpdateArgs(stream); })) {
    MS_LOG(ERROR) << "Update task args failed";
    return false;
  }

  // NOTE(review): the stream synchronize and validity check sit inside the log expression, and their
  // result does not affect the return value. Whether they run when INFO logging is off depends on how
  // MS_LOG expands - confirm.
  MS_LOG(INFO) << "Check rtMode valid " << ((rtStreamSynchronize(stream) == RT_ERROR_NONE) && CheckRtModelValid(graph));
  return true;
}
// Read the task args of every kernel of `graph` back from the device and validate them with CheckTaskValid.
// Kernels rejected by NeedSkipZeroCopy and tasks with an args size of 0 are not checked.
// Returns false if any kernel has no matching task, any kernel fails CheckTaskValid, or reading the
// args back fails (the latter aborts the check immediately); returns true otherwise.
bool RtModelZeroCopy::CheckRtModelValid(const session::KernelGraph &graph) {
  auto graph_id = graph.graph_id();
  auto tasks = ge::model_runner::ModelRunner::Instance().GetTaskList(graph_id);
  std::map<std::string, TaskPtr> op_name_to_task;
  std::transform(tasks.begin(), tasks.end(), std::inserter(op_name_to_task, op_name_to_task.end()),
                 [](const TaskPtr &task) { return std::make_pair(task->task_name(), task); });

  auto nodes = graph.execution_order();
  bool task_valid = true;
  for (const auto &node : nodes) {
    if (NeedSkipZeroCopy(node)) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(node);
    auto unique_name = node->UniqueName();
    auto iter = op_name_to_task.find(unique_name);
    if (iter == op_name_to_task.end()) {
      // Keep going so that every problem is reported, not just the first.
      MS_LOG(ERROR) << "Cannot found task of op " << unique_name;
      task_valid = false;
      continue;
    }
    auto task = iter->second;
    MS_EXCEPTION_IF_NULL(task);
    auto task_args = task->Args();
    auto task_size = task->ArgsSize();
    if (task_size == 0) {
      // For example InitDataSet (AiCpu kernel).
      MS_LOG(INFO) << "task name " << task->task_name() << " task_size is 0";
      continue;
    }

    // The host buffer holds whole pointer slots only, so copy exactly that many bytes. Passing task_size
    // as both capacity and length would overrun the buffer whenever task_size is not a multiple of
    // sizeof(void *).
    std::vector<void *> args_datas(task_size / sizeof(void *), nullptr);
    const size_t copy_size = args_datas.size() * sizeof(void *);
    if (copy_size != 0 &&
        aclrtMemcpy(args_datas.data(), copy_size, task_args, copy_size, ACL_MEMCPY_DEVICE_TO_HOST) != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "aclrtMemcpy failed, task " << task->task_name() << " task_size " << task_size;
      return false;
    }
    if (!CheckTaskValid(node, args_datas)) {
      task_valid = false;
    }
  }
  return task_valid;
}
// Drop the ZeroCopyTasks registered for `graph_id`; does nothing if none are registered.
void RtModelZeroCopy::Release(uint32_t graph_id) {
  // The number of erased entries is not needed.
  (void)graph_zero_copy_tasks_.erase(graph_id);
}
} // namespace tasksink
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,102 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_
#define MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_
#include <memory>
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "runtime/device/device_address.h"
#include "backend/common/session/kernel_graph.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace tasksink {
// Describes one pointer slot inside the args of a distributed task, together with the graph node
// (Parameter or ValueNode) whose device pointer the slot must hold.
// Subclasses supply GetAddressPtr() to fetch the node's current device pointer.
class ZeroCopyTask {
 public:
  // anf_node:    node whose device pointer is written into the slot (held weakly).
  // args_base:   start of the task args the slot belongs to.
  // args_offset: byte offset of the slot from args_base.
  // task_name:   name of the owning task, used in log messages.
  ZeroCopyTask(AnfNodeWeakPtr anf_node, void *args_base, size_t args_offset, std::string task_name)
      : anf_node_(std::move(anf_node)),
        args_base_(args_base),
        args_offset_(args_offset),
        task_name_(std::move(task_name)) {}
  // Virtual because this is a polymorphic base that is owned through pointers to the base class.
  virtual ~ZeroCopyTask() = default;

  // Update the address in task args
  bool UpdateArgs(void *stream);

  // Return the device pointer that should currently be stored in the slot.
  virtual void *GetAddressPtr() = 0;

 protected:
  // Parameter or ValueNode
  AnfNodeWeakPtr anf_node_;

 private:
  void *args_base_;
  size_t args_offset_;
  std::string task_name_;
  // Pointer obtained by the most recent UpdateArgs call.
  void *device_ptr_{nullptr};
  // Pointer last written to the slot successfully; used to skip redundant copies.
  void *previous_ptr_{nullptr};
};
using ZeroCopyTaskPtr = std::shared_ptr<ZeroCopyTask>;
// ZeroCopyTask for an input node that is a Parameter.
// All constructor arguments are forwarded unchanged to ZeroCopyTask.
class ParameterZeroCopyTask : public ZeroCopyTask {
public:
ParameterZeroCopyTask(const AnfNodeWeakPtr &anf_node, void *args_base, size_t args_offset,
const std::string &task_name)
: ZeroCopyTask(anf_node, args_base, args_offset, task_name) {}
// Return the device pointer to store in the task args for this node (defined in the .cc file).
void *GetAddressPtr() override;
};
// ZeroCopyTask for an input node that is a ValueNode.
// All constructor arguments are forwarded unchanged to ZeroCopyTask.
class ValueNodeZeroCopyTask : public ZeroCopyTask {
public:
ValueNodeZeroCopyTask(const AnfNodeWeakPtr &anf_node, void *args_base, size_t args_offset,
const std::string &task_name)
: ZeroCopyTask(anf_node, args_base, args_offset, task_name) {}
// Return the device pointer to store in the task args for this node (defined in the .cc file).
void *GetAddressPtr() override;
};
// Updates the device addresses stored in task args so that graph input data does not have to be copied.
// Usually the address recorded in a task is assumed to be constant, so when the address of a graph input
// changes, the input tensor's data has to be copied to the address the task already knows.
// However, an operator reads its real addresses from the task args at execution time: the task args are
// a pointer to pointers, holding the operator's input and output addresses.
// Rewriting the input address inside the task args therefore avoids the data copy.
class RtModelZeroCopy {
public:
RtModelZeroCopy() = default;
~RtModelZeroCopy() = default;
// Generate the ZeroCopyTasks of a graph. Must run after the tasks of the rtModel are distributed,
// because the task args addresses are needed.
// Returns true if the graph does not carry kFlagPyNativeRunInGraph or the tasks were registered,
// false if tasks for the same graph id already exist.
bool GenerateZeroCopyTasks(const session::KernelGraph &graph);
// Copy the device address ptr into the task args for every task whose ptr changed.
// Returns false if no tasks are registered for the graph or any update fails.
bool UpdateTaskArgs(const session::KernelGraph &graph, void *stream) const;
// Check the rtModel after the task args were updated. The checking process follows the same steps
// as generating the tasks.
static bool CheckRtModelValid(const session::KernelGraph &graph);
// Release the tasks of a graph after the graph is destructed.
void Release(uint32_t graph_id);
private:
// ZeroCopyTasks of each graph, keyed by graph id.
mindspore::HashMap<uint32_t, std::vector<ZeroCopyTaskPtr>> graph_zero_copy_tasks_;
};
} // namespace tasksink
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_

View File

@ -162,13 +162,82 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
} }
} }
// Collect the input addresses of `node` in kernel input order.
// AtomicAddrClean nodes get their inputs from LaunchAddrCleanKernel. For other nodes, None inputs are
// skipped, and an input produced by a NonTask op is redirected to an offset inside that op's first input.
// Raises if `node` or any required device address is null.
AddressPtrList TaskGenerator::GetTaskInput(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  AddressPtrList kernel_inputs;
  auto op_name = common::AnfAlgo::GetCNodeName(node);
  if (op_name == kAtomicAddrCleanOpName) {
    LaunchAddrCleanKernel(node, &kernel_inputs);
    return kernel_inputs;
  }

  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    if (common::AnfAlgo::IsNoneInput(node, i)) {
      continue;
    }
    auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(node, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, input_index_in_graph);
    // The address of a forward-output ValueNode comes from its tensor and may be absent; fail with a
    // clear exception instead of dereferencing a null pointer.
    MS_EXCEPTION_IF_NULL(device_address);
    AddressPtr input = std::make_shared<Address>();
    MS_EXCEPTION_IF_NULL(input);
    input->addr = device_address->ptr_;
    input->size = device_address->size_;

    auto prenode_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph);
    MS_EXCEPTION_IF_NULL(prenode_with_index.first);
    if (AnfUtils::IsRealCNodeKernel(prenode_with_index.first)) {
      if (common::AnfAlgo::IsNonTaskOp(prenode_with_index.first->cast<CNodePtr>())) {
        // use memory offset to implement NonTask Type Split op
        // when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
        // offset is split's output index * split's output size
        auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
        MS_EXCEPTION_IF_NULL(split_input0_device_address);
        input->addr =
          static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
        MS_LOG(INFO) << "Change " << node->fullname_with_scope() << "'s input " << i << " address to "
                     << split_input0_device_address->ptr_ << " + " << prenode_with_index.second * input->size;
      }
    }
    kernel_inputs.push_back(input);
  }
  return kernel_inputs;
}
// Collect the output addresses of `node`, one entry per output tensor.
// Returns an empty list when the output of the node is a monad.
// Raises if `node` or any output device address is null.
AddressPtrList TaskGenerator::GetTaskOutput(const CNodePtr &node) {
  // Same precondition as GetTaskInput.
  MS_EXCEPTION_IF_NULL(node);
  AddressPtrList kernel_outputs;
  // No kernel output if output of the cnode is monad, such as LabelSwitch.
  if (HasAbstractMonad(node)) {
    return kernel_outputs;
  }

  size_t output_num = common::AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    auto output_address = AnfAlgo::GetOutputAddr(node, i, false);
    MS_EXCEPTION_IF_NULL(output_address);
    AddressPtr output = std::make_shared<Address>();
    output->addr = output_address->ptr_;
    output->size = output_address->size_;
    kernel_outputs.push_back(output);
  }
  return kernel_outputs;
}
// Collect the workspace addresses of `node`, one entry per element of the kernel mod's workspace size list.
// Raises if `node`, its kernel mod or any workspace device address is null.
AddressPtrList TaskGenerator::GetTaskWorkspace(const CNodePtr &node) {
  // Same precondition as GetTaskInput.
  MS_EXCEPTION_IF_NULL(node);
  AddressPtrList kernel_workspaces;
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  // Query the list once: some implementations do extra work inside GetWorkspaceSizeList().
  const size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
  for (size_t i = 0; i < workspace_num; ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(node, i);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = device_address->ptr_;
    workspace->size = device_address->size_;
    kernel_workspaces.push_back(workspace);
  }
  return kernel_workspaces;
}
bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id,
std::vector<TaskInfoPtr> *task_info_list) { std::vector<TaskInfoPtr> *task_info_list) {
MS_EXCEPTION_IF_NULL(task_info_list); MS_EXCEPTION_IF_NULL(task_info_list);
MS_EXCEPTION_IF_NULL(anf_node_ptr); MS_EXCEPTION_IF_NULL(anf_node_ptr);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr); auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
kernel_mod->set_unique_name(anf_node_ptr->UniqueName()); kernel_mod->set_unique_name(anf_node_ptr->UniqueName());
@ -185,59 +254,16 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
return true; return true;
} }
if (op_name != kAtomicAddrCleanOpName) { AddressPtrList kernel_inputs;
size_t input_num = common::AnfAlgo::GetInputTensorNum(anf_node_ptr); AddressPtrList kernel_workspaces;
for (size_t i = 0; i < input_num; ++i) { AddressPtrList kernel_outputs;
if (common::AnfAlgo::IsNoneInput(anf_node_ptr, i)) {
continue;
}
auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(anf_node_ptr, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, input_index_in_graph);
AddressPtr input = std::make_shared<Address>();
MS_EXCEPTION_IF_NULL(input);
input->addr = device_address->ptr_;
input->size = device_address->size_;
auto prenode_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node_ptr, input_index_in_graph); if (op_name == kAtomicAddrCleanOpName) {
MS_EXCEPTION_IF_NULL(prenode_with_index.first);
if (AnfUtils::IsRealCNodeKernel(prenode_with_index.first)) {
if (common::AnfAlgo::IsNonTaskOp(prenode_with_index.first->cast<CNodePtr>())) {
// use memory offset to implement NonTask Type Split op
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
// offset is split's output index * split's output size
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
MS_EXCEPTION_IF_NULL(split_input0_device_address);
input->addr =
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to "
<< split_input0_device_address->ptr_ << " + " << prenode_with_index.second * input->size;
}
}
kernel_inputs.push_back(input);
}
// No kernel output if output of the cnode is monad, such as LabelSwitch.
if (!HasAbstractMonad(anf_node_ptr)) {
size_t output_num = common::AnfAlgo::GetOutputTensorNum(anf_node_ptr);
for (size_t i = 0; i < output_num; ++i) {
auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i, false);
AddressPtr output = std::make_shared<Address>();
output->addr = it->ptr_;
output->size = it->size_;
kernel_outputs.push_back(output);
}
}
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
workspace->addr = device_address->ptr_;
workspace->size = device_address->size_;
kernel_workspaces.push_back(workspace);
}
} else {
LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs); LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs);
} else {
kernel_inputs = GetTaskInput(anf_node_ptr);
kernel_workspaces = GetTaskWorkspace(anf_node_ptr);
kernel_outputs = GetTaskOutput(anf_node_ptr);
} }
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
@ -251,6 +277,7 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
MS_LOG(ERROR) << "Empty task_info_ptrs."; MS_LOG(ERROR) << "Empty task_info_ptrs.";
return false; return false;
} }
MS_LOG(INFO) << "Node " << anf_node_ptr->fullname_with_scope() << " get task " << task_info_ptrs.front()->op_name();
debug_info->op_name_ = anf_node_ptr->fullname_with_scope(); debug_info->op_name_ = anf_node_ptr->fullname_with_scope();
debug_info->task_num_ = task_info_ptrs.size(); debug_info->task_num_ = task_info_ptrs.size();
debug_info->stream_id_ = task_info_ptrs[0]->stream_id(); debug_info->stream_id_ = task_info_ptrs[0]->stream_id();

View File

@ -63,6 +63,10 @@ class TaskGenerator {
std::vector<TaskDebugInfoPtr> GetTaskDebugInfo() const { return task_debug_info_list_; } std::vector<TaskDebugInfoPtr> GetTaskDebugInfo() const { return task_debug_info_list_; }
static void DumpTaskInfo(const string &real_filename, const std::vector<TaskDebugInfoPtr> &task_debug_info_list); static void DumpTaskInfo(const string &real_filename, const std::vector<TaskDebugInfoPtr> &task_debug_info_list);
static AddressPtrList GetTaskInput(const CNodePtr &node);
static AddressPtrList GetTaskOutput(const CNodePtr &node);
static AddressPtrList GetTaskWorkspace(const CNodePtr &node);
private: private:
std::vector<TaskDebugInfoPtr> task_debug_info_list_; std::vector<TaskDebugInfoPtr> task_debug_info_list_;
static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs);

View File

@ -97,7 +97,7 @@ bool AscendDeviceContext::PartitionGraph(const FuncGraphPtr &func_graph) const {
RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const { RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && IsGraphMode() && !IsDynamicShapeGraph(func_graph)) { if (ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && !IsDynamicShapeGraph(func_graph)) {
return RunMode::kGraphMode; return RunMode::kGraphMode;
} else { } else {
return RunMode::kKernelMode; return RunMode::kKernelMode;

View File

@ -33,6 +33,7 @@ namespace device {
namespace ascend { namespace ascend {
using KernelGraph = mindspore::session::KernelGraph; using KernelGraph = mindspore::session::KernelGraph;
namespace {
CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) { CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
size_t node_sizes = kernel_nodes.size(); size_t node_sizes = kernel_nodes.size();
if (index >= node_sizes - 1) { if (index >= node_sizes - 1) {
@ -162,6 +163,25 @@ void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
UnfoldRecursiveExecOrder(kernel_graph); UnfoldRecursiveExecOrder(kernel_graph);
} }
// Mark the device address of every graph input as "pointer not persisted".
// Only applies to graphs that carry kFlagPyNativeRunInGraph and run in graph run mode; inputs that have
// no output address at index 0 are left untouched.
void EnableGraphInputZeroCopy(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Zero copy is only enabled for PyNative.
  const bool zero_copy_enabled = graph->has_flag(kFlagPyNativeRunInGraph) && graph->is_graph_run_mode();
  if (!zero_copy_enabled) {
    return;
  }

  for (const auto &graph_input : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(graph_input);
    if (!AnfAlgo::OutputAddrExist(graph_input, 0)) {
      continue;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(graph_input, 0);
    MS_EXCEPTION_IF_NULL(address);
    address->set_is_ptr_persisted(false);
    MS_LOG(INFO) << "Enable zero copy for input " << graph_input->DebugString();
  }
}
} // namespace
void AscendGraphExecutor::Initialize() { void AscendGraphExecutor::Initialize() {
res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get()); res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get());
MS_EXCEPTION_IF_NULL(res_manager_); MS_EXCEPTION_IF_NULL(res_manager_);
@ -217,6 +237,7 @@ void AscendGraphExecutor::PreprocessBeforeRun(const KernelGraphPtr &graph) const
AllocateGraphMemory(NOT_NULL(graph)); AllocateGraphMemory(NOT_NULL(graph));
LoadModel(NOT_NULL(graph)); LoadModel(NOT_NULL(graph));
AssignOutputNopNodeDeviceAddress(graph, device_context_); AssignOutputNopNodeDeviceAddress(graph, device_context_);
EnableGraphInputZeroCopy(graph);
} }
void AscendGraphExecutor::UpdateExecOrder(const KernelGraphPtr &graph) const { void AscendGraphExecutor::UpdateExecOrder(const KernelGraphPtr &graph) const {

View File

@ -220,12 +220,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); auto hccl_node = anf_node_.lock();
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE); MS_EXCEPTION_IF_NULL(hccl_node);
if (!mutable_workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode) || auto func_graph = hccl_node->func_graph();
mode == kPynativeMode) { auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_run_mode = kernel_graph->is_graph_run_mode();
// Not task sink mode.
if (!mutable_workspace_size_list_.empty() || hccl_data_type_list_.empty() || !graph_run_mode) {
return mutable_workspace_size_list_; return mutable_workspace_size_list_;
} }
// Task sink mode.
mutable_workspace_size_list_.emplace_back( mutable_workspace_size_list_.emplace_back(
hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0])); hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
return mutable_workspace_size_list_; return mutable_workspace_size_list_;

View File

@ -149,7 +149,8 @@ bool EliminateGraphOutputTransdata::Run(const FuncGraphPtr &func_graph) {
std::vector<PrimitivePtr> return_type = {prim::kPrimMakeTuple}; std::vector<PrimitivePtr> return_type = {prim::kPrimMakeTuple};
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(func_graph->output(), 0, false, return_type); auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(func_graph->output(), 0, false, return_type);
if (!common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { if (!common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Graph output is not a MakeTuple"; MS_LOG(INFO) << "Graph output is not a MakeTuple";
return true;
} }
auto transdata_ref_count = GetTransdataRefCount(func_graph); auto transdata_ref_count = GetTransdataRefCount(func_graph);

View File

@ -695,8 +695,12 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
opt::BackendCommonOptimization(root_graph); opt::BackendCommonOptimization(root_graph);
root_graph->SetInputNodes(); root_graph->SetInputNodes();
auto graph_id = CompileGraphImpl(root_graph, device_context); GraphId graph_id = 0;
if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
graph_id = CompileGraphImpl(root_graph, device_context);
} else {
graph_id = root_graph->graph_id();
}
// Set summary nodes for all graphs. // Set summary nodes for all graphs.
session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs); session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs);
@ -704,6 +708,7 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
// for ascend mindRT. // for ascend mindRT.
session_->DumpGraphs(all_graphs); session_->DumpGraphs(all_graphs);
if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
// Cache the backend graph output nodes to front nodes with output index. // Cache the backend graph output nodes to front nodes with output index.
auto output = func_graph->output(); auto output = func_graph->output();
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
@ -711,6 +716,14 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
MS_EXCEPTION_IF_NULL(backend_node); MS_EXCEPTION_IF_NULL(backend_node);
root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output}); root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
AnfAlgo::UpdateGraphValidRefPair(root_graph); AnfAlgo::UpdateGraphValidRefPair(root_graph);
} else {
for (auto &node : root_graph->execution_order()) {
if (common::AnfAlgo::IsControlOpExecInBackend(node)) {
root_graph->set_flag(kFlagsIsCutGraph, true);
}
}
root_graph->set_front_outputs({func_graph->output()});
}
MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id; MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
return graph_id; return graph_id;

View File

@ -22,6 +22,7 @@
#include "ir/tensor.h" #include "ir/tensor.h"
#include "include/common/utils/convert_utils.h" #include "include/common/utils/convert_utils.h"
#include "include/common/utils/anfalgo.h" #include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "backend/common/session/anf_runtime_algorithm.h" #include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/graph_scheduler/device_tensor_store.h" #include "runtime/graph_scheduler/device_tensor_store.h"
#include "runtime/device/ms_device_shape_transfer.h" #include "runtime/device/ms_device_shape_transfer.h"
@ -330,4 +331,27 @@ void GraphAdapter::ReplaceGraphParameterProperties(const KernelGraphPtr &graph,
} }
} }
} }
// Tell whether the parallel mode of the current ParallelContext is semi-auto-parallel or auto-parallel.
bool GraphAdapter::IsAutoParallel() {
  auto context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto mode = context->parallel_mode();
  if (mode == parallel::kSemiAutoParallel) {
    return true;
  }
  return mode == parallel::kAutoParallel;
}
// Decide whether `func_graph` may run with task sink.
// Outside PyNative execution mode the answer is always true. In PyNative mode it is true only when the
// parallel mode is not (semi-)auto-parallel, the graph does not carry kFlagIsDynamicStructure, and no
// node reachable from the return node is a control op executed in the backend.
// Raises if the MsContext is null, or if `func_graph` is null in PyNative mode.
bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  bool pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  if (!pynative_mode) {
    return true;
  }

  // The graph is only inspected in PyNative mode, so it is validated here rather than at entry.
  MS_EXCEPTION_IF_NULL(func_graph);
  // Cheap checks first: either one decides the result without walking the graph.
  if (IsAutoParallel() || func_graph->has_flag(kFlagIsDynamicStructure)) {
    return false;
  }

  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  auto is_cut_graph = std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
    return common::AnfAlgo::IsControlOpExecInBackend(node);
  });
  return !is_cut_graph;
}
} // namespace mindspore::pynative } // namespace mindspore::pynative

View File

@ -33,6 +33,8 @@ class GraphAdapter {
static void RemoveUnusedValueNodes(const KernelGraphPtr &graph); static void RemoveUnusedValueNodes(const KernelGraphPtr &graph);
static void HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &tensors, static void HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &tensors,
const std::vector<device::DeviceContext *> &device_contexts); const std::vector<device::DeviceContext *> &device_contexts);
static bool PyNativeEnableTaskSink(const FuncGraphPtr &func_graph);
static bool IsAutoParallel();
}; };
} // namespace mindspore::pynative } // namespace mindspore::pynative
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_ #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_

View File

@ -160,6 +160,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_device_address.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_device_address.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_pool.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_pool.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"

View File

@ -27,6 +27,10 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
uint32_t graph_id) { uint32_t graph_id) {
return true; return true;
} }
// Stand-in implementation: reports no input addresses.
AddressPtrList TaskGenerator::GetTaskInput(const CNodePtr &node) {
  return AddressPtrList{};
}
// Stand-in implementation: reports no output addresses.
AddressPtrList TaskGenerator::GetTaskOutput(const CNodePtr &node) {
  return AddressPtrList{};
}
// Stand-in implementation: reports no workspace addresses.
AddressPtrList TaskGenerator::GetTaskWorkspace(const CNodePtr &node) {
  return AddressPtrList{};
}
} // namespace tasksink } // namespace tasksink
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
void DataDumper::LoadDumpInfo() {} void DataDumper::LoadDumpInfo() {}