PyNative Ascend Task Sink
This commit is contained in:
parent
855e06c7ca
commit
2d6fbcd9b3
|
@ -550,6 +550,20 @@ bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t outp
|
||||||
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||||
bool skip_nop_node) {
|
bool skip_nop_node) {
|
||||||
KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
|
||||||
|
if (kernel_with_index.first->isa<ValueNode>()) {
|
||||||
|
auto value_node = kernel_with_index.first->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value = value_node->value();
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
if (value->isa<tensor::Tensor>()) {
|
||||||
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
if (tensor->is_forward_output()) {
|
||||||
|
return dynamic_cast<const DeviceAddress *>(tensor->device_address().get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -454,13 +454,6 @@ std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompile
|
||||||
|
|
||||||
return input_tensor_lists;
|
return input_tensor_lists;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsAutoParallel() {
|
|
||||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
|
||||||
auto parallel_mode = parallel_context->parallel_mode();
|
|
||||||
return parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
|
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
|
||||||
|
@ -605,6 +598,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
|
ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
|
||||||
real_execution_mode_ = ms_execution_mode_;
|
real_execution_mode_ = ms_execution_mode_;
|
||||||
|
func_graph->set_flag(kFlagPyNativeRunInGraph, real_execution_mode_ == kPynativeMode);
|
||||||
|
|
||||||
// Compile root graph.
|
// Compile root graph.
|
||||||
graph_id_to_device_context_.clear();
|
graph_id_to_device_context_.clear();
|
||||||
|
@ -615,9 +609,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
|
||||||
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
|
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
MS_EXCEPTION_IF_NULL(device_context);
|
||||||
bool all_support = device_context->PartitionGraph(func_graph);
|
bool all_support = device_context->PartitionGraph(func_graph);
|
||||||
auto run_mode = device_context->GetRunMode(func_graph);
|
|
||||||
if (all_support) {
|
if (all_support) {
|
||||||
if (run_mode == device::RunMode::kGraphMode) {
|
auto run_mode = device_context->GetRunMode(func_graph);
|
||||||
|
if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
|
||||||
auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
|
auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
|
||||||
graph_id_to_device_context_[graph_id] = device_context;
|
graph_id_to_device_context_[graph_id] = device_context;
|
||||||
} else {
|
} else {
|
||||||
|
@ -1004,8 +998,9 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
|
||||||
graph->set_flag(kFlagPyNativeRunInGraph, true);
|
graph->set_flag(kFlagPyNativeRunInGraph, true);
|
||||||
graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph));
|
graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph));
|
||||||
|
|
||||||
// The size of control_nodes is at least 1 since there is return node in the graph.
|
// KernelByKernel: The size of control_nodes is at least 1 since there is return node in the graph.
|
||||||
if (control_nodes_.size() == 1 && graphs.size() == 1) {
|
// GraphMode: No control nodes.
|
||||||
|
if (control_nodes_.size() <= 1 && graphs.size() == 1) {
|
||||||
MS_LOG(INFO) << "Replace parameter format";
|
MS_LOG(INFO) << "Replace parameter format";
|
||||||
// The input tensors of heterogeneous graphs or control flow graphs are null.
|
// The input tensors of heterogeneous graphs or control flow graphs are null.
|
||||||
// Need to get tensor after ParseControlNodes.
|
// Need to get tensor after ParseControlNodes.
|
||||||
|
@ -1121,7 +1116,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
|
||||||
|
|
||||||
// Save grad node to Bucket
|
// Save grad node to Bucket
|
||||||
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
|
if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
|
||||||
!kernel->is_parallel() && IsAutoParallel()) {
|
!kernel->is_parallel() && pynative::GraphAdapter::IsAutoParallel()) {
|
||||||
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
|
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -951,13 +951,6 @@ void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
|
||||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
|
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
|
||||||
backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
|
backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
|
||||||
};
|
};
|
||||||
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
|
|
||||||
// PYNATIVE: no need set any context.
|
|
||||||
if (mode == kPynativeMode) {
|
|
||||||
MS_LOG(INFO) << "Run graph mode with pynative.";
|
|
||||||
set_ctx(false, false, false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
||||||
// GPU/CPU no need set any context.
|
// GPU/CPU no need set any context.
|
||||||
|
|
|
@ -802,7 +802,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
|
||||||
if (top_cell()->dynamic_shape()) {
|
if (top_cell()->dynamic_shape()) {
|
||||||
bprop_graph->set_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE, true);
|
bprop_graph->set_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE, true);
|
||||||
}
|
}
|
||||||
if (top_cell()->is_dynamic_structure()) {
|
if (top_cell()->is_real_dynamic_structure()) {
|
||||||
bprop_graph->set_flag(kFlagIsDynamicStructure, true);
|
bprop_graph->set_flag(kFlagIsDynamicStructure, true);
|
||||||
}
|
}
|
||||||
// Do opt for final bprop graph
|
// Do opt for final bprop graph
|
||||||
|
@ -918,7 +918,7 @@ void GradExecutor::CheckNeedCompileGraph() {
|
||||||
EraseTopCellFromTopCellList(pre_top_cell);
|
EraseTopCellFromTopCellList(pre_top_cell);
|
||||||
pre_top_cell->Clear();
|
pre_top_cell->Clear();
|
||||||
already_run_top_cell_[already_top_cell_id] = new_top_cell;
|
already_run_top_cell_[already_top_cell_id] = new_top_cell;
|
||||||
top_cell()->set_dynamic_structure(true);
|
top_cell()->set_is_real_dynamic_structure(true);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
|
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
|
||||||
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
|
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
|
||||||
|
|
|
@ -96,6 +96,10 @@ class TopCellInfo {
|
||||||
size_t grad_order() const { return grad_order_; }
|
size_t grad_order() const { return grad_order_; }
|
||||||
bool is_dynamic_structure() const { return is_dynamic_structure_; }
|
bool is_dynamic_structure() const { return is_dynamic_structure_; }
|
||||||
void set_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
|
void set_dynamic_structure(bool is_dynamic_structure) { is_dynamic_structure_ = is_dynamic_structure; }
|
||||||
|
bool is_real_dynamic_structure() const { return is_real_dynamic_structure_; }
|
||||||
|
void set_is_real_dynamic_structure(bool is_real_dynamic_structure) {
|
||||||
|
is_real_dynamic_structure_ = is_real_dynamic_structure;
|
||||||
|
}
|
||||||
bool dynamic_shape() const { return dynamic_shape_; }
|
bool dynamic_shape() const { return dynamic_shape_; }
|
||||||
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
|
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
|
||||||
bool hook_changed() const { return hook_changed_; }
|
bool hook_changed() const { return hook_changed_; }
|
||||||
|
@ -184,6 +188,8 @@ class TopCellInfo {
|
||||||
private:
|
private:
|
||||||
bool is_topest_{false};
|
bool is_topest_{false};
|
||||||
bool is_dynamic_structure_{false};
|
bool is_dynamic_structure_{false};
|
||||||
|
// Set this flag to ture when all_op_info of top_cell is changed.
|
||||||
|
bool is_real_dynamic_structure_{false};
|
||||||
bool dynamic_shape_{false};
|
bool dynamic_shape_{false};
|
||||||
bool vm_compiled_{false};
|
bool vm_compiled_{false};
|
||||||
bool hook_changed_{false};
|
bool hook_changed_{false};
|
||||||
|
|
|
@ -258,6 +258,8 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
|
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rt_model_zero_copy_.Release(graph_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const {
|
void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const {
|
||||||
|
@ -575,6 +577,11 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph &graph) {
|
||||||
MS_LOG(EXCEPTION) << "Distribute Task Failed, \nerror msg: " << e.what();
|
MS_LOG(EXCEPTION) << "Distribute Task Failed, \nerror msg: " << e.what();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!rt_model_zero_copy_.GenerateZeroCopyTasks(graph)) {
|
||||||
|
MS_LOG(ERROR) << "Generate ZeroCopyTask failed, graph id " << graph.graph_id();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
if (ProfilingManager::GetInstance().IsProfilingInitialized()) {
|
if (ProfilingManager::GetInstance().IsProfilingInitialized()) {
|
||||||
auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
|
auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
|
||||||
|
@ -1065,6 +1072,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph &graph) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!rt_model_zero_copy_.UpdateTaskArgs(graph, compute_stream())) {
|
||||||
|
MS_LOG(ERROR) << "Update rtModel task args failed, graph id " << graph.graph_id();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ModelRunner::Instance().RunModel(graph.graph_id());
|
ModelRunner::Instance().RunModel(graph.graph_id());
|
||||||
} catch (const std::exception &) {
|
} catch (const std::exception &) {
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "runtime/device/kernel_runtime.h"
|
#include "runtime/device/kernel_runtime.h"
|
||||||
#include "runtime/context.h"
|
#include "runtime/context.h"
|
||||||
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
|
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.h"
|
||||||
#include "runtime/device/kernel_runtime_manager.h"
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
#include "backend/common/session/session_basic.h"
|
#include "backend/common/session/session_basic.h"
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
|
@ -114,6 +115,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
static bool DeleteDumpDir(const std::string &path);
|
static bool DeleteDumpDir(const std::string &path);
|
||||||
static int DeleteDumpFile(std::string path);
|
static int DeleteDumpFile(std::string path);
|
||||||
static std::string GetRealPath(const std::string &path);
|
static std::string GetRealPath(const std::string &path);
|
||||||
|
void CreateDefaultStream(uint32_t device_id);
|
||||||
|
|
||||||
rtContext_t rt_context_{nullptr};
|
rtContext_t rt_context_{nullptr};
|
||||||
bool initialized_{false};
|
bool initialized_{false};
|
||||||
|
@ -128,7 +130,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_;
|
std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_;
|
||||||
std::map<uint32_t, void *> stream_id_map_;
|
std::map<uint32_t, void *> stream_id_map_;
|
||||||
std::set<uint32_t> initialized_device_set_{};
|
std::set<uint32_t> initialized_device_set_{};
|
||||||
void CreateDefaultStream(uint32_t device_id);
|
tasksink::RtModelZeroCopy rt_model_zero_copy_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
|
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
|
||||||
|
|
|
@ -481,7 +481,7 @@ void AscendStreamAssign::InsertEventForNonTaskSink(const NotNull<KernelGraphPtr>
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
if (graph_ptr->has_flag(kFlagPyNativeRunInGraph)) {
|
if (graph_ptr->has_flag(kFlagPyNativeRunInGraph) && !graph_ptr->is_graph_run_mode()) {
|
||||||
AssignStreamForPynativeGraph(graph_ptr.get());
|
AssignStreamForPynativeGraph(graph_ptr.get());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,15 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
|
||||||
return model_iter->second->GetTaskIdList();
|
return model_iter->second->GetTaskIdList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Task>> &ModelRunner::GetTaskList(uint32_t model_id) const {
|
||||||
|
auto model_iter = runtime_models_.find(model_id);
|
||||||
|
if (model_iter == runtime_models_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Model id " << model_id << " not found.";
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(model_iter->second);
|
||||||
|
return model_iter->second->GetTaskList();
|
||||||
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
|
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
|
||||||
decltype(runtime_models_)::const_iterator model_iter = runtime_models_.find(model_id);
|
decltype(runtime_models_)::const_iterator model_iter = runtime_models_.find(model_id);
|
||||||
if (model_iter == runtime_models_.cend()) {
|
if (model_iter == runtime_models_.cend()) {
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
|
#include "plugin/device/ascend/hal/device/ge_runtime/davinci_model.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/ge_runtime/task/task.h"
|
||||||
|
|
||||||
namespace mindspore::ge::model_runner {
|
namespace mindspore::ge::model_runner {
|
||||||
class RuntimeModel;
|
class RuntimeModel;
|
||||||
|
@ -38,6 +39,8 @@ class ModelRunner {
|
||||||
|
|
||||||
void LoadModelComplete(uint32_t model_id);
|
void LoadModelComplete(uint32_t model_id);
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Task>> &GetTaskList(uint32_t model_id) const;
|
||||||
|
|
||||||
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
|
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
|
||||||
|
|
||||||
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
|
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
|
||||||
|
|
|
@ -293,5 +293,7 @@ void RuntimeModel::RtEventDestory() noexcept {
|
||||||
|
|
||||||
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
|
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<Task>> &RuntimeModel::GetTaskList() const { return task_list_; }
|
||||||
|
|
||||||
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
|
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
|
||||||
} // namespace mindspore::ge::model_runner
|
} // namespace mindspore::ge::model_runner
|
||||||
|
|
|
@ -36,6 +36,7 @@ class RuntimeModel {
|
||||||
void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
|
void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model);
|
||||||
void DistributeTask();
|
void DistributeTask();
|
||||||
void LoadComplete();
|
void LoadComplete();
|
||||||
|
const std::vector<std::shared_ptr<Task>> &GetTaskList() const;
|
||||||
const std::vector<uint32_t> &GetTaskIdList() const;
|
const std::vector<uint32_t> &GetTaskIdList() const;
|
||||||
const std::vector<uint32_t> &GetStreamIdList() const;
|
const std::vector<uint32_t> &GetStreamIdList() const;
|
||||||
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
|
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
|
||||||
|
|
|
@ -30,7 +30,8 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
|
||||||
stream_(nullptr),
|
stream_(nullptr),
|
||||||
args_(nullptr),
|
args_(nullptr),
|
||||||
ext_info_(nullptr),
|
ext_info_(nullptr),
|
||||||
input_output_addr_(nullptr) {
|
input_output_addr_(nullptr),
|
||||||
|
args_size_(0) {
|
||||||
MS_EXCEPTION_IF_NULL(task_info_);
|
MS_EXCEPTION_IF_NULL(task_info_);
|
||||||
|
|
||||||
auto stream_list = model_context.stream_list();
|
auto stream_list = model_context.stream_list();
|
||||||
|
@ -60,20 +61,20 @@ void AicpuTask::Distribute() {
|
||||||
(void)io_addrs.insert(io_addrs.cend(), task_info_->output_data_addrs().cbegin(),
|
(void)io_addrs.insert(io_addrs.cend(), task_info_->output_data_addrs().cbegin(),
|
||||||
task_info_->output_data_addrs().cend());
|
task_info_->output_data_addrs().cend());
|
||||||
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
|
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
|
||||||
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
|
io_addrs_size_ = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
|
||||||
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
|
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
|
||||||
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
|
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size_;
|
||||||
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
|
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
|
||||||
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
|
args_size_ = sizeof(aicpu::AicpuParamHead) + io_addrs_size_ + static_cast<uint32_t>(task_info_->node_def().size()) +
|
||||||
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);
|
sizeof(uint32_t);
|
||||||
|
|
||||||
// Malloc device memory for args
|
// Malloc device memory for args
|
||||||
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
|
rtError_t rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret;
|
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
SetAicpuParamHead(args_size, io_addrs_num);
|
SetAicpuParamHead(args_size_, io_addrs_num);
|
||||||
SetInputOutputAddrs(io_addrs, io_addr_offset);
|
SetInputOutputAddrs(io_addrs, io_addr_offset);
|
||||||
SetNodeDef(node_def_len_offset, node_def_addr_offset);
|
SetNodeDef(node_def_len_offset, node_def_addr_offset);
|
||||||
|
|
||||||
|
@ -82,12 +83,12 @@ void AicpuTask::Distribute() {
|
||||||
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
|
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
|
||||||
auto cpu_flag = task_info_->cust_aicpu() ? RT_KERNEL_CUSTOM_AICPU : dump_flag;
|
auto cpu_flag = task_info_->cust_aicpu() ? RT_KERNEL_CUSTOM_AICPU : dump_flag;
|
||||||
|
|
||||||
MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size << ", io_addrs_num =" << io_addrs_num
|
MS_LOG(INFO) << "Distribute AicpuTask start, args_size = " << args_size_ << ", io_addrs_num =" << io_addrs_num
|
||||||
<< ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
|
<< ", so_name = " << task_info_->so_name() << ", kernel_name = " << task_info_->kernel_name()
|
||||||
<< ", dump_flag = " << dump_flag;
|
<< ", dump_flag = " << dump_flag;
|
||||||
rtArgsEx_t argsInfo = {};
|
rtArgsEx_t argsInfo = {};
|
||||||
argsInfo.args = args_;
|
argsInfo.args = args_;
|
||||||
argsInfo.argsSize = args_size;
|
argsInfo.argsSize = args_size_;
|
||||||
rt_ret = rtCpuKernelLaunchWithFlag(static_cast<const void *>(task_info_->so_name().data()),
|
rt_ret = rtCpuKernelLaunchWithFlag(static_cast<const void *>(task_info_->so_name().data()),
|
||||||
static_cast<const void *>(task_info_->kernel_name().data()), 1, &argsInfo, nullptr,
|
static_cast<const void *>(task_info_->kernel_name().data()), 1, &argsInfo, nullptr,
|
||||||
stream_, cpu_flag);
|
stream_, cpu_flag);
|
||||||
|
|
|
@ -33,6 +33,8 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
|
||||||
|
|
||||||
void *Args() const override { return input_output_addr_; }
|
void *Args() const override { return input_output_addr_; }
|
||||||
|
|
||||||
|
size_t ArgsSize() const override { return io_addrs_size_; }
|
||||||
|
|
||||||
std::string task_name() const override { return task_info_->op_name(); }
|
std::string task_name() const override { return task_info_->op_name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -46,6 +48,8 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
|
||||||
void *args_;
|
void *args_;
|
||||||
void *ext_info_;
|
void *ext_info_;
|
||||||
void *input_output_addr_;
|
void *input_output_addr_;
|
||||||
|
size_t io_addrs_size_;
|
||||||
|
size_t args_size_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::ge::model_runner
|
} // namespace mindspore::ge::model_runner
|
||||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
|
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_AICPU_TASK_H_
|
||||||
|
|
|
@ -33,6 +33,8 @@ class Task {
|
||||||
|
|
||||||
virtual void *Args() const { return nullptr; }
|
virtual void *Args() const { return nullptr; }
|
||||||
|
|
||||||
|
virtual size_t ArgsSize() const { return 0; }
|
||||||
|
|
||||||
virtual std::string task_name() const { return ""; }
|
virtual std::string task_name() const { return ""; }
|
||||||
|
|
||||||
void set_model_handle(rtModel_t model_handle) { model_handle_ = model_handle; }
|
void set_model_handle(rtModel_t model_handle) { model_handle_ = model_handle; }
|
||||||
|
|
|
@ -27,7 +27,8 @@ TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTas
|
||||||
task_info_(task_info),
|
task_info_(task_info),
|
||||||
stream_(nullptr),
|
stream_(nullptr),
|
||||||
stub_func_(nullptr),
|
stub_func_(nullptr),
|
||||||
args_(nullptr) {
|
args_(nullptr),
|
||||||
|
args_size_(0) {
|
||||||
MS_EXCEPTION_IF_NULL(task_info);
|
MS_EXCEPTION_IF_NULL(task_info);
|
||||||
|
|
||||||
auto stream_list = model_context.stream_list();
|
auto stream_list = model_context.stream_list();
|
||||||
|
@ -74,14 +75,14 @@ void TbeTask::Distribute() {
|
||||||
task_info_->output_data_addrs().cend());
|
task_info_->output_data_addrs().cend());
|
||||||
tensor_device_addrs.insert(tensor_device_addrs.cend(), task_info_->workspace_addrs().cbegin(),
|
tensor_device_addrs.insert(tensor_device_addrs.cend(), task_info_->workspace_addrs().cbegin(),
|
||||||
task_info_->workspace_addrs().cend());
|
task_info_->workspace_addrs().cend());
|
||||||
auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
|
args_size_ = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
|
||||||
|
|
||||||
rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
|
rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size;
|
MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
rt_ret = aclrtMemcpy(args_, args_size, static_cast<void *>(tensor_device_addrs.data()), args_size,
|
rt_ret = aclrtMemcpy(args_, args_size_, static_cast<void *>(tensor_device_addrs.data()), args_size_,
|
||||||
ACL_MEMCPY_HOST_TO_DEVICE);
|
ACL_MEMCPY_HOST_TO_DEVICE);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret;
|
MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret;
|
||||||
|
@ -91,10 +92,10 @@ void TbeTask::Distribute() {
|
||||||
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
|
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
|
||||||
rtArgsEx_t args_info = {};
|
rtArgsEx_t args_info = {};
|
||||||
args_info.args = args_;
|
args_info.args = args_;
|
||||||
args_info.argsSize = args_size;
|
args_info.argsSize = args_size_;
|
||||||
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), &args_info, nullptr, stream_, dump_flag);
|
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), &args_info, nullptr, stream_, dump_flag);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size;
|
MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size_;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
|
MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,8 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
|
||||||
|
|
||||||
void *Args() const override { return args_; }
|
void *Args() const override { return args_; }
|
||||||
|
|
||||||
|
size_t ArgsSize() const override { return args_size_; }
|
||||||
|
|
||||||
std::string task_name() const override { return task_info_->op_name(); }
|
std::string task_name() const override { return task_info_->op_name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -39,6 +41,7 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
|
||||||
void *stream_;
|
void *stream_;
|
||||||
void *stub_func_;
|
void *stub_func_;
|
||||||
void *args_;
|
void *args_;
|
||||||
|
size_t args_size_;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::ge::model_runner
|
} // namespace mindspore::ge::model_runner
|
||||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
|
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_TASK_TBE_TASK_H_
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "runtime/rt.h"
|
||||||
|
#include "external/acl/acl_rt.h"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "runtime/device/kernel_info.h"
|
||||||
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/ge_runtime/model_runner.h"
|
||||||
|
#include "plugin/device/ascend/hal/device/tasksink/task_generator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
namespace tasksink {
|
||||||
|
using TaskPtr = std::shared_ptr<ge::model_runner::Task>;
|
||||||
|
namespace {
|
||||||
|
bool IsForwardOutputValueNode(const AnfNodePtr &input) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
if (input->isa<ValueNode>()) {
|
||||||
|
auto value_node = input->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value = value_node->value();
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
if (value->isa<tensor::Tensor>()) {
|
||||||
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
if (tensor->is_forward_output()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CheckTaskValid(const CNodePtr &node, const std::vector<void *> &args_datas) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
bool task_valid = true;
|
||||||
|
// Check input/output/workspace
|
||||||
|
auto input_addrs = TaskGenerator::GetTaskInput(node);
|
||||||
|
auto output_addrs = TaskGenerator::GetTaskOutput(node);
|
||||||
|
auto workspace_addrs = TaskGenerator::GetTaskWorkspace(node);
|
||||||
|
|
||||||
|
std::vector<AddressPtr> node_addresses;
|
||||||
|
std::move(input_addrs.begin(), input_addrs.end(), std::back_inserter(node_addresses));
|
||||||
|
std::move(output_addrs.begin(), output_addrs.end(), std::back_inserter(node_addresses));
|
||||||
|
std::move(workspace_addrs.begin(), workspace_addrs.end(), std::back_inserter(node_addresses));
|
||||||
|
|
||||||
|
if (node_addresses.size() != args_datas.size()) {
|
||||||
|
MS_LOG(ERROR) << "Check failed, Node " << node->UniqueName() << " total addr size " << node_addresses.size()
|
||||||
|
<< " is not equal to " << args_datas.size();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < node_addresses.size(); ++i) {
|
||||||
|
auto node_address = node_addresses[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(node_address);
|
||||||
|
if (node_address->addr != args_datas[i]) {
|
||||||
|
MS_LOG(ERROR) << "Node " << node->UniqueName() << " addr " << node_address->addr << " not equal to addr of task "
|
||||||
|
<< args_datas[i];
|
||||||
|
task_valid = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return task_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NeedSkipZeroCopy(const CNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (common::AnfAlgo::IsNonTaskOp(node)) {
|
||||||
|
MS_LOG(INFO) << "Skip generate ZeroCopyTask for NonTaskOp " << node->fullname_with_scope();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto kernel_type = AnfAlgo::GetKernelType(node);
|
||||||
|
if (kernel_type != KernelType::TBE_KERNEL && kernel_type != KernelType::AICPU_KERNEL) {
|
||||||
|
MS_LOG(INFO) << "Skip generate ZeroCopyTask for " << node->fullname_with_scope();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void *ParameterZeroCopyTask::GetAddressPtr() {
|
||||||
|
auto node = anf_node_.lock();
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!node->isa<Parameter>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Not a parameter node " << node->DebugString();
|
||||||
|
}
|
||||||
|
auto kernel_info = dynamic_cast<KernelInfo *>(node->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
auto parameter_address = kernel_info->GetOutputAddr(0);
|
||||||
|
MS_EXCEPTION_IF_NULL(parameter_address);
|
||||||
|
return parameter_address->GetMutablePtr();
|
||||||
|
}
|
||||||
|
|
||||||
|
void *ValueNodeZeroCopyTask::GetAddressPtr() {
|
||||||
|
auto node = anf_node_.lock();
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!node->isa<ValueNode>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Not a ValueNode " << node->DebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto value_node = node->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value = value_node->value();
|
||||||
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
auto value_node_address = tensor->device_address();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node_address);
|
||||||
|
return value_node_address->GetMutablePtr();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ZeroCopyTask::UpdateArgs(void *stream) {
|
||||||
|
device_ptr_ = GetAddressPtr();
|
||||||
|
if (device_ptr_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Device address ptr is null, task " << task_name_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (device_ptr_ == previous_ptr_) {
|
||||||
|
MS_LOG(DEBUG) << "No need to update task of " << task_name_;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = aclrtMemcpyAsync(static_cast<uint8_t *>(args_base_) + args_offset_, sizeof(void *), &device_ptr_,
|
||||||
|
sizeof(void *), ACL_MEMCPY_HOST_TO_DEVICE, stream);
|
||||||
|
if (ret != ACL_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "Update task " << task_name_ << " aclrtMemcpy failed, ret " << ret;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
previous_ptr_ = device_ptr_;
|
||||||
|
MS_LOG(INFO) << "Update task " << task_name_ << " args_offset " << args_offset_ << " device_ptr " << device_ptr_
|
||||||
|
<< " success";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RtModelZeroCopy::GenerateZeroCopyTasks(const session::KernelGraph &graph) {
|
||||||
|
if (!graph.has_flag(kFlagPyNativeRunInGraph)) {
|
||||||
|
MS_LOG(INFO) << "RtModelZeroCopy is not enabled";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ZeroCopyTaskPtr> zero_copy_tasks;
|
||||||
|
auto task_lists = ge::model_runner::ModelRunner::Instance().GetTaskList(graph.graph_id());
|
||||||
|
std::map<std::string, TaskPtr> op_name_to_task;
|
||||||
|
std::transform(task_lists.begin(), task_lists.end(), std::inserter(op_name_to_task, op_name_to_task.end()),
|
||||||
|
[](const TaskPtr &task) { return std::make_pair(task->task_name(), task); });
|
||||||
|
|
||||||
|
auto nodes = graph.execution_order();
|
||||||
|
for (const auto &node : nodes) {
|
||||||
|
if (NeedSkipZeroCopy(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto op_name = node->UniqueName();
|
||||||
|
auto iter = op_name_to_task.find(op_name);
|
||||||
|
if (iter == op_name_to_task.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Cannot found task of op " << op_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto task = iter->second;
|
||||||
|
MS_EXCEPTION_IF_NULL(task);
|
||||||
|
auto input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||||
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
|
auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(node, i);
|
||||||
|
auto input = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph, true).first;
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
if (input->isa<Parameter>()) {
|
||||||
|
zero_copy_tasks.emplace_back(std::make_shared<tasksink::ParameterZeroCopyTask>(
|
||||||
|
input, task->Args(), i * sizeof(void *), task->task_name()));
|
||||||
|
MS_LOG(INFO) << "Generate ZeroCopyTask for Node " << node->fullname_with_scope() << " Parameter "
|
||||||
|
<< input->DebugString();
|
||||||
|
} else if (IsForwardOutputValueNode(input)) {
|
||||||
|
zero_copy_tasks.emplace_back(std::make_shared<tasksink::ValueNodeZeroCopyTask>(
|
||||||
|
input, task->Args(), i * sizeof(void *), task->task_name()));
|
||||||
|
MS_LOG(INFO) << "Generate ZeroCopyTask for Node " << node->fullname_with_scope() << " ValueNode "
|
||||||
|
<< input->DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = graph_zero_copy_tasks_.try_emplace(graph.graph_id(), zero_copy_tasks);
|
||||||
|
if (!iter.second) {
|
||||||
|
MS_LOG(ERROR) << "Generate ZeroCopyTask failed, Duplicate graph id " << graph.graph_id();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RtModelZeroCopy::UpdateTaskArgs(const session::KernelGraph &graph, void *stream) const {
|
||||||
|
if (!graph.has_flag(kFlagPyNativeRunInGraph)) {
|
||||||
|
MS_LOG(INFO) << "RtModelZeroCopy is not enabled, no need to update task args.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = graph_zero_copy_tasks_.find(graph.graph_id());
|
||||||
|
if (iter == graph_zero_copy_tasks_.end()) {
|
||||||
|
MS_LOG(ERROR) << "No zero copy tasks found. graph id " << graph.graph_id();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto zero_copy_tasks = iter->second;
|
||||||
|
if (std::any_of(zero_copy_tasks.begin(), zero_copy_tasks.end(),
|
||||||
|
[stream](const ZeroCopyTaskPtr &task) { return !task->UpdateArgs(stream); })) {
|
||||||
|
MS_LOG(ERROR) << "Update task args failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Check rtMode valid " << ((rtStreamSynchronize(stream) == RT_ERROR_NONE) && CheckRtModelValid(graph));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RtModelZeroCopy::CheckRtModelValid(const session::KernelGraph &graph) {
|
||||||
|
auto graph_id = graph.graph_id();
|
||||||
|
auto tasks = ge::model_runner::ModelRunner::Instance().GetTaskList(graph_id);
|
||||||
|
std::map<std::string, TaskPtr> op_name_to_task;
|
||||||
|
std::transform(tasks.begin(), tasks.end(), std::inserter(op_name_to_task, op_name_to_task.end()),
|
||||||
|
[](const TaskPtr &task) { return std::make_pair(task->task_name(), task); });
|
||||||
|
|
||||||
|
auto nodes = graph.execution_order();
|
||||||
|
bool task_valid = true;
|
||||||
|
for (const auto &node : nodes) {
|
||||||
|
if (NeedSkipZeroCopy(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto unique_name = node->UniqueName();
|
||||||
|
auto iter = op_name_to_task.find(unique_name);
|
||||||
|
if (iter == op_name_to_task.end()) {
|
||||||
|
MS_LOG(ERROR) << "Cannot found task of op " << unique_name;
|
||||||
|
task_valid = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto task = iter->second;
|
||||||
|
MS_EXCEPTION_IF_NULL(task);
|
||||||
|
auto task_args = task->Args();
|
||||||
|
auto task_size = task->ArgsSize();
|
||||||
|
if (task_size == 0) {
|
||||||
|
// For example InitDataSet (AiCpu kernel).
|
||||||
|
MS_LOG(INFO) << "task name " << task->task_name() << " task_size is 0";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::vector<void *> args_datas(task_size / sizeof(void *), nullptr);
|
||||||
|
if (aclrtMemcpy(args_datas.data(), task_size, task_args, task_size, ACL_MEMCPY_DEVICE_TO_HOST) != ACL_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "aclrtMemcpy failed, task " << task->task_name() << " task_size " << task_size;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CheckTaskValid(node, args_datas)) {
|
||||||
|
task_valid = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return task_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RtModelZeroCopy::Release(uint32_t graph_id) { (void)graph_zero_copy_tasks_.erase(graph_id); }
|
||||||
|
} // namespace tasksink
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,102 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_
|
||||||
|
#define MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "runtime/device/device_address.h"
|
||||||
|
#include "backend/common/session/kernel_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
namespace tasksink {
|
||||||
|
class ZeroCopyTask {
|
||||||
|
public:
|
||||||
|
ZeroCopyTask(AnfNodeWeakPtr anf_node, void *args_base, size_t args_offset, std::string task_name)
|
||||||
|
: anf_node_(std::move(anf_node)),
|
||||||
|
args_base_(args_base),
|
||||||
|
args_offset_(args_offset),
|
||||||
|
task_name_(std::move(task_name)) {}
|
||||||
|
~ZeroCopyTask() = default;
|
||||||
|
|
||||||
|
// Update the address in task args
|
||||||
|
bool UpdateArgs(void *stream);
|
||||||
|
|
||||||
|
virtual void *GetAddressPtr() = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Parameter or ValueNode
|
||||||
|
AnfNodeWeakPtr anf_node_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void *args_base_;
|
||||||
|
size_t args_offset_;
|
||||||
|
std::string task_name_;
|
||||||
|
void *device_ptr_{nullptr};
|
||||||
|
void *previous_ptr_{nullptr};
|
||||||
|
};
|
||||||
|
using ZeroCopyTaskPtr = std::shared_ptr<ZeroCopyTask>;
|
||||||
|
|
||||||
|
class ParameterZeroCopyTask : public ZeroCopyTask {
|
||||||
|
public:
|
||||||
|
ParameterZeroCopyTask(const AnfNodeWeakPtr &anf_node, void *args_base, size_t args_offset,
|
||||||
|
const std::string &task_name)
|
||||||
|
: ZeroCopyTask(anf_node, args_base, args_offset, task_name) {}
|
||||||
|
void *GetAddressPtr() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ValueNodeZeroCopyTask : public ZeroCopyTask {
|
||||||
|
public:
|
||||||
|
ValueNodeZeroCopyTask(const AnfNodeWeakPtr &anf_node, void *args_base, size_t args_offset,
|
||||||
|
const std::string &task_name)
|
||||||
|
: ZeroCopyTask(anf_node, args_base, args_offset, task_name) {}
|
||||||
|
void *GetAddressPtr() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update the device address in task without copying data.
|
||||||
|
// Usually we assume that the address in the task is constant.
|
||||||
|
// If the address of graph input changed, we need to copy data of graph input tensor to the address of the task.
|
||||||
|
// In fact, when the operator is executed, the real address is obtained from the task args.
|
||||||
|
// Task args is a secondary pointer, which stores the input address or output address of the operator.
|
||||||
|
// If we can update the input address in task args, we can avoid data copying.
|
||||||
|
class RtModelZeroCopy {
|
||||||
|
public:
|
||||||
|
RtModelZeroCopy() = default;
|
||||||
|
~RtModelZeroCopy() = default;
|
||||||
|
|
||||||
|
// Generate ZeroCopyTasks after the tasks of rtModel is Distributed. (Need to get task args address)
|
||||||
|
bool GenerateZeroCopyTasks(const session::KernelGraph &graph);
|
||||||
|
// Copy device address ptr to task args if the ptr changed.
|
||||||
|
bool UpdateTaskArgs(const session::KernelGraph &graph, void *stream) const;
|
||||||
|
// Check rtModel after update task args. The process of checking is consistent with the process of generating tasks.
|
||||||
|
static bool CheckRtModelValid(const session::KernelGraph &graph);
|
||||||
|
// Release resource after the graph is destructed.
|
||||||
|
void Release(uint32_t graph_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
mindspore::HashMap<uint32_t, std::vector<ZeroCopyTaskPtr>> graph_zero_copy_tasks_;
|
||||||
|
};
|
||||||
|
} // namespace tasksink
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TASKSINK_RTMODEL_ZERO_COPY_H_
|
|
@ -162,13 +162,82 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AddressPtrList TaskGenerator::GetTaskInput(const CNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
AddressPtrList kernel_inputs;
|
||||||
|
auto op_name = common::AnfAlgo::GetCNodeName(node);
|
||||||
|
if (op_name == kAtomicAddrCleanOpName) {
|
||||||
|
LaunchAddrCleanKernel(node, &kernel_inputs);
|
||||||
|
return kernel_inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||||
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
|
if (common::AnfAlgo::IsNoneInput(node, i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(node, i);
|
||||||
|
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, input_index_in_graph);
|
||||||
|
AddressPtr input = std::make_shared<Address>();
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
input->addr = device_address->ptr_;
|
||||||
|
input->size = device_address->size_;
|
||||||
|
|
||||||
|
auto prenode_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(prenode_with_index.first);
|
||||||
|
if (AnfUtils::IsRealCNodeKernel(prenode_with_index.first)) {
|
||||||
|
if (common::AnfAlgo::IsNonTaskOp(prenode_with_index.first->cast<CNodePtr>())) {
|
||||||
|
// use memory offset to implement NonTask Type Split op
|
||||||
|
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
|
||||||
|
// offset is split's output index * split's output size
|
||||||
|
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(split_input0_device_address);
|
||||||
|
input->addr =
|
||||||
|
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
|
||||||
|
MS_LOG(INFO) << "Change " << node->fullname_with_scope() << "'s input " << i << " address to "
|
||||||
|
<< split_input0_device_address->ptr_ << " + " << prenode_with_index.second * input->size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kernel_inputs.push_back(input);
|
||||||
|
}
|
||||||
|
return kernel_inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
AddressPtrList TaskGenerator::GetTaskOutput(const CNodePtr &node) {
|
||||||
|
AddressPtrList kernel_outputs;
|
||||||
|
// No kernel output if output of the cnode is monad, such as LabelSwitch.
|
||||||
|
if (!HasAbstractMonad(node)) {
|
||||||
|
size_t output_num = common::AnfAlgo::GetOutputTensorNum(node);
|
||||||
|
for (size_t i = 0; i < output_num; ++i) {
|
||||||
|
auto it = AnfAlgo::GetOutputAddr(node, i, false);
|
||||||
|
AddressPtr output = std::make_shared<Address>();
|
||||||
|
output->addr = it->ptr_;
|
||||||
|
output->size = it->size_;
|
||||||
|
kernel_outputs.push_back(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kernel_outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
AddressPtrList TaskGenerator::GetTaskWorkspace(const CNodePtr &node) {
|
||||||
|
AddressPtrList kernel_workspaces;
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
||||||
|
auto device_address = AnfAlgo::GetWorkspaceAddr(node, i);
|
||||||
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||||
|
MS_EXCEPTION_IF_NULL(workspace);
|
||||||
|
workspace->addr = device_address->ptr_;
|
||||||
|
workspace->size = device_address->size_;
|
||||||
|
kernel_workspaces.push_back(workspace);
|
||||||
|
}
|
||||||
|
return kernel_workspaces;
|
||||||
|
}
|
||||||
|
|
||||||
bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id,
|
bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id,
|
||||||
std::vector<TaskInfoPtr> *task_info_list) {
|
std::vector<TaskInfoPtr> *task_info_list) {
|
||||||
MS_EXCEPTION_IF_NULL(task_info_list);
|
MS_EXCEPTION_IF_NULL(task_info_list);
|
||||||
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
||||||
AddressPtrList kernel_inputs;
|
|
||||||
AddressPtrList kernel_workspaces;
|
|
||||||
AddressPtrList kernel_outputs;
|
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
|
auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
kernel_mod->set_unique_name(anf_node_ptr->UniqueName());
|
kernel_mod->set_unique_name(anf_node_ptr->UniqueName());
|
||||||
|
@ -185,59 +254,16 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op_name != kAtomicAddrCleanOpName) {
|
AddressPtrList kernel_inputs;
|
||||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(anf_node_ptr);
|
AddressPtrList kernel_workspaces;
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
AddressPtrList kernel_outputs;
|
||||||
if (common::AnfAlgo::IsNoneInput(anf_node_ptr, i)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(anf_node_ptr, i);
|
|
||||||
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, input_index_in_graph);
|
|
||||||
AddressPtr input = std::make_shared<Address>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
|
||||||
input->addr = device_address->ptr_;
|
|
||||||
input->size = device_address->size_;
|
|
||||||
|
|
||||||
auto prenode_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node_ptr, input_index_in_graph);
|
if (op_name == kAtomicAddrCleanOpName) {
|
||||||
MS_EXCEPTION_IF_NULL(prenode_with_index.first);
|
|
||||||
if (AnfUtils::IsRealCNodeKernel(prenode_with_index.first)) {
|
|
||||||
if (common::AnfAlgo::IsNonTaskOp(prenode_with_index.first->cast<CNodePtr>())) {
|
|
||||||
// use memory offset to implement NonTask Type Split op
|
|
||||||
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
|
|
||||||
// offset is split's output index * split's output size
|
|
||||||
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
|
|
||||||
MS_EXCEPTION_IF_NULL(split_input0_device_address);
|
|
||||||
input->addr =
|
|
||||||
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
|
|
||||||
MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to "
|
|
||||||
<< split_input0_device_address->ptr_ << " + " << prenode_with_index.second * input->size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
kernel_inputs.push_back(input);
|
|
||||||
}
|
|
||||||
|
|
||||||
// No kernel output if output of the cnode is monad, such as LabelSwitch.
|
|
||||||
if (!HasAbstractMonad(anf_node_ptr)) {
|
|
||||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(anf_node_ptr);
|
|
||||||
for (size_t i = 0; i < output_num; ++i) {
|
|
||||||
auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i, false);
|
|
||||||
AddressPtr output = std::make_shared<Address>();
|
|
||||||
output->addr = it->ptr_;
|
|
||||||
output->size = it->size_;
|
|
||||||
kernel_outputs.push_back(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
|
||||||
auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i);
|
|
||||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
|
||||||
MS_EXCEPTION_IF_NULL(workspace);
|
|
||||||
workspace->addr = device_address->ptr_;
|
|
||||||
workspace->size = device_address->size_;
|
|
||||||
kernel_workspaces.push_back(workspace);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs);
|
LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs);
|
||||||
|
} else {
|
||||||
|
kernel_inputs = GetTaskInput(anf_node_ptr);
|
||||||
|
kernel_workspaces = GetTaskWorkspace(anf_node_ptr);
|
||||||
|
kernel_outputs = GetTaskOutput(anf_node_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
|
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
|
||||||
|
@ -251,6 +277,7 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
||||||
MS_LOG(ERROR) << "Empty task_info_ptrs.";
|
MS_LOG(ERROR) << "Empty task_info_ptrs.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "Node " << anf_node_ptr->fullname_with_scope() << " get task " << task_info_ptrs.front()->op_name();
|
||||||
debug_info->op_name_ = anf_node_ptr->fullname_with_scope();
|
debug_info->op_name_ = anf_node_ptr->fullname_with_scope();
|
||||||
debug_info->task_num_ = task_info_ptrs.size();
|
debug_info->task_num_ = task_info_ptrs.size();
|
||||||
debug_info->stream_id_ = task_info_ptrs[0]->stream_id();
|
debug_info->stream_id_ = task_info_ptrs[0]->stream_id();
|
||||||
|
|
|
@ -63,6 +63,10 @@ class TaskGenerator {
|
||||||
std::vector<TaskDebugInfoPtr> GetTaskDebugInfo() const { return task_debug_info_list_; }
|
std::vector<TaskDebugInfoPtr> GetTaskDebugInfo() const { return task_debug_info_list_; }
|
||||||
static void DumpTaskInfo(const string &real_filename, const std::vector<TaskDebugInfoPtr> &task_debug_info_list);
|
static void DumpTaskInfo(const string &real_filename, const std::vector<TaskDebugInfoPtr> &task_debug_info_list);
|
||||||
|
|
||||||
|
static AddressPtrList GetTaskInput(const CNodePtr &node);
|
||||||
|
static AddressPtrList GetTaskOutput(const CNodePtr &node);
|
||||||
|
static AddressPtrList GetTaskWorkspace(const CNodePtr &node);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<TaskDebugInfoPtr> task_debug_info_list_;
|
std::vector<TaskDebugInfoPtr> task_debug_info_list_;
|
||||||
static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs);
|
static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs);
|
||||||
|
|
|
@ -97,7 +97,7 @@ bool AscendDeviceContext::PartitionGraph(const FuncGraphPtr &func_graph) const {
|
||||||
RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
|
RunMode AscendDeviceContext::GetRunMode(const FuncGraphPtr &func_graph) const {
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
if (ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && IsGraphMode() && !IsDynamicShapeGraph(func_graph)) {
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && !IsDynamicShapeGraph(func_graph)) {
|
||||||
return RunMode::kGraphMode;
|
return RunMode::kGraphMode;
|
||||||
} else {
|
} else {
|
||||||
return RunMode::kKernelMode;
|
return RunMode::kKernelMode;
|
||||||
|
|
|
@ -33,6 +33,7 @@ namespace device {
|
||||||
namespace ascend {
|
namespace ascend {
|
||||||
using KernelGraph = mindspore::session::KernelGraph;
|
using KernelGraph = mindspore::session::KernelGraph;
|
||||||
|
|
||||||
|
namespace {
|
||||||
CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
|
CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
|
||||||
size_t node_sizes = kernel_nodes.size();
|
size_t node_sizes = kernel_nodes.size();
|
||||||
if (index >= node_sizes - 1) {
|
if (index >= node_sizes - 1) {
|
||||||
|
@ -162,6 +163,25 @@ void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
|
||||||
UnfoldRecursiveExecOrder(kernel_graph);
|
UnfoldRecursiveExecOrder(kernel_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EnableGraphInputZeroCopy(const KernelGraphPtr &graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
// Zero copy is only enabled for PyNative.
|
||||||
|
if (!graph->has_flag(kFlagPyNativeRunInGraph) || !graph->is_graph_run_mode()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto &input_nodes = graph->input_nodes();
|
||||||
|
for (const auto &input : input_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
if (AnfAlgo::OutputAddrExist(input, 0)) {
|
||||||
|
auto input_address = AnfAlgo::GetMutableOutputAddr(input, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_address);
|
||||||
|
input_address->set_is_ptr_persisted(false);
|
||||||
|
MS_LOG(INFO) << "Enable zero copy for input " << input->DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void AscendGraphExecutor::Initialize() {
|
void AscendGraphExecutor::Initialize() {
|
||||||
res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get());
|
res_manager_ = dynamic_cast<AscendDeviceResManager *>(device_context_->device_res_manager_.get());
|
||||||
MS_EXCEPTION_IF_NULL(res_manager_);
|
MS_EXCEPTION_IF_NULL(res_manager_);
|
||||||
|
@ -217,6 +237,7 @@ void AscendGraphExecutor::PreprocessBeforeRun(const KernelGraphPtr &graph) const
|
||||||
AllocateGraphMemory(NOT_NULL(graph));
|
AllocateGraphMemory(NOT_NULL(graph));
|
||||||
LoadModel(NOT_NULL(graph));
|
LoadModel(NOT_NULL(graph));
|
||||||
AssignOutputNopNodeDeviceAddress(graph, device_context_);
|
AssignOutputNopNodeDeviceAddress(graph, device_context_);
|
||||||
|
EnableGraphInputZeroCopy(graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendGraphExecutor::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
void AscendGraphExecutor::UpdateExecOrder(const KernelGraphPtr &graph) const {
|
||||||
|
|
|
@ -220,12 +220,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
|
||||||
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
||||||
auto context_ptr = MsContext::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
auto hccl_node = anf_node_.lock();
|
||||||
auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||||
if (!mutable_workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode) ||
|
auto func_graph = hccl_node->func_graph();
|
||||||
mode == kPynativeMode) {
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
auto graph_run_mode = kernel_graph->is_graph_run_mode();
|
||||||
|
// Not task sink mode.
|
||||||
|
if (!mutable_workspace_size_list_.empty() || hccl_data_type_list_.empty() || !graph_run_mode) {
|
||||||
return mutable_workspace_size_list_;
|
return mutable_workspace_size_list_;
|
||||||
}
|
}
|
||||||
|
// Task sink mode.
|
||||||
mutable_workspace_size_list_.emplace_back(
|
mutable_workspace_size_list_.emplace_back(
|
||||||
hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
|
hccl::HcclAdapter::GetInstance().CalcWorkspaceSize(anf_node_.lock(), hccl_data_type_list_[0]));
|
||||||
return mutable_workspace_size_list_;
|
return mutable_workspace_size_list_;
|
||||||
|
|
|
@ -149,7 +149,8 @@ bool EliminateGraphOutputTransdata::Run(const FuncGraphPtr &func_graph) {
|
||||||
std::vector<PrimitivePtr> return_type = {prim::kPrimMakeTuple};
|
std::vector<PrimitivePtr> return_type = {prim::kPrimMakeTuple};
|
||||||
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(func_graph->output(), 0, false, return_type);
|
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(func_graph->output(), 0, false, return_type);
|
||||||
if (!common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
|
if (!common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
|
||||||
MS_LOG(EXCEPTION) << "Graph output is not a MakeTuple";
|
MS_LOG(INFO) << "Graph output is not a MakeTuple";
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto transdata_ref_count = GetTransdataRefCount(func_graph);
|
auto transdata_ref_count = GetTransdataRefCount(func_graph);
|
||||||
|
|
|
@ -695,8 +695,12 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
|
||||||
opt::BackendCommonOptimization(root_graph);
|
opt::BackendCommonOptimization(root_graph);
|
||||||
root_graph->SetInputNodes();
|
root_graph->SetInputNodes();
|
||||||
|
|
||||||
auto graph_id = CompileGraphImpl(root_graph, device_context);
|
GraphId graph_id = 0;
|
||||||
|
if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
|
||||||
|
graph_id = CompileGraphImpl(root_graph, device_context);
|
||||||
|
} else {
|
||||||
|
graph_id = root_graph->graph_id();
|
||||||
|
}
|
||||||
// Set summary nodes for all graphs.
|
// Set summary nodes for all graphs.
|
||||||
session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs);
|
session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs);
|
||||||
|
|
||||||
|
@ -704,6 +708,7 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
|
||||||
// for ascend mindRT.
|
// for ascend mindRT.
|
||||||
session_->DumpGraphs(all_graphs);
|
session_->DumpGraphs(all_graphs);
|
||||||
|
|
||||||
|
if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
|
||||||
// Cache the backend graph output nodes to front nodes with output index.
|
// Cache the backend graph output nodes to front nodes with output index.
|
||||||
auto output = func_graph->output();
|
auto output = func_graph->output();
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
@ -711,6 +716,14 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
|
||||||
MS_EXCEPTION_IF_NULL(backend_node);
|
MS_EXCEPTION_IF_NULL(backend_node);
|
||||||
root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
|
root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
|
||||||
AnfAlgo::UpdateGraphValidRefPair(root_graph);
|
AnfAlgo::UpdateGraphValidRefPair(root_graph);
|
||||||
|
} else {
|
||||||
|
for (auto &node : root_graph->execution_order()) {
|
||||||
|
if (common::AnfAlgo::IsControlOpExecInBackend(node)) {
|
||||||
|
root_graph->set_flag(kFlagsIsCutGraph, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
root_graph->set_front_outputs({func_graph->output()});
|
||||||
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
|
MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
|
||||||
return graph_id;
|
return graph_id;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
#include "include/common/utils/convert_utils.h"
|
#include "include/common/utils/convert_utils.h"
|
||||||
#include "include/common/utils/anfalgo.h"
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "include/common/utils/parallel_context.h"
|
||||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
#include "runtime/graph_scheduler/device_tensor_store.h"
|
#include "runtime/graph_scheduler/device_tensor_store.h"
|
||||||
#include "runtime/device/ms_device_shape_transfer.h"
|
#include "runtime/device/ms_device_shape_transfer.h"
|
||||||
|
@ -330,4 +331,27 @@ void GraphAdapter::ReplaceGraphParameterProperties(const KernelGraphPtr &graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GraphAdapter::IsAutoParallel() {
|
||||||
|
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||||
|
auto parallel_mode = parallel_context->parallel_mode();
|
||||||
|
return parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
bool pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
|
||||||
|
if (!pynative_mode) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||||
|
auto is_cut_graph = std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
|
||||||
|
return common::AnfAlgo::IsControlOpExecInBackend(node);
|
||||||
|
});
|
||||||
|
|
||||||
|
return !IsAutoParallel() && !is_cut_graph && !func_graph->has_flag(kFlagIsDynamicStructure);
|
||||||
|
}
|
||||||
} // namespace mindspore::pynative
|
} // namespace mindspore::pynative
|
||||||
|
|
|
@ -33,6 +33,8 @@ class GraphAdapter {
|
||||||
static void RemoveUnusedValueNodes(const KernelGraphPtr &graph);
|
static void RemoveUnusedValueNodes(const KernelGraphPtr &graph);
|
||||||
static void HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &tensors,
|
static void HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &tensors,
|
||||||
const std::vector<device::DeviceContext *> &device_contexts);
|
const std::vector<device::DeviceContext *> &device_contexts);
|
||||||
|
static bool PyNativeEnableTaskSink(const FuncGraphPtr &func_graph);
|
||||||
|
static bool IsAutoParallel();
|
||||||
};
|
};
|
||||||
} // namespace mindspore::pynative
|
} // namespace mindspore::pynative
|
||||||
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
|
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_GRAPH_ADAPTER_H_
|
||||||
|
|
|
@ -160,6 +160,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_device_address.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_device_address.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_pool.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_memory_pool.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
|
||||||
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_res_manager.cc"
|
||||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"
|
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_kernel_executor.cc"
|
||||||
|
|
|
@ -27,6 +27,10 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
|
||||||
uint32_t graph_id) {
|
uint32_t graph_id) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AddressPtrList TaskGenerator::GetTaskInput(const CNodePtr &node) { return {}; }
|
||||||
|
AddressPtrList TaskGenerator::GetTaskOutput(const CNodePtr &node) { return {}; }
|
||||||
|
AddressPtrList TaskGenerator::GetTaskWorkspace(const CNodePtr &node) { return {}; }
|
||||||
} // namespace tasksink
|
} // namespace tasksink
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
void DataDumper::LoadDumpInfo() {}
|
void DataDumper::LoadDumpInfo() {}
|
||||||
|
|
Loading…
Reference in New Issue