forked from mindspore-Ecosystem/mindspore
!14686 fix a split_index bug in launch allreduce
From: @lvchangquan Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
48aed2796e
|
@ -2289,6 +2289,23 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
|
|||
return bucket_size_list;
|
||||
}
|
||||
|
||||
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
|
||||
if (split_index->empty()) {
|
||||
return;
|
||||
}
|
||||
// calculate split index num
|
||||
auto split_index_len = split_index->size();
|
||||
uint32_t split_index_num = (*split_index)[split_index_len - 1];
|
||||
// obtain graph output tensor num
|
||||
auto grads_count = GetBpropGraphGradsCount(graph);
|
||||
if (split_index_num == 0 || split_index_num >= grads_count) {
|
||||
MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count;
|
||||
}
|
||||
if (split_index_num < grads_count - 1) {
|
||||
split_index->push_back(grads_count - 1);
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
|
||||
|
@ -2301,6 +2318,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
|
|||
std::vector<std::shared_ptr<device::Bucket>> bucket_list;
|
||||
// Create bucket for every split allreduce ops
|
||||
auto split_index = GetAllReduceSplitIndex();
|
||||
PreProcessOnSplitIndex(graph, &split_index);
|
||||
auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
|
||||
uint32_t bucket_id = 0;
|
||||
for (auto bucket_size : bucket_size_list) {
|
||||
|
|
Loading…
Reference in New Issue