!46664 fixed mem bugs for 2.0

Merge pull request !46664 from wanghenchang/master_2.0_1210
This commit is contained in:
i-robot 2022-12-12 01:09:51 +00:00 committed by Gitee
commit 8a211d9edf
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 3 deletions

View File

@ -320,7 +320,9 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden()); parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]); auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout); auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]); const auto &clone_value = ShallowCopyTensorValue(input_param_values[i]);
ClearDeviceAddress(clone_value);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, clone_value);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint)); (void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
} }
} }
@ -469,6 +471,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
MS_EXCEPTION_IF_NULL(sens_out); MS_EXCEPTION_IF_NULL(sens_out);
MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString(); MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
last_node_ = output_node; last_node_ = output_node;
ClearDeviceAddress(sens_out);
sens_value_ = sens_out; sens_value_ = sens_out;
} }

View File

@ -40,11 +40,18 @@ void AsyncQueue::WorkerLoop() {
while (true) { while (true) {
std::shared_ptr<AsyncTask> task; std::shared_ptr<AsyncTask> task;
bool task_empty = false;
{ {
MS_LOG(DEBUG) << "Wait task in queue"; MS_LOG(DEBUG) << "Wait task in queue";
std::unique_lock<std::mutex> lock(task_mutex_); std::unique_lock<std::mutex> lock(task_mutex_);
task_cond_var_.wait(lock, [this]() { return !tasks_.empty(); }); task_empty = tasks_.empty();
task = tasks_.front(); if (!task_empty) {
task = tasks_.front();
}
}
if (task_empty) {
std::this_thread::yield();
continue;
} }
MS_LOG(DEBUG) << "Get task"; MS_LOG(DEBUG) << "Get task";