diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index 29ad141324..5945a86e2b 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptrfront(); if (original_row_count > 0) { size_t offset = 0; - if ((original_row_count % server_num) == 0) { - offset = original_row_count / server_num * rank_id; - } else { - offset = std::round((static_cast(original_row_count)) / server_num) * rank_id; + std::map rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num); + for (size_t i = 0; i < rank_id; i++) { + if (rank_dims.count(i) == 0) { + MS_LOG(EXCEPTION) << "No local shard number for rank " << i; + } + offset += rank_dims[i]; } for (size_t i = 0; i < indices_size; i++) { indices_data[i] -= offset; diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc index 326e1c113e..ec1f01626e 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc @@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) { bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } int Util::LocalShard(int first_dim, int rank_id, int server_num) { - int shard_size = std::round((static_cast(first_dim)) / server_num); - int remain_size = first_dim % server_num; - if (remain_size == 0 || rank_id < server_num - 1) { - return shard_size; - } else { - return first_dim - (shard_size * (server_num - 1)); + std::map shard_dims = AllRankLocalShard(first_dim, rank_id, server_num); + if (shard_dims.count(rank_id) == 0) { + MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id; } + return shard_dims[rank_id]; +} + +std::map Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) { + if (rank_id >= server_num) { + MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num; + } + std::map shard_dims; + for (int i = 0; i < server_num; i++) { + shard_dims[i] = 0; + } + if (server_num != static_cast(shard_dims.size())) { + MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size(); + } + int server_index = -1; + for (int i = 0; i < first_dim; i++) { + server_index = (server_index + 1) % server_num; + shard_dims[server_index] = shard_dims[server_index] + 1; + } + if (shard_dims.count(rank_id) == 0) { + MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num; + } + return shard_dims; } void Util::SetRankId(int rank_id) { rank_id_ = rank_id; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h index fe55f51222..9974482e2b 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/util.h +++ b/mindspore/ccsrc/frontend/parallel/ps/util.h @@ -39,6 +39,7 @@ class Util { static std::string optimizer_node_name(int id); static bool is_optimizer(std::string name); static int LocalShard(int first_dim, int rank_id, int server_num); + static std::map AllRankLocalShard(int first_dim, int rank_id, int server_num); static void SetRankId(int rank_id); static int GetRankId(); static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,