!496 fix bug in cross entropy error

Merge pull request !496 from SanjayChan/cross_entropy
mindspore-ci-bot 2020-04-21 15:11:42 +08:00 committed by Gitee
commit 58488c5dc8
5 changed files with 17 additions and 134 deletions

View File

@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <stdint.h>
-#include "cross_entropy_cuda_impl.cuh"
-#include "include/cuda_runtime.h"
-
-__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits,
-                                              const float *labels, const int batch_size, const int num_classes,
-                                              float *loss, float *dx) {
-  extern __shared__ float loss_shared[];
-  const float mean_scale = 1.0f / static_cast<float>(batch_size);
-  loss_shared[threadIdx.x] = 0;
-  for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) {
-    loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i];
-    dx[i] = (softmax_logits[i] - labels[i]) * mean_scale;
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    *loss = 0;
-    for (int i = 0; i < batch_size; i++) {
-      *loss += loss_shared[i];
-    }
-    *loss *= mean_scale;
-  }
-}
-
-void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
-                             const int batch_size, const int num_classes, float *loss, float *dx,
-                             cudaStream_t cuda_stream) {
-  CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>(
-    softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx);
-}
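For reference, this deleted fused kernel computes the mean-reduced cross entropy loss = -(1/N) * sum_i sum_j labels[i*C+j] * log_softmax_logits[i*C+j] and its gradient dx = (softmax_logits - labels) / N, using one thread per sample and a shared-memory buffer to reduce the scalar loss. A minimal host-side reference of the same arithmetic (the helper name is hypothetical, not part of the patch), useful for sanity-checking any replacement:

// Host-side reference for the deleted kernel's arithmetic; a sketch for
// checking results, not part of the patch. Inputs mirror the kernel's:
// softmax_logits = softmax(x), log_softmax_logits = log(softmax(x)).
void CalCrossEntropyWithGradReference(const float *softmax_logits, const float *log_softmax_logits,
                                      const float *labels, int batch_size, int num_classes,
                                      float *loss, float *dx) {
  const float mean_scale = 1.0f / static_cast<float>(batch_size);
  *loss = 0.0f;
  for (int i = 0; i < batch_size * num_classes; ++i) {
    *loss -= log_softmax_logits[i] * labels[i];            // -sum(y * log p) over the whole batch
    dx[i] = (softmax_logits[i] - labels[i]) * mean_scale;  // dL/dx = (p - y) / N
  }
  *loss *= mean_scale;  // mean over the batch, matching the kernel's mean_scale
}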

View File

@@ -1,26 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
-
-#include "device/gpu/cuda_common.h"
-
-void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
-                             const int batch_size, const int num_classes, float *loss, float *dx,
-                             cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_

View File

@@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
 }
 
 template <typename T, typename S>
-__global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size,
-                                                const size_t class_num, T *losses) {
-  T epsilon = 1e-6;
-  for (size_t i = 0; i < batch_size; ++i) {
-    T logit = 0.0;
-    for (size_t j = 0; j < class_num; j++) {
-      if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) {
-        logit = logits[i * class_num + j];
-        break;
-      }
-    }
-    if (logit <= 0) {
-      logit += epsilon;
-    }
-    losses[i] = -logf(logit);
-  }
-  return;
-}
-
-template <typename T, typename S>
-__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size,
-                                                    const size_t class_num, T *grad) {
-  for (size_t i = 0; i < batch_size; i++) {
-    for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) {
-      if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) {
-        grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size;
-      } else {
-        grad[i * class_num + j] = logits[i * class_num + j] / batch_size;
-      }
-    }
-  }
-  return;
-}
+__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
+  losses[threadIdx.x] = 0;
+  for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
+    losses[threadIdx.x] -= logf(logits[i]) * labels[i];
+    dlogits[i] = logits[i] - labels[i];
+  }
+}
 
 template <typename T, typename S>
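The bug being fixed is visible in the removed loss kernel: it searched each row for the single entry whose label equals 1.0 (within 1e-8) and took -logf of that one probability, which is only correct for strictly one-hot labels. With soft labels (label smoothing, mixed targets) the search can match nothing, leaving logit at 0 so the loss becomes -logf(epsilon), regardless of the prediction. The new CrossEntropyKernel computes the full -sum_j y_j * log(p_j) per sample, one thread per sample, and fuses the gradient dlogits = p - y into the same pass. Note also that the old grad kernel folded the 1/batch_size mean into the gradient, while the new dlogits is unscaled, so any mean reduction presumably happens downstream now. A small standalone contrast under an assumed soft label (values hypothetical, not from the patch):

// Contrast of the two loss definitions for one sample with a soft label.
// Hypothetical standalone check, not part of the patch.
#include <cmath>
#include <cstdio>

int main() {
  const float p[2] = {0.7f, 0.3f};  // softmax output for one sample
  const float y[2] = {0.9f, 0.1f};  // soft label: no entry equals 1.0
  // Old kernel: looks for the class whose label is exactly 1.0 -- none
  // matches, so logit stays 0, gets bumped to epsilon, and the loss is
  // -logf(1e-6) no matter what the model predicted.
  float old_loss = -logf(1e-6f);
  // New kernel: full cross entropy over all classes.
  float new_loss = 0.0f;
  for (int j = 0; j < 2; ++j) {
    new_loss -= logf(p[j]) * y[j];
  }
  printf("old: %f  new: %f\n", old_loss, new_loss);  // old: ~13.8, new: ~0.44
  return 0;
}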
@@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
 }
 
 template <typename T, typename S>
-void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                               T *losses, cudaStream_t cuda_stream) {
-  CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses);
-  return;
-}
-
-template <typename T, typename S>
-void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                                   T *grad, cudaStream_t cuda_stream) {
-  CrossEntropyGradWithoutSparseKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>(
-    logits, labels, batch_size, class_num, grad);
-  return;
-}
+void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
+                  T *dlogits, cudaStream_t cuda_stream) {
+  CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
+}
 
 template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size,
@@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse<float, int>(const float *logits, const
 template void CrossEntropyGradWithSparse<float, int64_t>(const float *logits, const int64_t *labels,
                                                          const size_t batch_size, const size_t class_num, float *grad,
                                                          cudaStream_t cuda_stream);
-template void CrossEntropyWithoutSparse<float, float>(const float *logits, const float *labels, const size_t batch_size,
-                                                      const size_t class_num, float *losses, cudaStream_t cuda_stream);
-template void CrossEntropyGradWithoutSparse<float, float>(const float *logits, const float *labels,
-                                                          const size_t batch_size, const size_t class_num, float *grad,
-                                                          cudaStream_t cuda_stream);
+template void CrossEntropy<float, float>(const float *logits, const float *labels, const size_t batch_size,
+                                         const size_t class_num, float *losses, float *dlogits,
+                                         cudaStream_t cuda_stream);
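One constraint worth noting: the new launcher runs CrossEntropyKernel as a single block of batch_size threads, so it silently assumes the batch fits within the device's maximum threads per block (1024 on current NVIDIA GPUs). The deleted CalCrossEntropyWithGradKernel made the same single-block assumption, while the removed CrossEntropyWithoutSparseKernel went to the other extreme and serialized the whole batch on one thread (<<<1, 1>>>). A sketch of a launcher that makes the limit explicit; the guard is an assumption for illustration, not something the patch adds:

// Guarded variant of the new launcher; the assert is illustrative only.
// Lifting the limit would take a grid-stride or multi-block kernel.
#include <cassert>

template <typename T, typename S>
void CrossEntropyChecked(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
                         T *losses, T *dlogits, cudaStream_t cuda_stream) {
  assert(batch_size <= 1024 && "CrossEntropyKernel uses one thread per sample in a single block");
  CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
}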

View File

@@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
                                 T *grad, cudaStream_t cuda_stream);
 
 template <typename T, typename S>
-void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                               T *losses, cudaStream_t cuda_stream);
-
-template <typename T, typename S>
-void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                                   T *grad, cudaStream_t cuda_stream);
+void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
+                  T *dlogits, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_

View File

@@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
     }
     T *logits_addr = GetDeviceAddress<T>(inputs, 0);
     S *labels_addr = GetDeviceAddress<S>(inputs, 1);
-    T *output1_addr = GetDeviceAddress<T>(outputs, 0);
-    T *output2_addr = GetDeviceAddress<T>(outputs, 1);
+    T *loss_addr = GetDeviceAddress<T>(outputs, 0);
+    T *dlogits_addr = GetDeviceAddress<T>(outputs, 1);
     T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0);
 
     const float alpha = 1;
@@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
                                            softmax_output_descriptor_, softmax_output_logits),
                        "cudnnSoftmaxForward failed.");
-    CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr,
-                              reinterpret_cast<cudaStream_t>(stream_ptr));
-    CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+    CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
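After this change the GPU kernel's Launch path is: cudnnSoftmaxForward fills the workspace with softmax probabilities, then a single CrossEntropy call writes both outputs, per-sample losses into output 0 and dlogits into output 1. A minimal host-side smoke test of the fused launcher; the header name cross_entropy_impl.cuh is inferred from the include guard above, and the whole snippet is a sketch, not part of the patch:

// Hypothetical standalone check of CrossEntropy<float, float>; feeds
// softmax outputs directly, exactly as the GPU kernel does.
#include <cstdio>
#include <cuda_runtime.h>
#include "cross_entropy_impl.cuh"  // assumed header path for the declarations above

int main() {
  const size_t batch = 2, classes = 2;
  const float h_probs[4] = {0.7f, 0.3f, 0.2f, 0.8f};   // softmax(x), one row per sample
  const float h_labels[4] = {1.0f, 0.0f, 0.0f, 1.0f};  // one-hot labels
  float *probs, *labels, *losses, *dlogits;
  cudaMalloc(&probs, sizeof(h_probs));
  cudaMalloc(&labels, sizeof(h_labels));
  cudaMalloc(&losses, batch * sizeof(float));
  cudaMalloc(&dlogits, sizeof(h_probs));
  cudaMemcpy(probs, h_probs, sizeof(h_probs), cudaMemcpyHostToDevice);
  cudaMemcpy(labels, h_labels, sizeof(h_labels), cudaMemcpyHostToDevice);
  CrossEntropy(probs, labels, batch, classes, losses, dlogits, nullptr);  // nullptr = default stream
  float h_losses[2];
  cudaMemcpy(h_losses, losses, sizeof(h_losses), cudaMemcpyDeviceToHost);
  printf("losses: %f %f\n", h_losses[0], h_losses[1]);  // expect -log(0.7) ~ 0.357, -log(0.8) ~ 0.223
  cudaFree(probs); cudaFree(labels); cudaFree(losses); cudaFree(dlogits);
  return 0;
}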