forked from OSSInnovation/mindspore
!5590 Fixbugfix for server shard range computation
Merge pull request !5590 from ZPaC/master-fix-server-shard-method
This commit is contained in:
commit
d92c220cc0
|
@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_
|
|||
size_t original_row_count = input_shapes->front();
|
||||
if (original_row_count > 0) {
|
||||
size_t offset = 0;
|
||||
if ((original_row_count % server_num) == 0) {
|
||||
offset = original_row_count / server_num * rank_id;
|
||||
} else {
|
||||
offset = std::round((static_cast<float>(original_row_count)) / server_num) * rank_id;
|
||||
std::map<int, int> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
|
||||
for (size_t i = 0; i < rank_id; i++) {
|
||||
if (rank_dims.count(i) == 0) {
|
||||
MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
|
||||
}
|
||||
offset += rank_dims[i];
|
||||
}
|
||||
for (size_t i = 0; i < indices_size; i++) {
|
||||
indices_data[i] -= offset;
|
||||
|
|
|
@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) {
|
|||
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
|
||||
|
||||
int Util::LocalShard(int first_dim, int rank_id, int server_num) {
|
||||
int shard_size = std::round((static_cast<float>(first_dim)) / server_num);
|
||||
int remain_size = first_dim % server_num;
|
||||
if (remain_size == 0 || rank_id < server_num - 1) {
|
||||
return shard_size;
|
||||
} else {
|
||||
return first_dim - (shard_size * (server_num - 1));
|
||||
std::map<int, int> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
|
||||
if (shard_dims.count(rank_id) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
|
||||
}
|
||||
return shard_dims[rank_id];
|
||||
}
|
||||
|
||||
std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) {
|
||||
if (rank_id >= server_num) {
|
||||
MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
|
||||
}
|
||||
std::map<int, int> shard_dims;
|
||||
for (int i = 0; i < server_num; i++) {
|
||||
shard_dims[i] = 0;
|
||||
}
|
||||
if (server_num != static_cast<int>(shard_dims.size())) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
|
||||
}
|
||||
int server_index = -1;
|
||||
for (int i = 0; i < first_dim; i++) {
|
||||
server_index = (server_index + 1) % server_num;
|
||||
shard_dims[server_index] = shard_dims[server_index] + 1;
|
||||
}
|
||||
if (shard_dims.count(rank_id) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
|
||||
}
|
||||
return shard_dims;
|
||||
}
|
||||
|
||||
void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
|
||||
|
|
|
@ -39,6 +39,7 @@ class Util {
|
|||
static std::string optimizer_node_name(int id);
|
||||
static bool is_optimizer(std::string name);
|
||||
static int LocalShard(int first_dim, int rank_id, int server_num);
|
||||
static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
|
||||
static void SetRankId(int rank_id);
|
||||
static int GetRankId();
|
||||
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
|
||||
|
|
Loading…
Reference in New Issue