!49671 Fix ReduceSum input change

Merge pull request !49671 from zhanzhan/reducesum
i-robot 2023-03-06 08:21:18 +00:00 committed by Gitee
commit 1bd0d28271
1 changed file with 9 additions and 5 deletions

@@ -209,6 +209,8 @@ bool ArrayReduceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
void ArrayReduceGpuKernelMod::InitCudnnResource() {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed.");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_),
"cudnnGetTensorSizeInBytes failed.");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_, outputC_descriptor_,
@@ -351,16 +353,16 @@ void ArrayReduceGpuKernelMod::LaunchIntKernel(const std::vector<AddressPtr> &inp
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
S *input_addr = GetDeviceAddress<S>(inputs, 0);
S *output_addr = GetDeviceAddress<S>(outputs, 0);
S *workspace_addr = GetPossiblyNullDeviceAddress<S>(workspace, 0);
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
S *workspace_addr = GetPossiblyNullDeviceAddress<S>(workspace, 0);
T *casted_input = GetDeviceAddress<T>(inputs, 0);
T *output_before_cast = GetDeviceAddress<T>(outputs, 0);
T *casted_input = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
T *output_before_cast =
reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));
const int input_num = input_size_ / sizeof(T);
const int output_num = output_size_list_[kIndex0] / sizeof(S);
const int output_num = output_size_ / sizeof(S);
Cast(input_num, input_addr, casted_input, reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
@@ -368,6 +370,8 @@ void ArrayReduceGpuKernelMod::LaunchIntKernel(const std::vector<AddressPtr> &inp
inputA_descriptor_, casted_input, &beta, outputC_descriptor_, output_before_cast),
"cudnnReduceTensor failed.");
Cast(output_num, output_before_cast, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(casted_input);
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_before_cast);
return;
}
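
Reading the hunks together: InitCudnnResource now also queries output_size_ with cudnnGetTensorSizeInBytes, LaunchIntKernel stages the integer reduction through T-typed scratch buffers allocated from GPUMemoryAllocator (sized by those cudnn-reported byte counts) instead of aliasing the kernel's own S-typed input/output addresses as T, and both buffers are freed once the reduced result has been cast back to the integer output. The consolidated sketch below shows the post-change flow; it assumes the surrounding members shown in the diff (cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_, outputC_descriptor_, input_size_, output_size_), while the template parameters, the indices arguments (nullptr, 0), and the workspace_size_ argument to cudnnReduceTensor are assumptions for illustration, not lines taken from this change.

template <typename T, typename S>
void ArrayReduceGpuKernelMod::LaunchIntKernel(const std::vector<AddressPtr> &inputs,
                                              const std::vector<AddressPtr> &workspace,
                                              const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  S *input_addr = GetDeviceAddress<S>(inputs, 0);
  S *output_addr = GetDeviceAddress<S>(outputs, 0);
  T alpha = static_cast<T>(1.0f);
  T beta = static_cast<T>(0.0f);
  S *workspace_addr = GetPossiblyNullDeviceAddress<S>(workspace, 0);

  // Scratch buffers for the cast-to-T staging copies; input_size_ and output_size_ are the
  // byte counts reported by cudnnGetTensorSizeInBytes in InitCudnnResource.
  T *casted_input =
    reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(input_size_));
  T *output_before_cast =
    reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(output_size_));

  const int input_num = input_size_ / sizeof(T);
  const int output_num = output_size_ / sizeof(S);

  // Integer input -> T, reduce in T with cuDNN, then cast the T result back to the integer output.
  Cast(input_num, input_addr, casted_input, reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);
  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
    cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
                      inputA_descriptor_, casted_input, &beta, outputC_descriptor_, output_before_cast),
    "cudnnReduceTensor failed.");
  Cast(output_num, output_before_cast, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr), GET_CTX_DEVICE_ID);

  // Release the temporary buffers now that output_addr holds the final integer result.
  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(casted_input);
  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_before_cast);
}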