!40704 RunOp compile refactor

Merge pull request !40704 from caifubi/master-pynative-op-compile-refactor
i-robot 2022-08-25 10:58:18 +00:00 committed by Gitee
commit cd8b826fcf
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 694 additions and 613 deletions
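Summary of the change (derived from the hunks below): single-op (RunOp) graph compilation moves out of GraphCompiler into a new pynative::OpCompiler singleton, and MindRTBackend tracks per-op state through OpCompilerInfo instead of building a GraphCompilerInfo per op. AddOutInRefToGraph is promoted from GraphCompiler to session::AnfRuntimeAlgorithm. The device-address helpers that previously lived in the anonymous namespace of graph_compiler.cc are extracted into a new runtime::DeviceAddressUtils class (runtime/device/device_address_utils.h/.cc), and PyNativeExecutor::ClearRes now also clears the OpCompiler cache.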


@ -1387,5 +1387,27 @@ void AnfRuntimeAlgorithm::UpdateInternalParameterShape(
}
}
}
void AnfRuntimeAlgorithm::AddOutInRefToGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &cnode : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
for (const auto &ref : kernel_info->out_in_ref_map()) {
size_t output_index = ref.first;
size_t input_index = ref.second;
auto final_pair = std::make_pair(cnode, output_index);
auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
<< ", output index: " << final_pair.second << " to input "
<< origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
// Add to graph only if the input is not a monad.
if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
graph->AddRefCorrespondPairs(final_pair, origin_pair);
}
}
}
}
} // namespace session
} // namespace mindspore


@ -180,6 +180,8 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
static void UpdateInternalParameterShape(const std::map<size_t, std::vector<AnfNodeWeakPtr>> &internal_parameters,
const CNodePtr &cnode);
static bool IsShapesDynamic(const std::vector<ShapeVector> &shapes);
static void AddOutInRefToGraph(const KernelGraphPtr &graph);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


@ -23,8 +23,8 @@
#include "backend/graph_compiler/transform.h" #include "backend/graph_compiler/transform.h"
#include "backend/common/session/session_factory.h" #include "backend/common/session/session_factory.h"
#include "runtime/pynative/op_executor.h" #include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "backend/common/optimizer/helper.h" #include "backend/common/optimizer/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/action.h" #include "pipeline/jit/action.h"
#include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/data_converter.h"
#include "ir/anf.h" #include "ir/anf.h"
@ -1388,47 +1388,8 @@ std::shared_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
strategy);
}
void MindRTBackend::EraseSingleOpCache(const GraphInfo &graph_info) {
pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
}
std::shared_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
const ActorInfo &actor_info, const std::vector<int64_t> &tensors_mask, const std::vector<TensorPtr> &input_tensors,
bool need_erase) {
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
runtime::KernelMapPosition outputs_order;
size_t position = 0;
MS_EXCEPTION_IF_NULL(graph_compiler_);
for (const auto &graph_info_to_context : graph_info_to_device_context_) {
const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
MS_EXCEPTION_IF_NULL(graph);
(void)graphs.emplace_back(graph);
(void)device_contexts.emplace_back(graph_info_to_context.second);
auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
for (const auto &output : outputs) {
if (outputs_order.count(output) == 0) {
outputs_order[output] = {position++};
} else {
(void)outputs_order[output].emplace_back(position++);
}
}
}
std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(&tensors_mask));
std::vector<std::vector<TensorPtr> *> input_tensors_list(
1, const_cast<std::vector<tensor::TensorPtr> *>(&input_tensors));
auto parser = std::make_shared<ControlNodeParser>();
return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
outputs_order, 0, actor_info, need_erase,
runtime::GraphExecutionStrategy::kStep);
}
void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id());
actor_to_graph_compiler_info_.erase(actor_info);
(void)graph_info_to_device_context_.erase(graph_info);
}
void MindRTBackend::ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors) {
@ -1440,13 +1401,11 @@ void MindRTBackend::CompileSingleOpGraphs(const std::vector<std::shared_ptr<runt
return;
}
std::vector<KernelGraphPtr> graphs;
std::vector<GraphCompilerInfo *> graph_compiler_infos;
for (const auto &task : build_tasks) {
MS_EXCEPTION_IF_NULL(task);
const auto &context = task->context();
MS_EXCEPTION_IF_NULL(context);
graphs.push_back(context->graph());
graph_compiler_infos.push_back(context->graph_compiler_info());
}
MS_EXCEPTION_IF_NULL(build_tasks[0]);
auto &task_context = build_tasks[0]->context();
@ -1456,11 +1415,7 @@ void MindRTBackend::CompileSingleOpGraphs(const std::vector<std::shared_ptr<runt
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, task_context->is_pynative_infer());
auto device_context = task_context->device_context();
graph_compiler_->BuildSingleOpGraphs(graphs, device_context);
pynative::OpCompiler::BatchBuild(graphs, device_context);
for (const auto &graph_compile_info : graph_compiler_infos) {
MS_EXCEPTION_IF_NULL(graph_compile_info);
graph_compile_info->input_tensors_.clear();
}
}
void MindRTBackend::OpRunCallback(const std::shared_ptr<runtime::OpTaskContext> &context) {
@ -1524,26 +1479,22 @@ void MindRTBackend::BatchBuildCallback() {
}
}
void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
const OpCompilerInfoPtr &op_compiler_info,
const session::BackendOpRunInfoPtr &op_run_info) {
MS_EXCEPTION_IF_NULL(graph_compiler_info);
MS_EXCEPTION_IF_NULL(op_compiler_info);
// Fetch outputs.
if (graph_compiler_info->graphs_.empty()) {
MS_LOG(EXCEPTION) << "No graph found, op:" << graph_compiler_info->name_;
}
const auto &graph = graph_compiler_info->graphs_.front();
const auto &graph = op_compiler_info->graph_;
MS_EXCEPTION_IF_NULL(graph);
const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
const auto &output_nodes = op_compiler_info->graph_output_nodes_;
runtime::UpdateDeviceAddress(graph, GetTensorWithoutValueMask(op_run_info),
graph_compiler_info->device_contexts_.front());
runtime::UpdateDeviceAddress(graph, GetTensorWithoutValueMask(op_run_info), op_compiler_info->device_context_);
UpdateOutput(output_nodes, outputs);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
auto run_op_context = std::make_shared<runtime::OpTaskContext>(
graph_compiler_info, graph, output_nodes, op_run_info, graph_compiler_info->device_contexts_.front(), infer_flag);
auto run_op_context = std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, output_nodes, op_run_info,
op_compiler_info->device_context_, infer_flag);
// Save build task and run task.
std::promise<bool> promise;
@ -1565,28 +1516,28 @@ void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
}
}
void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph_compiler_info,
void MindRTBackend::RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
MS_EXCEPTION_IF_NULL(op_compiler_info);
// Fetch outputs.
const auto &graph = graph_compiler_info->graphs_.front();
const auto &graph = op_compiler_info->graph_;
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(graph_compiler_);
const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
const auto &output_nodes = op_compiler_info->graph_output_nodes_;
MS_EXCEPTION_IF_NULL(outputs);
auto device_context = graph_compiler_info->device_contexts_.front();
auto device_context = op_compiler_info->device_context_;
auto &op_executor = runtime::OpExecutor::GetInstance();
bool is_dynamic_shape =
op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.has_dynamic_input;
bool async_exec_disabled = is_dynamic_shape || graph_compiler_info->need_erase_ ||
bool async_exec_disabled = is_dynamic_shape || op_compiler_info->need_erase_ ||
!op_run_info->base_op_run_info.lazy_build || OpInBlackList(op_run_info) ||
GetExecutionMode() == kGraphMode || EnablePyNativeSyncRunning();
if (!async_exec_disabled) {
MS_LOG(DEBUG) << "Async exec enabled, op:" << op_run_info->base_op_run_info.op_name;
DispatchOpTask(single_op_cache_hit, outputs, graph_compiler_info, op_run_info);
DispatchOpTask(single_op_cache_hit, outputs, op_compiler_info, op_run_info);
return;
}
@ -1595,7 +1546,7 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph
WaitTaskFinish();
}
if (!single_op_cache_hit) {
CompileSingleOpGraph(graph, device_context, graph_compiler_info);
CompileSingleOpGraph(graph, device_context);
}
auto tensors_without_value_mask = GetTensorWithoutValueMask(op_run_info);
runtime::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context);
@ -1609,8 +1560,8 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph
if (op_run_info->base_op_run_info.has_dynamic_output) {
UpdateOutputAbstract(graph, op_run_info);
}
if (graph_compiler_info->need_erase_) {
if (op_compiler_info->need_erase_) {
EraseSingleOpCache(graph_compiler_info->name_, op_run_info->base_op_run_info.graph_info, graph);
EraseSingleOpCache(op_run_info->base_op_run_info.graph_info);
}
}
@ -1624,22 +1575,14 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
device_context->Initialize();
bool single_op_cache_hit = true;
auto graph_id = graph_compiler_->CompileGraph(op_run_info, &single_op_cache_hit, device_context);
std::string actor_info = std::to_string(graph_id) + "_" + op_run_info->base_op_run_info.op_name;
if (runtime::OpExecutor::GetInstance().ActorInQueue(actor_info)) {
auto op_compiler_info =
pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_context);
MS_EXCEPTION_IF_NULL(op_compiler_info);
if (runtime::OpExecutor::GetInstance().ActorInQueue(op_compiler_info->graph_id_)) {
WaitTaskFinish();
}
GraphCompilerInfo *graph_compiler_info_ptr;
if (single_op_cache_hit) {
auto iter = actor_to_graph_compiler_info_.find(actor_info);
if (iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can not find graph compiler info for actor set: " << actor_info;
}
graph_compiler_info_ptr = iter->second.get();
} else {
graph_info_to_device_context_.clear();
graph_info_to_device_context_[op_run_info->base_op_run_info.graph_info] = device_context;
if (!single_op_cache_hit) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
@ -1647,7 +1590,7 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
bool is_dynamic_shape =
op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.has_dynamic_input;
if (is_dynamic_shape) {
const auto &graph = graph_compiler_->Fetch(op_run_info->base_op_run_info.graph_info);
const auto &graph = op_compiler_info->graph_;
MS_EXCEPTION_IF_NULL(graph);
graph->UpdateGraphDynamicAttr();
// Dynamic shape but select static op, must no cache
@ -1655,26 +1598,16 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
enable_cache = false;
}
}
auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, op_run_info->base_op_run_info.input_mask,
op_run_info->base_op_run_info.input_tensor, !enable_cache);
graph_compiler_info_ptr = graph_compiler_info.get();
auto ret = actor_to_graph_compiler_info_.try_emplace(actor_info, std::move(graph_compiler_info));
if (!ret.second) {
MS_LOG(WARNING) << "ActorInfo:" << actor_info << " already exist in the map.";
}
op_compiler_info->need_erase_ = !enable_cache;
}
RunOpImpl(single_op_cache_hit, graph_compiler_info_ptr, op_run_info, outputs);
RunOpImpl(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
}
void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
GraphCompilerInfo *graph_compiler_info) const {
void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
graph_compiler_->BuildSingleOpGraphs({graph}, device_context);
pynative::OpCompiler::BatchBuild({graph}, device_context);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
graph_compiler_info->input_tensors_.clear();
}
void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
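Note: the new pynative::OpCompiler and OpCompilerInfo used above are defined in runtime/pynative/op_compiler.h, which is not shown in this view. The sketch below is inferred from the call sites in this file only; it is not the actual header, and the member types and parameter names are assumptions.

// Rough sketch of the OpCompiler surface used by MindRTBackend (inferred, not authoritative).
namespace mindspore::pynative {
struct OpCompilerInfo {
  GraphId graph_id_;                                            // used by OpExecutor::ActorInQueue
  KernelGraphPtr graph_;                                        // the compiled single-op kernel graph
  std::vector<session::KernelWithIndex> graph_output_nodes_;    // replaces GraphCompiler::GetGraphOutputNodes
  DeviceContext *device_context_;
  bool need_erase_;                                             // erase the cache entry after the op runs
};
using OpCompilerInfoPtr = std::shared_ptr<OpCompilerInfo>;

class OpCompiler {
 public:
  static OpCompiler &GetInstance();
  // Compile the single-op KernelGraph for op_run_info, or return the cached entry on a hit.
  OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                            const DeviceContext *device_context);
  // Create the kernels of several single-op graphs in one batch.
  static void BatchBuild(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context);
  void ClearOpCache(const GraphInfo &graph_info);
  void ClearAllCache();
};
}  // namespace mindspore::pynative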


@ -34,6 +34,7 @@
#include "runtime/hardware/device_context.h" #include "runtime/hardware/device_context.h"
#include "runtime/graph_scheduler/graph_scheduler.h" #include "runtime/graph_scheduler/graph_scheduler.h"
#include "runtime/pynative/op_task.h" #include "runtime/pynative/op_task.h"
#include "runtime/pynative/op_compiler.h"
#include "include/backend/visible.h" #include "include/backend/visible.h"
namespace mindspore { namespace mindspore {
@ -136,8 +137,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
void CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode);
// CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
GraphCompilerInfo *graph_compiler_info) const;
void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
// Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);
@ -153,27 +153,21 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
// Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode.
std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);
// Construct the GraphCompilerInfo by the compilation results of graph, used in PyNative mode.
std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const ActorInfo &actor_info,
const std::vector<int64_t> &tensors_mask,
const std::vector<TensorPtr> &input_tensors,
bool need_erase);
void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);
// In PyNative mode, the size of single op cache list will be increasing, which lead to memory cost increasing,
// so the latest single op cache should be erased when cache list size exceeds threshold value.
void EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info, const KernelGraphPtr &graph);
void EraseSingleOpCache(const GraphInfo &graph_info);
// Execute OpBuildTask and OpRunTask when the OpExecutor queue is full in PyNative mode.
void BatchBuildCallback();
// Run op or dispatch build task and run task.
void RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph_compiler_info,
void RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs);
// Dispatch task and execute the task in another thread.
void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, const OpCompilerInfoPtr &op_compiler_info,
const session::BackendOpRunInfoPtr &op_run_info);
void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,


@ -23,8 +23,8 @@
#include "frontend/optimizer/ad/grad.h" #include "frontend/optimizer/ad/grad.h"
#include "pipeline/jit/pass.h" #include "pipeline/jit/pass.h"
#include "runtime/pynative/op_executor.h" #include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "ir/cell.h" #include "ir/cell.h"
#include "ir/func_graph_cloner.h"
namespace mindspore::pynative { namespace mindspore::pynative {
PyNativeExecutorPtr PyNativeExecutor::executor_ = nullptr; PyNativeExecutorPtr PyNativeExecutor::executor_ = nullptr;
@ -139,6 +139,7 @@ void PyNativeExecutor::set_kernel_build_server_dir(const py::object &kernel_buil
void PyNativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
runtime::OpExecutor::GetInstance().Reset();
pynative::OpCompiler::GetInstance().ClearAllCache();
// Maybe exit in runop step
auto ms_context = MsContext::GetInstance();


@ -3,7 +3,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc" "memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
"ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc" "tensors_queue.cc" "auto_mem_offload.cc" "ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc" "tensors_queue.cc" "auto_mem_offload.cc"
"common_somas_allocator.cc" "common_somas_allocator.cc" "device_address_utils.cc"
) )
if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC) if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC)


@ -0,0 +1,321 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/device_address_utils.h"
#include <string>
#include <map>
#include <vector>
#include "ir/tensor.h"
#include "runtime/device/ms_device_shape_transfer.h"
namespace mindspore {
using tensor::TensorPtr;
namespace runtime {
// Whether device address of anf node is valid and device address type
// is consistent with device type, for example, device address type
// DeviceType::kGPU should be used on GPU device
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
if (AnfAlgo::OutputAddrExist(node, index)) {
const auto &address = AnfAlgo::GetOutputAddr(node, index, false);
MS_EXCEPTION_IF_NULL(address);
return address->GetDeviceType() == device_context->GetDeviceType();
}
return false;
}
void DeviceAddressUtils::CreateParameterDeviceAddress(const DeviceContext *device_context,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
const std::vector<bool> &graph_valid_input = graph->valid_inputs();
(void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
// Anf nodes which need create device address.
std::vector<AnfNodePtr> nodes_list;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
AnfNodePtr item = graph_inputs[i];
MS_EXCEPTION_IF_NULL(item);
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
continue;
}
if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
for (const auto &out : outs) {
MS_EXCEPTION_IF_NULL(out);
if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
continue;
}
nodes_list.push_back(out);
}
}
if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
continue;
}
nodes_list.push_back(item);
}
// Create device address for anf node in nodes_list
for (const auto &item : nodes_list) {
auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
for (size_t index = 0; index < output_size; index++) {
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
if (output_type_id == kTypeUnknown) {
output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
trans::GetRuntimePaddingShape(item, index));
device_address->set_from_persistent_mem(item->isa<Parameter>());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
<< " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, index, item.get());
}
}
}
void DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
const ValuePtr &node_value, size_t output_idx,
const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(node_value);
MS_EXCEPTION_IF_NULL(value_node);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::vector<TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (const auto &tensor : tensors) {
if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null";
return;
}
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
// We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
// in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
if (output_type_id == kTypeUnknown) {
output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
}
}
void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *device_context,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (NodeDeviceAddressExist(device_context, value_node, 0)) {
continue;
}
const auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
} else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
auto address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT,
kNumberTypeUInt8, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
<< " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
}
}
}
void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *device_context,
const KernelGraphPtr &graph, bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
continue;
}
bool is_from_persistent_mem =
(is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));
auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
for (size_t i = 0; i < output_size; ++i) {
if (AnfAlgo::OutputAddrExist(kernel, i)) {
continue;
}
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i));
if (is_from_persistent_mem) {
device_address->set_from_persistent_mem(true);
}
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
}
}
}
void DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
continue;
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
break;
}
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
kTypeUnknown, ShapeVector());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address;
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
}
}
}
void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
// Collect the inplace groups.
std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
continue;
}
auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
MS_EXCEPTION_IF_NULL(primitive);
auto inplace_group_attr = primitive->GetAttr("inplace_group");
MS_EXCEPTION_IF_NULL(inplace_group_attr);
auto group_id = GetValue<uint32_t>(inplace_group_attr);
(void)inplace_groups[group_id].emplace_back(kernel);
}
const size_t kMinInplaceGroupSize = 2;
for (const auto &inplace_group : inplace_groups) {
auto &group_nodes = inplace_group.second;
if (group_nodes.size() < kMinInplaceGroupSize) {
continue;
}
// Get the device address of the first node in the inplace group.
auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
MS_EXCEPTION_IF_NULL(node_primitive);
auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
MS_EXCEPTION_IF_NULL(device_address);
// Update the device address of other nodes using device address of the first node in the inplace group.
for (size_t i = 1; i < group_nodes.size(); ++i) {
auto &group_node = group_nodes[i];
auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
MS_EXCEPTION_IF_NULL(prim);
auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
// Update the reference count of device address.
device_address->IncreaseOriginalRefCount();
device_address->ResetRefCount();
}
}
}
void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
const session::AnfWithOutIndex &origin_pair) {
MS_EXCEPTION_IF_NULL(cur_pair.first);
MS_EXCEPTION_IF_NULL(origin_pair.first);
auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
MS_EXCEPTION_IF_NULL(origin_node_output_addr);
auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
MS_EXCEPTION_IF_NULL(cur_node_output_addr);
if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_pair.second, cur_pair.first.get());
// Update the reference count of device address.
cur_node_output_addr->DecreaseOriginalRefCount();
cur_node_output_addr->ResetRefCount();
origin_node_output_addr->IncreaseOriginalRefCount();
origin_node_output_addr->ResetRefCount();
} else {
MS_LOG(INFO) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
}
}
void DeviceAddressUtils::UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
for (auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
if (output_num == 0) {
MS_LOG(DEBUG) << "This kernel has no output size.";
continue;
}
for (size_t i = 0; i < output_num; ++i) {
session::AnfWithOutIndex out_pair(kernel, i);
if (graph->IsInRefOutputMap(out_pair)) {
auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
UpdateDeviceAddress(out_pair, origin_pair);
}
}
}
}
} // namespace runtime
} // namespace mindspore


@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
#include "runtime/hardware/device_context.h"
namespace mindspore {
using device::DeviceContext;
namespace runtime {
// Extract the methods related to DeviceAddress in GraphCompiler to the DeviceAddressUtils class.
class DeviceAddressUtils {
public:
static void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
static void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
size_t output_idx, const ValueNodePtr &value_node);
static void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
static void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
bool is_gradient_out);
static void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
static void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph);
static void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
const session::AnfWithOutIndex &origin_pair);
static void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph);
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
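A minimal usage sketch, assuming the class is consumed the same way GraphCompiler::CreateDeviceAddress does further below; the wrapper function name here is hypothetical and not part of this change.

// Hypothetical wrapper: create every address a kernel graph needs, in the same
// order GraphCompiler::CreateDeviceAddress calls the utilities below.
void PrepareGraphDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  using mindspore::runtime::DeviceAddressUtils;
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, /*is_gradient_out=*/false);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}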


@ -21,6 +21,7 @@
#include <algorithm>
#include <functional>
#include "runtime/graph_scheduler/graph_scheduler.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/device/device_address.h"
#include "runtime/device/ms_device_shape_transfer.h"
@ -51,294 +52,6 @@
namespace mindspore {
namespace runtime {
namespace {
// Whether device address of anf node is valid and device address type
// is consistent with device type, for example, device address type
// DeviceType::kGPU should be used on GPU device
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
if (AnfAlgo::OutputAddrExist(node, index)) {
const auto &address = AnfAlgo::GetOutputAddr(node, index, false);
MS_EXCEPTION_IF_NULL(address);
return address->GetDeviceType() == device_context->GetDeviceType();
}
return false;
}
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
const std::vector<bool> &graph_valid_input = graph->valid_inputs();
(void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
// Anf nodes which need create device address.
std::vector<AnfNodePtr> nodes_list;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
AnfNodePtr item = graph_inputs[i];
MS_EXCEPTION_IF_NULL(item);
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
continue;
}
if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
for (const auto &out : outs) {
MS_EXCEPTION_IF_NULL(out);
if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
continue;
}
nodes_list.push_back(out);
}
}
if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
continue;
}
nodes_list.push_back(item);
}
// Create device address for anf node in nodes_list
for (const auto &item : nodes_list) {
auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
for (size_t index = 0; index < output_size; index++) {
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
if (output_type_id == kTypeUnknown) {
output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
trans::GetRuntimePaddingShape(item, index));
device_address->set_from_persistent_mem(item->isa<Parameter>());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
<< " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, index, item.get());
}
}
}
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
size_t output_idx, const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(node_value);
MS_EXCEPTION_IF_NULL(value_node);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::vector<TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (const auto &tensor : tensors) {
if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null";
return;
}
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
// We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
// in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
if (output_type_id == kTypeUnknown) {
output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
}
}
void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (NodeDeviceAddressExist(device_context, value_node, 0)) {
continue;
}
const auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
} else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
auto address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT,
kNumberTypeUInt8, ShapeVector());
MS_EXCEPTION_IF_NULL(address);
address->set_from_persistent_mem(true);
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
<< " addr:" << address;
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
}
}
}
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
bool is_gradient_out) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
continue;
}
bool is_from_persistent_mem =
(is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));
auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
for (size_t i = 0; i < output_size; ++i) {
if (AnfAlgo::OutputAddrExist(kernel, i)) {
continue;
}
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i));
if (is_from_persistent_mem) {
device_address->set_from_persistent_mem(true);
}
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address;
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
}
}
}
void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
continue;
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
break;
}
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
kTypeUnknown, ShapeVector());
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
<< " addr:" << device_address;
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
}
}
}
void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
// Collect the inplace groups.
std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
continue;
}
auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
MS_EXCEPTION_IF_NULL(primitive);
auto inplace_group_attr = primitive->GetAttr("inplace_group");
MS_EXCEPTION_IF_NULL(inplace_group_attr);
auto group_id = GetValue<uint32_t>(inplace_group_attr);
(void)inplace_groups[group_id].emplace_back(kernel);
}
const size_t kMinInplaceGroupSize = 2;
for (const auto &inplace_group : inplace_groups) {
auto &group_nodes = inplace_group.second;
if (group_nodes.size() < kMinInplaceGroupSize) {
continue;
}
// Get the device address of the first node in the inplace group.
auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
MS_EXCEPTION_IF_NULL(node_primitive);
auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
MS_EXCEPTION_IF_NULL(device_address);
// Update the device address of other nodes using device address of the first node in the inplace group.
for (size_t i = 1; i < group_nodes.size(); ++i) {
auto &group_node = group_nodes[i];
auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
MS_EXCEPTION_IF_NULL(prim);
auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
// Update the reference count of device address.
device_address->IncreaseOriginalRefCount();
device_address->ResetRefCount();
}
}
}
void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair, const session::AnfWithOutIndex &origin_pair) {
MS_EXCEPTION_IF_NULL(cur_pair.first);
MS_EXCEPTION_IF_NULL(origin_pair.first);
auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
MS_EXCEPTION_IF_NULL(origin_node_output_addr);
auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
MS_EXCEPTION_IF_NULL(cur_node_output_addr);
if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_pair.second, cur_pair.first.get());
// Update the reference count of device address.
cur_node_output_addr->DecreaseOriginalRefCount();
cur_node_output_addr->ResetRefCount();
origin_node_output_addr->IncreaseOriginalRefCount();
origin_node_output_addr->ResetRefCount();
} else {
MS_LOG(INFO) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
<< ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
<< ", index is " << cur_pair.second;
}
}
void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
for (auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
if (output_num == 0) {
MS_LOG(DEBUG) << "This kernel has no output size.";
continue;
}
for (size_t i = 0; i < output_num; ++i) {
session::AnfWithOutIndex out_pair(kernel, i);
if (graph->IsInRefOutputMap(out_pair)) {
auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
UpdateDeviceAddress(out_pair, origin_pair);
}
}
}
}
void SetSummaryNodesRefCount(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!graph->summary_node_exist()) {
@ -359,27 +72,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
device_address->ResetRefCount();
}
}
void SetGraphInputNodeActualAbstract(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!op_run_info->base_op_run_info.has_dynamic_output && !op_run_info->base_op_run_info.has_dynamic_input) {
return;
}
const auto &tensor_mask = op_run_info->base_op_run_info.input_mask;
const auto &input_tensors = op_run_info->base_op_run_info.input_tensor;
auto &graph_inputs = graph->inputs();
for (size_t i = 0, j = 0; i < op_run_info->base_op_run_info.input_tensor.size() && j < graph_inputs.size(); ++i) {
if (tensor_mask[i] == kValueNodeTensorMask) {
continue;
}
if (input_tensors[i]->base_shape_ptr() != nullptr) {
const auto &shape_of_tensor = input_tensors[i]->shape();
auto actual_abstract = std::make_shared<abstract::AbstractTensor>(input_tensors[i]->Dtype(), shape_of_tensor);
graph_inputs[j]->set_user_data(kActualAbstract, actual_abstract);
}
++j;
}
}
} // namespace
GraphCompilerInfo::~GraphCompilerInfo() {
@ -758,7 +450,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
device_context->kernel_executor_->CreateKernel(graph->execution_order());
// Read the output and input ref map and set to the kernel graph.
AddOutInRefToGraph(graph);
AnfAlgo::AddOutInRefToGraph(graph);
// Optimize the nop node.
if (!run_in_pynative) {
@ -821,152 +513,23 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
return graph->graph_id();
}
GraphId GraphCompiler::CompileGraph(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
const DeviceContext *device_context) {
// Check if the graph cache exists.
auto iter = run_op_graphs_.find(op_run_info->base_op_run_info.graph_info);
auto &op_executor = runtime::OpExecutor::GetInstance();
if (iter != run_op_graphs_.end() && op_executor.BuildQueueEmpty()) {
const auto &graph = iter->second;
MS_EXCEPTION_IF_NULL(graph);
SetGraphInputNodeActualAbstract(op_run_info, graph);
*single_op_cache_hit = true;
return graph->graph_id();
}
*single_op_cache_hit = false;
// Generate kernel graph.
MS_EXCEPTION_IF_NULL(session_);
KernelGraphPtr graph = session_->ConstructSingleOpGraph(
op_run_info, op_run_info->base_op_run_info.input_tensor, op_run_info->base_op_run_info.input_mask,
device_context->GetDeviceType() == device::DeviceType::kAscend);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
graph->set_run_mode(device::RunMode::kKernelMode);
graph->set_is_from_single_op(true);
// session_ is SessionBasic, AscendUnifyMindIR has not been executed.
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
deprecated_kernel_executor->UnifyMindIR(graph);
} else {
opt::CommonUnifyMindIR(graph);
}
// Select kernel and optimize
device_context->kernel_executor_->OptimizeGraph(graph);
UpdateRefInfoBeforeCreateKernel(op_run_info, graph);
// Set dynamic shape actual abstract
SetGraphInputNodeActualAbstract(op_run_info, graph);
// Create device address for all anf nodes of graph.
CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);
run_op_graphs_[op_run_info->base_op_run_info.graph_info] = graph;
auto output_nodes = graph->outputs();
auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
for (auto &node : output_nodes) {
MS_EXCEPTION_IF_NULL(node);
(void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
}
AnfAlgo::UpdateGraphValidRefPair(graph);
return graph->graph_id();
}
void GraphCompiler::UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info,
const KernelGraphPtr &graph) const {
// Building Graph and Create Kernel is async, under pynative mode.Ref info is bind with kernel.
// So need to get ref info to generate output addr, before create kernel.
if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
op_run_info->base_op_run_info.device_target != kGPUDevice) {
// just ascend ref mode is diff with cpu and gpu
return;
}
AddOutInRefToGraph(graph);
}
void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs,
const DeviceContext *device_context) const {
MS_EXCEPTION_IF_NULL(device_context);
std::vector<CNodePtr> node_to_build;
for (const auto &graph : graphs) {
const auto &nodes = graph->execution_order();
(void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
}
// Kernel build
device_context->kernel_executor_->CreateKernel(node_to_build);
for (const auto &graph : graphs) {
device_context->kernel_executor_->PreprocessBeforeRun(graph);
CreateKernelWorkspaceDeviceAddress(device_context, graph);
// Need to execute after PreprocessBeforeRunSingleOpGraph
runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
}
}
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
MS_EXCEPTION_IF_NULL(session_);
return session_->GetGraph(graph_id);
}
KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
auto iter = run_op_graphs_.find(graph_info);
if (iter == run_op_graphs_.end()) {
MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
return nullptr;
}
return iter->second;
}
void GraphCompiler::AddOutInRefToGraph(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &cnode : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
for (const auto &ref : kernel_info->out_in_ref_map()) {
size_t output_index = ref.first;
size_t input_index = ref.second;
auto final_pair = std::make_pair(cnode, output_index);
auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
<< ", output index: " << final_pair.second << " to input "
<< origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
// Add to graph only if the input is not a monad.
if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
graph->AddRefCorrespondPairs(final_pair, origin_pair);
}
}
}
}
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
MS_LOG(INFO) << "Status record: start create device address. graph id: " << graph->graph_id();
-CreateParameterDeviceAddress(device_context, graph);
-CreateValueNodeDeviceAddress(device_context, graph);
-CreateKernelOutputDeviceAddress(device_context, graph, false);
-CreateKernelWorkspaceDeviceAddress(device_context, graph);
-UpdateDeviceAddressForInplaceNode(graph);
-UpdateDeviceAddressForRefNode(graph);
+DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
+DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
+DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, false);
+DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
+DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
+DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
}
void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
const DeviceContext *device_context,
bool is_gradient_out) const {
CreateParameterDeviceAddress(device_context, graph);
CreateValueNodeDeviceAddress(device_context, graph);
CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
UpdateDeviceAddressForInplaceNode(graph);
UpdateDeviceAddressForRefNode(graph);
}
void GraphCompiler::GetParamAndOutputIndex(
const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
std::map<AnfNodePtr, size_t> *parameter_index,
@ -1055,14 +618,6 @@ void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
session_->ClearAllBucket(graph_id);
}
const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
const auto &iter = run_op_graph_output_nodes_.find(graph_id);
if (iter == run_op_graph_output_nodes_.end()) {
MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
}
return iter->second;
}
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
@ -1079,11 +634,6 @@ void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
}
}
void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
(void)run_op_graphs_.erase(graph_info);
(void)run_op_graph_output_nodes_.erase(graph_id);
}
void GraphCompiler::SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(segment);

View File

@ -105,19 +105,9 @@ class GraphCompiler {
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
GraphId CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func_graph, const DeviceContext *device_context);
// Construct single op kernel graph and compile the kernel graph in PyNative mode.
GraphId CompileGraph(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
const DeviceContext *device_context);
// Create kernels and workspaces for graphs in PyNative mode.
void BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context) const;
// Get graph by graph id; return nullptr if it does not exist. Used in Graph mode.
KernelGraphPtr Fetch(GraphId graph_id) const;
// Get graph by graph info; return nullptr if it does not exist. Used in PyNative mode.
KernelGraphPtr Fetch(const GraphInfo &graph_info) const;
// The following four methods are used in PyNative back propagation to split the complete kernel graph into single
// op graphs, and these methods will be moved to class MindRTBackend after the session module is deleted.
@ -177,16 +167,11 @@ class GraphCompiler {
// operator.
void ClearAllBucket(const GraphId &graph_id);
const std::vector<KernelWithIndex> &GetGraphOutputNodes(GraphId graph_id) const;
// Register a summary callback function, which is called in the final stages of summary.
void RegisterSummaryCallBackFunc(const CallBackFunc &callback) const;
// Execute graph summary.
void Summary(const std::vector<KernelGraphPtr> &graphs) const;
// Remove single op kernel graph cache and output nodes cache.
void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);
// The implementation of compiling graph in Graph Mode, including optimizing graph,
// setting operator info, creating kernel and transforming kernel graph to ActorSet.
GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context,
@ -195,28 +180,12 @@ class GraphCompiler {
private:
DISABLE_COPY_AND_ASSIGN(GraphCompiler);
// Add operators' output and input reference map to the graph.
void AddOutInRefToGraph(const KernelGraphPtr &graph) const;
// Update the ref info of the graph before creating kernels.
void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info,
const KernelGraphPtr &graph) const;
// Create device address for all anf nodes of graph.
void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
// Create device address for input and output of ops.
void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
bool is_gradient_out) const;
// Set Graph's dependencies for pre_graph and post_graph.
void SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const;
// Single op kernel graph cache for PyNative mode.
mindspore::HashMap<GraphInfo, KernelGraphPtr> run_op_graphs_;
// Single op kernel graph output nodes cache for PyNative mode.
mindspore::HashMap<GraphId, std::vector<KernelWithIndex>> run_op_graph_output_nodes_;
// The member variable 'session_' will be removed after the session module is removed.
// Now all GraphCompiler instances share the same 'session_'.
session::SessionPtr session_;

View File

@ -0,0 +1,162 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/pynative/op_compiler.h"
#include <memory>
#include <algorithm>
#include <vector>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/device_address_utils.h"
namespace mindspore {
using runtime::DeviceAddressUtils;
namespace pynative {
namespace {
void SetGraphInputNodeActualAbstract(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!op_run_info->base_op_run_info.has_dynamic_output && !op_run_info->base_op_run_info.has_dynamic_input) {
return;
}
const auto &tensor_mask = op_run_info->base_op_run_info.input_mask;
const auto &input_tensors = op_run_info->base_op_run_info.input_tensor;
auto &graph_inputs = graph->inputs();
for (size_t i = 0, j = 0; i < op_run_info->base_op_run_info.input_tensor.size() && j < graph_inputs.size(); ++i) {
if (tensor_mask[i] == kValueNodeTensorMask) {
continue;
}
if (input_tensors[i]->base_shape_ptr() != nullptr) {
const auto &shape_of_tensor = input_tensors[i]->shape();
auto actual_abstract = std::make_shared<abstract::AbstractTensor>(input_tensors[i]->Dtype(), shape_of_tensor);
graph_inputs[j]->set_user_data(kActualAbstract, actual_abstract);
}
++j;
}
}
void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
// Graph building and kernel creation are asynchronous in PyNative mode, and ref info is bound to the kernel.
// So the ref info must be collected before kernel creation in order to generate the output addresses.
if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
op_run_info->base_op_run_info.device_target != kGPUDevice) {
// Only the Ascend ref mode differs from CPU and GPU.
return;
}
AnfAlgo::AddOutInRefToGraph(graph);
}
void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
bool is_gradient_out) {
DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}
} // namespace
OpCompiler::OpCompiler() { session_ = session::SessionFactory::Get().Create(kSessionBasic); }
OpCompiler &OpCompiler::GetInstance() {
static OpCompiler instance;
return instance;
}
OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(op_run_info);
auto graph_info = op_run_info->base_op_run_info.graph_info;
auto iter = op_compiler_infos_.find(graph_info);
// Check if the graph cache exists.
auto &op_executor = runtime::OpExecutor::GetInstance();
if (iter != op_compiler_infos_.end() && op_executor.BuildQueueEmpty()) {
const auto &op_compiler_info = iter->second;
MS_EXCEPTION_IF_NULL(op_compiler_info);
SetGraphInputNodeActualAbstract(op_run_info, op_compiler_info->graph_);
*single_op_cache_hit = true;
return iter->second;
}
*single_op_cache_hit = false;
// Generate kernel graph.
MS_EXCEPTION_IF_NULL(session_);
KernelGraphPtr graph = session_->ConstructSingleOpGraph(
op_run_info, op_run_info->base_op_run_info.input_tensor, op_run_info->base_op_run_info.input_mask,
device_context->GetDeviceType() == device::DeviceType::kAscend);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);
graph->set_run_mode(device::RunMode::kKernelMode);
graph->set_is_from_single_op(true);
// session_ is SessionBasic, so AscendUnifyMindIR has not been executed yet.
auto deprecated_kernel_executor =
dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
if (deprecated_kernel_executor != nullptr) {
deprecated_kernel_executor->UnifyMindIR(graph);
} else {
opt::CommonUnifyMindIR(graph);
}
// Select kernel and optimize
device_context->kernel_executor_->OptimizeGraph(graph);
UpdateRefInfoBeforeCreateKernel(op_run_info, graph);
// Set the actual abstract for dynamic shape inputs
SetGraphInputNodeActualAbstract(op_run_info, graph);
// Create device address for all anf nodes of graph.
CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);
auto output_nodes = graph->outputs();
std::vector<KernelWithIndex> outputs_with_index;
for (auto &node : output_nodes) {
MS_EXCEPTION_IF_NULL(node);
(void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
}
AnfAlgo::UpdateGraphValidRefPair(graph);
auto op_compiler_info =
std::make_shared<OpCompilerInfo>(graph_info, graph->graph_id(), graph, outputs_with_index, device_context, false);
op_compiler_infos_[graph_info] = op_compiler_info;
return op_compiler_info;
}
void OpCompiler::BatchBuild(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
std::vector<CNodePtr> node_to_build;
for (const auto &graph : graphs) {
const auto &nodes = graph->execution_order();
(void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
}
// Kernel build
device_context->kernel_executor_->CreateKernel(node_to_build);
for (const auto &graph : graphs) {
device_context->kernel_executor_->PreprocessBeforeRun(graph);
DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
// Must be executed after PreprocessBeforeRun.
runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
}
}
void OpCompiler::ClearOpCache(const GraphInfo &graph_info) { op_compiler_infos_.erase(graph_info); }
void OpCompiler::ClearAllCache() { op_compiler_infos_.clear(); }
} // namespace pynative
} // namespace mindspore
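For orientation, a minimal caller-side sketch of the new compile path follows. It is an illustration, not code from this patch: only OpCompiler::GetInstance, Compile, BatchBuild and the OpCompilerInfo fields come from the files above; op_run_info and device_context are presumed to be prepared by the backend.
// Hypothetical usage sketch (not part of this patch).
// Assumes #include "runtime/pynative/op_compiler.h" and a prepared op_run_info/device_context.
bool single_op_cache_hit = false;
auto op_compiler_info =
    pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_context);
if (!single_op_cache_hit) {
  // On a cache miss the kernels of the freshly constructed graph still need to be built;
  // BatchBuild can create kernels for several single-op graphs in one pass.
  pynative::OpCompiler::BatchBuild({op_compiler_info->graph_}, device_context);
}
// op_compiler_info now carries everything needed to launch the op:
// graph_, graph_id_, graph_output_nodes_ and device_context_.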

View File

@ -0,0 +1,86 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_
#include <utility>
#include <vector>
#include <memory>
#include "utils/ms_utils.h"
#include "backend/common/session/kernel_graph.h"
#include "backend/common/session/session_basic.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
using device::DeviceContext;
using session::KernelWithIndex;
namespace pynative {
struct OpCompilerInfo {
OpCompilerInfo(GraphInfo graph_info, GraphId graph_id, KernelGraphPtr graph,
std::vector<KernelWithIndex> graph_output_nodes, DeviceContext *device_context, bool need_erase)
: graph_info_(std::move(graph_info)),
graph_id_(graph_id),
graph_(std::move(graph)),
graph_output_nodes_(std::move(graph_output_nodes)),
device_context_(device_context),
need_erase_(need_erase) {}
~OpCompilerInfo() = default;
GraphInfo graph_info_;
GraphId graph_id_;
KernelGraphPtr graph_;
std::vector<KernelWithIndex> graph_output_nodes_;
DeviceContext *device_context_;
bool need_erase_;
};
using OpCompilerInfoPtr = std::shared_ptr<OpCompilerInfo>;
// FuncGraph, Backend and GraphCompiler correspond one-to-one,
// and GraphCompiler stores the compilation cache of operators.
// When the graph structure changes, the front-end will send multiple graphs, and the operators of
// each graph would be compiled separately, which results in very poor performance.
// Therefore, the OpCompiler class is required to hold all operator caches and keep them independent of any single graph.
class BACKEND_EXPORT OpCompiler {
public:
static OpCompiler &GetInstance();
// Compile RunOpInfo into a KernelGraph.
OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
device::DeviceContext *device_context);
// Clear op cache in dynamic scenes.
// Otherwise, the operator cache will keep growing, resulting in insufficient memory.
void ClearOpCache(const GraphInfo &graph_info);
// Accumulate a certain number of operators,
// and then compile the operators in parallel to improve compilation efficiency.
static void BatchBuild(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context);
// Clear anf resources before process exit.
void ClearAllCache();
private:
OpCompiler();
~OpCompiler() = default;
DISABLE_COPY_AND_ASSIGN(OpCompiler);
// All operators share the same session.
session::SessionPtr session_;
mindspore::HashMap<GraphInfo, OpCompilerInfoPtr> op_compiler_infos_;
};
} // namespace pynative
using OpCompilerInfoPtr = pynative::OpCompilerInfoPtr;
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_
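A minimal sketch of how a caller might honor need_erase_ to keep the cache bounded in dynamic shape scenarios; the helper function and its call site are assumptions, only GetInstance, ClearOpCache and ClearAllCache come from this header.
// Hypothetical eviction helper (not part of this patch).
void ReleaseOpCacheIfNeeded(const pynative::OpCompilerInfoPtr &info) {
  // Dynamic shape ops can produce a new GraphInfo per shape, so such entries are
  // dropped right after execution to avoid unbounded growth of op_compiler_infos_.
  if (info->need_erase_) {
    pynative::OpCompiler::GetInstance().ClearOpCache(info->graph_info_);
  }
}
// Before process exit, all cached single-op graphs can be released at once:
//   pynative::OpCompiler::GetInstance().ClearAllCache();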

View File

@ -89,7 +89,7 @@ void OpExecutor::PushOpBuildTask(const std::shared_ptr<OpBuildTask> &op_build_ta
void OpExecutor::PushOpRunTask(const std::shared_ptr<OpTask> &op_run_task) {
std::lock_guard<std::mutex> lock(task_mutex_);
op_run_tasks_.push(op_run_task);
-actor_in_queue_.insert(op_run_task->context()->graph_compiler_info()->name_);
+actor_in_queue_.insert(op_run_task->context()->graph_id());
task_cond_var_.notify_all();
}
@ -117,9 +117,9 @@ bool OpExecutor::BuildQueueFull() {
return op_build_tasks_.size() > kMaxQueueSize;
}
-bool OpExecutor::ActorInQueue(const std::string &actor_info) {
+bool OpExecutor::ActorInQueue(GraphId graph_id) {
std::lock_guard<std::mutex> lock(task_mutex_);
-auto iter = actor_in_queue_.find(actor_info);
+auto iter = actor_in_queue_.find(graph_id);
return iter != actor_in_queue_.end();
}
@ -152,7 +152,7 @@ void OpExecutor::WorkerLoop() {
std::unique_lock<std::mutex> lock(task_mutex_);
if (!op_run_tasks_.empty()) {
op_run_tasks_.pop();
-actor_in_queue_.erase(task->context()->graph_compiler_info()->name_);
+actor_in_queue_.erase(task->context()->graph_id());
}
if (op_run_tasks_.empty()) {
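The GraphId-keyed bookkeeping above is what lets a launcher decide whether a graph's CNode cache is still in use. A small illustrative sketch, assuming an op_compiler_info obtained from OpCompiler; only GetInstance, ActorInQueue and Wait are APIs shown in this file.
// Hypothetical launcher-side sketch (not part of this patch).
auto &op_executor = runtime::OpExecutor::GetInstance();
if (op_executor.ActorInQueue(op_compiler_info->graph_id_)) {
  // A pending task still references this single-op graph's CNode cache,
  // so drain the queue before reusing the graph.
  op_executor.Wait();
}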

View File

@ -67,7 +67,7 @@ class BACKEND_EXPORT OpExecutor {
// Determine if there is another task for the same graph in execution.
// Tasks for the same graph use the same CNode cache. So we need to wait.
-bool ActorInQueue(const std::string &actor_info);
+bool ActorInQueue(GraphId graph_id);
// Wait for all OpRunTasks to finish executing.
void Wait();
@ -88,7 +88,7 @@ class BACKEND_EXPORT OpExecutor {
std::vector<std::shared_ptr<OpBuildTask>> op_build_tasks_;
std::queue<std::shared_ptr<OpTask>> op_run_tasks_;
-std::set<std::string> actor_in_queue_;
+std::set<GraphId> actor_in_queue_;
std::function<void()> batch_build_callback_{nullptr};
inline static size_t kMaxQueueSize = 20;
bool executing_{false};

View File

@ -32,10 +32,9 @@
namespace mindspore::runtime {
class OpTaskContext {
public:
-OpTaskContext(GraphCompilerInfo *graph_compiler_info, KernelGraphPtr graph,
-std::vector<session::KernelWithIndex> output_nodes, session::BackendOpRunInfoPtr op_run_info,
-device::DeviceContext *device_context, bool is_pynative_infer)
-: graph_compiler_info_(graph_compiler_info),
+OpTaskContext(GraphId graph_id, KernelGraphPtr graph, std::vector<session::KernelWithIndex> output_nodes,
+session::BackendOpRunInfoPtr op_run_info, device::DeviceContext *device_context, bool is_pynative_infer)
+: graph_id_(graph_id),
graph_(std::move(graph)),
output_nodes_(std::move(output_nodes)),
op_run_info_(std::move(op_run_info)),
@ -43,7 +42,7 @@ class OpTaskContext {
is_pyantive_infer_(is_pynative_infer) {}
~OpTaskContext() = default;
-GraphCompilerInfo *graph_compiler_info() const { return graph_compiler_info_; }
+GraphId graph_id() const { return graph_id_; }
const KernelGraphPtr &graph() const { return graph_; }
const std::vector<session::KernelWithIndex> &output_nodes() const { return output_nodes_; }
const session::BackendOpRunInfoPtr &op_run_info() const { return op_run_info_; }
@ -51,7 +50,7 @@ class OpTaskContext {
bool is_pynative_infer() const { return is_pyantive_infer_; }
private:
-GraphCompilerInfo *graph_compiler_info_;
+GraphId graph_id_;
KernelGraphPtr graph_;
std::vector<session::KernelWithIndex> output_nodes_;
session::BackendOpRunInfoPtr op_run_info_;
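For illustration, constructing the context with the new GraphId-based signature might look like the sketch below; the variable names and the is_pynative_infer value are assumptions, only the constructor shape comes from the diff above.
// Hypothetical construction sketch (not part of this patch).
auto context = std::make_shared<runtime::OpTaskContext>(
    op_compiler_info->graph_id_,            // GraphId replaces the old GraphCompilerInfo pointer
    op_compiler_info->graph_,               // single-op kernel graph
    op_compiler_info->graph_output_nodes_,  // cached output KernelWithIndex list
    op_run_info,                            // session::BackendOpRunInfoPtr
    op_compiler_info->device_context_,      // device::DeviceContext *
    /*is_pynative_infer=*/false);           // value assumed for illustration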