!17839 Fix device memory not being released in PyNative mode

Merge pull request !17839 from zyli2020/fix_issue_defect
i-robot 2021-06-10 11:37:13 +08:00 committed by Gitee
commit 9852cced86
24 changed files with 278 additions and 213 deletions

View File

@ -35,7 +35,10 @@ DatasetIteratorKernel::DatasetIteratorKernel()
DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); }
void DatasetIteratorKernel::ReleaseResource() { GpuBufferMgr::GetInstance().Close(handle_); }
void DatasetIteratorKernel::ReleaseResource() {
GpuBufferMgr::GetInstance().Close(handle_);
handle_ = HandleMgr::INVALID_HANDLE;
}
const std::vector<size_t> &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; }
@ -122,6 +125,13 @@ bool DatasetIteratorKernel::ReadDevice(void **addr, size_t *len) {
bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream) {
if (handle_ == HandleMgr::INVALID_HANDLE) {
handle_ = GpuBufferMgr::GetInstance().Open(0, queue_name_, output_size_list_);
if (handle_ == HandleMgr::INVALID_HANDLE) {
MS_LOG(EXCEPTION) << "Gpu Queue(" << queue_name_ << ") Open Failed";
}
}
void *addr = nullptr;
size_t len = 0;
if (!ReadDevice(&addr, &len)) {
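The two hunks above form a pair: ReleaseResource() now closes the GPU data queue and marks the handle invalid, and Launch() lazily re-opens the queue when it sees an invalid handle, so releasing device memory between steps does not break the next launch. A self-contained sketch of that lazy re-open pattern (the QueueManager and handle names below are simplified stand-ins, not the real GpuBufferMgr API):

#include <cstdint>
#include <iostream>
#include <string>

constexpr uint32_t kInvalidHandle = UINT32_MAX;  // plays the role of HandleMgr::INVALID_HANDLE

struct QueueManager {                             // minimal stand-in for GpuBufferMgr
  uint32_t Open(const std::string &name) { std::cout << "open " << name << "\n"; return 1; }
  void Close(uint32_t handle) { if (handle != kInvalidHandle) std::cout << "close " << handle << "\n"; }
};

class IteratorKernelSketch {
 public:
  void ReleaseResource() {
    mgr_.Close(handle_);
    handle_ = kInvalidHandle;                     // remember that the queue was released
  }
  bool Launch() {
    if (handle_ == kInvalidHandle) {              // lazily re-open after a release
      handle_ = mgr_.Open("dataset_queue");
      if (handle_ == kInvalidHandle) return false;
    }
    return true;                                  // read the next batch from the queue here
  }
 private:
  QueueManager mgr_;
  uint32_t handle_{kInvalidHandle};
};

int main() {
  IteratorKernelSketch kernel;
  kernel.Launch();            // opens the queue on first use
  kernel.ReleaseResource();   // frees it (e.g. when the graph is destroyed)
  kernel.Launch();            // transparently re-opens
  return 0;
}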

View File

@ -41,6 +41,11 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (init_seed_ == 0 && cur_seed_ == 0) {
// Update current seed.
cur_seed_ = time(NULL);
generator_.seed(cur_seed_);
}
VARIABLE_NOT_USED(workspaces);
T *sampled_candidates = GetDeviceAddress<T>(outputs, 0);
S *true_expected_count = GetDeviceAddress<S>(outputs, 1);
@ -87,10 +92,15 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
num_sampled_ = GetAttr<int64_t>(kernel_node, "num_sampled");
unique_ = GetAttr<bool>(kernel_node, "unique");
range_max_ = GetAttr<int64_t>(kernel_node, "range_max");
int64_t seed = GetAttr<int64_t>(kernel_node, "seed");
remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits");
if (seed == 0) seed = time(NULL);
generator_.seed(seed);
init_seed_ = GetAttr<int64_t>(kernel_node, "seed");
if (init_seed_ == 0) {
cur_seed_ = time(NULL);
generator_.seed(cur_seed_);
} else {
generator_.seed(init_seed_);
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformCandidateSampler supports only 2-D inputs.";
@ -104,6 +114,13 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
return true;
}
void ReleaseResource() override {
// Reset current seed.
if (init_seed_ == 0 && cur_seed_ != 0) {
cur_seed_ = 0;
}
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
@ -174,6 +191,8 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int64_t init_seed_{0};
int64_t cur_seed_{0};
};
} // namespace kernel
} // namespace mindspore
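The intent of the seed handling above is: a non-zero seed attribute keeps the sampler deterministic, while a zero seed means the kernel picks a time-based seed and, because ReleaseResource() clears cur_seed_, picks a fresh one on the next launch instead of replaying the cached graph with stale state. A rough standalone model of that policy (simplified names, not the real kernel class):

#include <cstdint>
#include <ctime>
#include <random>

class SamplerSeedSketch {
 public:
  explicit SamplerSeedSketch(int64_t attr_seed) : init_seed_(attr_seed) {
    if (init_seed_ == 0) {
      cur_seed_ = static_cast<int64_t>(time(nullptr));        // no user seed: start from the clock
      generator_.seed(static_cast<uint64_t>(cur_seed_));
    } else {
      generator_.seed(static_cast<uint64_t>(init_seed_));     // user seed: deterministic run
    }
  }
  int64_t Launch() {
    if (init_seed_ == 0 && cur_seed_ == 0) {                  // re-seed after ReleaseResource()
      cur_seed_ = static_cast<int64_t>(time(nullptr));
      generator_.seed(static_cast<uint64_t>(cur_seed_));
    }
    return std::uniform_int_distribution<int64_t>(0, 99)(generator_);
  }
  void ReleaseResource() {
    if (init_seed_ == 0 && cur_seed_ != 0) {
      cur_seed_ = 0;                                          // force a fresh seed next time
    }
  }
 private:
  std::mt19937_64 generator_;
  int64_t init_seed_{0};
  int64_t cur_seed_{0};
};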

View File

@ -562,7 +562,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() &&
kOpCacheAllowList.find(op_run_info.op_name) == kOpCacheAllowList.end()) {
kOpCacheBlackList.find(op_run_info.op_name) == kOpCacheBlackList.end()) {
return;
}
// Prepare the graph
@ -606,7 +606,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
UpdateOutputAbstract(kernel_graph, op_run_info);
}
RunOpClearMemory(kernel_graph.get());
if (kOpCacheAllowList.find(op_run_info->op_name) != kOpCacheAllowList.end()) {
if (kOpCacheBlackList.find(op_run_info->op_name) != kOpCacheBlackList.end()) {
run_op_graphs_.erase(graph_info);
}
}
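Semantically, kOpCacheBlackList is the opposite of the old name: ops on the list bypass the single-op graph cache when building and have their cached graph erased after each run, so stateful kernels such as GetNext are rebuilt (and re-acquire their resources) on every launch. A compact illustration of that policy, with made-up container and op names:

#include <map>
#include <set>
#include <string>

// Ops whose single-op graphs must not be reused across launches (illustrative list).
const std::set<std::string> kNoCacheOps = {"UniformCandidateSampler", "GetNext", "InitDataSetQueue"};

struct OpGraphCacheSketch {
  std::map<std::string, int> graphs_;  // graph_info -> compiled graph (simplified to an id)

  bool NeedBuild(const std::string &graph_info, const std::string &op_name) const {
    // Reuse a cached graph only when the op is allowed to be cached.
    return graphs_.count(graph_info) == 0 || kNoCacheOps.count(op_name) != 0;
  }

  void AfterRun(const std::string &graph_info, const std::string &op_name) {
    // Drop the graph so the next launch rebuilds it from scratch.
    if (kNoCacheOps.count(op_name) != 0) {
      graphs_.erase(graph_info);
    }
  }
};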

View File

@ -1344,6 +1344,13 @@ std::string KernelGraph::ToString() const { return std::string("kernel_graph_").
KernelGraph::~KernelGraph() {
try {
// Release the kernel resource.
for (const auto &kernel : execution_order_) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
if (kernel_mod != nullptr) {
kernel_mod->ReleaseResource();
}
}
device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_, *inputs_, graph_value_nodes_,
execution_order_);
} catch (const std::exception &e) {

View File

@ -312,7 +312,6 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
AnfAlgo::SetOutputAddr(device_address, 0, param.get());
}
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// construct abstract of parameter
@ -425,92 +424,6 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}
void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
const auto &node = kernel_with_index.first;
if (node->isa<CNode>()) {
(*ref_count)[kernel_with_index] += 1;
}
}
}
}
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
MS_EXCEPTION_IF_NULL(ref_count);
MS_EXCEPTION_IF_NULL(op_output_map);
for (auto &kernel_with_index : input_kernel) {
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
if (!kernel_with_index.first->isa<CNode>()) {
continue;
}
auto ref_iter = ref_count->find(kernel_with_index);
if (ref_iter == ref_count->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
// Reduce reference count number, when it was reduced to zero, release the useless output of pre node.
ref_iter->second -= 1;
if (ref_iter->second != 0) {
continue;
}
ref_count->erase(ref_iter);
auto output_iter = op_output_map->find(kernel_with_index);
if (output_iter == op_output_map->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
op_output_map->erase(output_iter);
}
}
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, GraphOutputInfo *graph_output_info) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(graph_output_info);
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
}
size_t out_index = 0;
for (const auto &output_tensor : output_tensors) {
auto kernel_with_index = make_pair(kernel, out_index++);
if (ref_count.find(kernel_with_index) != ref_count.end()) {
(*op_output_map)[kernel_with_index] = output_tensor;
}
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
if (iter == graph_output_info->output_indexes.end()) {
continue;
}
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
for (const auto &ref_indexes : multiple_ref_indexes) {
size_t n = 0;
const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
for (; n < ref_indexes.size() - 1; n += 1) {
size_t index = ref_indexes.at(n);
if (index >= cur_vector_ref->size()) {
MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
<< cur_vector_ref->size();
}
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
}
}
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
@ -1301,6 +1214,94 @@ void SessionBasic::CreateOutputPlaceholder(
}
}
void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
const auto &node = kernel_with_index.first;
if (node->isa<CNode>()) {
(*ref_count)[kernel_with_index] += 1;
}
}
}
}
void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
MS_EXCEPTION_IF_NULL(ref_count);
MS_EXCEPTION_IF_NULL(op_output_map);
for (auto &kernel_with_index : input_kernel) {
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
if (!kernel_with_index.first->isa<CNode>()) {
continue;
}
auto ref_iter = ref_count->find(kernel_with_index);
if (ref_iter == ref_count->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
// Reduce reference count number, when it was reduced to zero, release the useless output of pre node.
ref_iter->second -= 1;
if (ref_iter->second != 0) {
continue;
}
ref_count->erase(ref_iter);
auto output_iter = op_output_map->find(kernel_with_index);
if (output_iter == op_output_map->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
op_output_map->erase(output_iter);
}
}
void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
GraphOutputInfo *graph_output_info) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(graph_output_info);
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
}
size_t out_index = 0;
for (const auto &output_tensor : output_tensors) {
auto kernel_with_index = make_pair(kernel, out_index++);
if (ref_count.find(kernel_with_index) != ref_count.end()) {
(*op_output_map)[kernel_with_index] = output_tensor;
}
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
if (iter == graph_output_info->output_indexes.end()) {
continue;
}
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
for (const auto &ref_indexes : multiple_ref_indexes) {
size_t n = 0;
const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
for (; n < ref_indexes.size() - 1; n += 1) {
size_t index = ref_indexes.at(n);
if (index >= cur_vector_ref->size()) {
MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
<< cur_vector_ref->size();
}
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
}
}
}
TensorPtr SessionBasic::GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<ValueNode>()) {
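The ref-count machinery moved into SessionBasic above is the heart of the memory fix: each intermediate op output is kept only while later ops still reference it, and HandleOpInputs drops the tensor (and with it the device memory) as soon as the last consumer has run. A minimal standalone model of that behaviour, using plain integers and shared_ptr in place of the real node and tensor types:

#include <cstddef>
#include <map>
#include <memory>
#include <utility>
#include <vector>

using NodeIndex = std::pair<int, size_t>;            // (node id, output index)
using Tensor = std::shared_ptr<std::vector<float>>;  // stand-in for a device tensor

// Count how many later ops consume each output (analogue of GetRefCount).
std::map<NodeIndex, size_t> CountRefs(const std::vector<std::vector<NodeIndex>> &op_inputs) {
  std::map<NodeIndex, size_t> refs;
  for (const auto &inputs : op_inputs) {
    for (const auto &in : inputs) {
      refs[in] += 1;
    }
  }
  return refs;
}

// After an op has run, decrement the counts of its inputs and drop outputs that
// nobody needs any more (analogue of HandleOpInputs).
void ReleaseDeadOutputs(const std::vector<NodeIndex> &consumed,
                        std::map<NodeIndex, size_t> *refs,
                        std::map<NodeIndex, Tensor> *live_outputs) {
  for (const auto &in : consumed) {
    auto it = refs->find(in);
    if (it == refs->end()) {
      continue;
    }
    if (--it->second == 0) {
      refs->erase(it);
      live_outputs->erase(in);  // last consumer done: the device tensor can be freed
    }
  }
}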

View File

@ -181,6 +181,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, GraphOutputInfo *graph_output_info);
protected:
friend class Executor;

View File

@ -43,6 +43,7 @@
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "runtime/framework/actor/actor_common.h"
#if (ENABLE_CPU && !_WIN32)
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
@ -81,9 +82,10 @@ void ExecuteActionForMindRT(const ResourcePtr &res) {
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size " << args.size();
auto outs = mindrt_bc_ptr->RunGraph(actor_info, args);
MS_LOG(DEBUG) << "out size " << outs.size();
return outs[0];
VectorRef outputs;
mindrt_bc_ptr->RunGraph(actor_info, args, &outputs);
MS_LOG(DEBUG) << "out size " << outputs.size();
return outputs[0];
});
res->results()[kOutput] = run;
}
@ -550,7 +552,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
}
// The graph compiling of mindRT.
if ((backend == kMsConvert) && compile::IsMindRTUsed()) {
if ((backend == kMsConvert) && IsMindRTUsed()) {
TaskEmitActionForMindRT(res);
return true;
}
@ -580,7 +582,7 @@ bool ExecuteAction(const ResourcePtr &res) {
std::string backend = MsContext::GetInstance()->backend_policy();
// The graph running of mindRT.
if ((backend == kMsConvert) && compile::IsMindRTUsed()) {
if ((backend == kMsConvert) && IsMindRTUsed()) {
ExecuteActionForMindRT(res);
return true;
}

View File

@ -51,6 +51,7 @@
#include "load_mindir/load_model.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/actor/actor_common.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/constants.h"
@ -1090,13 +1091,14 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
auto backend = compile::CreateBackend();
MS_EXCEPTION_IF_NULL(backend);
// The data set graph compiling and running of mindRT.
if (compile::IsMindRTUsed()) {
if (IsMindRTUsed()) {
const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
MS_EXCEPTION_IF_NULL(mindrt_backend);
auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
VectorRef args;
if (need_run) {
(void)mindrt_backend->RunGraph(actor_info, args);
VectorRef outputs;
mindrt_backend->RunGraph(actor_info, args, &outputs);
}
ConfigManager::GetInstance().set_iter_num(size);
return true;

View File

@ -65,6 +65,8 @@
#endif
#include "debug/anf_ir_dump.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/hardware/device_context_manager.h"
using mindspore::tensor::TensorPy;
@ -1688,7 +1690,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
op_exec_info->next_input_index};
#endif
VectorRef outputs;
if (!compile::IsMindRTUsed()) {
if (!IsMindRTUsed()) {
kSession->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
} else {
if (mind_rt_backend == nullptr) {
@ -1698,7 +1700,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
}
const compile::ActorInfo &actor_info =
mind_rt_backend->CompileGraph(op_run_info, graph_info, &tensors_mask, &input_tensors);
outputs = mind_rt_backend->RunGraph(actor_info, &op_run_info, &tensors_mask, &input_tensors);
mind_rt_backend->RunGraph(actor_info, &op_run_info, &tensors_mask, &input_tensors, &outputs);
}
if (op_exec_info->is_dynamic_shape) {
@ -2788,10 +2790,22 @@ void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::obj
}
void PynativeExecutor::Sync() {
if (kSession == nullptr) {
MS_EXCEPTION(NotExistsError) << "No session has been created!";
if (!IsMindRTUsed()) {
if (kSession == nullptr) {
MS_EXCEPTION(NotExistsError) << "No session has been created!";
}
kSession->SyncStream();
} else {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
const auto &device_context =
device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
MS_EXCEPTION_IF_NULL(device_context);
(void)device_context->SyncStream();
}
kSession->SyncStream();
}
void PynativeExecutor::EnterConstruct(const py::object &cell) {

View File

@ -281,17 +281,8 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) {
MS_LOG(INFO) << "Clear graph:" << graph_id << " GPU runtime resource";
// Release the kernel resource.
for (const auto &kernel : execution_order) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
if (kernel_mod == nullptr) {
continue;
}
kernel_mod->ReleaseResource();
}
// Clear the output address of graph.
ClearOutputAddress(inputs, value_nodes, execution_order);
graph_output_map_.erase(graph_id);
}

View File

@ -126,4 +126,8 @@ bool IsGatherActor(const AnfNodePtr &front_node,
return false;
}
} // namespace runtime
// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
// Return false in the transitional stage.
bool IsMindRTUsed() { return false; }
} // namespace mindspore

View File

@ -74,6 +74,9 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node);
bool IsGatherActor(const AnfNodePtr &front_node,
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor_);
} // namespace runtime
// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
bool IsMindRTUsed();
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ACTOR_COMMON_H_

View File

@ -75,7 +75,7 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
if (CheckLaunchCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_context_->UpdateKernelDynamicShape(kernel_);
device_context_->UpdateDynamicShape(kernel_);
}
FetchInputDeviceTensor(context);
@ -92,7 +92,7 @@ void KernelActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *cont
if (CheckLaunchCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_context_->UpdateKernelDynamicShape(kernel_);
device_context_->UpdateDynamicShape(kernel_);
}
FetchInputDeviceTensor(context);

View File

@ -108,8 +108,12 @@ void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, cons
}
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
bool is_graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
if (is_graph_mode || is_pynative_infer) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
}
continue;
}
@ -372,50 +376,24 @@ void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const
*graph_info = session_->GetSingleOpGraphInfo(kernel, input_tensors);
}
void GraphCompiler::RecoverGraphOutput(
const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs,
std::vector<TensorPtr> *runop_output_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(outputs);
std::vector<TensorPtr> output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
}
size_t out_index = 0;
for (const auto &output_tensor : output_tensors) {
auto kernel_with_index = std::make_pair(kernel, out_index++);
(*op_output_map)[kernel_with_index] = output_tensor;
const auto &iter = output_indexes.find(kernel_with_index);
if (iter == output_indexes.end()) {
continue;
}
void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
MS_EXCEPTION_IF_NULL(session_);
session_->GetRefCount(graph.get(), ref_count);
}
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
for (const auto &ref_indexes : multiple_ref_indexes) {
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
for (; n < ref_indexes.size() - 1; n += 1) {
size_t index = ref_indexes.at(n);
if (index >= cur_vector_ref->size()) {
MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
<< cur_vector_ref->size();
}
void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
MS_EXCEPTION_IF_NULL(session_);
session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
runop_output_tensors->emplace_back(output_tensor);
}
}
void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, TensorPtr> *op_output_map,
GraphOutputInfo *graph_output_info) const {
MS_EXCEPTION_IF_NULL(session_);
session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}
void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {

View File

@ -22,6 +22,7 @@
#include <string>
#include <unordered_map>
#include <map>
#include <set>
#include "runtime/hardware/device_context.h"
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
@ -30,6 +31,7 @@
namespace mindspore {
using device::DeviceContext;
using session::CallBackFunc;
using session::GraphOutputInfo;
using session::InputTensorInfo;
using session::KernelGraph;
using session::KernelWithIndex;
@ -76,11 +78,19 @@ class GraphCompiler {
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
OpRunInfo *run_info, GraphInfo *graph_info);
// Calculate ref count of PyNative back propagation operators.
void CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const;
// Update ref count of PyNative back propagation operators.
void UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const;
// Handle single op output tensor and recover output of original complete kernel graph.
void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs,
std::vector<TensorPtr> *runop_output_tensors);
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, TensorPtr> *op_output_map,
GraphOutputInfo *graph_output_info) const;
// Collect output tensors of back propagation graph for allreduce operators to average gradient,
// used in PyNative distributed training mode.
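The three entry points declared above (CalculateRefCount, UpdateRefCount, RecoverGraphOutput) are meant to be driven once per kernel graph in PyNative mode: compute the reference counts up front, then after each single-op launch decrement them and collect whatever the graph still needs. A hedged skeleton of that driving loop, with stubbed stand-in types rather than the real MindSpore classes:

#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

using KernelId = std::pair<int, size_t>;  // stand-in for KernelWithIndex
using Tensor = int;                       // stand-in for tensor::TensorPtr

// Stubbed GraphCompiler with the same shape as the three declarations above.
struct GraphCompilerSketch {
  void CalculateRefCount(const std::vector<int> &order, std::map<KernelId, size_t> *rc) const {}
  void UpdateRefCount(const std::set<KernelId> &consumed, std::map<KernelId, size_t> *rc,
                      std::map<KernelId, Tensor> *outs) const {}
  void RecoverGraphOutput(int kernel, const std::vector<Tensor> &op_outputs,
                          const std::map<KernelId, size_t> &rc, std::map<KernelId, Tensor> *outs,
                          std::vector<Tensor> *graph_outputs) const {}
};

// Shape of the per-op loop in PyNative mode (launching the op itself is elided).
void RunGraphOpByOp(const GraphCompilerSketch &compiler, const std::vector<int> &execution_order) {
  std::map<KernelId, size_t> ref_count;
  std::map<KernelId, Tensor> op_output_map;
  std::vector<Tensor> graph_outputs;
  compiler.CalculateRefCount(execution_order, &ref_count);            // once per graph
  for (int kernel : execution_order) {
    std::vector<Tensor> op_outputs;   // filled by compiling and launching the single op
    std::set<KernelId> consumed;      // inputs the op actually read
    compiler.UpdateRefCount(consumed, &ref_count, &op_output_map);    // free dead intermediates
    compiler.RecoverGraphOutput(kernel, op_outputs, ref_count, &op_output_map, &graph_outputs);
  }
}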

View File

@ -1923,7 +1923,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
// Check the data source actors.
for (const auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
if (data_source_actor->output_data_arrows_.size() == 0) {
if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() == 0) {
MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
return false;
}

View File

@ -87,7 +87,7 @@ class DeviceContext {
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const = 0;
// Infer kernel shape and update abstract info for dynamic shape kernel.
virtual void UpdateKernelDynamicShape(const CNodePtr &kernel) const { AnfAlgo::InferShape(kernel); }
virtual void UpdateDynamicShape(const CNodePtr &kernel) const { AnfAlgo::InferShape(kernel); }
// Launch a kernel via 'KernelMod' of the kernel.
virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,

View File

@ -292,7 +292,7 @@ void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const
void GPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const { CreateGPUKernel(nodes); }
void GPUDeviceContext::UpdateKernelDynamicShape(const CNodePtr &kernel) const {
void GPUDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);

View File

@ -57,7 +57,7 @@ class GPUDeviceContext : public DeviceContext {
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
// Infer kernel shape and update abstract info for dynamic shape kernel.
void UpdateKernelDynamicShape(const CNodePtr &kernel) const override;
void UpdateDynamicShape(const CNodePtr &kernel) const override;
bool LaunchKernel(const CNodePtr &kernel, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,

View File

@ -552,7 +552,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
const std::set<std::string> kOpCacheAllowList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
kGetNextOpName};
const std::set<std::string> kHWSpecialFormatSet = {

View File

@ -346,23 +346,29 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
device_context->Initialize();
bool single_op_cache_hit;
(void)graph_compiler_->CompileGraph(op_run_info, graph_info, tensors_mask, input_tensors, &single_op_cache_hit,
device_context);
auto graph_id = graph_compiler_->CompileGraph(op_run_info, graph_info, tensors_mask, input_tensors,
&single_op_cache_hit, device_context);
// The actor set name: graph_id + single operator name.
std::string actor_info = std::to_string(graph_id) + "_" + op_run_info.op_name;
if (single_op_cache_hit) {
return graph_info;
auto iter = actor_to_graph_compiler_info_.find(actor_info);
if (iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can not find graph compiler info for actor set: " << actor_info;
}
return iter->first;
}
graph_info_to_device_context_.clear();
graph_info_to_device_context_[graph_info] = device_context;
auto graph_compiler_info = ConstructGraphCompilerInfo(tensors_mask, input_tensors);
auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, tensors_mask, input_tensors);
const auto actor_set =
runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info, runtime::GraphExecutionStrategy::kStep);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
graph_compiler_info->input_tensors_.clear();
actor_to_graph_compiler_info_.emplace(actor_set->name_, std::move(graph_compiler_info));
return actor_set->name_;
auto ret = actor_to_graph_compiler_info_.emplace(actor_info, std::move(graph_compiler_info));
return ret.first->first;
}
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
@ -370,11 +376,16 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
MS_EXCEPTION_IF_NULL(graph_compiler_);
for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
const auto &graph = graphs[graph_index];
std::map<AnfNodePtr, size_t> parameter_index;
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index, &output_indexes);
std::map<AnfNodePtr, size_t> parameter_index;
GraphOutputInfo graph_output_info;
graph_output_info.graph_outputs = outputs;
graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
&graph_output_info.output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref_count;
graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
// Clear bucket resources every step
if (graph->is_bprop()) {
@ -392,28 +403,30 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
const ActorInfo &actor_info =
CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors);
VectorRef op_outputs =
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors);
VectorRef op_outputs;
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
&op_outputs);
std::vector<tensor::TensorPtr> new_output_tensors;
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, output_indexes, &op_output_map, outputs,
&new_output_tensors);
graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
graph_output_info.graph_output_tensors.clear();
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (graph->is_bprop()) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), new_output_tensors);
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
}
}
}
}
VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args) {
void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
MS_LOG(INFO) << "Run actor begin, actor name: " << actor_info;
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return VectorRef();
return;
}
// Fetch the graph compiler info.
@ -448,12 +461,12 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
input_tensors.emplace_back(input_tensor);
// Run in the pynative mode.
VectorRef outputs;
MS_EXCEPTION_IF_NULL(outputs);
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (pynative_mode) {
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, &outputs);
return outputs;
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
return;
}
mindspore::ScopedLongRunning long_running;
@ -485,14 +498,13 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
VectorRef tmp;
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(tmp.elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
outputs.emplace_back(std::move(tmp));
outputs->emplace_back(std::move(tmp));
} else if (output_tensors.size() == 1) {
outputs.emplace_back(std::move(output_tensors.front()));
outputs->emplace_back(std::move(output_tensors.front()));
}
MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;
graph_compiler_->Summary(graph_compiler_info.graphs_);
return outputs;
}
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
@ -552,19 +564,18 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
}
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
const std::vector<int64_t> *tensors_mask, const std::vector<tensor::TensorPtr> *input_tensors) {
const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(graph_compiler_);
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
runtime::KernelMapPosition outputs_order;
size_t position = 0;
std::string name;
for (const auto &graph_info_to_context : graph_info_to_device_context_) {
const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
graphs.emplace_back(graph);
device_contexts.emplace_back(graph_info_to_context.second);
name.append(graph_info_to_context.first);
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
for (const auto &output : outputs) {
@ -578,12 +589,12 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
return std::make_unique<GraphCompilerInfo>(
graphs, device_contexts, tensors_mask_list, input_tensors_list, std::vector<AnfNodePtr>(),
std::vector<AnfNodePtr>(), outputs_order, FrontToBackendNodeWithContext(), FuncGraphToParameter(),
HostParameterToWeight(), std::vector<AnfNodePtr>(), outputs_order.size(), name);
HostParameterToWeight(), std::vector<AnfNodePtr>(), outputs_order.size(), actor_info);
}
VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,
const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors) {
void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,
const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs) {
const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
if (graph_iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
@ -619,18 +630,28 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run
}
// Fetch outputs.
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
VectorRef outputs;
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(outputs.elements_),
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(outputs->elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
// update output abstract of dynamic op to op_run_info
// Update output abstract of dynamic op to op_run_info
if (op_run_info->is_dynamic_shape) {
UpdateOutputAbstract(graph_compiler_info.graphs_.front(), op_run_info);
}
return outputs;
// Release the kernel resource.
const auto &graph = graph_compiler_info.graphs_.front();
MS_EXCEPTION_IF_NULL(graph);
const auto &kernel = graph->execution_order().front();
MS_EXCEPTION_IF_NULL(kernel);
if (kOpCacheBlackList.find(AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
if (kernel_mod) {
kernel_mod->ReleaseResource();
}
}
}
} // namespace compile
} // namespace mindspore
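Two details of the reworked CompileGraph/RunGraph above are worth isolating: the actor set is now keyed by "graph_id_op_name", and on a cache hit the returned reference points at the key stored inside actor_to_graph_compiler_info_, so it stays valid after the function returns. A simplified standalone cache with the same shape (names are illustrative, not the real backend API):

#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct CompilerInfoSketch {          // stand-in for GraphCompilerInfo
  std::string name;
};

class SingleOpCacheSketch {
 public:
  // Returns a reference to the cache key, mirroring how CompileGraph returns actor_info.
  const std::string &Compile(int graph_id, const std::string &op_name, bool cache_hit) {
    const std::string actor_info = std::to_string(graph_id) + "_" + op_name;
    if (cache_hit) {
      auto iter = cache_.find(actor_info);
      if (iter == cache_.end()) {
        throw std::runtime_error("no graph compiler info for actor set: " + actor_info);
      }
      return iter->first;            // reuse the already scheduled actor set
    }
    auto info = std::make_unique<CompilerInfoSketch>(CompilerInfoSketch{actor_info});
    auto ret = cache_.emplace(actor_info, std::move(info));
    return ret.first->first;         // key owned by the map, safe to hand back by reference
  }

 private:
  std::map<std::string, std::unique_ptr<CompilerInfoSketch>> cache_;
};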

View File

@ -35,6 +35,7 @@
namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;
using GraphOutputInfo = session::GraphOutputInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
@ -114,11 +115,11 @@ class MindRTBackend : public Backend {
std::vector<tensor::TensorPtr> *input_tensors);
// Run Graph in the graph mode.
VectorRef RunGraph(const ActorInfo &actor_info, const VectorRef &args);
void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs);
// Run Graph in the pyNative mode.
VectorRef RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info, const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors);
void RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info, const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs);
private:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
@ -129,7 +130,8 @@ class MindRTBackend : public Backend {
std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);
// Construct the GraphCompilerInfo by the compilation results of graph, used in PyNative mode.
std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const std::vector<int64_t> *tensors_mask,
std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const ActorInfo &actor_info,
const std::vector<int64_t> *tensors_mask,
const std::vector<tensor::TensorPtr> *input_tensors);
// Split complete kernel graph to single op graph in PyNative back

View File

@ -32,6 +32,7 @@
#include "utils/ms_context.h"
#include "debug/trace.h"
#include "debug/anf_ir_dump.h"
#include "runtime/framework/actor/actor_common.h"
namespace mindspore {
namespace compile {
@ -521,10 +522,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
return rt;
}
// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
// Return false in the transitional stage.
bool IsMindRTUsed() { return false; }
BackendPtr CreateBackend() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);

View File

@ -131,9 +131,6 @@ class CompileGraphs {
BackendPtr backend_;
};
// Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
bool IsMindRTUsed();
BackendPtr CreateBackend();
} // namespace compile