diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cu
index 8fba2edba48..285981b1724 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cu
@@ -36,43 +36,48 @@ constexpr int ALIGN_BYTES = 16;
   }  \
 }
 
-inline dim3 SpatialSoftMax_getGridSize(dim3 *block, uint32_t max_active_blocks, uint64_t outer_size, uint64_t dim_size,
-                                       uint64_t inner_size) {
-  // First, tile as many blocks as we can over the y axis
-  uint32_t inner_blocks = (inner_size + block->y - 1) / block->y;
-  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
-  // Fill the x axis with as many blocks as we can fit (a little more is ok too)
-  uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
-  if (outer_blocks > outer_size) outer_blocks = outer_size;
-  return dim3(outer_blocks, inner_blocks);
+inline dim3 SpatialSoftMaxGetGridSize(dim3 *block, uint32_t activate_block, uint64_t outer_size, uint64_t dim_size,
+                                      uint64_t inner_size) {
+  uint32_t inner = (inner_size + block->y - 1) / block->y;
+  if (inner > activate_block) {
+    inner = activate_block;
+  }
+  uint32_t outer = (activate_block + inner - 1) / inner;
+  if (outer > outer_size) {
+    outer = outer_size;
+  }
+  return dim3(outer, inner);
 }
 
-inline dim3 SpatialSoftMax_getBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
-  uint32_t inner_threads = inner_size;
-  inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads));
+inline dim3 SpatialSoftMaxGetBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
+  uint32_t inner_ths = inner_size;
+  inner_ths = std::min(inner_ths, static_cast<uint32_t>(max_threads));
   uint32_t dim_threads = 1;
-  if (inner_threads <= 64 && dim_size >= 64) {
-    while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size) dim_threads *= 2;
+  if (inner_ths <= 64 && dim_size >= 64) {
+    while ((inner_ths * dim_threads <= max_threads) && (dim_threads <= dim_size)) {
+      dim_threads *= 2;
+    }
     dim_threads /= 2;
   }
-  return dim3(dim_threads, inner_threads);
+  return dim3(dim_threads, inner_ths);
 }
 
 template
-void SpatialSoftMax_getLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
-                                   dim3 *block, uint32_t *smem_size, uint32_t device_id) {
-  *block = SpatialSoftMax_getBlockSize(outer_size, dim_size, inner_size);
-  uint32_t block_threads = block->x * block->y;
-  *smem_size = block->x == 1 ? 0 : block_threads * sizeof(accumulate_t);
-
-  int max_active_blocks;
-
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, k, block_threads, *smem_size);
+void SpatialSoftMaxGetLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
+                                  dim3 *block, uint32_t *smem_size, uint32_t device_id) {
+  *block = SpatialSoftMaxGetBlockSize(outer_size, dim_size, inner_size);
+  uint32_t block_ths = block->x * block->y;
+  if (block->x == 1) {
+    *smem_size = 0;
+  } else {
+    *smem_size = block_ths * sizeof(accumulate_t);
+  }
+  int activate_size;
+  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&activate_size, k, block_ths, *smem_size);
   cudaDeviceProp prop;
   (void)cudaGetDeviceProperties(&prop, device_id);
-  max_active_blocks *= prop.multiProcessorCount;
-
-  *grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
+  activate_size *= prop.multiProcessorCount;
+  *grid = SpatialSoftMaxGetGridSize(block, activate_size, outer_size, dim_size, inner_size);
 }
 
 int log2_ceil(int val) {
@@ -135,63 +140,58 @@ __forceinline__ __device__ T SpatialBlockReduceX(T *memsha, T val) {
   return memsha[0];
 }
 
-template
+template
 __global__ void SpatialSoftMaxForward(output_t *output, input_t *input, uint32_t outer_size, uint32_t dim_size,
                                       uint32_t inner_size) {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accumulate_t *>(smem);
   const uint32_t outer_stride = inner_size * dim_size;
   const uint32_t dim_stride = inner_size;
-
   for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
     const uint32_t outer_offset = outer_index * outer_stride;
     for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size;
          inner_index += blockDim.y * gridDim.y) {
-      const uint32_t data_offset = outer_offset + inner_index;
-
+      const uint32_t offset = outer_offset + inner_index;
       if (blockDim.x > 1) {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
-        max_input = SpatialBlockReduceX(sdata, max_input);
-
+        max_data_input = SpatialBlockReduceX(sdata, max_data_input);
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
         sum = SpatialBlockReduceX(sdata, sum);
-
-        SoftMaxForwardEpilogue epilogue(max_input, sum);
+        SoftMaxForwardEpilogue epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       } else {
-        accumulate_t max_input = std::numeric_limits<accumulate_t>::lowest();
+        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
-          const accumulate_t value = static_cast<accumulate_t>(input[data_offset + d * dim_stride]);
-          max_input = atomic::Max()(max_input, value);
+          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
+          max_data_input = atomic::Max()(max_data_input, value);
         }
         accumulate_t sum = 0;
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          sum += std::exp(static_cast<accumulate_t>(input[data_offset + d * dim_stride]) - max_input);
-        SoftMaxForwardEpilogue epilogue(max_input, sum);
+          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
+
+        SoftMaxForwardEpilogue epilogue(max_data_input, sum);
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
-          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
+          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
       }
     }
   }
 }
 
-template
+template
 __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t *output, input_t max_k,
                                              input_t sum_all) {
-  SoftMaxForwardEpilogue epilogue(max_k, sum_all);
+  SoftMaxForwardEpilogue epilogue(max_k, sum_all);
   int offset = threadIdx.x;
-
   int last = clas % (InsP * blockDim.x);
   for (; offset < clas - last; offset += blockDim.x * InsP) {
     input_t tmp[InsP];
-
 #pragma unroll
     for (int j = 0; j < InsP; ++j) {
       tmp[j] = input[offset + j * blockDim.x];
@@ -201,54 +201,42 @@ __device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t
       output[offset + j * blockDim.x] = epilogue(tmp[j]);
     }
   }
-
   for (; offset < clas; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
   }
 }
 
-template
-__device__ __forceinline__ void WriteResultsVectorized(int size, const int shift, input_t *input, output_t *output,
+template
+__device__ __forceinline__ void WriteResultsVectorized(int size, const int deviate, input_t *input, output_t *output,
                                                        input_t max_k, input_t sum_all) {
-  SoftMaxForwardEpilogue epilogue(max_k, sum_all);
-
-  using LoadT = aligned_vector;
-  using StoreT = aligned_vector;
-
+  SoftMaxForwardEpilogue epilogue(max_k, sum_all);
+  using loadT = aligned_vector;
+  using storeT = aligned_vector;
   int offset = threadIdx.x;
-
-  if (shift > 0) {
-    input -= shift;
-    output -= shift;
-    size += shift;
-
-    if (threadIdx.x >= shift) {
+  if (deviate > 0) {
+    input -= deviate;
+    output -= deviate;
+    size += deviate;
+    if (threadIdx.x >= deviate) {
      output[offset] = epilogue(input[offset]);
    }
    size -= blockDim.x;
    input += blockDim.x;
    output += blockDim.x;
  }
-
   const int last = size % (InsP * blockDim.x);
-
   input_t in_v[InsP];
-  LoadT *in_value = reinterpret_cast<LoadT *>(&in_v);
-
+  loadT *in_value = reinterpret_cast<loadT *>(&in_v);
   output_t out_v[InsP];
-  StoreT *out_value = reinterpret_cast<StoreT *>(&out_v);
-
+  storeT *out_value = reinterpret_cast<storeT *>(&out_v);
   for (; offset * InsP < (size - last); offset += blockDim.x) {
-    *in_value = reinterpret_cast<LoadT *>(input)[offset];
-
+    *in_value = reinterpret_cast<loadT *>(input)[offset];
 #pragma unroll
    for (int j = 0; j < InsP; ++j) {
      out_v[j] = epilogue(in_v[j]);
    }
-
-    reinterpret_cast<StoreT *>(output)[offset] = *out_value;
+    reinterpret_cast<storeT *>(output)[offset] = *out_value;
  }
-
   offset = size - last + threadIdx.x;
   for (; offset < size; offset += blockDim.x) {
     output[offset] = epilogue(input[offset]);
@@ -256,52 +244,43 @@ __device__ __forceinline__ void WriteResultsVectorized(int size, const int shift
 }
 
 template
-__device__ __forceinline__ AccT ReduceBlock(AccT *smem, AccT val, AccT defaultVal) {
+__device__ __forceinline__ AccT ReduceBlock(AccT *sharemen, AccT val, AccT initVal) {
   Reduction r = Reduction();
-
   __syncthreads();
-
-  smem[threadIdx.x] = val;
-
+  sharemen[threadIdx.x] = val;
   __syncthreads();
-
-  AccT warpVal = defaultVal;
-
+  AccT warpVal = initVal;
   uint32_t mask = (((uint64_t)1) << (blockDim.x / WARPSIZE)) - 1;
   if (threadIdx.x < WARPSIZE) {
    int lane = threadIdx.x % WARPSIZE;
    if (lane < blockDim.x / WARPSIZE) {
 #pragma unroll
      for (int i = 0; i < WARPSIZE; ++i) {
-        warpVal = r(warpVal, smem[lane * WARPSIZE + i]);
+        warpVal = r(warpVal, sharemen[lane * WARPSIZE + i]);
      }
 #ifndef __HIP_PLATFORM_HCC__
      __syncwarp(mask);
 #endif
-      smem[lane] = warpVal;
+      sharemen[lane] = warpVal;
    }
  }
-
   __syncthreads();
-  AccT blockVal = defaultVal;
-
+  AccT blockVal = initVal;
   if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / WARPSIZE; ++i) {
-      blockVal = r(blockVal, smem[i]);
+      blockVal = r(blockVal, sharemen[i]);
    }
-    smem[0] = blockVal;
+    sharemen[0] = blockVal;
  }
-
   __syncthreads();
-  return smem[0];
+  return sharemen[0];
 }
 
 template