!17679 Optimize cse

From: @zpac
Reviewed-by: @wilfchen,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-06-04 09:14:26 +08:00 committed by Gitee
commit c41ab511cd
2 changed files with 26 additions and 2 deletions

View File

@ -23,6 +23,23 @@
// Serial sparse cross-entropy reduction.
//
// For each sample i in [0, batch_size) the kernel reads the logit stored at
// logits[i * class_num + labels[i]], replaces any value <= 0 with a small
// positive floor so the logarithm stays finite, and accumulates -log(value) in
// double precision. The mean over batch_size is written to loss[0].
//
// The body does not consult threadIdx/blockIdx: every thread that executes
// this kernel runs the complete loop and stores to loss[0].
// NOTE(review): this looks intended for a single-thread launch - confirm at the call site.
// NOTE(review): the values are passed straight to logf, so `logits` presumably already
// holds probabilities rather than raw scores - confirm against the caller.
template <typename T, typename S>
__global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, const size_t batch_size,
                                             const size_t class_num, T *loss) {
  // Floor substituted for non-positive inputs before taking the logarithm.
  const T floor_value = 1e-6;
  // Accumulate in double even though each term is computed with the float logf.
  double accumulated = 0.0;
  for (size_t sample = 0; sample != batch_size; ++sample) {
    const T picked = logits[sample * class_num + labels[sample]];
    const T safe_value = (picked <= 0) ? floor_value : picked;
    accumulated += -logf(safe_value);
  }
  // Mean over the batch, narrowed back to the output element type.
  loss[0] = static_cast<T>(accumulated / batch_size);
}
template <typename T, typename S>
__global__ void LargeBatchCrossEntropyWithSparseKernel(const T *logits, const S *labels, const size_t batch_size,
const size_t class_num, T *loss) {
*loss = 0;
T epsilon = 1e-6;
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size; index += blockDim.x * gridDim.x) {
@ -67,8 +84,12 @@ __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_
// Host-side launcher for the sparse cross-entropy loss.
//
// Enqueues exactly one kernel on `cuda_stream`, chosen by batch size:
//   * batch_size <= kLargeBatchLowLimit: CrossEntropyWithSparseKernel on a
//     single block with a single thread (the kernel loops over the whole batch itself).
//   * otherwise: LargeBatchCrossEntropyWithSparseKernel on a
//     GET_BLOCKS(batch_size) x GET_THREADS grid.
//
// Parameters:
//   logits      device buffer indexed as logits[i * class_num + labels[i]]
//   labels      device buffer with one class index per sample
//   batch_size  number of samples
//   class_num   number of entries per sample in `logits`
//   loss        device output; the kernels write the result through loss[0]
//   cuda_stream stream the kernel is launched on
//
// The launch is asynchronous; this function does not synchronize the stream.
template <typename T, typename S>
void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
                            cudaStream_t cuda_stream) {
  // The serial kernel must only be launched from this dispatch: launching it on a full
  // grid as well would make every thread repeat the whole reduction and store to
  // loss[0], only for the result to be recomputed by the launch below.
  if (batch_size <= kLargeBatchLowLimit) {
    CrossEntropyWithSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, loss);
  } else {
    LargeBatchCrossEntropyWithSparseKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(
      logits, labels, batch_size, class_num, loss);
  }
}

View File

@ -19,6 +19,9 @@
#include "runtime/device/gpu/cuda_common.h"
// Batch-size threshold used by CrossEntropyWithSparse to pick a kernel: batches of at most
// this many samples are handled by the single-thread kernel, larger batches by the
// multi-threaded large-batch kernel.
constexpr int kLargeBatchLowLimit = 32768;
// Launches the sparse cross-entropy loss computation on `cuda_stream`.
//   logits      device buffer read as logits[i * class_num + labels[i]] for each sample i
//   labels      device buffer holding one class index per sample
//   batch_size  number of samples
//   class_num   number of entries per sample in `logits`
//   loss        device output; the result is written to loss[0]
template <typename T, typename S>
void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
cudaStream_t cuda_stream);