From: @anancds
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-04-14 09:38:44 +08:00 committed by Gitee
commit d65b370bdf
11 changed files with 55 additions and 32 deletions

View File

@ -526,7 +526,9 @@ void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const
auto it = receive_messages_.find(request_id);
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
if (size > 0) {
auto ret = memcpy_s(received_data.get()->data(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
@ -586,7 +588,9 @@ void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const P
// If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
int ret = memcpy_s(received_data->data(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}

View File

@ -106,7 +106,7 @@ int HttpClient::ReadHeaderDoneCallback(struct evhttp_request *request, void *arg
handler->set_request(request);
struct evkeyvalq *headers = evhttp_request_get_input_headers(request);
MS_EXCEPTION_IF_NULL(headers);
struct evkeyval *header;
struct evkeyval *header = nullptr;
TAILQ_FOREACH(header, headers, next) {
MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value;
std::string len = "Content-Length";
@ -204,7 +204,6 @@ Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, st
bool HttpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_);
// int ret = event_base_dispatch(event_base_);
int ret = event_base_loop(event_base_, 0);
if (ret == 0) {
MS_LOG(DEBUG) << "Event base dispatch success!";

View File

@ -237,7 +237,9 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) {
void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer);
int ret = memcpy_s(body_->data() + offset_, num, buffer, num);
size_t dest_size = num;
size_t src_size = num;
int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}

View File

@ -53,10 +53,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
remaining_length_ -= copy_len;
num -= copy_len;
int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len);
size_t dest_size = copy_len;
size_t src_size = copy_len;
auto ret = memcpy_s(message_buffer_.get() + last_copy_len_, dest_size, buffer_data, src_size);
last_copy_len_ += copy_len;
buffer_data += copy_len;
if (ret != 0) {
if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}

View File

@ -48,7 +48,7 @@ class TcpMessageHandler {
bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_;
size_t remaining_length_;
char header_[16];
char header_[16]{0};
int header_index_;
size_t last_copy_len_;
MessageHeader message_header_;

View File

@ -102,7 +102,9 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
auto ret = memcpy_s(res.get(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
auto ret = memcpy_s(res.get(), dest_size, data, src_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}

View File

@ -548,8 +548,9 @@ void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, Ve
auto weight = ps_->weight(key);
*res_data.mutable_values() = {weight->begin(), weight->end()};
res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
@ -662,8 +663,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
res_data.add_keys(key);
res_data.add_values(ready);
res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
@ -679,8 +681,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_
res_data.add_keys(key);
res_data.add_values(ready);
res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
@ -699,8 +702,9 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
ps_->DoEmbeddingLookup(key, keys, &res_data);
res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}

View File

@ -936,7 +936,8 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in
return true;
}
bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) {
bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids,
size_t key) {
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_out_ids);

View File

@ -165,7 +165,7 @@ class PsCacheManager {
const float *insert_data, float *hash_table_addr);
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr);
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key);
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;

View File

@ -84,7 +84,9 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
int size = sizes[i] * sizeof(float);
auto ret = memcpy_s(dst_data, size, src_data, size);
size_t dest_size = IntToSize(size);
size_t src_size = IntToSize(size);
auto ret = memcpy_s(dst_data, dest_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -222,7 +224,8 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
std::string kv_data = embedding_table_meta.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -288,7 +291,8 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -371,7 +375,8 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -391,7 +396,8 @@ void Worker::Finalize() {
kvs.add_values(0.0f);
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -482,7 +488,7 @@ void Worker::InitPSOptimInputShapes(const size_t key) {
PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
}
void Worker::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) {
MS_EXCEPTION_IF_NULL(origin_addr);
std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
@ -632,7 +638,8 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va
} else {
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -658,7 +665,7 @@ void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<floa
}
}
void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens, int cmd,
void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd,
int64_t priority) {
MS_EXCEPTION_IF_NULL(vals);
KVMessage kvs;
@ -933,7 +940,8 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
@ -959,7 +967,8 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
std::string kv_data = messages.at(i).second.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;

View File

@ -92,7 +92,7 @@ class Worker {
void AddKeyByHashMod(const Key &key);
void InitPSOptimId(const size_t param_key);
void InitPSOptimInputShapes(const size_t key);
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size);
bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key);
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
@ -105,8 +105,8 @@ class Worker {
int command = 0, int64_t priority = 0);
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
int64_t priority = 0);
void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr,
int cmd = 0, int64_t priority = 0);
void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs);