!4954 Fix GPU non-sparse cross-entropy op returning all zeros

Merge pull request !4954 from tom_chen/cross_entropy
This commit is contained in:
mindspore-ci-bot 2020-08-24 09:22:25 +08:00 committed by Gitee
commit 75af54647f
1 changed files with 15 additions and 7 deletions

View File

@ -52,12 +52,18 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
} }
template <typename T, typename S> template <typename T, typename S>
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
losses[threadIdx.x] = 0; T epsilon, T *losses, T *dlogits) {
T epsilon = 1e-6; for (size_t index = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { index < batch_size;
losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i]; index += blockDim.x * gridDim.x) {
dlogits[i] = logits[i] - labels[i]; losses[index] = 0;
const int start = index * class_num;
const int end = (index + 1) * class_num;
for (int i = start; i < end; ++i) {
losses[index] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
dlogits[i] = logits[i] - labels[i];
}
} }
} }
@ -79,7 +85,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
template <typename T, typename S> template <typename T, typename S>
void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
T *dlogits, cudaStream_t cuda_stream) { T *dlogits, cudaStream_t cuda_stream) {
CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits); T epsilon = 1e-6;
CrossEntropyKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size, class_num,
epsilon, losses, dlogits);
} }
template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size, template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size,