forked from mindspore-Ecosystem/mindspore
support dynamic datasink on GPU
parent 55ba926a04
commit 39e89f73ac

@@ -14,14 +14,15 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h"

#include <cuda_runtime_api.h>
#include <memory>
#include <string>
#include <vector>
+#include <algorithm>
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
#include "backend/kernel_compiler/common_utils.h"

#ifndef ENABLE_SECURITY
#include "profiler/device/gpu/gpu_profiling.h"
#endif

@@ -36,7 +37,7 @@ namespace kernel {
using mindspore::device::GpuBufferMgr;

DatasetIteratorKernelMod::DatasetIteratorKernelMod()
-    : handle_(GpuBufferMgr::INVALID_HANDLE), total_bytes_(0), profiling_enable_(false), profiling_op_(nullptr) {}
+    : handle_(GpuBufferMgr::INVALID_HANDLE), profiling_enable_(false), profiling_op_(nullptr) {}

DatasetIteratorKernelMod::~DatasetIteratorKernelMod() { GpuBufferMgr::GetInstance().Close(handle_); }

@@ -46,17 +47,16 @@ bool DatasetIteratorKernelMod::Init(const CNodePtr &kernel_node) {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
std::vector<std::vector<int>> shapes;
-std::vector<TypePtr> types;
-GetShapeAndType(kernel_node, &shapes, &types);
-for (auto item : types) {
+std::vector<TypePtr> type_ptrs;
+GetShapeAndType(kernel_node, &shapes, &type_ptrs);
+for (auto item : type_ptrs) {
MS_EXCEPTION_IF_NULL(item);
}
+std::transform(type_ptrs.begin(), type_ptrs.end(), std::back_inserter(types_),
+               [](const TypePtr &value) { return value->type_id(); });

for (size_t i = 0; i < shapes.size(); i++) {
-int unit = UnitSizeInBytes(types[i]->type_id());
-int nums = ElementNums(shapes[i]);
-int bytes = unit * nums;
-output_size_list_.push_back(bytes);
-total_bytes_ += bytes;
+output_size_list_.push_back(0);  // output_size could be dynamic when shapes is dynamic, just give fake value here.
}

#ifndef ENABLE_SECURITY

@@ -74,7 +74,7 @@ bool DatasetIteratorKernelMod::Init(const CNodePtr &kernel_node) {
void DatasetIteratorKernelMod::InitSizeLists() { return; }

-bool DatasetIteratorKernelMod::ReadDevice(void **addr, size_t *len) {
+bool DatasetIteratorKernelMod::ReadDevice(std::vector<DataItemGpu> *data) {
uint64_t start_time_stamp = 0;
uint32_t queue_size = 0;
#ifndef ENABLE_SECURITY

@@ -90,7 +90,7 @@ bool DatasetIteratorKernelMod::ReadDevice(void **addr, size_t *len) {
queue_size = GpuBufferMgr::GetInstance().Size(handle_);
}
#endif
-auto ret = GpuBufferMgr::GetInstance().Front(handle_, addr, len);
+auto ret = GpuBufferMgr::GetInstance().Front(handle_, data);
if (ret == device::SUCCESS) {
#ifndef ENABLE_SECURITY
if (profiling_enable_) {

@@ -134,24 +134,18 @@ bool DatasetIteratorKernelMod::Launch(const std::vector<AddressPtr> &, const std
}
}

-void *addr = nullptr;
-size_t len = 0;
-if (!ReadDevice(&addr, &len)) {
-return false;
-}
-if (total_bytes_ != len) {
-MS_LOG(ERROR) << "For '" << kernel_name_ << "', dataset front error, read: " << len
-              << " Bytes, expect: " << total_bytes_ << " Bytes.";
+if (!ReadDevice(&output_data_)) {
+return false;
+}

-for (size_t i = 0; i < output_size_list_.size(); i++) {
+for (size_t i = 0; i < output_data_.size(); i++) {
void *output_addr = GetDeviceAddress<void>(outputs, i);
+auto device_addr = output_data_[i].device_addr_;
+auto data_len = output_data_[i].data_len_;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                           cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
+                           cudaMemcpyAsync(output_addr, device_addr, data_len, cudaMemcpyDeviceToDevice,
                            reinterpret_cast<cudaStream_t>(stream)),
                            "Cuda Memcpy Failed");
-addr = reinterpret_cast<unsigned char *>(addr) + output_size_list_[i];
}

CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)),

@@ -159,5 +153,15 @@ bool DatasetIteratorKernelMod::Launch(const std::vector<AddressPtr> &, const std
(void)GpuBufferMgr::GetInstance().Pop(handle_);
return true;
}
+
+void DatasetIteratorKernelMod::PostExecute() {
+std::vector<std::vector<size_t>> shapes;
+for (const auto &item : output_data_) {
+std::vector<size_t> shape;
+std::transform(item.shapes_.begin(), item.shapes_.end(), std::back_inserter(shape), LongToSize);
+shapes.push_back(shape);
+}
+AnfAlgo::SetOutputInferTypeAndShape(types_, shapes, kernel_node_.lock().get());
+}
} // namespace kernel
} // namespace mindspore
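
Taken together, this is the per-step flow for a dynamic GetNext: ReadDevice fills output_data_ with one DataItemGpu per output tensor, Launch copies each item using that item's own device address and byte length, and PostExecute republishes the actual shapes to the graph via AnfAlgo::SetOutputInferTypeAndShape. Below is a minimal, host-only sketch of that flow, assuming a plain memcpy instead of cudaMemcpyAsync and a stand-in Item struct instead of device::DataItemGpu; all names in the sketch are illustrative, not MindSpore API.

// Minimal sketch of the dynamic GetNext flow (host-only, simplified stand-ins):
// items carry their real device address, byte length and shape; Launch copies each
// payload; PostExecute collects the per-tensor shapes for re-inference.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

struct Item {                     // stands in for device::DataItemGpu
  void *device_addr_{nullptr};
  size_t data_len_{0};
  std::vector<int64_t> shapes_;
};

void Launch(const std::vector<Item> &items, std::vector<std::vector<uint8_t>> *outputs) {
  outputs->resize(items.size());
  for (size_t i = 0; i < items.size(); ++i) {
    (*outputs)[i].resize(items[i].data_len_);
    // real kernel: cudaMemcpyAsync(output_addr, device_addr, data_len, cudaMemcpyDeviceToDevice, stream)
    std::memcpy((*outputs)[i].data(), items[i].device_addr_, items[i].data_len_);
  }
}

std::vector<std::vector<size_t>> PostExecute(const std::vector<Item> &items) {
  // real kernel feeds these shapes to AnfAlgo::SetOutputInferTypeAndShape
  std::vector<std::vector<size_t>> shapes;
  for (const auto &item : items) {
    shapes.emplace_back(item.shapes_.begin(), item.shapes_.end());
  }
  return shapes;
}

int main() {
  float batch[6] = {0, 1, 2, 3, 4, 5};  // pretend this row came out of the data queue
  std::vector<Item> items = {{batch, sizeof(batch), {2, 3}}};
  std::vector<std::vector<uint8_t>> outputs;
  Launch(items, &outputs);
  for (const auto &s : PostExecute(items)) {
    std::cout << "output shape: " << s[0] << " x " << s[1] << "\n";
  }
  return 0;
}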

@@ -23,9 +23,11 @@
#include "backend/kernel_compiler/gpu/data/dataset_profiling.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

+#include "runtime/device/gpu/blocking_queue.h"
namespace mindspore {
namespace kernel {
+using mindspore::device::DataItemGpu;

class DatasetIteratorKernelMod : public NativeGpuKernelMod {
public:
DatasetIteratorKernelMod();

@@ -34,17 +36,19 @@ class DatasetIteratorKernelMod : public NativeGpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
            const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
+void PostExecute() override;

protected:
void InitSizeLists() override;

private:
-bool ReadDevice(void **addr, size_t *len);
+bool ReadDevice(std::vector<DataItemGpu> *data);
std::string queue_name_;
unsigned int handle_;
-size_t total_bytes_;
bool profiling_enable_;
std::shared_ptr<GetNextProfiling> profiling_op_;
+std::vector<TypeId> types_;
+std::vector<DataItemGpu> output_data_;
};

MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernelMod)

@@ -79,7 +79,7 @@ class GpuKernelRegister {
// because the variable created by the macro will also contain a space. So, we solve this
// problem by writing uchar when calling these macros, and expanding uchar after the
// variable has been created.
-#define uchar unsigned char
+using uchar = unsigned char;

#define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(g_##kernel##_gpu_kernel_reg, __COUNTER__)
#define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt)

@@ -552,6 +552,7 @@ Status DeviceQueueOp::WorkerEntry(int32_t worker_id) {
for (auto &i : current_row) {
device::DataItemGpu data_item;
data_item.data_len_ = static_cast<size_t>(i->SizeInBytes());
+data_item.shapes_ = i->shape().AsVector();
data_item.data_ptr_ = nullptr;
data_item.worker_id_ = worker_id;
items.push_back(data_item);
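
The host-side worker now records each tensor's shape in the DataItemGpu it pushes to the device queue; this is the shape information that ReadDevice and PostExecute consume on the kernel side, so the GPU kernel no longer needs a fixed, precomputed byte size per output.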

@@ -42,20 +42,19 @@ GpuQueue::GpuQueue(void *addr, const std::vector<size_t> &shape, const size_t &c
GpuQueue::~GpuQueue() { buffer_ = nullptr; }

-BlockQueueStatus_T GpuQueue::Push(const std::vector<DataItemGpu> &data) {
-int offset = 0;
+BlockQueueStatus_T GpuQueue::Push(std::vector<DataItemGpu> data) {
+void *addr = reinterpret_cast<uint8_t *>(buffer_) + tail_ * len_;
for (size_t i = 0; i < data.size(); i++) {
-auto item = data[i];
-if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) {
-MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_;
+auto &item = data[i];
+if (item.data_ptr_ == nullptr || item.data_len_ > shape_[i]) {
+MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_
+              << ", exceeds the max len: " << shape_[i];
return ERROR_INPUT;
}

-void *addr = reinterpret_cast<unsigned char *>(buffer_) + tail_ * len_ + offset;
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_),
                          "Cuda Memcpy Error");

-offset += item.data_len_;
+item.device_addr_ = addr;
+addr = reinterpret_cast<uint8_t *>(addr) + item.data_len_;
}

node_info_[tail_].event_.reset(new cudaEvent_t());

@@ -67,15 +66,13 @@ BlockQueueStatus_T GpuQueue::Push(const std::vector<DataItemGpu> &data) {
return SUCCESS;
}

-BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const {
+BlockQueueStatus_T GpuQueue::Front(std::vector<DataItemGpu> *data) const {
CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Syn Failed");
CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed");
-*addr = (unsigned char *)buffer_ + head_ * len_;
-*len = len_;

-for (auto item : node_info_[head_].data_) {
+for (auto &item : node_info_[head_].data_) {
host_release_(item.data_ptr_, item.worker_id_);
}
+*data = node_info_[head_].data_;
return SUCCESS;
}
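
Two design points are worth calling out here. Push now takes the item vector by value because the queue keeps its own copy (with each item's device_addr_ filled in) for the slot it wrote, and Front hands that stored copy back instead of a single base address and length, so the consumer sees the real per-tensor addresses and byte lengths of a dynamically shaped row. The sketch below is a minimal host-only analogue of that bookkeeping, assuming a plain memcpy and a single slot in place of the CUDA ring buffer; all names are illustrative, not the MindSpore classes.

// Host-only analogue of the new GpuQueue::Push/Front contract.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

struct Item {                  // stands in for device::DataItemGpu
  size_t data_len_{0};
  void *data_ptr_{nullptr};    // host-side source buffer
  void *device_addr_{nullptr}; // where Push placed the copy
};

class SlotQueue {
 public:
  explicit SlotQueue(size_t slot_bytes) : buffer_(slot_bytes) {}

  // Taken by value on purpose: the queue stores its own copy whose device_addr_ it
  // fills in, mirroring the const-ref -> by-value change in GpuQueue::Push.
  bool Push(std::vector<Item> items) {
    uint8_t *addr = buffer_.data();
    for (auto &item : items) {
      size_t remaining = static_cast<size_t>(buffer_.data() + buffer_.size() - addr);
      if (item.data_ptr_ == nullptr || item.data_len_ > remaining) {
        return false;  // mirrors the "exceeds the max len" check
      }
      // real code: cudaMemcpyAsync(addr, data_ptr_, data_len_, cudaMemcpyHostToDevice, stream_)
      std::memcpy(addr, item.data_ptr_, item.data_len_);
      item.device_addr_ = addr;   // remember where this tensor landed
      addr += item.data_len_;
    }
    stored_ = std::move(items);
    return true;
  }

  // Returns the stored items so the consumer knows every tensor's address and length.
  bool Front(std::vector<Item> *items) const {
    *items = stored_;
    return true;
  }

 private:
  std::vector<uint8_t> buffer_;
  std::vector<Item> stored_;
};

int main() {
  float a[4] = {1, 2, 3, 4};
  int32_t b[2] = {7, 8};
  SlotQueue queue(1024);
  queue.Push({{sizeof(a), a, nullptr}, {sizeof(b), b, nullptr}});
  std::vector<Item> row;
  queue.Front(&row);
  for (const auto &item : row) {
    std::cout << item.data_len_ << " bytes copied to " << item.device_addr_ << "\n";
  }
  return 0;
}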

@@ -124,14 +121,14 @@ BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, uns
return SUCCESS;
}

-BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) {
+BlockQueueStatus_T BlockingQueue::Front(std::vector<DataItemGpu> *data) {
std::unique_lock<std::mutex> locker(mutex_);
bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); });
if (!timeout) {
return TIMEOUT;
}

-return queue_->Front(addr, len);
+return queue_->Front(data);
}

BlockQueueStatus_T BlockingQueue::Pop() {

@@ -33,10 +33,12 @@ namespace device {
enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_EXIST, HANDLE_NOT_EXIST, ERROR_INPUT, INTERNAL_ERROR, TIMEOUT };

struct DataItemGpu {
-int32_t worker_id_;
+int32_t worker_id_{0};
std::string data_type_;
-size_t data_len_;
-void *data_ptr_;
+size_t data_len_{0};
+void *data_ptr_{nullptr};
+std::vector<int64_t> shapes_;
+void *device_addr_{nullptr};
};

class GpuQueue {
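
DataItemGpu now carries the tensor's shape and the device address that GpuQueue::Push assigned to it, and the scalar members get in-class default initializers, so a partially filled item no longer holds indeterminate values.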

@@ -49,8 +51,8 @@ class GpuQueue {
inline bool IsEmpty() const { return size_ == 0; }
inline bool IsFull() const { return size_ == capacity_; }

-BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data);
-BlockQueueStatus_T Front(void **ptr, size_t *len) const;
+BlockQueueStatus_T Push(std::vector<DataItemGpu> data);
+BlockQueueStatus_T Front(std::vector<DataItemGpu> *data) const;
BlockQueueStatus_T Pop();
bool Destroy();
size_t Size() { return size_; }

@@ -85,7 +87,7 @@ class BlockingQueue {
BlockQueueStatus_T Create(void *addr, const std::vector<size_t> &shape, const size_t &capacity);
void RegisterRelease(const std::function<void(void *, int32_t)> &func);
BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec);
-BlockQueueStatus_T Front(void **ptr, size_t *len);
+BlockQueueStatus_T Front(std::vector<DataItemGpu> *data);
BlockQueueStatus_T Pop();
bool Destroy();
size_t Size() { return queue_->Size(); }

@@ -114,12 +114,12 @@ BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector<Dat
return iter->second->Push(data, timeout_in_sec);
}

-BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) {
+BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, std::vector<DataItemGpu> *data) {
auto iter = handle_queue_map_.find(handle);
if (iter == handle_queue_map_.end()) {
return HANDLE_NOT_EXIST;
}
-return iter->second->Front(addr, len);
+return iter->second->Front(data);
}

BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) {

@@ -82,7 +82,7 @@ class GpuBufferMgr {

EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector<DataItemGpu> &data,
                               unsigned int timeout_in_sec);
-EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len);
+EXPORT BlockQueueStatus_T Front(unsigned int handle, std::vector<DataItemGpu> *data);
EXPORT BlockQueueStatus_T Pop(unsigned int handle);

EXPORT void set_device_id(int device_id);

@@ -145,7 +145,7 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
// Copy data from device queue by data kernel launching.
try {
auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
-                                             launch_info_.outputs_);
+                                             launch_info_.outputs_, AnfAlgo::IsDynamicShape(data_kernel_));
if (!ret) {
std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
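
The device-queue data-source actor now forwards AnfAlgo::IsDynamicShape(data_kernel_) to LaunchKernel; presumably this is what tells the runtime to launch GetNext in dynamic-shape mode and pick up the output shapes that PostExecute republishes.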

@@ -63,6 +63,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf

# transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
+if exec_dataset.dynamic_setting[0]:
+    _, dataset_shapes = exec_dataset.dynamic_min_max_shapes()
send_epoch_end = bool(dataset_size == -1)
exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
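
On the Python side, when the dataset has dynamic shapes configured (dynamic_setting[0] is True), _exec_datagraph replaces the statically derived dataset_shapes with the ones returned by dynamic_min_max_shapes() before the dataset is wrapped with device_que; the MindData test stub below gains a matching dynamic_setting default of [False, None].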

@@ -30,6 +30,7 @@ class MindData:
self._output_shapes = output_shapes
self._input_indexs = input_indexs
self._iter_num = 0
+self.dynamic_setting = [False, None]

def get_dataset_size(self):
    return self._size