!17733 fixed rank id

From: @anancds
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
mindspore-ci-bot 2021-06-04 14:30:39 +08:00 committed by Gitee
commit bfd1001367
10 changed files with 26 additions and 28 deletions
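In short, this patch changes rank ids from int (with -1 marking failure) to uint32_t (with UINT32_MAX/UINT_MAX as the invalid sentinel), so producers no longer need IntToUint conversions and callers test rank_id == UINT32_MAX instead of rank_id < 0. A minimal standalone C++ sketch of that convention, assuming hypothetical names (AssignRankId, num_nodes) that are not part of the patch:

// Minimal sketch (not repository code): uint32_t rank ids with UINT32_MAX as
// the invalid sentinel, replacing the old int/-1 convention.
#include <cstdint>
#include <iostream>

// Hypothetical helper; the name AssignRankId and parameter num_nodes are illustrative.
uint32_t AssignRankId(uint32_t candidate, uint32_t num_nodes) {
  // Unsigned comparison, so no IntToUint conversion is required.
  if (candidate >= num_nodes) {
    return UINT32_MAX;  // invalid rank id
  }
  return candidate;
}

int main() {
  uint32_t rank_id = AssignRankId(3, 2);
  // Callers now test against UINT32_MAX instead of rank_id < 0.
  if (rank_id == UINT32_MAX) {
    std::cout << "The rank id is wrong!" << std::endl;
  }
  return 0;
}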

View File

@@ -99,7 +99,7 @@ void AbstractNode::set_ready_for_scale_in() {
connected_nodes_.clear();
} else {
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
- node_info_.rank_id_ = UINT_MAX;
+ node_info_.rank_id_ = UINT32_MAX;
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
}
@@ -635,7 +635,7 @@ bool AbstractNode::InitClientToScheduler() {
return client_to_scheduler_->WaitConnected();
}
- const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &rank_id) {
+ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id) {
std::lock_guard<std::mutex> lock(client_mutex_);
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
return connected_nodes_[rank_id];

View File

@@ -135,7 +135,7 @@ class AbstractNode : public Node {
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler();
- const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
+ const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunMessageCallback(const uint64_t &request_id);
@@ -162,7 +162,7 @@ class AbstractNode : public Node {
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
std::mutex client_mutex_;
// the map's key is: rank_id
- std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
+ std::unordered_map<uint32_t, std::shared_ptr<TcpClient>> connected_nodes_;
// the key is: request_id, the value is: <rank_id, RecvMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;

View File

@@ -27,9 +27,9 @@ void NodeManager::InitNode() {
total_node_num_ = initial_total_node_num_;
}
- int NodeManager::NextRankId(const RegisterMessage &register_message) {
+ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
- int rank_id = -1;
+ uint32_t rank_id = UINT_MAX;
const std::string &node_id = register_message.node_id();
if (nodes_info_.find(node_id) != nodes_info_.end()) {
@@ -43,9 +43,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
uint32_t port = register_message.port();
rank_id = ++next_server_rank_id_;
- if (IntToUint(rank_id) >= meta_data_->server_num) {
+ if (rank_id >= meta_data_->server_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
- rank_id = -1;
+ rank_id = UINT_MAX;
--next_server_rank_id_;
}
NodeInfo node_info;
@@ -61,9 +61,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
rank_id = ++next_worker_rank_id_;
- if (IntToUint(rank_id) >= meta_data_->worker_num) {
+ if (rank_id >= meta_data_->worker_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
- rank_id = -1;
+ rank_id = UINT_MAX;
--next_worker_rank_id_;
}
NodeInfo node_info;
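The two branches above follow the same shape: bump a shared counter, and if the resulting rank reaches server_num/worker_num, roll the counter back and return the unsigned invalid value. A hedged sketch of that flow, using simplified stand-ins rather than the actual NodeManager members:

// Illustrative only: the overflow guard from NextRankId with unsigned rank ids.
#include <climits>
#include <cstdint>

struct ClusterMeta {
  uint32_t server_num;
};

uint32_t NextServerRank(uint32_t *next_server_rank_id, const ClusterMeta &meta) {
  uint32_t rank_id = ++(*next_server_rank_id);
  if (rank_id >= meta.server_num) {  // direct unsigned compare, no IntToUint
    rank_id = UINT_MAX;              // invalid sentinel, as in the patch
    --(*next_server_rank_id);        // undo the increment on failure
  }
  return rank_id;
}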

View File

@@ -54,7 +54,7 @@ class NodeManager {
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
void InitNode();
- int NextRankId(const RegisterMessage &register_message);
+ uint32_t NextRankId(const RegisterMessage &register_message);
void UpdateHeartbeat(const std::string &node_id);
bool CheckNodesScaluOutState();

View File

@@ -53,7 +53,7 @@ message MessageMeta {
// the role of the current node: worker,server,scheduler
NodeRole role = 3;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
- int32 rank_id = 4;
+ uint32 rank_id = 4;
// User-defined commands
int32 user_cmd = 5;
}
@@ -71,7 +71,7 @@ message RegisterMessage {
message RegisterRespMessage {
string node_id = 1;
- int32 rank_id = 2;
+ uint32 rank_id = 2;
}
message HeartbeatMessage {
@@ -109,7 +109,7 @@ message FetchServersRespMessage {
}
message ServersMeta {
- int32 rank_id = 1;
+ uint32 rank_id = 1;
string ip = 2;
int32 port = 3;

View File

@@ -117,8 +117,8 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
register_message.ParseFromArray(data, size);
// assign worker node and server node rank id
- int rank_id = node_manager_.NextRankId(register_message);
- if (rank_id < 0) {
+ uint32_t rank_id = node_manager_.NextRankId(register_message);
+ if (rank_id == UINT32_MAX) {
MS_LOG(WARNING) << "The rank id is wrong!";
}
const std::string &node_id = register_message.node_id();

View File

@@ -126,7 +126,7 @@ class PsCacheManager {
const size_t &QueryHashTableSize(const std::string &param_name) const;
bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
- void set_rank_id(int rank_id) { rank_id_ = rank_id; }
+ void set_rank_id(uint32_t rank_id) { rank_id_ = rank_id; }
bool initialized_ps_cache() const { return initialized_ps_cache_; }
size_t vocab_cache_size() const { return vocab_cache_size_; }
int cache_indices_lower_bound() const;
@@ -203,7 +203,7 @@ class PsCacheManager {
std::pair<int, int> emb_table_slice_bounds_;
std::pair<int, int> cache_indices_bounds_;
int vocab_cache_size_diff_{0};
- int rank_id_{0};
+ uint32_t rank_id_{0};
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};

View File

@@ -124,9 +124,9 @@ uint32_t PSContext::initial_server_num() const { return server_num_; }
std::string PSContext::scheduler_host() const { return scheduler_host_; }
- void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
+ void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
- int PSContext::ps_rank_id() const { return rank_id_; }
+ uint32_t PSContext::ps_rank_id() const { return rank_id_; }
void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const {
@@ -166,7 +166,7 @@ void PSContext::set_cache_enable(bool cache_enable) const {
#endif
}
- void PSContext::set_rank_id(int rank_id) const {
+ void PSContext::set_rank_id(uint32_t rank_id) const {
#if (ENABLE_CPU && !_WIN32)
ps_cache_instance.set_rank_id(rank_id);
#endif

View File

@@ -67,8 +67,8 @@ class PSContext {
uint32_t initial_worker_num() const;
uint32_t initial_server_num() const;
std::string scheduler_host() const;
- void SetPSRankId(int rank_id);
- int ps_rank_id() const;
+ void SetPSRankId(uint32_t rank_id);
+ uint32_t ps_rank_id() const;
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const;
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
@@ -77,7 +77,7 @@ class PSContext {
void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
void set_cache_enable(bool cache_enable) const;
- void set_rank_id(int rank_id) const;
+ void set_rank_id(uint32_t rank_id) const;
bool enable_ssl() const;
void set_enable_ssl(bool enabled);
@@ -155,7 +155,7 @@ class PSContext {
is_pserver_(false),
is_sched_(false),
enable_ssl_(false),
- rank_id_(-1),
+ rank_id_(0),
worker_num_(0),
server_num_(0),
scheduler_host_(""),
@@ -182,7 +182,7 @@ class PSContext {
bool is_pserver_;
bool is_sched_;
bool enable_ssl_;
- int rank_id_;
+ uint32_t rank_id_;
uint32_t worker_num_;
uint32_t server_num_;
std::string scheduler_host_;
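A side effect visible in the last two hunks of this header: the default for rank_id_ changes from -1 to 0, because the member is now unsigned and an initializer of -1 would simply wrap to UINT32_MAX, the value the patch reserves for "invalid". A small standalone illustration (not repository code):

#include <cstdint>
#include <iostream>

int main() {
  // Converting -1 to a 32-bit unsigned integer wraps (modulo 2^32) to
  // 4294967295, i.e. UINT32_MAX, which is the invalid-rank sentinel.
  uint32_t wrapped = static_cast<uint32_t>(-1);
  std::cout << std::boolalpha << (wrapped == UINT32_MAX) << std::endl;  // true
  return 0;
}

Since 0 is itself a legal rank id ([0, numOfWorker-1] per the proto comment above), the new default no longer doubles as an "unassigned" marker.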

View File

@@ -24,8 +24,6 @@
namespace mindspore {
namespace ps {
- int64_t Util::rank_id_ = -1;
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
{kApplyMomentum, 0},
{kSparseAdam, 1},