!18656 Fix out-of-memory bug for ResNet50 with batch size 512

Merge pull request !18656 from JoyLvliang/fix_bug_of_resnet50_512_batch_size_memory_not_enough
i-robot 2021-06-24 01:06:41 +00:00 committed by Gitee
commit 0768ae686e
3 changed files with 30 additions and 11 deletions


@@ -720,11 +720,20 @@ void TopCellInfo::ClearDeviceMemory() {
   }
   k_pynative_cell_ptr_ = nullptr;
-  for (const auto &elem : tensor_id_with_tensor_object_) {
-    std::for_each(elem.second.begin(), elem.second.end(), [](const tensor::TensorPtr &tensor) {
-      MS_EXCEPTION_IF_NULL(tensor);
-      tensor->set_device_address(nullptr);
-    });
-  }
+  // Get all tensor objects held in value nodes of the running graph
+  std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
+  MS_EXCEPTION_IF_NULL(resource_);
+  const auto &bprop_graph = resource_->func_graph();
+  MS_EXCEPTION_IF_NULL(bprop_graph);
+  const auto &value_node_list = bprop_graph->value_nodes();
+  for (const auto &elem : value_node_list) {
+    auto value_node = elem.first->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
+  }
+  for (const auto &elem : tensors_in_bprop_graph) {
+    MS_EXCEPTION_IF_NULL(elem);
+    elem->set_device_address(nullptr);
+  }
 }
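
The memory fix is easiest to see in isolation. Below is a minimal, self-contained sketch of the new release pattern; GraphLike and TensorLike are hypothetical stand-ins for MindSpore's FuncGraph and tensor::Tensor, and TensorValueToTensor is reduced to a plain copy. The point is the traversal: every tensor reachable from the graph's value nodes gets its device address cleared, so the backing device memory can be reused.

// Minimal sketch of the release pattern above; GraphLike and TensorLike
// are hypothetical stand-ins, not MindSpore's real types.
#include <memory>
#include <utility>
#include <vector>

struct DeviceAddress {};

struct TensorLike {
  std::shared_ptr<DeviceAddress> device_address_;
  void set_device_address(std::shared_ptr<DeviceAddress> addr) { device_address_ = std::move(addr); }
};
using TensorLikePtr = std::shared_ptr<TensorLike>;

struct GraphLike {
  // One entry per value node; a node may hold several tensors (e.g. a tuple constant).
  std::vector<std::vector<TensorLikePtr>> value_nodes_;
};

void ClearDeviceMemory(const GraphLike &bprop_graph) {
  // Gather every tensor held by the graph's value nodes ...
  std::vector<TensorLikePtr> tensors_in_bprop_graph;
  for (const auto &node_tensors : bprop_graph.value_nodes_) {
    tensors_in_bprop_graph.insert(tensors_in_bprop_graph.end(), node_tensors.begin(), node_tensors.end());
  }
  // ... then drop each device address so the device memory is released.
  for (const auto &tensor : tensors_in_bprop_graph) {
    tensor->set_device_address(nullptr);
  }
}

The old loop only walked tensor_id_with_tensor_object_, so tensors captured as constants in the bprop graph's value nodes could keep their device addresses alive; at ResNet50's batch size 512, that residue is apparently what exhausted device memory.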
@@ -1513,6 +1522,13 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
 void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
   MS_EXCEPTION_IF_NULL(resource);
+  // Get the ids of all tensors produced by forward ops
+  std::unordered_set<std::string> forward_op_tensor_id;
+  const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
+  for (const auto &record : op_info_with_tensor_id) {
+    std::for_each(record.second.begin(), record.second.end(),
+                  [&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); });
+  }
   // Get all tensor objects in value nodes of the bprop graph
   const auto &bprop_graph = resource->func_graph();
   MS_EXCEPTION_IF_NULL(bprop_graph);
@@ -1532,6 +1548,9 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
   // Save tensors held in value nodes of the bprop graph
   for (const auto &tensor : tensors_in_bprop_graph) {
     MS_EXCEPTION_IF_NULL(tensor);
+    if (forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end() || tensor->device_address() == nullptr) {
+      continue;
+    }
     tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
     MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
                   << " device address: " << tensor->device_address() << " shape and dtype "
@@ -2084,7 +2103,7 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p
 }
 void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g,
-                                                  const py::object &out) {
+                                                  const py::object &out, const std::string &out_id) {
   MS_EXCEPTION_IF_NULL(curr_g);
   if (!(py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out))) {
     MS_LOG(EXCEPTION) << "The out of top cell should be tuple or list when set maketuple as output node";
@@ -2101,7 +2120,6 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
   auto cnode = curr_g_->NewCNode(inputs);
   MS_LOG(DEBUG) << "Tuple output node info " << cnode->DebugString();
   // Record node info in the graph map
-  auto out_id = GetId(out);
   SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
   SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
   if (grad_is_running_ && !bprop_grad_stack_.top().second) {
@@ -2131,7 +2149,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
   if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
     if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
       auto tuple_out = py::cast<py::tuple>(out);
-      CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, tuple_out);
+      CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, tuple_out, out_id);
     } else {
       MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
       MakeValueNode(out, out_id);
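
These last hunks are a small refactor riding along with the memory fix: EndGraphInner already has out_id in hand, so it now passes the id into CreateMakeTupleNodeForMultiOut instead of having the callee recompute it with GetId(out). A hypothetical sketch of the pattern (ComputeId stands in for GetId, which this diff does not show):

// Hypothetical illustration: compute an id once at the call site and
// thread it through, rather than recomputing it in the callee.
#include <iostream>
#include <string>

// Stand-in for GetId: derives a stable id from an object.
std::string ComputeId(const std::string &obj) { return "id_" + obj; }

// Callee takes the precomputed id instead of calling ComputeId again.
void CreateNode(const std::string &out, const std::string &out_id) {
  std::cout << "node for " << out << " with id " << out_id << "\n";
}

void EndGraph(const std::string &out) {
  const std::string out_id = ComputeId(out);  // computed once at the call site
  CreateNode(out, out_id);                    // ... and reused here
}

int main() {
  EndGraph("top_cell_output");
  return 0;
}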


@@ -250,7 +250,8 @@ class GradExecutor {
                           const std::vector<int64_t> &index) {
     top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
   }
-  void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out);
+  void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out,
+                                      const std::string &out_id);
   void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
 
 private:


@@ -753,7 +753,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
     if (!mock) {
       LaunchKernelWithoutMock(graph, kernel, kernel_inputs, kernel_workspaces, kernel_outputs, profiling);
-      if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
+      if (gpu_kernel != nullptr && dynamic_kernel != nullptr && dynamic_kernel->is_dynamic_shape()) {
        gpu_kernel->PostExecute();
       }
@@ -844,7 +844,7 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
       MS_LOG(ERROR) << "Launch kernel failed.";
       return false;
     }
-    if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
+    if (gpu_kernel != nullptr && dynamic_kernel != nullptr && dynamic_kernel->is_dynamic_shape()) {
      gpu_kernel->PostExecute();
    }
  }
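
The two GPU-runtime hunks do not change behavior: an implicit pointer-to-bool conversion is replaced by an explicit comparison against nullptr, which many C++ style guides require for readability. A self-contained illustration of the equivalence, using a hypothetical Kernel type in place of the runtime's kernel classes:

#include <memory>

struct Kernel {
  bool is_dynamic_shape() const { return true; }
  void PostExecute() {}
};

int main() {
  auto gpu_kernel = std::make_shared<Kernel>();
  auto dynamic_kernel = gpu_kernel;

  // Old form: relies on shared_ptr's operator bool in a boolean context.
  if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
    gpu_kernel->PostExecute();
  }

  // New form: same short-circuit evaluation, but the null check is spelled out.
  if (gpu_kernel != nullptr && dynamic_kernel != nullptr && dynamic_kernel->is_dynamic_shape()) {
    gpu_kernel->PostExecute();
  }
  return 0;
}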