forked from mindspore-Ecosystem/mindspore
use cub lib
This commit is contained in:
parent
c218d871be
commit
476484c2d9
|
@ -15,9 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/masked_select_grad_impl.cuh"
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||
|
||||
|
@ -103,7 +101,6 @@ CUDA_LIB_EXPORT void MaskedSelectGrad(T *input_grad_ptr, const bool *mask_ptr, s
|
|||
const std::vector<size_t> i, const std::vector<size_t> j, const std::vector<size_t> o,
|
||||
T *input_broadcast_grad_ptr, bool *mask_broadcast_ptr,
|
||||
T *output_grad_ptr, cudaStream_t cuda_stream) {
|
||||
auto device = thrust::cuda::par.on(cuda_stream);
|
||||
size_t broadcast_size = o[0] * o[1] * o[2] * o[3] * o[4] * o[5] * o[6];
|
||||
const bool *last_mask = nullptr;
|
||||
|
||||
|
@ -116,12 +113,14 @@ CUDA_LIB_EXPORT void MaskedSelectGrad(T *input_grad_ptr, const bool *mask_ptr, s
|
|||
last_mask = mask_ptr;
|
||||
}
|
||||
|
||||
auto thrust_mask_ptr = thrust::device_pointer_cast(last_mask);
|
||||
auto thrust_index_ptr = thrust::device_pointer_cast(index_ptr);
|
||||
|
||||
// using scan method to calculate prefix sum of 01 transformed sequence
|
||||
thrust::transform(device, thrust_mask_ptr, thrust_mask_ptr + broadcast_size, thrust_index_ptr, BoolToSize());
|
||||
thrust::inclusive_scan(device, thrust_index_ptr, thrust_index_ptr + broadcast_size, thrust_index_ptr);
|
||||
// using cub to calculate prefix sum of 01 transformed sequence
|
||||
BoolToSize op;
|
||||
cub::TransformInputIterator<size_t, BoolToSize, const bool*> iter(last_mask, op);
|
||||
size_t temp_storage_bytes = 0;
|
||||
(void)cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, iter, index_ptr, broadcast_size, cuda_stream);
|
||||
void *d_temp_storage = nullptr;
|
||||
(void)cudaMalloc(&d_temp_storage, temp_storage_bytes);
|
||||
(void)cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, iter, index_ptr, broadcast_size, cuda_stream);
|
||||
|
||||
// Extract the first index to appear and transform into output index
|
||||
if (input_broadcast_grad_ptr != nullptr) {
|
||||
|
@ -136,6 +135,7 @@ CUDA_LIB_EXPORT void MaskedSelectGrad(T *input_grad_ptr, const bool *mask_ptr, s
|
|||
index_ptr, output_grad_ptr,
|
||||
broadcast_size);
|
||||
}
|
||||
(void)cudaFree(d_temp_storage);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void MaskedSelectGrad<uint8_t>(uint8_t *input_grad_ptr,
|
||||
|
|
Loading…
Reference in New Issue