!20336 [bugfix] Fix GPU OOM that occurs when caching all output tensors of a graph

Merge pull request !20336 from zyli2020/mindrt_debug
This commit is contained in:
i-robot 2021-07-16 01:39:06 +00:00 committed by Gitee
commit 8bce519e9b
4 changed files with 18 additions and 1 deletions

View File

@ -342,6 +342,11 @@ py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
tensor.Wait();
}
tensor.data_sync();
// Release device address of graph output tensor.
if (tensor.need_release_device_mem()) {
const_cast<Tensor &>(tensor).set_device_address(nullptr);
}
}
return AsNumpy(tensor);
}

View File

@ -117,7 +117,10 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
if (output_position >= outputs_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
}
outputs_[output_position] = CreateOutputTensor(output_node, output_index, output_position);
auto tensor = CreateOutputTensor(output_node, output_index, output_position);
tensor->set_need_release_device_mem(true);
outputs_[output_position] = tensor;
current_outputs_num_++;
// Save the output nodes to clear the device tensor in the running end.

View File

@ -473,6 +473,7 @@ Tensor::Tensor(const Tensor &tensor)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
need_release_device_mem_(tensor.need_release_device_mem_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
@ -487,6 +488,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
need_release_device_mem_(tensor.need_release_device_mem_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
@ -548,6 +550,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);
device_sync_ = tensor.device_sync_;
need_release_device_mem_ = tensor.need_release_device_mem_;
data_ = tensor.data_;
id_ = tensor.id_;
event_ = tensor.event_;

View File

@ -294,6 +294,10 @@ class Tensor : public MetaTensor {
device_sync_->ResetRefCount();
}
}
bool need_release_device_mem() const { return need_release_device_mem_; }
void set_need_release_device_mem(bool release_device_mem) { need_release_device_mem_ = release_device_mem; }
void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; }
std::string padding_type() const { return padding_type_; }
@ -375,6 +379,8 @@ class Tensor : public MetaTensor {
bool graph_output_{false};
bool updated_by_device_{false};
DeviceSyncPtr device_sync_{nullptr};
// Release device address of graph output tensor or not.
bool need_release_device_mem_{false};
bool cache_enable_{false};
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};