forked from mindspore-Ecosystem/mindspore
!17679 Optimize cse
From: @zpac Reviewed-by: @wilfchen,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
c41ab511cd
|
@ -23,6 +23,23 @@
|
|||
template <typename T, typename S>
|
||||
__global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, const size_t batch_size,
|
||||
const size_t class_num, T *loss) {
|
||||
double total_loss = 0.0;
|
||||
T epsilon = 1e-6;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
T logit = logits[i * class_num + labels[i]];
|
||||
if (logit <= 0) {
|
||||
logit = epsilon;
|
||||
}
|
||||
total_loss += -logf(logit);
|
||||
}
|
||||
total_loss /= batch_size;
|
||||
loss[0] = static_cast<T>(total_loss);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void LargeBatchCrossEntropyWithSparseKernel(const T *logits, const S *labels, const size_t batch_size,
|
||||
const size_t class_num, T *loss) {
|
||||
*loss = 0;
|
||||
T epsilon = 1e-6;
|
||||
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size; index += blockDim.x * gridDim.x) {
|
||||
|
@ -67,8 +84,12 @@ __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_
|
|||
template <typename T, typename S>
|
||||
void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
|
||||
cudaStream_t cuda_stream) {
|
||||
CrossEntropyWithSparseKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size,
|
||||
class_num, loss);
|
||||
if (batch_size <= kLargeBatchLowLimit) {
|
||||
CrossEntropyWithSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, loss);
|
||||
} else {
|
||||
LargeBatchCrossEntropyWithSparseKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
logits, labels, batch_size, class_num, loss);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
// The batch size limit to judge whether to use multiple threads.
|
||||
constexpr int kLargeBatchLowLimit = 32768;
|
||||
|
||||
template <typename T, typename S>
|
||||
void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
Loading…
Reference in New Issue