forked from mindspore-Ecosystem/mindspore
fixed rank id
This commit is contained in:
parent
bbc0122bf0
commit
fcca87f9c4
|
@ -99,7 +99,7 @@ void AbstractNode::set_ready_for_scale_in() {
|
|||
connected_nodes_.clear();
|
||||
} else {
|
||||
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
|
||||
node_info_.rank_id_ = UINT_MAX;
|
||||
node_info_.rank_id_ = UINT32_MAX;
|
||||
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
|
||||
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
|
||||
}
|
||||
|
@ -635,7 +635,7 @@ bool AbstractNode::InitClientToScheduler() {
|
|||
return client_to_scheduler_->WaitConnected();
|
||||
}
|
||||
|
||||
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &rank_id) {
|
||||
const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint32_t &rank_id) {
|
||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||
if (connected_nodes_.find(rank_id) != connected_nodes_.end()) {
|
||||
return connected_nodes_[rank_id];
|
||||
|
|
|
@ -135,7 +135,7 @@ class AbstractNode : public Node {
|
|||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool InitClientToScheduler();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id);
|
||||
|
||||
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
|
@ -162,7 +162,7 @@ class AbstractNode : public Node {
|
|||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||
std::mutex client_mutex_;
|
||||
// the map's key is: rank_id
|
||||
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
std::unordered_map<uint32_t, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
// the key is: request_id, the value is: <rank_id, RecvMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
|
||||
|
|
|
@ -27,9 +27,9 @@ void NodeManager::InitNode() {
|
|||
total_node_num_ = initial_total_node_num_;
|
||||
}
|
||||
|
||||
int NodeManager::NextRankId(const RegisterMessage &register_message) {
|
||||
uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
|
||||
std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
|
||||
int rank_id = -1;
|
||||
uint32_t rank_id = UINT_MAX;
|
||||
|
||||
const std::string &node_id = register_message.node_id();
|
||||
if (nodes_info_.find(node_id) != nodes_info_.end()) {
|
||||
|
@ -43,9 +43,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
|
|||
uint32_t port = register_message.port();
|
||||
|
||||
rank_id = ++next_server_rank_id_;
|
||||
if (IntToUint(rank_id) >= meta_data_->server_num) {
|
||||
if (rank_id >= meta_data_->server_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
|
||||
rank_id = -1;
|
||||
rank_id = UINT_MAX;
|
||||
--next_server_rank_id_;
|
||||
}
|
||||
NodeInfo node_info;
|
||||
|
@ -61,9 +61,9 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
|
|||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
if (IntToUint(rank_id) >= meta_data_->worker_num) {
|
||||
if (rank_id >= meta_data_->worker_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
|
||||
rank_id = -1;
|
||||
rank_id = UINT_MAX;
|
||||
--next_worker_rank_id_;
|
||||
}
|
||||
NodeInfo node_info;
|
||||
|
|
|
@ -54,7 +54,7 @@ class NodeManager {
|
|||
|
||||
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
|
||||
void InitNode();
|
||||
int NextRankId(const RegisterMessage &register_message);
|
||||
uint32_t NextRankId(const RegisterMessage &register_message);
|
||||
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
bool CheckNodesScaluOutState();
|
||||
|
|
|
@ -53,7 +53,7 @@ message MessageMeta {
|
|||
// the role of the current node: worker,server,scheduler
|
||||
NodeRole role = 3;
|
||||
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
|
||||
int32 rank_id = 4;
|
||||
uint32 rank_id = 4;
|
||||
// User-defined commands
|
||||
int32 user_cmd = 5;
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ message RegisterMessage {
|
|||
|
||||
message RegisterRespMessage {
|
||||
string node_id = 1;
|
||||
int32 rank_id = 2;
|
||||
uint32 rank_id = 2;
|
||||
}
|
||||
|
||||
message HeartbeatMessage {
|
||||
|
@ -109,7 +109,7 @@ message FetchServersRespMessage {
|
|||
}
|
||||
|
||||
message ServersMeta {
|
||||
int32 rank_id = 1;
|
||||
uint32 rank_id = 1;
|
||||
string ip = 2;
|
||||
int32 port = 3;
|
||||
|
||||
|
|
|
@ -117,8 +117,8 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
register_message.ParseFromArray(data, size);
|
||||
|
||||
// assign worker node and server node rank id
|
||||
int rank_id = node_manager_.NextRankId(register_message);
|
||||
if (rank_id < 0) {
|
||||
uint32_t rank_id = node_manager_.NextRankId(register_message);
|
||||
if (rank_id == UINT32_MAX) {
|
||||
MS_LOG(WARNING) << "The rank id is wrong!";
|
||||
}
|
||||
const std::string &node_id = register_message.node_id();
|
||||
|
|
|
@ -126,7 +126,7 @@ class PsCacheManager {
|
|||
const size_t &QueryHashTableSize(const std::string &param_name) const;
|
||||
bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
|
||||
void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
|
||||
void set_rank_id(int rank_id) { rank_id_ = rank_id; }
|
||||
void set_rank_id(uint32_t rank_id) { rank_id_ = rank_id; }
|
||||
bool initialized_ps_cache() const { return initialized_ps_cache_; }
|
||||
size_t vocab_cache_size() const { return vocab_cache_size_; }
|
||||
int cache_indices_lower_bound() const;
|
||||
|
@ -203,7 +203,7 @@ class PsCacheManager {
|
|||
std::pair<int, int> emb_table_slice_bounds_;
|
||||
std::pair<int, int> cache_indices_bounds_;
|
||||
int vocab_cache_size_diff_{0};
|
||||
int rank_id_{0};
|
||||
uint32_t rank_id_{0};
|
||||
std::atomic_bool finish_insert_init_info_{false};
|
||||
std::atomic_bool finish_init_parameter_server_{false};
|
||||
std::atomic_bool running_{false};
|
||||
|
|
|
@ -124,9 +124,9 @@ uint32_t PSContext::initial_server_num() const { return server_num_; }
|
|||
|
||||
std::string PSContext::scheduler_host() const { return scheduler_host_; }
|
||||
|
||||
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
|
||||
void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
|
||||
|
||||
int PSContext::ps_rank_id() const { return rank_id_; }
|
||||
uint32_t PSContext::ps_rank_id() const { return rank_id_; }
|
||||
|
||||
void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
|
||||
size_t vocab_size) const {
|
||||
|
@ -166,7 +166,7 @@ void PSContext::set_cache_enable(bool cache_enable) const {
|
|||
#endif
|
||||
}
|
||||
|
||||
void PSContext::set_rank_id(int rank_id) const {
|
||||
void PSContext::set_rank_id(uint32_t rank_id) const {
|
||||
#if (ENABLE_CPU && !_WIN32)
|
||||
ps_cache_instance.set_rank_id(rank_id);
|
||||
#endif
|
||||
|
|
|
@ -67,8 +67,8 @@ class PSContext {
|
|||
uint32_t initial_worker_num() const;
|
||||
uint32_t initial_server_num() const;
|
||||
std::string scheduler_host() const;
|
||||
void SetPSRankId(int rank_id);
|
||||
int ps_rank_id() const;
|
||||
void SetPSRankId(uint32_t rank_id);
|
||||
uint32_t ps_rank_id() const;
|
||||
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
|
||||
size_t vocab_size) const;
|
||||
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
|
||||
|
@ -77,7 +77,7 @@ class PSContext {
|
|||
void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
|
||||
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
|
||||
void set_cache_enable(bool cache_enable) const;
|
||||
void set_rank_id(int rank_id) const;
|
||||
void set_rank_id(uint32_t rank_id) const;
|
||||
bool enable_ssl() const;
|
||||
void set_enable_ssl(bool enabled);
|
||||
|
||||
|
@ -155,7 +155,7 @@ class PSContext {
|
|||
is_pserver_(false),
|
||||
is_sched_(false),
|
||||
enable_ssl_(false),
|
||||
rank_id_(-1),
|
||||
rank_id_(0),
|
||||
worker_num_(0),
|
||||
server_num_(0),
|
||||
scheduler_host_(""),
|
||||
|
@ -182,7 +182,7 @@ class PSContext {
|
|||
bool is_pserver_;
|
||||
bool is_sched_;
|
||||
bool enable_ssl_;
|
||||
int rank_id_;
|
||||
uint32_t rank_id_;
|
||||
uint32_t worker_num_;
|
||||
uint32_t server_num_;
|
||||
std::string scheduler_host_;
|
||||
|
|
|
@ -24,8 +24,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
int64_t Util::rank_id_ = -1;
|
||||
|
||||
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
|
||||
{kApplyMomentum, 0},
|
||||
{kSparseAdam, 1},
|
||||
|
|
Loading…
Reference in New Issue