forked from mindspore-Ecosystem/mindspore

!12165 add input data type check for ps cache mode

From: @zyli2020
Reviewed-by: @cristoval
Signed-off-by:

This commit is contained in:
commit a3057441d6
@@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() {
   }

   // Data prefetch only when PS mode enables cache.
   if (items.size() > 0) {
-    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
-      return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
-    }
+    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
+                                                        items[0].data_type_)) {
+      return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
+    }
   }
   while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
     BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
@@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
     if (sub_item.data_ptr_ == nullptr) {
       return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
     }
+    if (curr_row[i] == nullptr) {
+      MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null";
+      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr.");
+    }
+    sub_item.data_type_ = curr_row[i]->type().ToString();
     const unsigned char *column_data = curr_row[i]->GetBuffer();
     if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
                  static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {
@@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
 #if ENABLE_D
   // Data prefetch only when PS mode enables cache.
   if (items.size() > 0) {
-    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
+    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_,
+                                                        items[0].tensorType_)) {
       return FAILED;
     }
   }
@@ -44,10 +44,17 @@ std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string
   return iter->second;
 }

-bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
+bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
+                                  const std::string &data_type) {
   if (cache_enable_ == false) {
     return true;
   }
+  // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
+  const std::string supported_data_type = "int32";
+  if (data_type != supported_data_type) {
+    MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
+    return false;
+  }
   if (data == nullptr) {
     MS_LOG(WARNING) << "No data prefetch.";
     return true;
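
The hunk above carries the substance of the change: with the PS cache enabled, PrefetchData now refuses any input whose declared data type is not "int32" before touching the cache. The standalone sketch below is not MindSpore code; PrefetchAllowed and the small main() driver are invented here purely so the accept/reject behaviour of that guard can be tried in isolation.

#include <cstdint>
#include <iostream>
#include <string>

// Illustrative stand-in for the guard added to PsDataPrefetch::PrefetchData:
// skip validation entirely when the PS cache is disabled, and fail fast when
// the id tensor's type string is anything other than "int32".
bool PrefetchAllowed(bool cache_enable, const std::string &data_type, const void *data) {
  if (!cache_enable) {
    return true;  // cache off: nothing to prefetch, nothing to validate
  }
  const std::string supported_data_type = "int32";
  if (data_type != supported_data_type) {
    std::cerr << "Parameter server cache mode needs input ids of type [int32], but got [" << data_type << "]"
              << std::endl;
    return false;  // corresponds to the MS_LOG(ERROR) + return false path in the hunk
  }
  if (data == nullptr) {
    std::cout << "No data to prefetch." << std::endl;  // corresponds to the MS_LOG(WARNING) path
  }
  return true;
}

int main() {
  int32_t ids[] = {0, 1, 2, 3};
  std::cout << PrefetchAllowed(true, "int32", ids) << std::endl;     // 1: accepted
  std::cout << PrefetchAllowed(true, "float32", ids) << std::endl;   // 0: rejected with an error message
  std::cout << PrefetchAllowed(false, "float32", ids) << std::endl;  // 1: cache disabled, check skipped
  return 0;
}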
@@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch {
   EXPORT bool cache_enable() const { return cache_enable_; }
   EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
   EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
-  EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
+  EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
+                           const std::string &data_type);
   EXPORT bool FinalizeData(const std::string &channel_name);
   EXPORT void NotifyFinalize();
   EXPORT void *data(const std::string &channel_name) const;
@@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,

 struct DataItemGpu {
   int32_t worker_id_;
+  std::string data_type_;
   size_t data_len_;
   void *data_ptr_;
 };
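
With data_type_ now a member of DataItemGpu, the GPU call site in the first hunk can pass the type string alongside the pointer and length. The sketch below is a hypothetical, self-contained mock: the local DataItemGpu copy, MockPrefetch, and the host-side id buffer are all invented here to show how the new field travels with data_len_ and data_ptr_ into a four-argument prefetch call; in MindSpore the real values come from curr_row[i] and the buffer lives on the device.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Local mock of the struct from the last hunk; for illustration only.
struct DataItemGpu {
  int32_t worker_id_;
  std::string data_type_;
  size_t data_len_;
  void *data_ptr_;
};

// Hypothetical stand-in for the four-argument PrefetchData overload.
bool MockPrefetch(const std::string &channel, void *data, size_t len, const std::string &type) {
  if (type != "int32") {
    std::cerr << "channel " << channel << ": unsupported id type [" << type << "]" << std::endl;
    return false;
  }
  std::cout << "channel " << channel << ": prefetching " << len << " bytes" << std::endl;
  return true;
}

int main() {
  std::vector<int32_t> ids = {7, 11, 13};
  DataItemGpu item;
  item.worker_id_ = 0;
  item.data_type_ = "int32";                      // in MindSpore: curr_row[i]->type().ToString()
  item.data_len_ = ids.size() * sizeof(int32_t);  // in MindSpore: curr_row[i]->SizeInBytes()
  item.data_ptr_ = ids.data();                    // in MindSpore: a device buffer from MallocForGPUData
  // Mirrors the call shape in PushDataToGPU: the type string rides along with pointer and length.
  return MockPrefetch("example_channel", item.data_ptr_, item.data_len_, item.data_type_) ? 0 : 1;
}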