forked from mindspore-Ecosystem/mindspore
!4954 Fix GPU non-sparse cross-entropy op returning all zeros
Merge pull request !4954 from tom_chen/cross_entropy
This commit is contained in:
commit
75af54647f
|
@ -52,12 +52,18 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
|
|||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
|
||||
losses[threadIdx.x] = 0;
|
||||
T epsilon = 1e-6;
|
||||
for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
|
||||
losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
|
||||
dlogits[i] = logits[i] - labels[i];
|
||||
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
|
||||
T epsilon, T *losses, T *dlogits) {
|
||||
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
index < batch_size;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
losses[index] = 0;
|
||||
const int start = index * class_num;
|
||||
const int end = (index + 1) * class_num;
|
||||
for (int i = start; i < end; ++i) {
|
||||
losses[index] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
|
||||
dlogits[i] = logits[i] - labels[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,7 +85,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
|
|||
template <typename T, typename S>
|
||||
void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
|
||||
T *dlogits, cudaStream_t cuda_stream) {
|
||||
CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
|
||||
T epsilon = 1e-6;
|
||||
CrossEntropyKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size, class_num,
|
||||
epsilon, losses, dlogits);
|
||||
}
|
||||
|
||||
template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size,
|
||||
|
|
Loading…
Reference in New Issue