Fix bug: CPU does not support copies larger than 2G

This commit is contained in:
limingqi107 2021-07-09 13:20:57 +08:00
parent 3a408218ec
commit ae46d0beb9
7 changed files with 74 additions and 11 deletions

View File

@ -81,5 +81,35 @@ void IntToLong(void *dst, const void *src, size_t elem_num) {
long_data[i] = static_cast<int64_t>(int_data[i]);
}
}
// Copies `size` bytes from `src` to `dst` element by element, without changing the element type.
// `type` selects the element type used for the copy; the byte count is converted to an element
// count before delegating to the typed ConvertSameType overload.
// Handled types: float16, float32, float64, int16, int32, int64. Any other type is reported
// through MS_LOG(EXCEPTION).
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type) {
  switch (type) {
    case kNumberTypeFloat16: {
      auto *out = static_cast<float16 *>(dst);
      const auto *in = static_cast<const float16 *>(src);
      // Halve the byte count to obtain the number of elements.
      ConvertSameType(out, in, size >> 1);
      break;
    }
    case kNumberTypeFloat32: {
      auto *out = static_cast<float *>(dst);
      const auto *in = static_cast<const float *>(src);
      ConvertSameType(out, in, size / sizeof(float));
      break;
    }
    case kNumberTypeFloat64: {
      auto *out = static_cast<double *>(dst);
      const auto *in = static_cast<const double *>(src);
      ConvertSameType(out, in, size / sizeof(double));
      break;
    }
    case kNumberTypeInt16: {
      auto *out = static_cast<int16_t *>(dst);
      const auto *in = static_cast<const int16_t *>(src);
      // Halve the byte count to obtain the number of elements.
      ConvertSameType(out, in, size >> 1);
      break;
    }
    case kNumberTypeInt32: {
      auto *out = static_cast<int *>(dst);
      const auto *in = static_cast<const int *>(src);
      ConvertSameType(out, in, size / sizeof(int));
      break;
    }
    case kNumberTypeInt64: {
      auto *out = static_cast<int64_t *>(dst);
      const auto *in = static_cast<const int64_t *>(src);
      ConvertSameType(out, in, size / sizeof(int64_t));
      break;
    }
    default:
      MS_LOG(EXCEPTION) << "Invalid Type: " << TypeIdLabel(type);
  }
}
} // namespace device
} // namespace mindspore

View File

@ -31,6 +31,14 @@ void ShortToInt(void *dst, const void *src, size_t elem_num);
void IntToShort(void *dst, const void *src, size_t elem_num);
void LongToInt(void *dst, const void *src, size_t elem_num);
void IntToLong(void *dst, const void *src, size_t elem_num);
// Same-type copy of `size` bytes from `src` to `dst`, performed element by element.
// `type` selects the element type; unsupported types are reported through MS_LOG(EXCEPTION).
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type);
// Copies `elem_num` elements of type T from `src` to `dst`, in ascending address order,
// using plain element assignment.
template <typename T>
void ConvertSameType(T *dst, const T *src, size_t elem_num) {
  const T *const src_end = src + elem_num;
  while (src != src_end) {
    *dst++ = *src++;
  }
}
} // namespace device
} // namespace mindspore

View File

@ -17,11 +17,23 @@
#include <vector>
#include <memory>
#include "runtime/device/convert_tensor_utils.h"
#include "runtime/hardware/cpu/cpu_memory_pool.h"
#include "debug/data_dump/dump_json_parser.h"
namespace mindspore {
namespace device {
namespace cpu {
// Hands the held pointer back to CPUMemoryPool when it is non-null and from_mem_pool_ is set;
// in every other case the pointer is left untouched.
CPUDeviceAddress::~CPUDeviceAddress() {
  const bool release_to_pool = (ptr_ != nullptr) && from_mem_pool_;
  if (!release_to_pool) {
    return;
  }
  CPUMemoryPool::GetInstance().FreeTensorMem(ptr_);
  ptr_ = nullptr;
}
bool CPUDeviceAddress::DumpMemToFile(const std::string &filepath, const std::string &, const ShapeVector &host_shape,
TypeId host_type, bool) const {
bool ret = false;
@ -56,7 +68,10 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
return true;
}
auto ret_code = memcpy_s(host_ptr, size, ptr_, size);
if (ret_code != EOK) {
// Return ERANGE when the copy size is larger than SECUREC_MEM_MAX_LEN.
if (ret_code == ERANGE) {
ConvertSameType(host_ptr, ptr_, size, type);
} else if (ret_code != EOK) {
MS_LOG(ERROR) << "Failed to copy tensor!";
return false;
}
@ -94,14 +109,17 @@ bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /* shape */, size_t
if (type == type_id_) {
if (size > size_) {
MS_LOG(INFO) << "No need sync, host size: " << size << ", device size: " << size_;
MS_LOG(WARNING) << "No need sync, host size: " << size << ", device size: " << size_;
return true;
}
auto ret_code = memcpy_s(ptr_, size, host_ptr, size);
if (ret_code != EOK) {
MS_LOG(ERROR) << "Failed to copy tensor!";
return false;
// Use the tensor host ptr to set the device ptr.
if (from_mem_pool_) {
CPUMemoryPool::GetInstance().FreeTensorMem(ptr_);
from_mem_pool_ = false;
}
ptr_ = const_cast<void *>(host_ptr);
original_ref_count_ = SIZE_MAX;
ref_count_ = SIZE_MAX;
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat16) {
HalfToFloat(ptr_, host_ptr, size >> 1);
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {

View File

@ -34,7 +34,7 @@ class CPUDeviceAddress : public DeviceAddress {
CPUDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const KernelWithIndex &node_index)
: DeviceAddress(ptr, size, format, type_id, node_index) {}
~CPUDeviceAddress() override = default;
~CPUDeviceAddress() override;
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,

View File

@ -75,6 +75,7 @@ class DeviceAddress : public mindspore::DeviceSync {
void SetSize(size_t size) { size_ = size; }
std::string format() const { return format_; }
TypeId type_id() const { return type_id_; }
bool from_mem_pool() const { return from_mem_pool_; }
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
@ -114,14 +115,14 @@ class DeviceAddress : public mindspore::DeviceSync {
return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
: KernelWithIndex{node_index_.first.lock(), node_index_.second};
}
void *ptr_{nullptr};
mutable void *ptr_{nullptr};
size_t size_{0};
size_t original_ref_count_{1};
mutable size_t original_ref_count_{1};
// It will be decreased in the running, and reset by original_ref_count_ when it is zero.
size_t ref_count_{1};
mutable size_t ref_count_{1};
string format_{"DefaultFormat"};
TypeId type_id_{kNumberTypeFloat16};
bool from_mem_pool_{false};
mutable bool from_mem_pool_{false};
uint8_t *communication_ptr_{nullptr};
ShapeVector host_shape_{};
// {node, out_index}

View File

@ -73,6 +73,9 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(address->ptr_);
MS_EXCEPTION_IF_NULL(mem_manager_);
if (!address->from_mem_pool()) {
return;
}
mem_manager_->FreeMemFromMemPool(address->ptr_);
address->ptr_ = nullptr;
}

View File

@ -183,6 +183,9 @@ bool GPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size
// Returns the memory held by `address` to the memory pool and clears its pointer.
// `address` and `address->ptr_` must be non-null (checked via MS_EXCEPTION_IF_NULL).
// Addresses for which from_mem_pool() is false are left untouched.
void GPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
  MS_EXCEPTION_IF_NULL(address);
  MS_EXCEPTION_IF_NULL(address->ptr_);
  if (!address->from_mem_pool()) {
    return;
  }
  // Guard the manager before dereferencing it, as CPUDeviceContext::FreeMemory does.
  // Placed after the early return so addresses not from the pool behave exactly as before.
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->FreeMemFromMemPool(address->ptr_);
  address->ptr_ = nullptr;
}