!14409 fix a bug in launch allreduce
From: @lvchangquan Reviewed-by: @chujinjin,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
08189a1428
|
@ -92,6 +92,7 @@ void Bucket::CalculateMean() {
|
|||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
auto grad_mean = parallel_context->gradients_mean();
|
||||
if (!grad_mean) {
|
||||
UpdateTensorOutputAddr(ar_output_addr_);
|
||||
return;
|
||||
}
|
||||
if (launch_mul_ == nullptr) {
|
||||
|
@ -102,12 +103,16 @@ void Bucket::CalculateMean() {
|
|||
launch_mul_->SetInputAddr(ar_output_addr_);
|
||||
// launch mean
|
||||
launch_mul_->LaunchOpKernel();
|
||||
// store output tensor addr
|
||||
// store tensor output addr
|
||||
auto launch_output = launch_mul_->GetKernelOutputAddr();
|
||||
if (launch_output.size() != 1) {
|
||||
MS_LOG(ERROR) << "launch mul outputs should have one output";
|
||||
MS_LOG(EXCEPTION) << "launch mul outputs should have one output";
|
||||
}
|
||||
uint8_t *tensor_output = launch_output[0];
|
||||
UpdateTensorOutputAddr(launch_output[0]);
|
||||
}
|
||||
|
||||
void Bucket::UpdateTensorOutputAddr(uint8_t *addr) {
|
||||
uint8_t *tensor_output = addr;
|
||||
for (size_t i = 0; i < bucket_size_; ++i) {
|
||||
new_tensor_output_addrs_.emplace_back(tensor_output);
|
||||
tensor_output += align_size_list_[i];
|
||||
|
|
|
@ -84,6 +84,7 @@ class Bucket {
|
|||
virtual void FreeAllDeviceMem() = 0;
|
||||
virtual void FreeDeviceMem(void *dev_ptr) = 0;
|
||||
virtual void CopyTensorToContiguousMemory() = 0;
|
||||
void UpdateTensorOutputAddr(uint8_t *addr);
|
||||
void LazyDeleteOldAddr();
|
||||
};
|
||||
} // namespace mindspore::device
|
||||
|
|
Loading…
Reference in New Issue