diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 624aa06f907..c1d103b4c83 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() { } // Data prefetch only when PS mode enables cache. - if (items.size() > 0) { - if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { - return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); - } + if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_, + items[0].data_type_)) { + return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); } while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); @@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector *items, if (sub_item.data_ptr_ == nullptr) { return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed."); } + if (curr_row[i] == nullptr) { + MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null"; + return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr."); + } + sub_item.data_type_ = curr_row[i]->type().ToString(); const unsigned char *column_data = curr_row[i]->GetBuffer(); if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, static_cast(curr_row[i++]->SizeInBytes())) != 0) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 9bfdcacee45..94175037e21 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe #if ENABLE_D // Data prefetch only when PS mode enables cache. if (items.size() > 0) { - if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { + if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_, + items[0].tensorType_)) { return FAILED; } } diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc index 0b32e2d7df2..e0a9d0ab01c 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc @@ -44,10 +44,17 @@ std::shared_ptr PsDataPrefetch::ps_data_channel(const std::string return iter->second; } -bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { +bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size, + const std::string &data_type) { if (cache_enable_ == false) { return true; } + // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32' + const std::string supported_data_type = "int32"; + if (data_type != supported_data_type) { + MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; + return false; + } if (data == nullptr) { MS_LOG(WARNING) << "No data prefetch."; return true; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h index 5b69fa47798..256a6475ddf 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h @@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch { EXPORT bool cache_enable() const { return cache_enable_; } EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); - EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); + EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size, + const std::string &data_type); EXPORT bool FinalizeData(const std::string &channel_name); EXPORT void NotifyFinalize(); EXPORT void *data(const std::string &channel_name) const; diff --git a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h index 52bb954d527..ebb97f0866b 100644 --- a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h +++ b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h @@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, struct DataItemGpu { int32_t worker_id_; + std::string data_type_; size_t data_len_; void *data_ptr_; };