forked from mindspore-Ecosystem/mindspore
!40704 RunOp compile refactor
Merge pull request !40704 from caifubi/master-pynative-op-compile-refactor
Commit: cd8b826fcf
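The refactor replaces MindRTBackend's per-actor GraphCompilerInfo bookkeeping for single ops with a dedicated pynative::OpCompiler singleton, and moves the DeviceAddress helpers out of GraphCompiler into a new DeviceAddressUtils class. For orientation, a minimal sketch of the call pattern the backend ends up with, assembled only from the call sites visible in the hunks below (OpCompiler's internals are not part of this diff):

// Sketch, not actual PR code: the OpCompiler entry points as used in the hunks below.
void RunOpSketch(const session::BackendOpRunInfoPtr &op_run_info, DeviceContext *device_context,
                 const GraphInfo &graph_info) {
  bool single_op_cache_hit = true;
  // Compile the single-op graph, or fetch it from the GraphInfo-keyed cache.
  auto op_compiler_info =
    pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_context);
  // Kernel creation funnels through one batched entry point (single graph or many).
  pynative::OpCompiler::BatchBuild({op_compiler_info->graph_}, op_compiler_info->device_context_);
  // Eviction: one cached entry after a run when caching is disabled...
  pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
  // ...or everything when PyNativeExecutor tears down.
  pynative::OpCompiler::GetInstance().ClearAllCache();
}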
@@ -1387,5 +1387,27 @@ void AnfRuntimeAlgorithm::UpdateInternalParameterShape(
     }
   }
 }
+
+void AnfRuntimeAlgorithm::AddOutInRefToGraph(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  for (const auto &cnode : graph->execution_order()) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
+    MS_EXCEPTION_IF_NULL(kernel_info);
+    for (const auto &ref : kernel_info->out_in_ref_map()) {
+      size_t output_index = ref.first;
+      size_t input_index = ref.second;
+      auto final_pair = std::make_pair(cnode, output_index);
+      auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
+      MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
+                   << ", output index: " << final_pair.second << " to input "
+                   << origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
+      // Add to graph only if the input is not a monad.
+      if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
+        graph->AddRefCorrespondPairs(final_pair, origin_pair);
+      }
+    }
+  }
+}
 }  // namespace session
 }  // namespace mindspore
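The pairs recorded by AddOutInRefToGraph are consumed once device addresses are assigned: DeviceAddressUtils::UpdateDeviceAddressForRefNode, added later in this PR, looks each kernel output up in the graph's ref map and makes it reuse the origin node's device address. The consuming side, as it appears in the new device_address_utils.cc further below:

// From the consuming side (DeviceAddressUtils::UpdateDeviceAddressForRefNode):
// each recorded pair makes a kernel output alias the device address of the
// input it references.
session::AnfWithOutIndex out_pair(kernel, i);
if (graph->IsInRefOutputMap(out_pair)) {
  auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
  UpdateDeviceAddress(out_pair, origin_pair);  // output shares the origin's device address
}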
@@ -180,6 +180,8 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
   static void UpdateInternalParameterShape(const std::map<size_t, std::vector<AnfNodeWeakPtr>> &internal_parameters,
                                            const CNodePtr &cnode);
   static bool IsShapesDynamic(const std::vector<ShapeVector> &shapes);
+
+  static void AddOutInRefToGraph(const KernelGraphPtr &graph);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -23,8 +23,8 @@
 #include "backend/graph_compiler/transform.h"
 #include "backend/common/session/session_factory.h"
 #include "runtime/pynative/op_executor.h"
+#include "runtime/pynative/op_compiler.h"
 #include "backend/common/optimizer/helper.h"
-#include "pipeline/pynative/pynative_execute.h"
 #include "pipeline/jit/action.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "ir/anf.h"
@@ -1388,47 +1388,8 @@ std::shared_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
                                              strategy);
 }
 
-std::shared_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
-  const ActorInfo &actor_info, const std::vector<int64_t> &tensors_mask, const std::vector<TensorPtr> &input_tensors,
-  bool need_erase) {
-  std::vector<KernelGraphPtr> graphs;
-  std::vector<DeviceContext *> device_contexts;
-  runtime::KernelMapPosition outputs_order;
-  size_t position = 0;
-  MS_EXCEPTION_IF_NULL(graph_compiler_);
-  for (const auto &graph_info_to_context : graph_info_to_device_context_) {
-    const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
-    MS_EXCEPTION_IF_NULL(graph);
-    (void)graphs.emplace_back(graph);
-    (void)device_contexts.emplace_back(graph_info_to_context.second);
-
-    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
-    for (const auto &output : outputs) {
-      if (outputs_order.count(output) == 0) {
-        outputs_order[output] = {position++};
-      } else {
-        (void)outputs_order[output].emplace_back(position++);
-      }
-    }
-  }
-
-  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(&tensors_mask));
-  std::vector<std::vector<TensorPtr> *> input_tensors_list(
-    1, const_cast<std::vector<tensor::TensorPtr> *>(&input_tensors));
-  auto parser = std::make_shared<ControlNodeParser>();
-  return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
-                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
-                                             outputs_order, 0, actor_info, need_erase,
-                                             runtime::GraphExecutionStrategy::kStep);
-}
-
-void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info,
-                                       const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(graph_compiler_);
-  graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id());
-  actor_to_graph_compiler_info_.erase(actor_info);
-  (void)graph_info_to_device_context_.erase(graph_info);
+void MindRTBackend::EraseSingleOpCache(const GraphInfo &graph_info) {
+  pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
 }
 
 void MindRTBackend::ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors) {
@@ -1440,13 +1401,11 @@ void MindRTBackend::CompileSingleOpGraphs(const std::vector<std::shared_ptr<runt
     return;
   }
   std::vector<KernelGraphPtr> graphs;
-  std::vector<GraphCompilerInfo *> graph_compiler_infos;
   for (const auto &task : build_tasks) {
     MS_EXCEPTION_IF_NULL(task);
     const auto &context = task->context();
     MS_EXCEPTION_IF_NULL(context);
     graphs.push_back(context->graph());
-    graph_compiler_infos.push_back(context->graph_compiler_info());
   }
   MS_EXCEPTION_IF_NULL(build_tasks[0]);
   auto &task_context = build_tasks[0]->context();
@@ -1456,11 +1415,7 @@ void MindRTBackend::CompileSingleOpGraphs(const std::vector<std::shared_ptr<runt
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, task_context->is_pynative_infer());
 
   auto device_context = task_context->device_context();
-  graph_compiler_->BuildSingleOpGraphs(graphs, device_context);
-  for (const auto &graph_compile_info : graph_compiler_infos) {
-    MS_EXCEPTION_IF_NULL(graph_compile_info);
-    graph_compile_info->input_tensors_.clear();
-  }
+  pynative::OpCompiler::BatchBuild(graphs, device_context);
 }
 
 void MindRTBackend::OpRunCallback(const std::shared_ptr<runtime::OpTaskContext> &context) {
@@ -1524,26 +1479,22 @@ void MindRTBackend::BatchBuildCallback() {
   }
 }
 
-void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
+void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
+                                   const OpCompilerInfoPtr &op_compiler_info,
                                    const session::BackendOpRunInfoPtr &op_run_info) {
-  MS_EXCEPTION_IF_NULL(graph_compiler_info);
-  // Fetch outputs.
-  if (graph_compiler_info->graphs_.empty()) {
-    MS_LOG(EXCEPTION) << "No graph found, op:" << graph_compiler_info->name_;
-  }
-  const auto &graph = graph_compiler_info->graphs_.front();
+  MS_EXCEPTION_IF_NULL(op_compiler_info);
+  const auto &graph = op_compiler_info->graph_;
   MS_EXCEPTION_IF_NULL(graph);
-  const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
+  const auto &output_nodes = op_compiler_info->graph_output_nodes_;
 
-  runtime::UpdateDeviceAddress(graph, GetTensorWithoutValueMask(op_run_info),
-                               graph_compiler_info->device_contexts_.front());
+  runtime::UpdateDeviceAddress(graph, GetTensorWithoutValueMask(op_run_info), op_compiler_info->device_context_);
   UpdateOutput(output_nodes, outputs);
 
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
-  auto run_op_context = std::make_shared<runtime::OpTaskContext>(
-    graph_compiler_info, graph, output_nodes, op_run_info, graph_compiler_info->device_contexts_.front(), infer_flag);
+  auto run_op_context = std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, output_nodes, op_run_info,
+                                                                 op_compiler_info->device_context_, infer_flag);
 
   // Save build task and run task.
   std::promise<bool> promise;
@@ -1565,28 +1516,28 @@ void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
   }
 }
 
-void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph_compiler_info,
+void MindRTBackend::RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
                               const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(op_run_info);
-  MS_EXCEPTION_IF_NULL(graph_compiler_info);
+  MS_EXCEPTION_IF_NULL(op_compiler_info);
   // Fetch outputs.
-  const auto &graph = graph_compiler_info->graphs_.front();
+  const auto &graph = op_compiler_info->graph_;
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(graph_compiler_);
-  const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
+  const auto &output_nodes = op_compiler_info->graph_output_nodes_;
   MS_EXCEPTION_IF_NULL(outputs);
 
-  auto device_context = graph_compiler_info->device_contexts_.front();
+  auto device_context = op_compiler_info->device_context_;
   auto &op_executor = runtime::OpExecutor::GetInstance();
   bool is_dynamic_shape =
     op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.has_dynamic_input;
 
-  bool async_exec_disabled = is_dynamic_shape || graph_compiler_info->need_erase_ ||
+  bool async_exec_disabled = is_dynamic_shape || op_compiler_info->need_erase_ ||
                              !op_run_info->base_op_run_info.lazy_build || OpInBlackList(op_run_info) ||
                              GetExecutionMode() == kGraphMode || EnablePyNativeSyncRunning();
   if (!async_exec_disabled) {
     MS_LOG(DEBUG) << "Async exec enabled, op:" << op_run_info->base_op_run_info.op_name;
-    DispatchOpTask(single_op_cache_hit, outputs, graph_compiler_info, op_run_info);
+    DispatchOpTask(single_op_cache_hit, outputs, op_compiler_info, op_run_info);
     return;
   }
 
@@ -1595,7 +1546,7 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph
     WaitTaskFinish();
   }
   if (!single_op_cache_hit) {
-    CompileSingleOpGraph(graph, device_context, graph_compiler_info);
+    CompileSingleOpGraph(graph, device_context);
   }
   auto tensors_without_value_mask = GetTensorWithoutValueMask(op_run_info);
   runtime::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context);
@@ -1609,8 +1560,8 @@ void MindRTBackend::RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph
   if (op_run_info->base_op_run_info.has_dynamic_output) {
     UpdateOutputAbstract(graph, op_run_info);
   }
-  if (graph_compiler_info->need_erase_) {
-    EraseSingleOpCache(graph_compiler_info->name_, op_run_info->base_op_run_info.graph_info, graph);
+  if (op_compiler_info->need_erase_) {
+    EraseSingleOpCache(op_run_info->base_op_run_info.graph_info);
   }
 }
 
@@ -1624,22 +1575,14 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
   device_context->Initialize();
 
   bool single_op_cache_hit = true;
-  auto graph_id = graph_compiler_->CompileGraph(op_run_info, &single_op_cache_hit, device_context);
-  std::string actor_info = std::to_string(graph_id) + "_" + op_run_info->base_op_run_info.op_name;
-  if (runtime::OpExecutor::GetInstance().ActorInQueue(actor_info)) {
+  auto op_compiler_info =
+    pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_context);
+  MS_EXCEPTION_IF_NULL(op_compiler_info);
+  if (runtime::OpExecutor::GetInstance().ActorInQueue(op_compiler_info->graph_id_)) {
     WaitTaskFinish();
   }
 
-  GraphCompilerInfo *graph_compiler_info_ptr;
-  if (single_op_cache_hit) {
-    auto iter = actor_to_graph_compiler_info_.find(actor_info);
-    if (iter == actor_to_graph_compiler_info_.end()) {
-      MS_LOG(EXCEPTION) << "Can not find graph compiler info for actor set: " << actor_info;
-    }
-    graph_compiler_info_ptr = iter->second.get();
-  } else {
-    graph_info_to_device_context_.clear();
-    graph_info_to_device_context_[op_run_info->base_op_run_info.graph_info] = device_context;
+  if (!single_op_cache_hit) {
     auto context_ptr = MsContext::GetInstance();
     MS_EXCEPTION_IF_NULL(context_ptr);
     bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
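RunOp now gets both the compiled graph and the cache-hit flag from a single Compile call. The OpCompiler internals are not in this diff; a hypothetical sketch of the GraphInfo-keyed lookup that the call sites imply (the cache member and helper names below are assumptions, not PR code):

// Hypothetical sketch only, consistent with the Compile/ClearOpCache calls in
// this PR; not the actual OpCompiler source.
OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                          DeviceContext *device_context) {
  const auto &key = op_run_info->base_op_run_info.graph_info;
  auto iter = op_compiler_infos_.find(key);  // assumed cache member
  if (iter != op_compiler_infos_.end()) {
    *single_op_cache_hit = true;
    return iter->second;
  }
  *single_op_cache_hit = false;
  auto info = CompileGraphAndCreateInfo(op_run_info, device_context);  // hypothetical helper
  op_compiler_infos_[key] = info;
  return info;
}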
@@ -1647,7 +1590,7 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
     bool is_dynamic_shape =
       op_run_info->base_op_run_info.has_dynamic_output || op_run_info->base_op_run_info.has_dynamic_input;
     if (is_dynamic_shape) {
-      const auto &graph = graph_compiler_->Fetch(op_run_info->base_op_run_info.graph_info);
+      const auto &graph = op_compiler_info->graph_;
       MS_EXCEPTION_IF_NULL(graph);
       graph->UpdateGraphDynamicAttr();
       // Dynamic shape but select static op, must no cache
@@ -1655,26 +1598,16 @@ void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, Vecto
         enable_cache = false;
       }
     }
-    auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, op_run_info->base_op_run_info.input_mask,
-                                                          op_run_info->base_op_run_info.input_tensor, !enable_cache);
-    graph_compiler_info_ptr = graph_compiler_info.get();
-
-    auto ret = actor_to_graph_compiler_info_.try_emplace(actor_info, std::move(graph_compiler_info));
-    if (!ret.second) {
-      MS_LOG(WARNING) << "ActorInfo:" << actor_info << " already exist in the map.";
-    }
+    op_compiler_info->need_erase_ = !enable_cache;
   }
 
-  RunOpImpl(single_op_cache_hit, graph_compiler_info_ptr, op_run_info, outputs);
+  RunOpImpl(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
 }
 
-void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
-                                         GraphCompilerInfo *graph_compiler_info) const {
+void MindRTBackend::CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(device_context);
-  graph_compiler_->BuildSingleOpGraphs({graph}, device_context);
-  MS_EXCEPTION_IF_NULL(graph_compiler_info);
-  graph_compiler_info->input_tensors_.clear();
+  pynative::OpCompiler::BatchBuild({graph}, device_context);
 }
 
 void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
@@ -34,6 +34,7 @@
 #include "runtime/hardware/device_context.h"
 #include "runtime/graph_scheduler/graph_scheduler.h"
 #include "runtime/pynative/op_task.h"
+#include "runtime/pynative/op_compiler.h"
 #include "include/backend/visible.h"
 
 namespace mindspore {
@@ -136,8 +137,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   void CompileGraph(const GraphSegmentPtr &segment, device::RunMode run_mode);
 
   // CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
-  void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
-                            GraphCompilerInfo *graph_compiler_info) const;
+  void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context) const;
 
   // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
   void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);
@@ -153,27 +153,21 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
   // Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode.
   std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);
 
-  // Construct the GraphCompilerInfo by the compilation results of graph, used in PyNative mode.
-  std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const ActorInfo &actor_info,
-                                                                const std::vector<int64_t> &tensors_mask,
-                                                                const std::vector<TensorPtr> &input_tensors,
-                                                                bool need_erase);
-
   void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);
 
   // In PyNative mode, the size of single op cache list will be increasing, which lead to memory cost increasing,
   // so the latest single op cache should be erased when cache list size exceeds threshold value.
-  void EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info, const KernelGraphPtr &graph);
+  void EraseSingleOpCache(const GraphInfo &graph_info);
 
   // Execute OpBuildTask and OpRunTask when the OpExecutor queue is full in PyNative mode.
   void BatchBuildCallback();
 
   // Run op or dispatch build task and run task.
-  void RunOpImpl(bool single_op_cache_hit, GraphCompilerInfo *graph_compiler_info,
+  void RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
                  const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs);
 
   // Dispatch task and execute the task in another thread.
-  void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, GraphCompilerInfo *graph_compiler_info,
+  void DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, const OpCompilerInfoPtr &op_compiler_info,
                       const session::BackendOpRunInfoPtr &op_run_info);
 
   void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
@@ -23,8 +23,8 @@
 #include "frontend/optimizer/ad/grad.h"
 #include "pipeline/jit/pass.h"
 #include "runtime/pynative/op_executor.h"
+#include "runtime/pynative/op_compiler.h"
 #include "ir/cell.h"
-#include "ir/func_graph_cloner.h"
 
 namespace mindspore::pynative {
 PyNativeExecutorPtr PyNativeExecutor::executor_ = nullptr;
@@ -139,6 +139,7 @@ void PyNativeExecutor::set_kernel_build_server_dir(const py::object &kernel_buil
 void PyNativeExecutor::ClearRes() {
   MS_LOG(DEBUG) << "Clear all res";
   runtime::OpExecutor::GetInstance().Reset();
+  pynative::OpCompiler::GetInstance().ClearAllCache();
 
   // Maybe exit in runop step
   auto ms_context = MsContext::GetInstance();
@@ -3,7 +3,7 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*
     "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
     "memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc" "tensor_array.cc"
     "ms_device_shape_transfer.cc" "context_extends.cc" "stream_synchronizer.cc" "tensors_queue.cc" "auto_mem_offload.cc"
-    "common_somas_allocator.cc"
+    "common_somas_allocator.cc" "device_address_utils.cc"
 )
 
 if("${ENABLE_HIDDEN}" STREQUAL "OFF" AND NOT MSVC)
@@ -0,0 +1,321 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/device_address_utils.h"
+
+#include <string>
+#include <map>
+#include <vector>
+#include "ir/tensor.h"
+#include "runtime/device/ms_device_shape_transfer.h"
+
+namespace mindspore {
+using tensor::TensorPtr;
+namespace runtime {
+// Whether device address of anf node is valid and device address type
+// is consistent with device type, for example, device address type
+// DeviceType::kGPU should be used on GPU device
+bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(device_context);
+  if (AnfAlgo::OutputAddrExist(node, index)) {
+    const auto &address = AnfAlgo::GetOutputAddr(node, index, false);
+    MS_EXCEPTION_IF_NULL(address);
+    return address->GetDeviceType() == device_context->GetDeviceType();
+  }
+  return false;
+}
+
+void DeviceAddressUtils::CreateParameterDeviceAddress(const DeviceContext *device_context,
+                                                      const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(device_context);
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
+  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
+  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
+
+  // Anf nodes which need create device address.
+  std::vector<AnfNodePtr> nodes_list;
+  for (size_t i = 0; i < graph_inputs.size(); ++i) {
+    AnfNodePtr item = graph_inputs[i];
+    MS_EXCEPTION_IF_NULL(item);
+    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
+      continue;
+    }
+
+    if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
+      std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
+      for (const auto &out : outs) {
+        MS_EXCEPTION_IF_NULL(out);
+        if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
+          continue;
+        }
+        nodes_list.push_back(out);
+      }
+    }
+    if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
+      continue;
+    }
+    nodes_list.push_back(item);
+  }
+
+  // Create device address for anf node in nodes_list
+  for (const auto &item : nodes_list) {
+    auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
+    for (size_t index = 0; index < output_size; index++) {
+      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
+      if (output_type_id == kTypeUnknown) {
+        output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
+      }
+
+      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
+      auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
+        nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
+        trans::GetRuntimePaddingShape(item, index));
+      device_address->set_from_persistent_mem(item->isa<Parameter>());
+      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
+                    << " addr:" << device_address;
+      AnfAlgo::SetOutputAddr(device_address, index, item.get());
+    }
+  }
+}
+
+void DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
+                                                           const ValuePtr &node_value, size_t output_idx,
+                                                           const ValueNodePtr &value_node) {
+  MS_EXCEPTION_IF_NULL(device_context);
+  MS_EXCEPTION_IF_NULL(node_value);
+  MS_EXCEPTION_IF_NULL(value_node);
+  const auto &ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::vector<TensorPtr> tensors;
+  TensorValueToTensor(node_value, &tensors);
+
+  for (const auto &tensor : tensors) {
+    if (tensor == nullptr) {
+      MS_LOG(WARNING) << "Tensor is null";
+      return;
+    }
+    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+    if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
+      // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
+      // in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
+      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
+                             value_node.get());
+      continue;
+    }
+
+    size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
+    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
+    if (output_type_id == kTypeUnknown) {
+      output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
+    }
+    std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
+
+    device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
+      nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
+    MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
+    MS_EXCEPTION_IF_NULL(address);
+    address->set_from_persistent_mem(true);
+    AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
+  }
+}
+
+void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *device_context,
+                                                      const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(device_context);
+  MS_EXCEPTION_IF_NULL(graph);
+  for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    if (NodeDeviceAddressExist(device_context, value_node, 0)) {
+      continue;
+    }
+
+    const auto &node_value = value_node->value();
+    MS_EXCEPTION_IF_NULL(node_value);
+    if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
+      CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
+    } else if (node_value->isa<StringImm>()) {
+      auto value = GetValue<std::string>(node_value);
+      size_t tensor_size = value.size();
+      auto address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT,
+                                                                              kNumberTypeUInt8, ShapeVector());
+      MS_EXCEPTION_IF_NULL(address);
+      address->set_from_persistent_mem(true);
+      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
+                    << " addr:" << address;
+
+      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
+    }
+  }
+}
+
+void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *device_context,
+                                                         const KernelGraphPtr &graph, bool is_gradient_out) {
+  MS_EXCEPTION_IF_NULL(device_context);
+  MS_EXCEPTION_IF_NULL(graph);
+
+  bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
+  auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
+
+  const std::vector<CNodePtr> &kernels = graph->execution_order();
+  for (const auto &kernel : kernels) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
+      continue;
+    }
+
+    bool is_from_persistent_mem =
+      (is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));
+
+    auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
+    for (size_t i = 0; i < output_size; ++i) {
+      if (AnfAlgo::OutputAddrExist(kernel, i)) {
+        continue;
+      }
+      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
+      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
+      auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
+      auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
+        nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i));
+      if (is_from_persistent_mem) {
+        device_address->set_from_persistent_mem(true);
+      }
+      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
+                    << " addr:" << device_address;
+      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
+    }
+  }
+}
+
+void DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context,
+                                                            const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(device_context);
+  MS_EXCEPTION_IF_NULL(graph);
+  const std::vector<CNodePtr> &kernels = graph->execution_order();
+  for (const auto &kernel : kernels) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
+      continue;
+    }
+    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+    MS_EXCEPTION_IF_NULL(kernel_mod);
+    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
+    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
+      if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
+        break;
+      }
+      auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
+                                                                                     kTypeUnknown, ShapeVector());
+      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
+                    << " addr:" << device_address;
+      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
+    }
+  }
+}
+
+void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  // Collect the inplace groups.
+  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
+  const std::vector<CNodePtr> &kernels = graph->execution_order();
+  for (const auto &kernel : kernels) {
+    if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
+      continue;
+    }
+    auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto inplace_group_attr = primitive->GetAttr("inplace_group");
+    MS_EXCEPTION_IF_NULL(inplace_group_attr);
+    auto group_id = GetValue<uint32_t>(inplace_group_attr);
+    (void)inplace_groups[group_id].emplace_back(kernel);
+  }
+
+  const size_t kMinInplaceGroupSize = 2;
+  for (const auto &inplace_group : inplace_groups) {
+    auto &group_nodes = inplace_group.second;
+    if (group_nodes.size() < kMinInplaceGroupSize) {
+      continue;
+    }
+    // Get the device address of the first node in the inplace group.
+    auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
+    MS_EXCEPTION_IF_NULL(node_primitive);
+    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
+    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
+    MS_EXCEPTION_IF_NULL(device_address);
+
+    // Update the device address of other nodes using device address of the first node in the inplace group.
+    for (size_t i = 1; i < group_nodes.size(); ++i) {
+      auto &group_node = group_nodes[i];
+      auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
+      MS_EXCEPTION_IF_NULL(prim);
+      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
+      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
+      // Update the reference count of device address.
+      device_address->IncreaseOriginalRefCount();
+      device_address->ResetRefCount();
+    }
+  }
+}
+
+void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
+                                             const session::AnfWithOutIndex &origin_pair) {
+  MS_EXCEPTION_IF_NULL(cur_pair.first);
+  MS_EXCEPTION_IF_NULL(origin_pair.first);
+
+  auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
+  MS_EXCEPTION_IF_NULL(origin_node_output_addr);
+  auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
+  MS_EXCEPTION_IF_NULL(cur_node_output_addr);
+
+  if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
+    MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
+                 << ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
+                 << ", index is " << cur_pair.second;
+    AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_pair.second, cur_pair.first.get());
+    // Update the reference count of device address.
+    cur_node_output_addr->DecreaseOriginalRefCount();
+    cur_node_output_addr->ResetRefCount();
+    origin_node_output_addr->IncreaseOriginalRefCount();
+    origin_node_output_addr->ResetRefCount();
+  } else {
+    MS_LOG(INFO) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
+                 << ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
+                 << ", index is " << cur_pair.second;
+  }
+}
+
+void DeviceAddressUtils::UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto &kernels = graph->execution_order();
+  for (auto &kernel : kernels) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
+    if (output_num == 0) {
+      MS_LOG(DEBUG) << "This kernel has no output size.";
+      continue;
+    }
+    for (size_t i = 0; i < output_num; ++i) {
+      session::AnfWithOutIndex out_pair(kernel, i);
+      if (graph->IsInRefOutputMap(out_pair)) {
+        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
+        UpdateDeviceAddress(out_pair, origin_pair);
+      }
+    }
+  }
+}
+}  // namespace runtime
+}  // namespace mindspore
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
+#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
+
+#include "runtime/hardware/device_context.h"
+
+namespace mindspore {
+using device::DeviceContext;
+namespace runtime {
+// Extract the methods related to DeviceAddress in GraphCompiler to the DeviceAddressUtils class.
+class DeviceAddressUtils {
+ public:
+  static void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
+  static void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
+                                                size_t output_idx, const ValueNodePtr &value_node);
+  static void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
+  static void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
+                                              bool is_gradient_out);
+  static void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
+  static void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph);
+  static void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
+                                  const session::AnfWithOutIndex &origin_pair);
+  static void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph);
+};
+}  // namespace runtime
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
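Taken together, the class covers the whole device-address preparation pass for a compiled kernel graph. A plausible ordering of the calls, assuming the caller (GraphCompiler, whose updated call sites are not shown in this diff) holds a compiled KernelGraphPtr and its DeviceContext:

// Sketch of one plausible call sequence; not taken verbatim from the PR.
void PrepareDeviceAddresses(const DeviceContext *device_context, const KernelGraphPtr &graph,
                            bool is_gradient_out) {
  // Inputs first: parameters and constant value nodes.
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  // Then per-kernel outputs and workspaces.
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
  // Finally collapse aliases: inplace groups and ref outputs share one address.
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}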
@@ -21,6 +21,7 @@
 #include <algorithm>
 #include <functional>
 #include "runtime/graph_scheduler/graph_scheduler.h"
+#include "runtime/device/device_address_utils.h"
 #include "runtime/pynative/op_executor.h"
 #include "runtime/device/device_address.h"
 #include "runtime/device/ms_device_shape_transfer.h"
@ -51,294 +52,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
namespace {
|
namespace {
|
||||||
// Whether device address of anf node is valid and device address type
|
|
||||||
// is consistent with device type, for example, device address type
|
|
||||||
// DeviceType::kGPU should be used on GPU device
|
|
||||||
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
if (AnfAlgo::OutputAddrExist(node, index)) {
|
|
||||||
const auto &address = AnfAlgo::GetOutputAddr(node, index, false);
|
|
||||||
MS_EXCEPTION_IF_NULL(address);
|
|
||||||
return address->GetDeviceType() == device_context->GetDeviceType();
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
|
||||||
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
|
|
||||||
const std::vector<bool> &graph_valid_input = graph->valid_inputs();
|
|
||||||
(void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
|
|
||||||
|
|
||||||
// Anf nodes which need create device address.
|
|
||||||
std::vector<AnfNodePtr> nodes_list;
|
|
||||||
for (size_t i = 0; i < graph_inputs.size(); ++i) {
|
|
||||||
AnfNodePtr item = graph_inputs[i];
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
|
|
||||||
std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
|
|
||||||
for (const auto &out : outs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(out);
|
|
||||||
if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
nodes_list.push_back(out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
nodes_list.push_back(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create device address for anf node in nodes_list
|
|
||||||
for (const auto &item : nodes_list) {
|
|
||||||
auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
|
|
||||||
for (size_t index = 0; index < output_size; index++) {
|
|
||||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
|
|
||||||
if (output_type_id == kTypeUnknown) {
|
|
||||||
output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
|
|
||||||
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
|
|
||||||
nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
|
|
||||||
trans::GetRuntimePaddingShape(item, index));
|
|
||||||
device_address->set_from_persistent_mem(item->isa<Parameter>());
|
|
||||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
|
|
||||||
<< " addr:" << device_address;
|
|
||||||
AnfAlgo::SetOutputAddr(device_address, index, item.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
|
|
||||||
size_t output_idx, const ValueNodePtr &value_node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
MS_EXCEPTION_IF_NULL(node_value);
|
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
|
||||||
const auto &ms_context = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(ms_context);
|
|
||||||
std::vector<TensorPtr> tensors;
|
|
||||||
TensorValueToTensor(node_value, &tensors);
|
|
||||||
|
|
||||||
for (const auto &tensor : tensors) {
|
|
||||||
if (tensor == nullptr) {
|
|
||||||
MS_LOG(WARNING) << "Tensor is null";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
|
||||||
if (output_address != nullptr && output_address->GetDeviceType() == device_context->GetDeviceType()) {
|
|
||||||
// We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
|
|
||||||
// in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
|
|
||||||
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
|
|
||||||
value_node.get());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
|
|
||||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
|
|
||||||
if (output_type_id == kTypeUnknown) {
|
|
||||||
output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
|
|
||||||
}
|
|
||||||
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
|
|
||||||
|
|
||||||
device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(
|
|
||||||
nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
|
|
||||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
|
|
||||||
MS_EXCEPTION_IF_NULL(address);
|
|
||||||
address->set_from_persistent_mem(true);
|
|
||||||
AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
|
||||||
for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
|
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
|
||||||
if (NodeDeviceAddressExist(device_context, value_node, 0)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto &node_value = value_node->value();
|
|
||||||
MS_EXCEPTION_IF_NULL(node_value);
|
|
||||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
|
||||||
CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
|
|
||||||
} else if (node_value->isa<StringImm>()) {
|
|
||||||
auto value = GetValue<std::string>(node_value);
|
|
||||||
size_t tensor_size = value.size();
|
|
||||||
auto address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT,
|
|
||||||
kNumberTypeUInt8, ShapeVector());
|
|
||||||
MS_EXCEPTION_IF_NULL(address);
|
|
||||||
address->set_from_persistent_mem(true);
|
|
||||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
|
|
||||||
<< " addr:" << address;
|
|
||||||
|
|
||||||
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
|
|
||||||
bool is_gradient_out) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
|
||||||
|
|
||||||
bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
|
|
||||||
auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
|
|
||||||
|
|
||||||
const std::vector<CNodePtr> &kernels = graph->execution_order();
|
|
||||||
for (const auto &kernel : kernels) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
|
||||||
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_from_persistent_mem =
|
|
||||||
(is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));
|
|
||||||
|
|
||||||
auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
|
|
||||||
for (size_t i = 0; i < output_size; ++i) {
|
|
||||||
if (AnfAlgo::OutputAddrExist(kernel, i)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
|
|
||||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
|
|
||||||
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
|
|
||||||
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
|
|
||||||
nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i));
|
|
||||||
if (is_from_persistent_mem) {
|
|
||||||
device_address->set_from_persistent_mem(true);
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
|
|
||||||
<< " addr:" << device_address;
|
|
||||||
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
|
|
||||||
MS_EXCEPTION_IF_NULL(device_context);
|
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
|
||||||
const std::vector<CNodePtr> &kernels = graph->execution_order();
|
|
||||||
for (const auto &kernel : kernels) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
|
||||||
if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
|
||||||
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
|
|
||||||
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
|
|
||||||
if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
auto device_address = device_context->device_res_manager_->CreateDeviceAddress(nullptr, workspace_sizes[i], "",
|
|
||||||
kTypeUnknown, ShapeVector());
|
|
||||||
MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
|
|
||||||
<< " addr:" << device_address;
|
|
||||||
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
-void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  // Collect the inplace groups.
-  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
-  const std::vector<CNodePtr> &kernels = graph->execution_order();
-  for (const auto &kernel : kernels) {
-    if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
-      continue;
-    }
-    auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
-    MS_EXCEPTION_IF_NULL(primitive);
-    auto inplace_group_attr = primitive->GetAttr("inplace_group");
-    MS_EXCEPTION_IF_NULL(inplace_group_attr);
-    auto group_id = GetValue<uint32_t>(inplace_group_attr);
-    (void)inplace_groups[group_id].emplace_back(kernel);
-  }
-
-  const size_t kMinInplaceGroupSize = 2;
-  for (const auto &inplace_group : inplace_groups) {
-    auto &group_nodes = inplace_group.second;
-    if (group_nodes.size() < kMinInplaceGroupSize) {
-      continue;
-    }
-    // Get the device address of the first node in the inplace group.
-    auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
-    MS_EXCEPTION_IF_NULL(node_primitive);
-    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
-    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
-    MS_EXCEPTION_IF_NULL(device_address);
-
-    // Update the device address of the other nodes using the device address of the first node in the inplace group.
-    for (size_t i = 1; i < group_nodes.size(); ++i) {
-      auto &group_node = group_nodes[i];
-      auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
-      MS_EXCEPTION_IF_NULL(prim);
-      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
-      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
-      // Update the reference count of the device address.
-      device_address->IncreaseOriginalRefCount();
-      device_address->ResetRefCount();
-    }
-  }
-}
-
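// A minimal sketch (hypothetical attribute values, not part of this patch) of the
// contract UpdateDeviceAddressForInplaceNode relies on: every node of one inplace
// group carries the same "inplace_group" id, and "inplace_output_index" names the
// output that is shared.
//   auto prim = common::AnfAlgo::GetCNodePrimitive(kernel);
//   prim->set_attr("inplace_group", MakeValue(static_cast<uint32_t>(3)));
//   prim->set_attr("inplace_output_index", MakeValue(static_cast<uint32_t>(0)));
// Afterwards the tagged output of every node in group 3 resolves to the device
// address of the group's first node, with the ref count raised accordingly.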
-void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair, const session::AnfWithOutIndex &origin_pair) {
-  MS_EXCEPTION_IF_NULL(cur_pair.first);
-  MS_EXCEPTION_IF_NULL(origin_pair.first);
-
-  auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
-  MS_EXCEPTION_IF_NULL(origin_node_output_addr);
-  auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
-  MS_EXCEPTION_IF_NULL(cur_node_output_addr);
-
-  if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
-    MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
-                 << ", index is " << origin_pair.second << ", cur kernel is " << cur_pair.first->fullname_with_scope()
-                 << ", index is " << cur_pair.second;
-    AnfAlgo::SetOutputAddr(origin_node_output_addr, cur_pair.second, cur_pair.first.get());
-    // Update the reference count of the device address.
-    cur_node_output_addr->DecreaseOriginalRefCount();
-    cur_node_output_addr->ResetRefCount();
-    origin_node_output_addr->IncreaseOriginalRefCount();
-    origin_node_output_addr->ResetRefCount();
-  } else {
-    MS_LOG(INFO) << "No need to update device address: ref origin kernel is "
-                 << origin_pair.first->fullname_with_scope() << ", index is " << origin_pair.second
-                 << ", cur kernel is " << cur_pair.first->fullname_with_scope() << ", index is " << cur_pair.second;
-  }
-}
-
-void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  auto &kernels = graph->execution_order();
-  for (auto &kernel : kernels) {
-    MS_EXCEPTION_IF_NULL(kernel);
-    auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
-    if (output_num == 0) {
-      MS_LOG(DEBUG) << "This kernel has no output.";
-      continue;
-    }
-    for (size_t i = 0; i < output_num; ++i) {
-      session::AnfWithOutIndex out_pair(kernel, i);
-      if (graph->IsInRefOutputMap(out_pair)) {
-        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
-        UpdateDeviceAddress(out_pair, origin_pair);
-      }
-    }
-  }
-}
-
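// A minimal sketch (hypothetical graph, not part of this patch) of what the ref
// update above achieves for an Assign-like kernel whose output 0 references
// input 0:
//   session::AnfWithOutIndex out_pair(assign_cnode, 0);  // assign_cnode is illustrative
//   if (graph->IsInRefOutputMap(out_pair)) {
//     auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
//     UpdateDeviceAddress(out_pair, origin_pair);
//   }
// Both ends now return the same DeviceAddress from AnfAlgo::GetMutableOutputAddr.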
void SetSummaryNodesRefCount(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {

@@ -359,27 +72,6 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
    device_address->ResetRefCount();
  }
}

-void SetGraphInputNodeActualAbstract(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  if (!op_run_info->base_op_run_info.has_dynamic_output && !op_run_info->base_op_run_info.has_dynamic_input) {
-    return;
-  }
-  const auto &tensor_mask = op_run_info->base_op_run_info.input_mask;
-  const auto &input_tensors = op_run_info->base_op_run_info.input_tensor;
-  auto &graph_inputs = graph->inputs();
-  for (size_t i = 0, j = 0; i < op_run_info->base_op_run_info.input_tensor.size() && j < graph_inputs.size(); ++i) {
-    if (tensor_mask[i] == kValueNodeTensorMask) {
-      continue;
-    }
-    if (input_tensors[i]->base_shape_ptr() != nullptr) {
-      const auto &shape_of_tensor = input_tensors[i]->shape();
-      auto actual_abstract = std::make_shared<abstract::AbstractTensor>(input_tensors[i]->Dtype(), shape_of_tensor);
-      graph_inputs[j]->set_user_data(kActualAbstract, actual_abstract);
-    }
-    ++j;
-  }
-}
}  // namespace

GraphCompilerInfo::~GraphCompilerInfo() {

@@ -758,7 +450,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
  device_context->kernel_executor_->CreateKernel(graph->execution_order());

  // Read the output and input ref map and set it on the kernel graph.
-  AddOutInRefToGraph(graph);
+  AnfAlgo::AddOutInRefToGraph(graph);

  // Optimize the nop node.
  if (!run_in_pynative) {

@@ -821,152 +513,23 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
  return graph->graph_id();
}

-GraphId GraphCompiler::CompileGraph(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
-                                    const DeviceContext *device_context) {
-  // Check if the graph cache exists.
-  auto iter = run_op_graphs_.find(op_run_info->base_op_run_info.graph_info);
-  auto &op_executor = runtime::OpExecutor::GetInstance();
-  if (iter != run_op_graphs_.end() && op_executor.BuildQueueEmpty()) {
-    const auto &graph = iter->second;
-    MS_EXCEPTION_IF_NULL(graph);
-    SetGraphInputNodeActualAbstract(op_run_info, graph);
-    *single_op_cache_hit = true;
-    return graph->graph_id();
-  }
-  *single_op_cache_hit = false;
-  // Generate kernel graph.
-  MS_EXCEPTION_IF_NULL(session_);
-  KernelGraphPtr graph = session_->ConstructSingleOpGraph(
-    op_run_info, op_run_info->base_op_run_info.input_tensor, op_run_info->base_op_run_info.input_mask,
-    device_context->GetDeviceType() == device::DeviceType::kAscend);
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(device_context);
-
-  graph->set_run_mode(device::RunMode::kKernelMode);
-  graph->set_is_from_single_op(true);
-  // session_ is SessionBasic, AscendUnifyMindIR has not been executed.
-  auto deprecated_kernel_executor =
-    dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
-  if (deprecated_kernel_executor != nullptr) {
-    deprecated_kernel_executor->UnifyMindIR(graph);
-  } else {
-    opt::CommonUnifyMindIR(graph);
-  }
-
-  // Select kernel and optimize
-  device_context->kernel_executor_->OptimizeGraph(graph);
-
-  UpdateRefInfoBeforeCreateKernel(op_run_info, graph);
-
-  // Set dynamic shape actual abstract
-  SetGraphInputNodeActualAbstract(op_run_info, graph);
-
-  // Create device address for all anf nodes of graph.
-  CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);
-
-  run_op_graphs_[op_run_info->base_op_run_info.graph_info] = graph;
-
-  auto output_nodes = graph->outputs();
-  auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
-  for (auto &node : output_nodes) {
-    MS_EXCEPTION_IF_NULL(node);
-    (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
-  }
-
-  AnfAlgo::UpdateGraphValidRefPair(graph);
-  return graph->graph_id();
-}
-
-void GraphCompiler::UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info,
-                                                    const KernelGraphPtr &graph) const {
-  // Building Graph and Create Kernel is async, under pynative mode.Ref info is bind with kernel.
-  // So need to get ref info to generate output addr, before create kernel.
-  if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
-      op_run_info->base_op_run_info.device_target != kGPUDevice) {
-    // just ascend ref mode is diff with cpu and gpu
-    return;
-  }
-
-  AddOutInRefToGraph(graph);
-}
-
-void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs,
-                                        const DeviceContext *device_context) const {
-  MS_EXCEPTION_IF_NULL(device_context);
-  std::vector<CNodePtr> node_to_build;
-  for (const auto &graph : graphs) {
-    const auto &nodes = graph->execution_order();
-    (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
-  }
-  // Kernel build
-  device_context->kernel_executor_->CreateKernel(node_to_build);
-
-  for (const auto &graph : graphs) {
-    device_context->kernel_executor_->PreprocessBeforeRun(graph);
-    CreateKernelWorkspaceDeviceAddress(device_context, graph);
-    // Need to execute after PreprocessBeforeRunSingleOpGraph
-    runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
-  }
-}
-
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}

-KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
-  auto iter = run_op_graphs_.find(graph_info);
-  if (iter == run_op_graphs_.end()) {
-    MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
-    return nullptr;
-  }
-  return iter->second;
-}
-
-void GraphCompiler::AddOutInRefToGraph(const KernelGraphPtr &graph) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  for (const auto &cnode : graph->execution_order()) {
-    MS_EXCEPTION_IF_NULL(cnode);
-    auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
-    MS_EXCEPTION_IF_NULL(kernel_info);
-    for (const auto &ref : kernel_info->out_in_ref_map()) {
-      size_t output_index = ref.first;
-      size_t input_index = ref.second;
-      auto final_pair = std::make_pair(cnode, output_index);
-      auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
-      MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
-                   << ", output index: " << final_pair.second << " to input "
-                   << origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
-      // Add to graph only if the input is not a monad.
-      if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
-        graph->AddRefCorrespondPairs(final_pair, origin_pair);
-      }
-    }
-  }
-}
-
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  MS_LOG(INFO) << "Status record: start create device address. graph id: " << graph->graph_id();
-  CreateParameterDeviceAddress(device_context, graph);
+  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
-  CreateValueNodeDeviceAddress(device_context, graph);
+  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
-  CreateKernelOutputDeviceAddress(device_context, graph, false);
+  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, false);
-  CreateKernelWorkspaceDeviceAddress(device_context, graph);
+  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
-  UpdateDeviceAddressForInplaceNode(graph);
+  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
-  UpdateDeviceAddressForRefNode(graph);
+  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);

  MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
}

-void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
-                                                        const DeviceContext *device_context,
-                                                        bool is_gradient_out) const {
-  CreateParameterDeviceAddress(device_context, graph);
-  CreateValueNodeDeviceAddress(device_context, graph);
-  CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
-  UpdateDeviceAddressForInplaceNode(graph);
-  UpdateDeviceAddressForRefNode(graph);
-}
-
void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,

@@ -1055,14 +618,6 @@ void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  session_->ClearAllBucket(graph_id);
}

-const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
-  const auto &iter = run_op_graph_output_nodes_.find(graph_id);
-  if (iter == run_op_graph_output_nodes_.end()) {
-    MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
-  }
-  return iter->second;
-}
-
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY

@@ -1079,11 +634,6 @@ void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  }
}

-void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
-  (void)run_op_graphs_.erase(graph_info);
-  (void)run_op_graph_output_nodes_.erase(graph_id);
-}
-
void GraphCompiler::SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(segment);

@@ -105,19 +105,9 @@ class GraphCompiler {
  // the detailed implementation of compiling graph is in 'CompileGraphImpl'.
  GraphId CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func_graph, const DeviceContext *device_context);

-  // Construct single op kernel graph and compile the kernel graph in PyNative mode.
-  GraphId CompileGraph(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
-                       const DeviceContext *device_context);
-
-  // Create kernel and create workspace for graphs in PyNative mode.
-  void BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context) const;
-
  // Get graph by graph id, if not exist return nullptr, used in Graph mode.
  KernelGraphPtr Fetch(GraphId graph_id) const;

-  // Get graph by graph info, if not exist return nullptr, used in PyNative mode.
-  KernelGraphPtr Fetch(const GraphInfo &graph_info) const;
-
  // The following four methods are used in PyNative back propagation to split a complete kernel graph into single
  // op graphs, and these methods will be moved to class MindRTBackend after deleting the session module.

@@ -177,16 +167,11 @@ class GraphCompiler {
  // operator.
  void ClearAllBucket(const GraphId &graph_id);

-  const std::vector<KernelWithIndex> &GetGraphOutputNodes(GraphId graph_id) const;
-
  // Register a summary callback function, which is called in the final stages of summary.
  void RegisterSummaryCallBackFunc(const CallBackFunc &callback) const;
  // Execute graph summary.
  void Summary(const std::vector<KernelGraphPtr> &graphs) const;

-  // Remove single op kernel graph cache and output nodes cache.
-  void EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id);
-
  // The implementation of compiling graph in Graph Mode, including optimizing graph,
  // setting operator info, creating kernel and transforming kernel graph to ActorSet.
  GraphId CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context,

@@ -195,28 +180,12 @@
 private:
  DISABLE_COPY_AND_ASSIGN(GraphCompiler);

-  // Add operators' output and input reference map to the graph.
-  void AddOutInRefToGraph(const KernelGraphPtr &graph) const;
-
-  // Update ref info of the graph before creating kernels.
-  void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info,
-                                       const KernelGraphPtr &graph) const;
-
  // Create device address for all anf nodes of graph.
  void CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const;

-  // Create device address for input and output of ops.
-  void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
-                                           bool is_gradient_out) const;
-
  // Set Graph's dependencies for pre_graph and post_graph.
  void SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const;

-  // Single op kernel graph cache for PyNative mode.
-  mindspore::HashMap<GraphInfo, KernelGraphPtr> run_op_graphs_;
-  // Single op kernel graph output nodes cache for PyNative mode.
-  mindspore::HashMap<GraphId, std::vector<KernelWithIndex>> run_op_graph_output_nodes_;
-
  // The member variable 'session_' will be removed after removing the session module.
  // Now all the GraphCompilers share the same 'session_'.
  session::SessionPtr session_;

@@ -0,0 +1,162 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_compiler.h"

#include <memory>
#include <algorithm>
#include <vector>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/device_address_utils.h"

namespace mindspore {
using runtime::DeviceAddressUtils;
namespace pynative {
namespace {
void SetGraphInputNodeActualAbstract(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!op_run_info->base_op_run_info.has_dynamic_output && !op_run_info->base_op_run_info.has_dynamic_input) {
    return;
  }
  const auto &tensor_mask = op_run_info->base_op_run_info.input_mask;
  const auto &input_tensors = op_run_info->base_op_run_info.input_tensor;
  auto &graph_inputs = graph->inputs();
  for (size_t i = 0, j = 0; i < op_run_info->base_op_run_info.input_tensor.size() && j < graph_inputs.size(); ++i) {
    if (tensor_mask[i] == kValueNodeTensorMask) {
      continue;
    }
    if (input_tensors[i]->base_shape_ptr() != nullptr) {
      const auto &shape_of_tensor = input_tensors[i]->shape();
      auto actual_abstract = std::make_shared<abstract::AbstractTensor>(input_tensors[i]->Dtype(), shape_of_tensor);
      graph_inputs[j]->set_user_data(kActualAbstract, actual_abstract);
    }
    ++j;
  }
}

void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
  // In PyNative mode, graph building and kernel creation are asynchronous, while ref info is bound to the kernel.
  // So the ref info must be fetched to generate the output address before the kernel is created.
  if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
      op_run_info->base_op_run_info.device_target != kGPUDevice) {
    // Only the Ascend ref mode differs from CPU and GPU.
    return;
  }

  AnfAlgo::AddOutInRefToGraph(graph);
}

void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                         bool is_gradient_out) {
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}
}  // namespace

OpCompiler::OpCompiler() { session_ = session::SessionFactory::Get().Create(kSessionBasic); }

OpCompiler &OpCompiler::GetInstance() {
  static OpCompiler instance;
  return instance;
}

OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                                      device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto graph_info = op_run_info->base_op_run_info.graph_info;
  auto iter = op_compiler_infos_.find(graph_info);
  // Check if the graph cache exists.
  auto &op_executor = runtime::OpExecutor::GetInstance();
  if (iter != op_compiler_infos_.end() && op_executor.BuildQueueEmpty()) {
    const auto &op_compiler_info = iter->second;
    MS_EXCEPTION_IF_NULL(op_compiler_info);
    SetGraphInputNodeActualAbstract(op_run_info, op_compiler_info->graph_);
    *single_op_cache_hit = true;
    return iter->second;
  }
  *single_op_cache_hit = false;
  // Generate the kernel graph.
  MS_EXCEPTION_IF_NULL(session_);
  KernelGraphPtr graph = session_->ConstructSingleOpGraph(
    op_run_info, op_run_info->base_op_run_info.input_tensor, op_run_info->base_op_run_info.input_mask,
    device_context->GetDeviceType() == device::DeviceType::kAscend);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);

  graph->set_run_mode(device::RunMode::kKernelMode);
  graph->set_is_from_single_op(true);
  // session_ is SessionBasic, so AscendUnifyMindIR has not been executed yet.
  auto deprecated_kernel_executor =
    dynamic_cast<device::DeprecatedKernelExecutor *>(device_context->kernel_executor_.get());
  if (deprecated_kernel_executor != nullptr) {
    deprecated_kernel_executor->UnifyMindIR(graph);
  } else {
    opt::CommonUnifyMindIR(graph);
  }

  // Select kernels and optimize the graph.
  device_context->kernel_executor_->OptimizeGraph(graph);

  UpdateRefInfoBeforeCreateKernel(op_run_info, graph);

  // Set the actual abstract for dynamic-shape inputs.
  SetGraphInputNodeActualAbstract(op_run_info, graph);

  // Create device addresses for all anf nodes of the graph.
  CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);

  auto output_nodes = graph->outputs();
  std::vector<KernelWithIndex> outputs_with_index;
  for (auto &node : output_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  }
  AnfAlgo::UpdateGraphValidRefPair(graph);

  auto op_compiler_info =
    std::make_shared<OpCompilerInfo>(graph_info, graph->graph_id(), graph, outputs_with_index, device_context, false);
  op_compiler_infos_[graph_info] = op_compiler_info;
  return op_compiler_info;
}

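// A minimal usage sketch of the cache path above (illustrative, not part of this
// patch); op_run_info and device_context are assumed to be prepared by the caller:
//   bool single_op_cache_hit = false;
//   auto op_compiler_info =
//     pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_context);
//   if (!single_op_cache_hit) {
//     // Kernels of the freshly built graph still need to be created, e.g. in a batch.
//     pynative::OpCompiler::BatchBuild({op_compiler_info->graph_}, op_compiler_info->device_context_);
//   }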
void OpCompiler::BatchBuild(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  std::vector<CNodePtr> node_to_build;
  for (const auto &graph : graphs) {
    const auto &nodes = graph->execution_order();
    (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
  }
  // Build the kernels of all graphs in one batch.
  device_context->kernel_executor_->CreateKernel(node_to_build);

  for (const auto &graph : graphs) {
    device_context->kernel_executor_->PreprocessBeforeRun(graph);
    DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
    // Must be executed after PreprocessBeforeRun.
    runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
  }
}

void OpCompiler::ClearOpCache(const GraphInfo &graph_info) { op_compiler_infos_.erase(graph_info); }

void OpCompiler::ClearAllCache() { op_compiler_infos_.clear(); }
}  // namespace pynative
}  // namespace mindspore

@@ -0,0 +1,86 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_

#include <utility>
#include <vector>
#include <memory>
#include "utils/ms_utils.h"
#include "backend/common/session/kernel_graph.h"
#include "backend/common/session/session_basic.h"
#include "runtime/hardware/device_context.h"

namespace mindspore {
using device::DeviceContext;
using session::KernelWithIndex;
namespace pynative {
struct OpCompilerInfo {
  OpCompilerInfo(GraphInfo graph_info, GraphId graph_id, KernelGraphPtr graph,
                 std::vector<KernelWithIndex> graph_output_nodes, DeviceContext *device_context, bool need_erase)
      : graph_info_(std::move(graph_info)),
        graph_id_(graph_id),
        graph_(std::move(graph)),
        graph_output_nodes_(std::move(graph_output_nodes)),
        device_context_(device_context),
        need_erase_(need_erase) {}
  ~OpCompilerInfo() = default;
  GraphInfo graph_info_;
  GraphId graph_id_;
  KernelGraphPtr graph_;
  std::vector<KernelWithIndex> graph_output_nodes_;
  DeviceContext *device_context_;
  bool need_erase_;
};
using OpCompilerInfoPtr = std::shared_ptr<OpCompilerInfo>;

// FuncGraph, Backend and GraphCompiler correspond one-to-one,
// and GraphCompiler stores the compilation cache of operators.
// When the graph structure changes, the front-end sends multiple graphs;
// if the operators of each graph were compiled separately, performance would be very poor.
// Therefore, the OpCompiler class saves all operator caches and makes them independent of the graph.
class BACKEND_EXPORT OpCompiler {
 public:
  static OpCompiler &GetInstance();

  // Compile RunOpInfo into a KernelGraph.
  OpCompilerInfoPtr Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                            device::DeviceContext *device_context);

  // Clear the op cache in dynamic scenes.
  // Otherwise, the operator cache will keep growing, resulting in insufficient memory.
  void ClearOpCache(const GraphInfo &graph_info);

  // Accumulate a certain number of operators,
  // and then compile the operators in parallel to improve compilation efficiency.
  static void BatchBuild(const std::vector<KernelGraphPtr> &graphs, const DeviceContext *device_context);

  // Clear anf resources before process exit.
  void ClearAllCache();

 private:
  OpCompiler();
  ~OpCompiler() = default;
  DISABLE_COPY_AND_ASSIGN(OpCompiler);

  // All operators share the same session.
  session::SessionPtr session_;
  mindspore::HashMap<GraphInfo, OpCompilerInfoPtr> op_compiler_infos_;
};
}  // namespace pynative
using OpCompilerInfoPtr = pynative::OpCompilerInfoPtr;
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_COMPILER_H_

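// A minimal sketch (illustrative, not part of this patch) of the cache-eviction
// contract described in the header above; stale_graph_info is a hypothetical
// GraphInfo key known to the caller:
//   auto &compiler = pynative::OpCompiler::GetInstance();
//   compiler.ClearOpCache(stale_graph_info);  // evict one entry in dynamic scenes
//   compiler.ClearAllCache();                 // or drop everything before process exit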
@@ -89,7 +89,7 @@ void OpExecutor::PushOpBuildTask(const std::shared_ptr<OpBuildTask> &op_build_ta
void OpExecutor::PushOpRunTask(const std::shared_ptr<OpTask> &op_run_task) {
  std::lock_guard<std::mutex> lock(task_mutex_);
  op_run_tasks_.push(op_run_task);
-  actor_in_queue_.insert(op_run_task->context()->graph_compiler_info()->name_);
+  actor_in_queue_.insert(op_run_task->context()->graph_id());
  task_cond_var_.notify_all();
}

@@ -117,9 +117,9 @@ bool OpExecutor::BuildQueueFull() {
  return op_build_tasks_.size() > kMaxQueueSize;
}

-bool OpExecutor::ActorInQueue(const std::string &actor_info) {
+bool OpExecutor::ActorInQueue(GraphId graph_id) {
  std::lock_guard<std::mutex> lock(task_mutex_);
-  auto iter = actor_in_queue_.find(actor_info);
+  auto iter = actor_in_queue_.find(graph_id);
  return iter != actor_in_queue_.end();
}

@@ -152,7 +152,7 @@ void OpExecutor::WorkerLoop() {
  std::unique_lock<std::mutex> lock(task_mutex_);
  if (!op_run_tasks_.empty()) {
    op_run_tasks_.pop();
-    actor_in_queue_.erase(task->context()->graph_compiler_info()->name_);
+    actor_in_queue_.erase(task->context()->graph_id());
  }

  if (op_run_tasks_.empty()) {

@@ -67,7 +67,7 @@ class BACKEND_EXPORT OpExecutor {
  // Determine if there is another task for the same graph in execution.
  // Tasks for the same graph use the same CNode cache, so we need to wait.
-  bool ActorInQueue(const std::string &actor_info);
+  bool ActorInQueue(GraphId graph_id);

  // Wait for all OpRunTasks to finish executing.
  void Wait();

@@ -88,7 +88,7 @@ class BACKEND_EXPORT OpExecutor {
  std::vector<std::shared_ptr<OpBuildTask>> op_build_tasks_;
  std::queue<std::shared_ptr<OpTask>> op_run_tasks_;
-  std::set<std::string> actor_in_queue_;
+  std::set<GraphId> actor_in_queue_;
  std::function<void()> batch_build_callback_{nullptr};
  inline static size_t kMaxQueueSize = 20;
  bool executing_{false};

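// A minimal sketch (assumed caller pattern, not part of this patch) of the
// waiting rule described above: tasks keyed by the same graph id share one
// CNode cache, so a new task for that graph is pushed only after the queue drains.
//   auto &executor = runtime::OpExecutor::GetInstance();
//   if (executor.ActorInQueue(graph_id)) {
//     executor.Wait();  // drain tasks that still reference this graph's cache
//   }
//   executor.PushOpRunTask(op_run_task);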
@@ -32,10 +32,9 @@
namespace mindspore::runtime {
class OpTaskContext {
 public:
-  OpTaskContext(GraphCompilerInfo *graph_compiler_info, KernelGraphPtr graph,
-                std::vector<session::KernelWithIndex> output_nodes, session::BackendOpRunInfoPtr op_run_info,
-                device::DeviceContext *device_context, bool is_pynative_infer)
-      : graph_compiler_info_(graph_compiler_info),
+  OpTaskContext(GraphId graph_id, KernelGraphPtr graph, std::vector<session::KernelWithIndex> output_nodes,
+                session::BackendOpRunInfoPtr op_run_info, device::DeviceContext *device_context, bool is_pynative_infer)
+      : graph_id_(graph_id),
        graph_(std::move(graph)),
        output_nodes_(std::move(output_nodes)),
        op_run_info_(std::move(op_run_info)),

@@ -43,7 +42,7 @@ class OpTaskContext {
        is_pyantive_infer_(is_pynative_infer) {}
  ~OpTaskContext() = default;

-  GraphCompilerInfo *graph_compiler_info() const { return graph_compiler_info_; }
+  GraphId graph_id() const { return graph_id_; }
  const KernelGraphPtr &graph() const { return graph_; }
  const std::vector<session::KernelWithIndex> &output_nodes() const { return output_nodes_; }
  const session::BackendOpRunInfoPtr &op_run_info() const { return op_run_info_; }

@@ -51,7 +50,7 @@ class OpTaskContext {
  bool is_pynative_infer() const { return is_pyantive_infer_; }

 private:
-  GraphCompilerInfo *graph_compiler_info_;
+  GraphId graph_id_;
  KernelGraphPtr graph_;
  std::vector<session::KernelWithIndex> output_nodes_;
  session::BackendOpRunInfoPtr op_run_info_;
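// A minimal sketch (illustrative, not part of this patch) of building a task
// context from a compiled op; all inputs are assumed to come from
// OpCompiler::Compile and the surrounding runtime:
//   auto context = std::make_shared<runtime::OpTaskContext>(
//     op_compiler_info->graph_id_, op_compiler_info->graph_, op_compiler_info->graph_output_nodes_,
//     op_run_info, op_compiler_info->device_context_, is_pynative_infer);
//   // The executor later keys its queue on context->graph_id().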