!14409 fix a bug in launch allreduce

From: @lvchangquan
Reviewed-by: @chujinjin,@jjfeing
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2021-04-02 16:42:44 +08:00 committed by Gitee
commit 08189a1428
2 changed files with 9 additions and 3 deletions

View File

@ -92,6 +92,7 @@ void Bucket::CalculateMean() {
MS_EXCEPTION_IF_NULL(parallel_context);
auto grad_mean = parallel_context->gradients_mean();
if (!grad_mean) {
UpdateTensorOutputAddr(ar_output_addr_);
return;
}
if (launch_mul_ == nullptr) {
@ -102,12 +103,16 @@ void Bucket::CalculateMean() {
launch_mul_->SetInputAddr(ar_output_addr_);
// launch mean
launch_mul_->LaunchOpKernel();
// store output tensor addr
// store tensor output addr
auto launch_output = launch_mul_->GetKernelOutputAddr();
if (launch_output.size() != 1) {
MS_LOG(ERROR) << "launch mul outputs should have one output";
MS_LOG(EXCEPTION) << "launch mul outputs should have one output";
}
uint8_t *tensor_output = launch_output[0];
UpdateTensorOutputAddr(launch_output[0]);
}
void Bucket::UpdateTensorOutputAddr(uint8_t *addr) {
uint8_t *tensor_output = addr;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list_[i];

View File

@ -84,6 +84,7 @@ class Bucket {
virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0;
virtual void CopyTensorToContiguousMemory() = 0;
void UpdateTensorOutputAddr(uint8_t *addr);
void LazyDeleteOldAddr();
};
} // namespace mindspore::device