support dynamic data sink on GPU

commit 39e89f73ac
parent 55ba926a04
Author: wYann
Date:   2022-01-18 19:10:36 +08:00
11 changed files with 64 additions and 53 deletions


@@ -14,14 +14,15 @@
  * limitations under the License.
  */
 #include "backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h"
 #include <cuda_runtime_api.h>
 #include <memory>
 #include <string>
 #include <vector>
+#include <algorithm>
 #include "utils/convert_utils.h"
 #include "backend/kernel_compiler/gpu/data/dataset_utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #ifndef ENABLE_SECURITY
 #include "profiler/device/gpu/gpu_profiling.h"
 #endif
@@ -36,7 +37,7 @@ namespace kernel {
 using mindspore::device::GpuBufferMgr;
 DatasetIteratorKernelMod::DatasetIteratorKernelMod()
-    : handle_(GpuBufferMgr::INVALID_HANDLE), total_bytes_(0), profiling_enable_(false), profiling_op_(nullptr) {}
+    : handle_(GpuBufferMgr::INVALID_HANDLE), profiling_enable_(false), profiling_op_(nullptr) {}
 DatasetIteratorKernelMod::~DatasetIteratorKernelMod() { GpuBufferMgr::GetInstance().Close(handle_); }
@@ -46,17 +47,16 @@ bool DatasetIteratorKernelMod::Init(const CNodePtr &kernel_node) {
   kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
   queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
   std::vector<std::vector<int>> shapes;
-  std::vector<TypePtr> types;
-  GetShapeAndType(kernel_node, &shapes, &types);
-  for (auto item : types) {
+  std::vector<TypePtr> type_ptrs;
+  GetShapeAndType(kernel_node, &shapes, &type_ptrs);
+  for (auto item : type_ptrs) {
     MS_EXCEPTION_IF_NULL(item);
   }
+  std::transform(type_ptrs.begin(), type_ptrs.end(), std::back_inserter(types_),
+                 [](const TypePtr &value) { return value->type_id(); });
   for (size_t i = 0; i < shapes.size(); i++) {
-    int unit = UnitSizeInBytes(types[i]->type_id());
-    int nums = ElementNums(shapes[i]);
-    int bytes = unit * nums;
-    output_size_list_.push_back(bytes);
-    total_bytes_ += bytes;
+    output_size_list_.push_back(0);  // The output size can be dynamic when the shapes are dynamic, so push a placeholder value here.
   }
 #ifndef ENABLE_SECURITY
@@ -74,7 +74,7 @@ bool DatasetIteratorKernelMod::Init(const CNodePtr &kernel_node) {
 void DatasetIteratorKernelMod::InitSizeLists() { return; }
-bool DatasetIteratorKernelMod::ReadDevice(void **addr, size_t *len) {
+bool DatasetIteratorKernelMod::ReadDevice(std::vector<DataItemGpu> *data) {
   uint64_t start_time_stamp = 0;
   uint32_t queue_size = 0;
 #ifndef ENABLE_SECURITY
@@ -90,7 +90,7 @@ bool DatasetIteratorKernelMod::ReadDevice(void **addr, size_t *len) {
     queue_size = GpuBufferMgr::GetInstance().Size(handle_);
   }
 #endif
-  auto ret = GpuBufferMgr::GetInstance().Front(handle_, addr, len);
+  auto ret = GpuBufferMgr::GetInstance().Front(handle_, data);
   if (ret == device::SUCCESS) {
 #ifndef ENABLE_SECURITY
     if (profiling_enable_) {
@@ -134,24 +134,18 @@ bool DatasetIteratorKernelMod::Launch(const std::vector<AddressPtr> &, const std
     }
   }
-  void *addr = nullptr;
-  size_t len = 0;
-  if (!ReadDevice(&addr, &len)) {
-    return false;
-  }
-  if (total_bytes_ != len) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len
-                  << " Bytes, expect: " << total_bytes_ << " Bytes.";
+  if (!ReadDevice(&output_data_)) {
     return false;
   }
-  for (size_t i = 0; i < output_size_list_.size(); i++) {
+  for (size_t i = 0; i < output_data_.size(); i++) {
     void *output_addr = GetDeviceAddress<void>(outputs, i);
+    auto device_addr = output_data_[i].device_addr_;
+    auto data_len = output_data_[i].data_len_;
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
+                               cudaMemcpyAsync(output_addr, device_addr, data_len, cudaMemcpyDeviceToDevice,
                                                reinterpret_cast<cudaStream_t>(stream)),
                                "Cuda Memcpy Failed");
-    addr = reinterpret_cast<unsigned char *>(addr) + output_size_list_[i];
   }
   CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)),
@@ -159,5 +153,15 @@ bool DatasetIteratorKernelMod::Launch(const std::vector<AddressPtr> &, const std
   (void)GpuBufferMgr::GetInstance().Pop(handle_);
   return true;
 }
+
+void DatasetIteratorKernelMod::PostExecute() {
+  std::vector<std::vector<size_t>> shapes;
+  for (const auto &item : output_data_) {
+    std::vector<size_t> shape;
+    std::transform(item.shapes_.begin(), item.shapes_.end(), std::back_inserter(shape), LongToSize);
+    shapes.push_back(shape);
+  }
+  AnfAlgo::SetOutputInferTypeAndShape(types_, shapes, kernel_node_.lock().get());
+}
 }  // namespace kernel
 }  // namespace mindspore
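
Note on the flow above: the kernel no longer precomputes a fixed total_bytes_ at Init time. Each Front call now returns per-item records that carry the actual length, device address, and shape, and PostExecute reports the observed shapes back to the framework. The following is a minimal host-only sketch of that pattern, not MindSpore code: Item, ReadDeviceSim, and plain memcpy (standing in for cudaMemcpyAsync) are illustrative stand-ins.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Simplified stand-in for device::DataItemGpu: per-item length, address, shape.
struct Item {
  size_t data_len_{0};
  void *device_addr_{nullptr};
  std::vector<int64_t> shapes_;
};

// Stand-in for ReadDevice/Front: the producer decides the size of every step,
// so the consumer can no longer validate against a fixed total byte count.
std::vector<Item> ReadDeviceSim(std::vector<float> *batch) {
  Item item;
  item.data_len_ = batch->size() * sizeof(float);
  item.device_addr_ = batch->data();
  item.shapes_ = {static_cast<int64_t>(batch->size())};
  return {item};
}

int main() {
  std::vector<float> step1(4, 1.0f), step2(7, 2.0f);  // two dynamic batch sizes
  for (auto *batch : {&step1, &step2}) {
    std::vector<Item> output_data = ReadDeviceSim(batch);
    for (const auto &item : output_data) {
      // Launch: copy exactly item.data_len_ bytes from the recorded address,
      // instead of walking a cursor through one contiguous queue slot.
      std::vector<char> output(item.data_len_);
      std::memcpy(output.data(), item.device_addr_, item.data_len_);
      // PostExecute: report the shape that was actually read.
      std::printf("copied %zu bytes, shape [%lld]\n", item.data_len_,
                  static_cast<long long>(item.shapes_[0]));
    }
  }
  return 0;
}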


@@ -23,9 +23,11 @@
 #include "backend/kernel_compiler/gpu/data/dataset_profiling.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "runtime/device/gpu/blocking_queue.h"
 namespace mindspore {
 namespace kernel {
+using mindspore::device::DataItemGpu;
 class DatasetIteratorKernelMod : public NativeGpuKernelMod {
  public:
   DatasetIteratorKernelMod();
@@ -34,17 +36,19 @@ class DatasetIteratorKernelMod : public NativeGpuKernelMod {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
   bool Init(const CNodePtr &kernel_node) override;
+  void PostExecute() override;
  protected:
   void InitSizeLists() override;
  private:
-  bool ReadDevice(void **addr, size_t *len);
+  bool ReadDevice(std::vector<DataItemGpu> *data);
   std::string queue_name_;
   unsigned int handle_;
-  size_t total_bytes_;
   bool profiling_enable_;
   std::shared_ptr<GetNextProfiling> profiling_op_;
+  std::vector<TypeId> types_;
+  std::vector<DataItemGpu> output_data_;
 };
 MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernelMod)


@@ -79,7 +79,7 @@ class GpuKernelRegister {
 // because the variable created by the macro will also contain a space. So, we solve this
 // problem by writing uchar when calling these macros, and expanding uchar after the
 // variable has been created.
-#define uchar unsigned char
+using uchar = unsigned char;
 #define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(g_##kernel##_gpu_kernel_reg, __COUNTER__)
 #define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt)


@@ -552,6 +552,7 @@ Status DeviceQueueOp::WorkerEntry(int32_t worker_id) {
     for (auto &i : current_row) {
       device::DataItemGpu data_item;
       data_item.data_len_ = static_cast<size_t>(i->SizeInBytes());
+      data_item.shapes_ = i->shape().AsVector();
      data_item.data_ptr_ = nullptr;
       data_item.worker_id_ = worker_id;
       items.push_back(data_item);


@@ -42,20 +42,19 @@ GpuQueue::GpuQueue(void *addr, const std::vector<size_t> &shape, const size_t &c
 GpuQueue::~GpuQueue() { buffer_ = nullptr; }
-BlockQueueStatus_T GpuQueue::Push(const std::vector<DataItemGpu> &data) {
-  int offset = 0;
+BlockQueueStatus_T GpuQueue::Push(std::vector<DataItemGpu> data) {
+  void *addr = reinterpret_cast<uint8_t *>(buffer_) + tail_ * len_;
   for (size_t i = 0; i < data.size(); i++) {
-    auto item = data[i];
-    if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) {
-      MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_;
+    auto &item = data[i];
+    if (item.data_ptr_ == nullptr || item.data_len_ > shape_[i]) {
+      MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_
+                    << ", exceeds the max len: " << shape_[i];
       return ERROR_INPUT;
     }
-    void *addr = reinterpret_cast<unsigned char *>(buffer_) + tail_ * len_ + offset;
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_),
                               "Cuda Memcpy Error");
-    offset += item.data_len_;
+    item.device_addr_ = addr;
+    addr = reinterpret_cast<uint8_t *>(addr) + item.data_len_;
   }
   node_info_[tail_].event_.reset(new cudaEvent_t());
@@ -67,15 +66,13 @@ BlockQueueStatus_T GpuQueue::Push(const std::vector<DataItemGpu> &data) {
   return SUCCESS;
 }
-BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const {
+BlockQueueStatus_T GpuQueue::Front(std::vector<DataItemGpu> *data) const {
   CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Syn Failed");
   CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed");
-  *addr = (unsigned char *)buffer_ + head_ * len_;
-  *len = len_;
-  for (auto item : node_info_[head_].data_) {
+  for (auto &item : node_info_[head_].data_) {
     host_release_(item.data_ptr_, item.worker_id_);
   }
+  *data = node_info_[head_].data_;
   return SUCCESS;
 }
@@ -124,14 +121,14 @@ BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, uns
   return SUCCESS;
 }
-BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) {
+BlockQueueStatus_T BlockingQueue::Front(std::vector<DataItemGpu> *data) {
   std::unique_lock<std::mutex> locker(mutex_);
   bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); });
   if (!timeout) {
     return TIMEOUT;
   }
-  return queue_->Front(addr, len);
+  return queue_->Front(data);
 }
 BlockQueueStatus_T BlockingQueue::Pop() {
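
Note on GpuQueue::Push above: the new bookkeeping copies each item into the current ring-buffer slot, records where it landed (device_addr_), and relaxes the length check from equality to an upper bound, since a dynamic-shape item may legally be smaller than the per-tensor maximum. A host-only sketch of that logic, with illustrative names (Item, PushSlot) and plain memcpy in place of cudaMemcpyAsync:

#include <cstdint>
#include <cstring>
#include <vector>

struct Item {
  size_t data_len_{0};
  void *data_ptr_{nullptr};     // host-side source
  void *device_addr_{nullptr};  // destination, filled in during the push
};

// Copy each item into one slot of a preallocated ring buffer and record its
// destination address in the item itself, mirroring the device_addr_ field
// this commit adds to DataItemGpu.
bool PushSlot(uint8_t *buffer, size_t slot, size_t slot_len,
              const std::vector<size_t> &max_lens, std::vector<Item> *items) {
  uint8_t *addr = buffer + slot * slot_len;
  for (size_t i = 0; i < items->size(); i++) {
    Item &item = (*items)[i];
    // '>' rather than the old '!=': a dynamic item may be smaller than the max.
    if (item.data_ptr_ == nullptr || item.data_len_ > max_lens[i]) {
      return false;
    }
    std::memcpy(addr, item.data_ptr_, item.data_len_);
    item.device_addr_ = addr;
    addr += item.data_len_;
  }
  return true;
}

int main() {
  std::vector<uint8_t> ring(2 * 64);  // two slots of 64 bytes each
  std::vector<float> payload(3, 1.5f);
  Item item;
  item.data_len_ = payload.size() * sizeof(float);
  item.data_ptr_ = payload.data();
  std::vector<Item> items = {item};
  return PushSlot(ring.data(), 0, 64, {64}, &items) ? 0 : 1;
}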


@@ -33,10 +33,12 @@ namespace device {
 enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_EXIST, HANDLE_NOT_EXIST, ERROR_INPUT, INTERNAL_ERROR, TIMEOUT };
 struct DataItemGpu {
-  int32_t worker_id_;
+  int32_t worker_id_{0};
   std::string data_type_;
-  size_t data_len_;
-  void *data_ptr_;
+  size_t data_len_{0};
+  void *data_ptr_{nullptr};
+  std::vector<int64_t> shapes_;
+  void *device_addr_{nullptr};
 };
 class GpuQueue {
@@ -49,8 +51,8 @@ class GpuQueue {
   inline bool IsEmpty() const { return size_ == 0; }
   inline bool IsFull() const { return size_ == capacity_; }
-  BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data);
-  BlockQueueStatus_T Front(void **ptr, size_t *len) const;
+  BlockQueueStatus_T Push(std::vector<DataItemGpu> data);
+  BlockQueueStatus_T Front(std::vector<DataItemGpu> *data) const;
   BlockQueueStatus_T Pop();
   bool Destroy();
   size_t Size() { return size_; }
@@ -85,7 +87,7 @@ class BlockingQueue {
   BlockQueueStatus_T Create(void *addr, const std::vector<size_t> &shape, const size_t &capacity);
   void RegisterRelease(const std::function<void(void *, int32_t)> &func);
   BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec);
-  BlockQueueStatus_T Front(void **ptr, size_t *len);
+  BlockQueueStatus_T Front(std::vector<DataItemGpu> *data);
   BlockQueueStatus_T Pop();
   bool Destroy();
   size_t Size() { return queue_->Size(); }
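
Note on the declarations above: every layer of the queue interface now traffics in DataItemGpu records instead of a raw (addr, len) pair, and the consumer-facing contract is a blocking Front with a 30-second timeout (see BlockingQueue::Front in the previous file). A self-contained sketch of that contract with illustrative names (Item, ItemQueue); the real class additionally manages CUDA events and the device ring buffer:

#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <vector>

enum Status { SUCCESS = 0, TIMEOUT };

struct Item {
  size_t data_len_{0};
  void *device_addr_{nullptr};
};

// Minimal analogue of BlockingQueue::Front: block until a row of items is
// available (or time out), then hand the caller the whole per-item vector.
class ItemQueue {
 public:
  void Push(std::vector<Item> row) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      rows_.push_back(std::move(row));
    }
    not_empty_.notify_one();
  }

  Status Front(std::vector<Item> *data) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (!not_empty_.wait_for(lock, std::chrono::seconds(30),
                             [this] { return !rows_.empty(); })) {
      return TIMEOUT;  // same contract as the MindSpore queue
    }
    *data = rows_.front();  // copy out item metadata; a Pop() would dequeue
    return SUCCESS;
  }

 private:
  std::mutex mutex_;
  std::condition_variable not_empty_;
  std::deque<std::vector<Item>> rows_;
};

int main() {
  ItemQueue queue;
  Item item;
  item.data_len_ = 16;
  queue.Push({item});
  std::vector<Item> row;
  return queue.Front(&row) == SUCCESS ? 0 : 1;
}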


@@ -114,12 +114,12 @@ BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector<Dat
   return iter->second->Push(data, timeout_in_sec);
 }
-BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) {
+BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, std::vector<DataItemGpu> *data) {
   auto iter = handle_queue_map_.find(handle);
   if (iter == handle_queue_map_.end()) {
     return HANDLE_NOT_EXIST;
   }
-  return iter->second->Front(addr, len);
+  return iter->second->Front(data);
 }
 BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) {


@@ -82,7 +82,7 @@ class GpuBufferMgr {
   EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector<DataItemGpu> &data,
                                  unsigned int timeout_in_sec);
-  EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len);
+  EXPORT BlockQueueStatus_T Front(unsigned int handle, std::vector<DataItemGpu> *data);
   EXPORT BlockQueueStatus_T Pop(unsigned int handle);
   EXPORT void set_device_id(int device_id);


@@ -145,7 +145,7 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
   // Copy data from device queue by data kernel launching.
   try {
     auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
-                                                 launch_info_.outputs_);
+                                                 launch_info_.outputs_, AnfAlgo::IsDynamicShape(data_kernel_));
     if (!ret) {
       std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);


@@ -63,6 +63,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
+    if exec_dataset.dynamic_setting[0]:
+        _, dataset_shapes = exec_dataset.dynamic_min_max_shapes()
     send_epoch_end = bool(dataset_size == -1)
     exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)


@@ -30,6 +30,7 @@ class MindData:
         self._output_shapes = output_shapes
         self._input_indexs = input_indexs
         self._iter_num = 0
+        self.dynamic_setting = [False, None]
 
     def get_dataset_size(self):
         return self._size