!18325 Add check for AllReduce fusion index in PyNative mode
Merge pull request !18325 from caifubi/master-pynative-allreduce
This commit is contained in:
commit
e38dc88d9c
|
@ -2278,14 +2278,25 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
|
|||
return bucket_size_list;
|
||||
}
|
||||
|
||||
void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
|
||||
uint32_t last = 0;
|
||||
for (auto &index : split_index) {
|
||||
if (index <= last) {
|
||||
MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
|
||||
}
|
||||
last = index;
|
||||
}
|
||||
}
|
||||
|
||||
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
|
||||
MS_EXCEPTION_IF_NULL(split_index);
|
||||
if (split_index->empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
CheckSplitIndexValid(*split_index);
|
||||
// calculate split index num
|
||||
auto split_index_len = split_index->size();
|
||||
uint32_t split_index_num = (*split_index)[split_index_len - 1];
|
||||
auto split_index_num = split_index->back();
|
||||
// obtain graph output tensor num
|
||||
auto grads_count = GetBpropGraphGradsCount(graph);
|
||||
if (split_index_num >= grads_count) {
|
||||
|
|
Loading…
Reference in New Issue