forked from mindspore-Ecosystem/mindspore
Move ascend dependent functions to ascend kernel runtime.
This commit is contained in:
parent
9018737e99
commit
5ac60ff202
|
@ -178,10 +178,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
#endif
|
||||
// alloc mem
|
||||
MemoryAlloc(root_graph.get());
|
||||
// task generate
|
||||
GenerateTaskInfo(root_graph);
|
||||
// load task into device
|
||||
LoadTask(root_graph);
|
||||
// generate and load task into device
|
||||
Load(root_graph);
|
||||
DumpAllGraphs(all_graphs);
|
||||
// return the root_graph id to backend
|
||||
auto graph_id = root_graph->graph_id();
|
||||
|
@ -258,10 +256,8 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|||
} else {
|
||||
// alloc memory, including static memory and dynamic memory
|
||||
MemoryAlloc(graph.get());
|
||||
// generate task info for task sink mode
|
||||
GenerateTaskInfo(graph);
|
||||
// load task info to device if it is sink mode
|
||||
LoadTask(graph);
|
||||
// generate and load task info to device if it is sink mode
|
||||
Load(graph);
|
||||
}
|
||||
// sync the inital const tensor to device
|
||||
SyncInitialTenosrToDevice();
|
||||
|
@ -322,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
|
|||
#endif
|
||||
{
|
||||
// run task on device
|
||||
ExecTask(kernel_graph);
|
||||
Execute(kernel_graph);
|
||||
}
|
||||
// summary
|
||||
Summary(kernel_graph.get());
|
||||
|
@ -554,30 +550,19 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
|
|||
runtime_instance->RunOpClearMemory(kernel_graph);
|
||||
}
|
||||
|
||||
void AscendSession::GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
bool ret_ok = runtime_instance->GenTask(kernel_graph.get());
|
||||
if (!ret_ok) {
|
||||
MS_LOG(EXCEPTION) << "Generate task error!";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
bool ret_ok = runtime_instance->LoadTask(kernel_graph.get());
|
||||
bool ret_ok = runtime_instance->Load(kernel_graph.get());
|
||||
if (!ret_ok) {
|
||||
MS_LOG(EXCEPTION) << "Load task error!";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
|
|
|
@ -81,9 +81,8 @@ class AscendSession : public SessionBasic {
|
|||
void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
|
||||
void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
|
|
@ -454,19 +454,31 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
|
|||
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||
SetContext();
|
||||
if (graph == nullptr) {
|
||||
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
|
||||
}
|
||||
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
|
||||
bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (!is_task_sink) {
|
||||
return true;
|
||||
}
|
||||
if (!GenTask(graph)) {
|
||||
return false;
|
||||
}
|
||||
if (!LoadTask(graph)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||
SetContext();
|
||||
if (graph == nullptr) {
|
||||
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
|
||||
}
|
||||
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_REUSE)) {
|
||||
// Get normal graph ir for memreuse
|
||||
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
|
||||
|
@ -517,13 +529,6 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
|
||||
}
|
||||
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (!is_task_sink) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (GraphWithEmptyTaskList(graph)) {
|
||||
MS_LOG(WARNING) << "LoadTask end, task list is empty";
|
||||
return true;
|
||||
|
@ -604,6 +609,36 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
|
|||
}
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
|
||||
bool ret = false;
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (is_task_sink) {
|
||||
ret = RunTask(graph);
|
||||
} else {
|
||||
ret = LaunchKernel(graph);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
|
||||
#else
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
|
||||
#endif
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
||||
SetContext();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -40,10 +40,12 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
~AscendKernelRuntime() override;
|
||||
bool Init() override;
|
||||
bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
|
||||
bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
|
||||
bool GenTask(const session::KernelGraph *graph) override;
|
||||
bool RunTask(const session::KernelGraph *graph) override;
|
||||
bool LoadTask(const session::KernelGraph *graph) override;
|
||||
bool LoadData(session::KernelGraph *graph, Debugger *debugger);
|
||||
bool GenTask(const session::KernelGraph *graph);
|
||||
bool LoadTask(const session::KernelGraph *graph);
|
||||
bool RunTask(const session::KernelGraph *graph);
|
||||
bool Load(session::KernelGraph *graph) override;
|
||||
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
|
||||
void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
|
||||
const std::unordered_set<ValueNodePtr> &value_nodes,
|
||||
const std::vector<CNodePtr> &execution_order) override;
|
||||
|
|
|
@ -40,37 +40,8 @@ KernelRuntime::~KernelRuntime() {
|
|||
#endif
|
||||
}
|
||||
|
||||
bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
|
||||
bool ret = false;
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
||||
if (is_task_sink) {
|
||||
ret = RunTask(graph);
|
||||
} else {
|
||||
ret = LaunchKernel(graph);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
|
||||
#else
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
|
||||
#endif
|
||||
return ret;
|
||||
}
|
||||
bool KernelRuntime::Load(session::KernelGraph *graph) { return true; }
|
||||
|
||||
// for D to impl
|
||||
bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
|
||||
if (graph != nullptr) {
|
||||
return true;
|
||||
|
@ -78,37 +49,6 @@ bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *d
|
|||
return false;
|
||||
}
|
||||
|
||||
// for D to impl
|
||||
bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
|
||||
if (graph != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// for D to impl
|
||||
bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||
if (graph != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
||||
if (graph != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// for D to impl
|
||||
bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
|
||||
if (graph != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
if (AnfAlgo::OutputAddrExist(kernel, index)) {
|
||||
|
|
|
@ -58,11 +58,9 @@ class KernelRuntime {
|
|||
void RunOpClearMemory(const session::KernelGraph *graph);
|
||||
bool DumpDataEnabled();
|
||||
bool DumpDataEnabledIteration();
|
||||
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr);
|
||||
virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
|
||||
virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger);
|
||||
virtual bool RunTask(const session::KernelGraph *graph);
|
||||
virtual bool GenTask(const session::KernelGraph *graph);
|
||||
virtual bool Load(session::KernelGraph *graph);
|
||||
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
|
||||
bool LaunchKernel(const session::KernelGraph *graph);
|
||||
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
|
||||
const AddressPtrList &kernel_outputs,
|
||||
|
@ -80,7 +78,6 @@ class KernelRuntime {
|
|||
#ifdef ENABLE_DUMP_E2E
|
||||
DumpConfPtr GetDumpConf();
|
||||
#endif
|
||||
virtual bool LoadTask(const session::KernelGraph *graph);
|
||||
// for GPU and D to impl
|
||||
virtual void ReleaseDeviceRes() {}
|
||||
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
|
||||
|
|
Loading…
Reference in New Issue