!35053 Replace thrust with cub for NonZero operator.
Merge pull request !35053 from hezhenhao1/add_nonzero
This commit is contained in:
commit
5b724e625d
|
@ -0,0 +1,27 @@
|
|||
# The CUDA toolkit 11 and higher versions include cub library,
|
||||
# but the lower version (CUDA 10) doesn't support.
|
||||
if(USE_CUDA)
|
||||
find_package(CUDA REQUIRED)
|
||||
set(cub_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||
find_path(CUB_INCLUDE_DIRS
|
||||
HINTS "${CUDA_INCLUDE_DIRS}"
|
||||
NAMES cub/cub.cuh
|
||||
DOC "The directory where cub library reside.")
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(cub
|
||||
FOUND_VAR CUB_FOUND
|
||||
REQUIRED_VARS CUB_INCLUDE_DIRS)
|
||||
if(CUB_FOUND)
|
||||
include_directories(${CUB_INCLUDE_DIRS})
|
||||
else()
|
||||
set(REQ_URL "https://github.com/NVlabs/cub/archive/1.8.0.zip")
|
||||
set(MD5 "a821b9dffbc9d1bacf1c8db2a59094bf")
|
||||
set(INCLUDE "cub")
|
||||
mindspore_add_pkg(cub
|
||||
VER 1.8.0
|
||||
HEAD_ONLY ${INCLUDE}
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5})
|
||||
include_directories(${cub_INC}/../)
|
||||
endif()
|
||||
endif()
|
|
@ -48,6 +48,7 @@ if(ENABLE_CPU)
|
|||
endif()
|
||||
|
||||
if(ENABLE_GPU)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cub.cmake)
|
||||
if(ENABLE_MPI)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
|
||||
endif()
|
||||
|
|
|
@ -99,6 +99,10 @@ bool NonZeroGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, co
|
|||
"NonZero cudaMemcpyAsync failed.");
|
||||
|
||||
NonZero(input_ptr, index_ptr, shape_ptr, output_ptr, input_size_, rank_, cuda_stream_);
|
||||
auto err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Nonzero kernel failed : " << cudaGetErrorString(err);
|
||||
}
|
||||
|
||||
// The last element of index_ptr is the final output size of NonZero.
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&real_output_size_, index_ptr + input_size_ - 1, sizeof(int64_t),
|
||||
|
|
|
@ -15,18 +15,12 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_zero_impl.cuh"
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||
|
||||
template <typename T>
|
||||
struct is_nonzero {
|
||||
typedef T data_type;
|
||||
typedef size_t index_type;
|
||||
|
||||
__device__ index_type operator()(const data_type &x) const { return x == T(0) ? 0 : 1; }
|
||||
template <typename DataType>
|
||||
struct IsZero {
|
||||
__host__ __device__ __forceinline__ size_t operator()(const DataType &x) const { return x == DataType(0) ? 0 : 1; }
|
||||
};
|
||||
|
||||
template <typename IndexType>
|
||||
|
@ -50,19 +44,20 @@ __global__ void NonZeroKernel(const size_t *index_ptr, const size_t *shape_ptr,
|
|||
template <typename DataType, typename IndexType>
|
||||
CUDA_LIB_EXPORT void NonZero(const DataType *input_ptr, size_t *index_ptr, size_t *shape_ptr, IndexType *output_ptr,
|
||||
size_t input_size, size_t rank, cudaStream_t cuda_stream) {
|
||||
auto device = thrust::cuda::par.on(cuda_stream);
|
||||
auto thrust_input_ptr = thrust::device_pointer_cast(input_ptr);
|
||||
auto thrust_index_ptr = thrust::device_pointer_cast(index_ptr);
|
||||
|
||||
// Transform each non-zero element to 0 if number is zero else is 1,
|
||||
// then using scan method to calculate prefix sum of 01 transformed sequence.
|
||||
thrust::transform(device, thrust_input_ptr, thrust_input_ptr + input_size, thrust_index_ptr, is_nonzero<DataType>());
|
||||
thrust::inclusive_scan(device, thrust_index_ptr, thrust_index_ptr + input_size, thrust_index_ptr);
|
||||
cub::TransformInputIterator<size_t, IsZero<DataType>, const DataType *> iter(input_ptr, IsZero<DataType>());
|
||||
void *d_temp_storage = NULL;
|
||||
size_t temp_storage_bytes = 0;
|
||||
(void)cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, iter, index_ptr, input_size, cuda_stream);
|
||||
(void)cudaMalloc(&d_temp_storage, temp_storage_bytes);
|
||||
(void)cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, iter, index_ptr, input_size, cuda_stream);
|
||||
|
||||
// Extract the first index to appear and transform into output index,
|
||||
// e.g., [0, 0, 1, 2, 2, 2] -> [(1, 2), (2, 3)] -> [(0, 0, 2), (0, 1, 0)] when shape is (2, 1, 3)
|
||||
NonZeroKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(index_ptr, shape_ptr, output_ptr, input_size,
|
||||
rank);
|
||||
// Since cudaGetLastError can return the last error from a runtime call,
|
||||
// we catch the error in Launch function.
|
||||
(void)cudaFree(d_temp_storage);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void NonZero<bool, int64_t>(const bool *input_ptr, size_t *index_ptr, size_t *shape_ptr,
|
||||
|
|
Loading…
Reference in New Issue