bugfix: SigmoidCrossEntropyWithLogitsGrad needs to multiply by dout

This commit is contained in:
lizhenyu 2020-08-17 15:45:43 +08:00
parent f41ca6b5c6
commit d667d6ee92
3 changed files with 13 additions and 11 deletions

View File

@ -18,24 +18,24 @@
template <typename T, typename S> template <typename T, typename S>
__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels,
T *outputs) { const T *dout_addr, T *outputs) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (logits[i] >= 0) { if (logits[i] >= 0) {
outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i];
} else { } else {
const T exp_val = exp(logits[i]); const T exp_val = exp(logits[i]);
outputs[i] = exp_val / (1. + exp_val) - labels[i]; outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i];
} }
} }
} }
template <typename T, typename S> template <typename T, typename S>
void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr,
cudaStream_t cuda_stream) { T *outputs, cudaStream_t cuda_stream) {
SigmoidCrossEntropyWithLogitsGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels, SigmoidCrossEntropyWithLogitsGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels,
outputs); dout_addr, outputs);
} }
template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits, template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits,
const float *labels, float *outputs, const float *labels, const float *dout_addr,
cudaStream_t cuda_stream); float *outputs, cudaStream_t cuda_stream);

View File

@ -19,7 +19,7 @@
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S> template <typename T, typename S>
void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr,
cudaStream_t cuda_stream); T *outputs, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_

View File

@ -38,9 +38,10 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override { const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *logits_addr = GetDeviceAddress<T>(inputs, 0); T *logits_addr = GetDeviceAddress<T>(inputs, 0);
S *labels_addr = GetDeviceAddress<S>(inputs, 1); S *labels_addr = GetDeviceAddress<S>(inputs, 1);
T *dout_addr = GetDeviceAddress<T>(inputs, 2);
T *outputs_addr = GetDeviceAddress<T>(outputs, 0); T *outputs_addr = GetDeviceAddress<T>(outputs, 0);
SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, dout_addr, outputs_addr,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }
@ -78,6 +79,7 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel {
void InitSizeLists() override { void InitSizeLists() override {
input_size_list_.push_back(logits_size_); input_size_list_.push_back(logits_size_);
input_size_list_.push_back(labels_size_); input_size_list_.push_back(labels_size_);
input_size_list_.push_back(logits_size_);
output_size_list_.push_back(outputs_size_); output_size_list_.push_back(outputs_size_);
} }