forked from mindspore-Ecosystem/mindspore

fixBCE

commit 651d26b1da (parent c5867d2d7e)
@@ -18,11 +18,18 @@
 #include "loss_with_reduction_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
+inline __device__ float logT(float x) { return logf(x); }
+inline __device__ half logT(half x) { return hlog(x); }
+inline __device__ float castT(float ref, int x) { return __int2float_rd(x); }
+inline __device__ half castT(half ref, int x) { return __int2half_rd(x); }
+inline __device__ float maxT(float a, float b) { return fmaxf(a, b); }
+inline __device__ half maxT(half a, half b) { return a > b ? a : b; }
+
 template <typename T>
 __global__ void Copy(T *loss, T *tmp_loss, int reduction, int input_size) {
   loss[0] += tmp_loss[0];
   if (reduction == 1) {
-    loss[0] /= input_size;
+    loss[0] /= castT(loss[0], input_size);
   }
 }
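A note on the new overloads (my reading of the change, not text from the commit): logf and fmaxf are float-only CUDA math functions, and expressions mixing half with int literals (such as 1 - x) do not resolve cleanly for T = half, so the kernels below route every log, max, and int cast through logT/maxT/castT and let overload resolution pick the right intrinsic (logf vs hlog, fmaxf vs a manual compare, __int2float_rd vs __int2half_rd). A minimal, hypothetical smoke-test kernel, not part of the commit, showing the dispatch:

#include <cuda_fp16.h>  // half, hlog, __int2half_rd

// Relies on the logT/maxT/castT overloads defined above; half comparison
// and hlog require a device of compute capability 5.3 or newer.
template <typename T>
__global__ void OverloadSmokeTest(const T *in, T *out, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    T lower = castT(in[i], 1);          // T(1); first argument only selects the overload
    out[i] = logT(maxT(in[i], lower));  // log(max(in, 1)), valid for float and half alike
  }
}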
@@ -48,14 +55,14 @@ __global__ void KLDivLossKernel(const int input_size, const int reduction, const
   T epsilon = 1e-6;
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_y[i], epsilon);
-      T value = input_y[i] * (logf(denominator) - input_x[i]);
+      T denominator = maxT(input_y[i], epsilon);
+      T value = input_y[i] * (logT(denominator) - input_x[i]);
       loss[i] = value;
     }
   } else {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_y[i], epsilon);
-      T value = input_y[i] * (logf(denominator) - input_x[i]);
+      T denominator = maxT(input_y[i], epsilon);
+      T value = input_y[i] * (logT(denominator) - input_x[i]);
       tmp_loss[i] = value;
     }
   }
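For reference, what KLDivLossKernel computes per element (inferred from the code; input_x is taken to hold log-probabilities, matching the usual kl_div convention):

  \ell_i = y_i \left( \log(\max(y_i, \varepsilon)) - x_i \right), \qquad \varepsilon = 10^{-6}

With reduction == 0 each \ell_i is written to loss directly; otherwise the per-element values land in tmp_loss, and a later reduction plus the Copy kernel above produce the sum, divided by input_size (via castT) for the mean case.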
@@ -85,21 +92,22 @@ template <typename T>
 __global__ void KLDivLossGradKernel(const int input_size, const int reduction, const T *input_x, const T *input_y,
                                     const T *dloss, T *dx, T *dy) {
   T epsilon = 1e-6;
+  T one = static_cast<T>(1);
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_y[i], epsilon);
+      T denominator = maxT(input_y[i], epsilon);
       dx[i] = -input_y[i] * dloss[i];
-      dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss[i];
+      dy[i] = (logT(denominator) + one - input_x[i]) * dloss[i];
     }
   } else {
     T dloss1 = dloss[0];
     if (reduction == 1) {
-      dloss1 = dloss[0] / input_size;
+      dloss1 = dloss[0] / castT(dloss[0], input_size);
     }
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_y[i], epsilon);
+      T denominator = maxT(input_y[i], epsilon);
       dx[i] = -input_y[i] * dloss1;
-      dy[i] = (logf(denominator) + 1 - input_x[i]) * dloss1;
+      dy[i] = (logT(denominator) + one - input_x[i]) * dloss1;
     }
   }
 }
@@ -115,16 +123,17 @@ template <typename T>
 __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x,
                                              const T *input_y, const T *weight, T *loss, T *tmp_loss) {
   T epsilon = 1e-12;
+  T one = static_cast<T>(1);
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T value =
-        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
       loss[i] = value;
     }
   } else {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T value =
-        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
       tmp_loss[i] = value;
     }
   }
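Per element this is the weighted binary cross-entropy (the one local replaces the integer literal 1 so the same expression also type-checks for T = half):

  \ell_i = -w_i \left( y_i \log(x_i + \varepsilon) + (1 - y_i) \log(1 - x_i + \varepsilon) \right), \qquad \varepsilon = 10^{-12}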
@@ -154,19 +163,20 @@ template <typename T>
 __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x,
                                                  const T *input_y, const T *weight, const T *dloss, T *dx) {
   T epsilon = 1e-12;
+  T one = static_cast<T>(1);
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_x[i] * (1 - input_x[i]), epsilon);
+      T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
       T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
       dx[i] = value * dloss[i];
     }
   } else {
     T dloss1 = dloss[0];
     if (reduction == 1) {
-      dloss1 = dloss[0] / input_size;
+      dloss1 = dloss[0] / castT(dloss[0], input_size);
     }
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T denominator = max(input_x[i] * (1 - input_x[i]), epsilon);
+      T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
       T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
       dx[i] = value * dloss1;
     }
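The backward kernel uses the closed-form derivative of that loss; the step worth showing (ignoring the tiny \varepsilon inside the logs):

  \frac{\partial \ell_i}{\partial x_i} = -w_i \left( \frac{y_i}{x_i} - \frac{1 - y_i}{1 - x_i} \right) = \frac{w_i (x_i - y_i)}{x_i (1 - x_i)}

which is exactly weight[i] * (input_x[i] - input_y[i]) / denominator, with the denominator clamped to \varepsilon by maxT so the division stays finite at x_i \in \{0, 1\}.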
@@ -180,16 +190,30 @@ void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, con
                                 input_y, weight, dloss, dx);
 }
 
-template void KLDivLoss(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
-                        float *loss, float *tmp_loss, cudaStream_t stream);
+template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
+                               float *loss, float *tmp_loss, cudaStream_t stream);
 
-template void KLDivLossGrad(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
-                            const float *dloss, float *dx, float *dy, cudaStream_t stream);
+template void KLDivLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
+                                   const float *input_y, const float *dloss, float *dx, float *dy, cudaStream_t stream);
 
-template void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const float *input_x,
-                                     const float *input_y, const float *weight, float *loss, float *tmp_loss,
-                                     cudaStream_t stream);
+template void BinaryCrossEntropyLoss<float>(const int &input_size, const int &reduction, const float *input_x,
+                                            const float *input_y, const float *weight, float *loss, float *tmp_loss,
+                                            cudaStream_t stream);
 
-template void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const float *input_x,
-                                         const float *input_y, const float *weight, const float *dloss, float *dx,
-                                         cudaStream_t stream);
+template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
+                                                const float *input_y, const float *weight, const float *dloss,
+                                                float *dx, cudaStream_t stream);
+
+template void KLDivLoss<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
+                              half *loss, half *tmp_loss, cudaStream_t stream);
+
+template void KLDivLossGrad<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
+                                  const half *dloss, half *dx, half *dy, cudaStream_t stream);
+
+template void BinaryCrossEntropyLoss<half>(const int &input_size, const int &reduction, const half *input_x,
+                                           const half *input_y, const half *weight, half *loss, half *tmp_loss,
+                                           cudaStream_t stream);
+
+template void BinaryCrossEntropyLossGrad<half>(const int &input_size, const int &reduction, const half *input_x,
+                                               const half *input_y, const half *weight, const half *dloss, half *dx,
+                                               cudaStream_t stream);
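Why the float instantiations gained explicit <float> arguments and half twins (my gloss, assuming the standard CUDA separate-compilation setup): these template definitions live in a .cu translation unit, so every specialization that the kernel wrappers elsewhere link against must be explicitly instantiated here. A minimal, hypothetical illustration of the pattern; the names are illustrative, not from the commit:

#include <cuda_fp16.h>

// scale.cu (hypothetical file)
template <typename T>
__global__ void ScaleKernel(T *data, T factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = data[i] * factor;  // half arithmetic needs sm_53 or newer
}

template <typename T>
void Scale(T *data, T factor, int n, cudaStream_t stream) {
  ScaleKernel<<<(n + 255) / 256, 256, 0, stream>>>(data, factor, n);
}

// Without these two lines, callers in other translation units fail at link
// time with an undefined reference: the definition above is not visible there.
template void Scale<float>(float *, float, int, cudaStream_t);
template void Scale<half>(half *, half, int, cudaStream_t);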
@@ -24,5 +24,12 @@ MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropy,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       BinaryCrossEntropyGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropy,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BinaryCrossEntropyGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
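Reading the registration hunks (a gloss, not commit text): each AddInputAttr declares the dtype of one kernel input in order, here x, y, and weight for BinaryCrossEntropy, and AddOutputAttr the loss output, so a float16 graph now resolves to BinaryCrossEntropyGpuKernel instantiated with half instead of failing kernel selection. The grad registrations below follow the same pattern with one extra input for dloss.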
@@ -26,5 +26,13 @@ MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       BinaryCrossEntropyGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BinaryCrossEntropyGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BinaryCrossEntropyGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
@@ -65,16 +65,15 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
 
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
     input_size_list_.push_back(input_size_ * sizeof(T));
-    input_size_list_.push_back(input_size_ * sizeof(T));
     if (reduction_ == 0) {
       input_size_list_.push_back(input_size_ * sizeof(T));
-      output_size_list_.push_back(input_size_ * sizeof(T));
     } else {
       input_size_list_.push_back(sizeof(T));
-      output_size_list_.push_back(sizeof(T));
     }
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
   }
 
  private:
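As I read the hunk above (the -/+ assignment is reconstructed, so treat this as an editorial inference): under mean or sum reduction only the dloss input shrinks to a single element, while the dx output must remain a full buffer. The fix keeps the reduction-dependent size for the dloss entry alone and always reserves input_size_ * sizeof(T) for the output, where the old code under-allocated dx to sizeof(T) in the reduced case.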
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_ONE(
   KLDivLoss,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   KLDivLossGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  KLDivLoss,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  KLDivLossGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
@@ -26,5 +26,13 @@ MS_REG_GPU_KERNEL_ONE(KLDivLossGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       KLDivLossGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(KLDivLossGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      KLDivLossGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore