Fix issues raised in code review
This commit is contained in:
parent
65ec61a480
commit
be1ae51180
|
@ -22,19 +22,20 @@ namespace ps {
|
|||
namespace core {
|
||||
/// Builds a handler around an HTTP message and caches its POST payload.
///
/// The payload pointer and its length are fetched once, here, through
/// GetPostMsg(). When no message is supplied, data_ and len_ keep their
/// initial values (nullptr and 0).
///
/// @param http_msg Shared handle to the underlying HTTP message; may be null.
HttpMsgHandler::HttpMsgHandler(std::shared_ptr<HttpMessageHandler> http_msg)
    : http_msg_(http_msg), data_(nullptr), len_(0) {
  // The message must only be dereferenced behind the null check; calling
  // GetPostMsg() unconditionally would crash on a null handle.
  if (http_msg_ != nullptr) {
    len_ = http_msg_->GetPostMsg(&data_);
  }
}
|
||||
|
||||
/// Returns the payload pointer cached by the constructor.
///
/// @return data_, or nullptr when no payload was captured.
void *HttpMsgHandler::data() const {
  // One null guard is sufficient. MS_ERROR_IF_NULL_W_RET_VAL is expected to
  // report the null pointer itself before returning the given value
  // (NOTE(review): confirm against the macro definition), so an additional
  // hand-written check of the same condition would only duplicate it.
  MS_ERROR_IF_NULL_W_RET_VAL(data_, nullptr);
  return data_;
}
|
||||
|
||||
/// Returns len_, the payload length stored by the constructor (0 when no
/// payload was read).
size_t HttpMsgHandler::len() const {
  return len_;
}
|
||||
|
||||
/// Sends `data` back through the HTTP message as a status-200 response.
///
/// @param data Response body; must not be null.
/// @param len  Length of `data`, forwarded unchanged to QuickResponse().
/// @return false when `data` or the underlying HTTP message is null; true
///         after QuickResponse() has been invoked (its own result is not
///         inspected here).
bool HttpMsgHandler::SendResponse(const void *data, const size_t &len) {
  MS_ERROR_IF_NULL_W_RET_VAL(data, false);
  // The constructor accepts a null message, so it has to be checked before use.
  MS_ERROR_IF_NULL_W_RET_VAL(http_msg_, false);
  // The casts strip constness and reinterpret the buffer as unsigned char to
  // match the argument type QuickResponse() is called with.
  http_msg_->QuickResponse(200, reinterpret_cast<unsigned char *>(const_cast<void *>(data)), len);
  return true;
}
|
||||
|
|
|
@ -31,6 +31,9 @@ bool TcpCommunicator::Start() {
|
|||
tcp_msg_callback_ = std::bind(
|
||||
[&](std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data,
|
||||
size_t size) -> void {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(data);
|
||||
TcpUserCommand user_command = static_cast<TcpUserCommand>(meta->user_cmd());
|
||||
const std::string &msg_type = kUserCommandToMsgType.at(user_command);
|
||||
if (msg_type == "" || !msg_callbacks_[msg_type]) {
|
||||
|
@ -41,7 +44,7 @@ bool TcpCommunicator::Start() {
|
|||
MS_LOG(DEBUG) << "TcpCommunicator receives message for " << msg_type;
|
||||
std::shared_ptr<MessageHandler> tcp_msg_handler =
|
||||
std::make_shared<TcpMsgHandler>(server_node_, conn, meta, data, size);
|
||||
MS_EXCEPTION_IF_NULL(tcp_msg_handler);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_msg_handler);
|
||||
// The Submit function timed out for 30s, if it returns false, it will retry 60 times.
|
||||
bool res = CommUtil::Retry([&] { return task_executor_->Submit(msg_callbacks_[msg_type], tcp_msg_handler); },
|
||||
kRetryCount, kRetryIntervalInMs);
|
||||
|
|
|
@ -29,15 +29,16 @@ TcpMsgHandler::TcpMsgHandler(ServerNode *server_node, std::shared_ptr<core::TcpC
|
|||
}
|
||||
|
||||
/// Returns the payload pointer held by this handler.
///
/// @return data_, or nullptr when no payload is available.
void *TcpMsgHandler::data() const {
  // One null guard is sufficient. MS_ERROR_IF_NULL_W_RET_VAL is expected to
  // report the null pointer itself before returning the given value
  // (NOTE(review): confirm against the macro definition), so an additional
  // hand-written check of the same condition would only duplicate it.
  MS_ERROR_IF_NULL_W_RET_VAL(data_, nullptr);
  return data_;
}
|
||||
|
||||
/// Returns len_, the payload length stored in this handler.
size_t TcpMsgHandler::len() const {
  return len_;
}
|
||||
|
||||
/// Sends `data` back to the peer over the TCP connection this handler wraps.
///
/// @param data Response payload; must not be null.
/// @param len  Length of `data`, forwarded unchanged to ServerNode::Response().
/// @return false when the connection, the message meta, `data`, or the server
///         node is null; true after Response() has been invoked (its own
///         result is not inspected here).
bool TcpMsgHandler::SendResponse(const void *data, const size_t &len) {
  MS_ERROR_IF_NULL_W_RET_VAL(tcp_conn_, false);
  MS_ERROR_IF_NULL_W_RET_VAL(meta_, false);
  MS_ERROR_IF_NULL_W_RET_VAL(data, false);
  // server_node_ is a raw pointer supplied by the caller; guard it like the
  // other collaborators before dereferencing.
  MS_ERROR_IF_NULL_W_RET_VAL(server_node_, false);
  // Response() is called with a non-const buffer, hence the const_cast.
  server_node_->Response(tcp_conn_, meta_, const_cast<void *>(data), len);
  return true;
}
|
||||
|
|
|
@ -148,6 +148,7 @@ bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
|
|||
|
||||
bool AscendPsCache::RecordEvent() {
|
||||
event_.reset(new rtEvent_t());
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
auto ret = rtEventCreate(&(*event_));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Create event failed";
|
||||
|
@ -162,6 +163,7 @@ bool AscendPsCache::RecordEvent() {
|
|||
}
|
||||
|
||||
bool AscendPsCache::SynchronizeEvent() {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
auto ret = rtEventSynchronize(*event_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "tEventSynchronize failed";
|
||||
|
@ -176,6 +178,7 @@ bool AscendPsCache::SynchronizeEvent() {
|
|||
}
|
||||
|
||||
bool AscendPsCache::SynchronizeStream() {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(stream_, false);
|
||||
auto ret = rtStreamSynchronize(stream_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "rtStreamSynchronize failed";
|
||||
|
@ -227,6 +230,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
|
|||
std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
|
||||
auto op_info =
|
||||
std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(op_info, false);
|
||||
RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod));
|
||||
|
||||
AddressPtrList kernel_inputs;
|
||||
|
@ -267,6 +271,7 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
|
|||
std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
|
||||
auto op_info =
|
||||
std::make_shared<KernelNodeInfo>(kernel::kUpdateCache, input_shape, input_type, output_shape, output_type);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(op_info, false);
|
||||
SetNodedefProto(op_info, hash_swap_in_mod);
|
||||
|
||||
AddressPtrList kernel_inputs;
|
||||
|
|
|
@ -43,6 +43,7 @@ void *GPUPsCache::MallocMemory(size_t size) {
|
|||
|
||||
bool GPUPsCache::RecordEvent() {
|
||||
event_.reset(new cudaEvent_t());
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda record event failed");
|
||||
|
@ -50,12 +51,14 @@ bool GPUPsCache::RecordEvent() {
|
|||
}
|
||||
|
||||
/// Waits on the stored CUDA event and then destroys it.
///
/// @return false when event_ is null; true once both CUDA calls have passed
///         their checks. On a failing CUDA call the check macro is presumed to
///         log the given message and return early (NOTE(review): confirm
///         against the macro definition).
bool GPUPsCache::SynchronizeEvent() {
  MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
  auto &event = *event_;
  CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(event), "Cuda sync event failed");
  CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(event), "Cuda destroy event failed");
  return true;
}
|
||||
|
||||
bool GPUPsCache::SynchronizeStream() {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(stream_, false);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda sync stream failed");
|
||||
return true;
|
||||
|
|
|
@ -148,7 +148,9 @@ void PsCacheManager::Initialize() {
|
|||
Worker::GetInstance().Run();
|
||||
}
|
||||
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
|
||||
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_host_cache_);
|
||||
AddEmbeddingTable();
|
||||
AllocMemForHashTable();
|
||||
SetLocalIdRank();
|
||||
|
@ -323,6 +325,8 @@ void PsCacheManager::DoProcessData(uint32_t device_id, const void *context) {
|
|||
void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) {
|
||||
MS_LOG(INFO) << "PS embedding cache process data task begin.";
|
||||
running_ = true;
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_->cache_);
|
||||
embedding_device_cache_->cache_->InitDevice(device_id, context);
|
||||
InitParameterServer();
|
||||
InitDataChannel();
|
||||
|
@ -371,6 +375,7 @@ bool PsCacheManager::ProcessData() {
|
|||
auto batch_ids = reinterpret_cast<int *>(data);
|
||||
auto batch_ids_len = data_size / sizeof(int);
|
||||
std::unique_ptr<int[]> hash_index = std::make_unique<int[]>(batch_ids_len);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_index, false);
|
||||
if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
|
||||
MS_LOG(ERROR) << "Process data memset failed.";
|
||||
return false;
|
||||
|
@ -667,6 +672,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() {
|
|||
|
||||
void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
|
||||
const float *input_addr, const int *indices_addr, float *output_addr) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(input_addr);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(indices_addr);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(output_addr);
|
||||
auto type_size = sizeof(float);
|
||||
size_t lens = outer_dim_size * type_size;
|
||||
for (size_t i = 0; i < indices_lens; ++i) {
|
||||
|
@ -693,6 +701,9 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
|
|||
|
||||
bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
|
||||
const int *indices_addr, float *output_addr) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(indices_addr, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(output_addr, false);
|
||||
size_t first_dim_size = host_vocab_cache_size_;
|
||||
size_t outer_dim_size = embedding_size;
|
||||
|
||||
|
@ -723,6 +734,9 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
|
|||
|
||||
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
|
||||
const float *insert_data, float *hash_table_addr) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(insert_indices, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(insert_data, false);
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_table_addr, false);
|
||||
size_t first_dim_size = host_vocab_cache_size_;
|
||||
size_t thread_num = insert_indices_size / kMaxIdsPerThread + 1;
|
||||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
|
@ -780,8 +794,10 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
|||
return true;
|
||||
}
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_info.host_address, false);
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
|
||||
|
@ -832,7 +848,9 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
|
|||
bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
|
||||
MS_ERROR_IF_NULL(embedding_host_cache_);
|
||||
auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_ids, false);
|
||||
auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(host_to_server_index, false);
|
||||
auto swap_indices_size = statistics_info_.host_to_server_size_;
|
||||
if (swap_indices_size == 0) {
|
||||
return true;
|
||||
|
@ -860,11 +878,14 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
|
|||
MS_ERROR_IF_NULL(embedding_host_cache_);
|
||||
auto swap_indices_size = statistics_info_.server_to_host_size_;
|
||||
auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_ids, false);
|
||||
auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(server_to_host_index, false);
|
||||
if (swap_indices_size == 0) {
|
||||
return true;
|
||||
}
|
||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(host_hash_table_addr, false);
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
|
||||
std::vector<int> lookup_ids(swap_indices_size, 0);
|
||||
|
@ -891,6 +912,8 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *
|
|||
if (swap_out_index_size == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
|
@ -917,6 +940,8 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in
|
|||
if (swap_in_ids_size == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(hash_info.device_address.addr, false);
|
||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||
auto embedding_size = hash_info.embedding_size;
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed,
|
||||
float *output_data) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(output_data, false);
|
||||
if (out_shape.size() == 0) {
|
||||
std::cout << "output data shape is error" << std::endl;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue