!31201 Bugfix for PyNative Refactor

Merge pull request !31201 from caifubi/master-pynative-no-actor
This commit is contained in:
i-robot 2022-03-12 16:38:59 +00:00 committed by Gitee
commit e856f186ab
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 34 additions and 0 deletions

View File

@ -1314,6 +1314,10 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph, const OpRunInf
}
}
ReleaseForwardOutput(input_tensors);
}
void MindRTBackend::ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors) {
// Update forward op output ref counts, release it
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@ -1377,6 +1381,7 @@ void MindRTBackend::LazyExecuteTaskCallback() {
auto tensor_without_value_mask = GetTensorWithoutValueMask(context->op_run_info());
runtime::RunSingleOpGraph(context->graph(), tensor_without_value_mask, context->device_context(),
context->op_run_info().is_dynamic_shape);
ReleaseForwardOutput(context->op_run_info().input_tensors);
ClearGraphDeviceAddress(context->graph(), context->device_context(), context->op_run_info().is_gradient_out);
ClearInputDeviceAddress(context->graph(), context->device_context());
op_lazy_builder.PopOpRunTask();
@ -1435,6 +1440,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g
runtime::UpdateDeviceAddress(graph, tensor_without_value_mask, device_context);
runtime::RunSingleOpGraph(graph, tensor_without_value_mask, device_context, op_run_info->is_dynamic_shape);
ReleaseForwardOutput(op_run_info->input_tensors);
UpdateOutput(output_nodes, outputs);
ClearGraphDeviceAddress(graph, device_context, op_run_info->is_gradient_out);
ClearInputDeviceAddress(graph, device_context);

View File

@ -180,6 +180,8 @@ class MindRTBackend : public Backend {
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);
void ReleaseForwardOutput(const std::vector<TensorPtr> &input_tensors);
// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
// the corresponding device_context.

View File

@ -75,6 +75,7 @@ void UpdateInputNodeDeviceAddress(const std::vector<AnfNodePtr> &input_nodes,
if (tensor_address == nullptr) {
input_tensor->set_device_address(node_address);
input_tensor->set_sync_status(kNeedSyncHostToDeviceImmediately);
node_address->set_from_persistent_mem(input_tensor->is_parameter());
}
// The DeviceType and format of DeviceAddress is always the same after UpdateInputTensor
@ -392,6 +393,29 @@ void LaunchKernels(const KernelGraphPtr &graph, const device::DeviceContext *dev
}
MS_LOG(DEBUG) << "End";
}
void WaitCommunicationFinish(const std::vector<tensor::TensorPtr> &input_tensors) {
for (auto &input_tensor : input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensor);
if (input_tensor->NeedWaitDevice()) {
input_tensor->WaitDevice();
}
}
}
void ReleaseKernelResource(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
const auto &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (kOpCacheBlackList.find(common::AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
if (kernel_mod) {
kernel_mod->ReleaseResource();
}
}
}
}
} // namespace
// Determine the address of the graph and do not change the address in subsequent executions
@ -408,8 +432,10 @@ void UpdateDeviceAddress(const KernelGraphPtr &graph, const std::vector<tensor::
void RunSingleOpGraph(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
const device::DeviceContext *device_context, bool is_dynamic_shape) {
WaitCommunicationFinish(input_tensors);
MallocForKernel(graph, device_context);
CopyDataToDevice(graph, input_tensors, device_context);
LaunchKernels(graph, device_context, is_dynamic_shape);
ReleaseKernelResource(graph);
}
} // namespace mindspore::runtime