!12165 add input data type check for ps cache mode

From: @zyli2020
Reviewed-by: @cristoval
Signed-off-by:
commit a3057441d6 by mindspore-ci-bot, 2021-02-07 16:43:24 +08:00, committed by Gitee
5 changed files with 21 additions and 7 deletions
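
In short, the two dataset producers (the GPU device queue op and the Ascend TDT plugin) now hand the first item's type string to PsDataPrefetch::PrefetchData(), which gains a fourth data_type parameter and, when the PS cache is enabled, accepts only 'int32' input ids. Below is a minimal, self-contained sketch of that gate; the function and variable names are illustrative stand-ins, only the "int32" requirement and the message wording come from the PsDataPrefetch hunk further down.

#include <iostream>
#include <string>

// Illustrative stand-in for the check this commit adds to PsDataPrefetch::PrefetchData();
// names are hypothetical, only the "int32" requirement comes from the diff below.
bool CheckPrefetchDataType(bool cache_enable, const std::string &data_type) {
  if (!cache_enable) {
    return true;  // PS cache disabled: nothing to validate or prefetch.
  }
  const std::string supported_data_type = "int32";
  if (data_type != supported_data_type) {
    std::cerr << "Parameter server cache mode needs input ids with data type[int32], but got[" << data_type << "]\n";
    return false;
  }
  return true;
}

int main() {
  std::cout << CheckPrefetchDataType(true, "int32") << "\n";     // 1: accepted
  std::cout << CheckPrefetchDataType(true, "float32") << "\n";   // 0: rejected, error logged
  std::cout << CheckPrefetchDataType(false, "float32") << "\n";  // 1: cache disabled, not checked
  return 0;
}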


@@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() {
     }
     // Data prefetch only when PS mode enables cache.
     if (items.size() > 0) {
-      if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
-        return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
-      }
+      if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
+                                                          items[0].data_type_)) {
+        return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
+      }
     while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
       BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
@@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
     if (sub_item.data_ptr_ == nullptr) {
       return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
     }
+    if (curr_row[i] == nullptr) {
+      MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null";
+      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr.");
+    }
+    sub_item.data_type_ = curr_row[i]->type().ToString();
     const unsigned char *column_data = curr_row[i]->GetBuffer();
     if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
                  static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {
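
So on the GPU path each column's type is recorded as a string (curr_row[i]->type().ToString()) on its DataItemGpu before the row is pushed, and a null column is now reported instead of dereferenced. A minimal standalone sketch of that pattern, using a hypothetical tensor stand-in rather than the real minddata Tensor:

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

// Hypothetical tensor stand-in; the real code calls curr_row[i]->type().ToString().
struct FakeTensor {
  std::string type_name;  // e.g. "int32", "float32"
  std::string TypeToString() const { return type_name; }
};

// Same field layout as the DataItemGpu struct extended at the end of this commit.
struct DataItemGpu {
  int32_t worker_id_;
  std::string data_type_;
  size_t data_len_;
  void *data_ptr_;
};

// Mirrors the producer-side logic added above: reject a null column up front,
// then record its type string so the PS prefetcher can check it later.
bool FillItemType(const std::shared_ptr<FakeTensor> &column, DataItemGpu *item) {
  if (column == nullptr || item == nullptr) {
    return false;  // Corresponds to the new curr_row[i] == nullptr error path.
  }
  item->data_type_ = column->TypeToString();
  return true;
}

int main() {
  auto column = std::make_shared<FakeTensor>(FakeTensor{"int32"});
  DataItemGpu item{};
  return FillItemType(column, &item) ? 0 : 1;
}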


@@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
 #if ENABLE_D
   // Data prefetch only when PS mode enables cache.
   if (items.size() > 0) {
-    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
+    if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_,
+                                                        items[0].tensorType_)) {
       return FAILED;
     }
   }


@@ -44,10 +44,17 @@ std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string
   return iter->second;
 }
-bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
+bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
+                                  const std::string &data_type) {
   if (cache_enable_ == false) {
     return true;
   }
+  // In PS cache mode, the input ids come from the dataset and the data type transmitted from minddata must be 'int32'.
+  const std::string supported_data_type = "int32";
+  if (data_type != supported_data_type) {
+    MS_LOG(ERROR) << "Parameter server cache mode needs input ids with data type[int32], but got[" << data_type << "]";
+    return false;
+  }
   if (data == nullptr) {
     MS_LOG(WARNING) << "No data prefetch.";
     return true;
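
The false return from this new check surfaces as kMDTimeOut ("Failed to prefetch data.") on the GPU path and as FAILED on the TDT path shown earlier, so a pipeline that feeds ids of any other type (for example int64) now errors out up front instead of caching mistyped data. Purely as an illustration of what the check demands, and not anything this commit adds, here is a sketch of narrowing int64 ids to int32 on the host; in practice the dataset itself should already emit int32 ids (e.g. via a type-cast operation).

#include <cstdint>
#include <vector>

// Illustration only (not part of this commit): if a pipeline produces int64 ids,
// they must be narrowed to int32 before the PS cache prefetch will accept them.
std::vector<int32_t> NarrowIdsToInt32(const std::vector<int64_t> &ids) {
  std::vector<int32_t> out;
  out.reserve(ids.size());
  for (int64_t id : ids) {
    out.push_back(static_cast<int32_t>(id));  // assumes every id fits in 32 bits, otherwise this truncates
  }
  return out;
}

int main() {
  const std::vector<int64_t> ids = {1, 2, 3};
  return NarrowIdsToInt32(ids).size() == ids.size() ? 0 : 1;
}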


@@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch {
   EXPORT bool cache_enable() const { return cache_enable_; }
   EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
   EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
-  EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
+  EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
+                           const std::string &data_type);
   EXPORT bool FinalizeData(const std::string &channel_name);
   EXPORT void NotifyFinalize();
   EXPORT void *data(const std::string &channel_name) const;


@@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
 struct DataItemGpu {
   int32_t worker_id_;
+  std::string data_type_;
   size_t data_len_;
   void *data_ptr_;
 };
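
Because data_type_ is inserted between worker_id_ and data_len_, any call site that brace-initializes DataItemGpu positionally has to account for the new member (a size_t does not convert to std::string, so such sites fail to compile rather than silently misbehave). A hypothetical example of filling the extended struct field by field:

#include <cstddef>
#include <cstdint>
#include <string>

// Copy of the extended struct, repeated here only so the example is self-contained.
struct DataItemGpu {
  int32_t worker_id_;
  std::string data_type_;
  size_t data_len_;
  void *data_ptr_;
};

int main() {
  static int32_t ids[4] = {1, 2, 3, 4};
  DataItemGpu item{};          // value-initialize, then fill field by field
  item.worker_id_ = 0;
  item.data_type_ = "int32";   // must be "int32" for PS cache mode to accept the prefetch
  item.data_len_ = sizeof(ids);
  item.data_ptr_ = ids;
  return item.data_len_ == sizeof(ids) ? 0 : 1;
}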