Added ReductionMode enum for loss-with-reduction GPU kernels.

markuskunej 2021-12-15 16:57:35 +00:00
parent 4705d5dce2
commit 80e86a697a
10 changed files with 111 additions and 115 deletions

View File

@@ -531,17 +531,6 @@ int Sign(float x) {
return 0;
}
int GetReductionInt(const std::string &reduction) {
if (reduction == "none") {
return 0;
} else if (reduction == "sum") {
return 2;
} else {
// reduction = 'mean'
return 1;
}
}
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
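For context: the removed GetReductionInt helper encoded the reduction as a bare int (none -> 0, mean -> 1, sum -> 2). The ReductionMode enum class introduced in the CUDA header later in this commit uses the default enumerator values, so the numeric encoding is unchanged; this minimal sketch (not part of the commit) verifies that:

enum class ReductionMode { kNone, kMean, kSum };
static_assert(static_cast<int>(ReductionMode::kNone) == 0, "was 'none' -> 0");
static_assert(static_cast<int>(ReductionMode::kMean) == 1, "was 'mean' -> 1");
static_assert(static_cast<int>(ReductionMode::kSum) == 2, "was 'sum' -> 2");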

View File

@@ -86,7 +86,6 @@ std::string GetProcessor(const AnfNodePtr &anf_node);
Processor GetProcessor(const string &processor);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
int Sign(float x);
int GetReductionInt(const std::string &reduction);
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);

View File

@@ -27,9 +27,9 @@ inline __device__ float maxT(float a, float b) { return fmaxf(a, b); }
inline __device__ half maxT(half a, half b) { return a > b ? a : b; }
template <typename T>
__global__ void Copy(T *loss, T *tmp_loss, int reduction, int input_size) {
__global__ void Copy(T *loss, T *tmp_loss, ReductionMode reduction, int input_size) {
loss[0] += tmp_loss[0];
if (reduction == 1) {
if (reduction == ReductionMode::kMean) {
loss[0] /= castT(loss[0], input_size);
}
}
@@ -108,14 +108,14 @@ void Sum(T *array, const int &size, cudaStream_t stream) {
}
template <typename T, typename S>
void Reduce(T *tmp_loss, const int &size, S *denom, const int &reduction, T *output, cudaStream_t stream) {
void Reduce(T *tmp_loss, const int &size, S *denom, const ReductionMode &reduction, T *output, cudaStream_t stream) {
// sum losses together
Sum(tmp_loss, size, stream);
if (reduction == 1) {
if (reduction == ReductionMode::kMean) {
// mean reduction, divide sum by denominator, store result in output
Divide<<<1, 1, 0, stream>>>(tmp_loss, denom, output);
} else if (reduction == 2) {
} else if (reduction == ReductionMode::kSum) {
// sum reduction, copy sum to output
CopyEqual<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(tmp_loss, output, size);
}
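As a host-side sketch of the contract Reduce implements (illustration only, assuming the ReductionMode enum from the CUDA header below): sum the partial losses, then divide by the denominator for kMean or pass the sum through for kSum.

// CPU reference sketch, not part of the commit.
enum class ReductionMode { kNone, kMean, kSum };
template <typename T, typename S>
T ReduceReference(const T *tmp_loss, int size, S denom, ReductionMode reduction) {
  T sum = static_cast<T>(0);
  for (int i = 0; i < size; ++i) sum += tmp_loss[i];
  if (reduction == ReductionMode::kMean) sum /= static_cast<T>(denom);
  return sum;
}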
@@ -134,10 +134,10 @@ __global__ void InitZero(T *array, int size) {
}
template <typename T>
__global__ void KLDivLossKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, T *loss,
T *tmp_loss) {
__global__ void KLDivLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, const T *input_y,
T *loss, T *tmp_loss) {
T epsilon = 1e-6;
if (reduction == 0) {
if (reduction == ReductionMode::kNone) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_y[i], epsilon);
T value = input_y[i] * (logT(denominator) - input_x[i]);
@@ -153,12 +153,12 @@ __global__ void KLDivLossKernel(const int input_size, const int reduction, const
}
template <typename T>
void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
cudaStream_t stream) {
void KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, T *loss,
T *tmp_loss, cudaStream_t stream) {
LossInitKernel<<<1, 1, 0, stream>>>(loss);
KLDivLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y, loss,
tmp_loss);
if (reduction != 0) {
if (reduction != ReductionMode::kNone) {
if (input_size % 2 == 1) {
AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1);
}
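For clarity, the per-element value KLDivLossKernel computes can be written on the CPU as follows (a sketch under the usual KLDiv convention that input_x holds log-probabilities; not part of the commit):

#include <algorithm>
#include <cmath>

// Mirrors the kernel: loss_i = y_i * (log(max(y_i, eps)) - x_i), eps = 1e-6.
inline float KLDivElement(float x, float y) {
  const float eps = 1e-6f;
  return y * (std::log(std::max(y, eps)) - x);
}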
@@ -173,11 +173,11 @@ void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, co
}
template <typename T>
__global__ void KLDivLossGradKernel(const int input_size, const int reduction, const T *input_x, const T *input_y,
const T *dloss, T *dx, T *dy) {
__global__ void KLDivLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x,
const T *input_y, const T *dloss, T *dx, T *dy) {
T epsilon = 1e-6;
T one = static_cast<T>(1);
if (reduction == 0) {
if (reduction == ReductionMode::kNone) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_y[i], epsilon);
dx[i] = -input_y[i] * dloss[i];
@@ -185,7 +185,7 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c
}
} else {
T dloss1 = dloss[0];
if (reduction == 1) {
if (reduction == ReductionMode::kMean) {
dloss1 = dloss[0] / castT(dloss[0], input_size);
}
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
@@ -197,29 +197,29 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c
}
template <typename T>
void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
T *dx, T *dy, cudaStream_t stream) {
void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
const T *dloss, T *dx, T *dy, cudaStream_t stream) {
KLDivLossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y,
dloss, dx, dy);
}
template <typename T>
__global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x,
__global__ void BinaryCrossEntropyLossKernel(const int input_size, const ReductionMode reduction, const T *input_x,
const T *input_y, const T *weight, T *loss, T *tmp_loss) {
T epsilon = 1e-12;
T one = static_cast<T>(1);
if (reduction == 0 && weight != nullptr) {
if (reduction == ReductionMode::kNone && weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value =
-weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
loss[i] = value;
}
} else if (reduction == 0 && weight == nullptr) {
} else if (reduction == ReductionMode::kNone && weight == nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
loss[i] = value;
}
} else if (reduction != 0 && weight != nullptr) {
} else if (reduction != ReductionMode::kNone && weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T value =
-weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
@@ -234,12 +234,12 @@ __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int red
}
template <typename T>
void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) {
LossInitKernel<<<1, 1, 0, stream>>>(loss);
BinaryCrossEntropyLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
input_y, weight, loss, tmp_loss);
if (reduction != 0) {
if (reduction != ReductionMode::kNone) {
if (input_size % 2 == 1) {
AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1);
}
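The per-element binary cross-entropy computed above, as a CPU sketch (illustration only; the weight factor defaults to 1 when no weight tensor is supplied):

#include <cmath>

// Mirrors the kernel: value_i = -w_i * (y_i * log(x_i + eps) + (1 - y_i) * log(1 - x_i + eps)), eps = 1e-12.
inline float BceElement(float x, float y, float w = 1.0f) {
  const float eps = 1e-12f;
  return -w * (y * std::log(x + eps) + (1.0f - y) * std::log(1.0f - x + eps));
}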
@@ -254,11 +254,11 @@ void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T
}
template <typename T>
__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x,
__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x,
const T *input_y, const T *weight, const T *dloss, T *dx) {
T epsilon = 1e-12;
T one = static_cast<T>(1);
if (reduction == 0) {
if (reduction == ReductionMode::kNone) {
if (weight != nullptr) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
@@ -274,7 +274,7 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
}
} else {
T dloss1 = dloss[0];
if (reduction == 1) {
if (reduction == ReductionMode::kMean) {
dloss1 = dloss[0] / castT(dloss[0], input_size);
}
if (weight != nullptr) {
@@ -294,8 +294,8 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
}
template <typename T>
void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
const T *weight, const T *dloss, T *dx, cudaStream_t stream) {
void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x,
const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream) {
BinaryCrossEntropyLossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
input_y, weight, dloss, dx);
}
@@ -319,10 +319,9 @@ __global__ void NLLLossKernel(const int n, const int c, const T *input, const in
}
template <typename T, typename S>
void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
// if reduction != "none"
if (reduction != 0) {
void NLLLoss(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
const S *weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
if (reduction != ReductionMode::kNone) {
NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, tmp_loss);
// sum target weights after populating them
Sum(tmp_target_weight, n, stream);
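What the reduced NLLLoss path computes, as a CPU sketch (assuming standard NLLLoss semantics; not part of the commit): each sample contributes -weight[target[i]] * input[i * c + target[i]], kSum accumulates these, and kMean additionally divides by the summed target weights.

#include <vector>

enum class ReductionMode { kNone, kMean, kSum };
inline float NLLLossReference(int n, int c, ReductionMode reduction, const std::vector<float> &input,
                              const std::vector<int> &target, const std::vector<float> &weight) {
  float loss = 0.0f, total_weight = 0.0f;
  for (int i = 0; i < n; ++i) {
    loss += -weight[target[i]] * input[i * c + target[i]];
    total_weight += weight[target[i]];
  }
  return reduction == ReductionMode::kMean ? loss / total_weight : loss;
}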
@@ -340,12 +339,13 @@ void NLLLoss(const int n, const int c, const int reduction, const T *input, cons
}
template <typename T, typename S>
__global__ void NLLLossGradKernel(const int n, const int c, const int reduction, const T *input, const int32_t *target,
const S *weight, const S *total_weight, const T *dloss, T *dinput) {
__global__ void NLLLossGradKernel(const int n, const int c, const ReductionMode reduction, const T *input,
const int32_t *target, const S *weight, const S *total_weight, const T *dloss,
T *dinput) {
int input_idx;
int target_class;
S tmp_quot;
if (reduction == 0) {
if (reduction == ReductionMode::kNone) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
@@ -353,7 +353,7 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx);
}
} else if (reduction == 1) {
} else if (reduction == ReductionMode::kMean) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
@@ -362,7 +362,7 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
tmp_quot = (-weight[target_class]) / *total_weight;
MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx);
}
} else if (reduction == 2) {
} else if (reduction == ReductionMode::kSum) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
@@ -374,8 +374,8 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
}
template <typename T, typename S>
void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) {
void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
const S *weight, const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) {
int input_size = n * c;
InitZero<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(dinput, input_size);
@@ -383,62 +383,62 @@ void NLLLossGrad(const int n, const int c, const int reduction, const T *input,
dloss, dinput);
}
template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
float *loss, float *tmp_loss, cudaStream_t stream);
template void KLDivLoss<float>(const int &input_size, const ReductionMode &reduction, const float *input_x,
const float *input_y, float *loss, float *tmp_loss, cudaStream_t stream);
template void KLDivLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
template void KLDivLossGrad<float>(const int &input_size, const ReductionMode &reduction, const float *input_x,
const float *input_y, const float *dloss, float *dx, float *dy, cudaStream_t stream);
template void BinaryCrossEntropyLoss<float>(const int &input_size, const int &reduction, const float *input_x,
template void BinaryCrossEntropyLoss<float>(const int &input_size, const ReductionMode &reduction, const float *input_x,
const float *input_y, const float *weight, float *loss, float *tmp_loss,
cudaStream_t stream);
template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
const float *input_y, const float *weight, const float *dloss,
float *dx, cudaStream_t stream);
template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const ReductionMode &reduction,
const float *input_x, const float *input_y, const float *weight,
const float *dloss, float *dx, cudaStream_t stream);
template void NLLLoss<float, float>(const int n, const int c, const int reduction, const float *input,
template void NLLLoss<float, float>(const int n, const int c, const ReductionMode reduction, const float *input,
const int32_t *target, const float *weight, float *loss, float *total_weight,
float *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
template void NLLLoss<float, half>(const int n, const int c, const int reduction, const float *input,
template void NLLLoss<float, half>(const int n, const int c, const ReductionMode reduction, const float *input,
const int32_t *target, const half *weight, float *loss, half *total_weight,
float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
template void NLLLossGrad<float, float>(const int n, const int c, const int reduction, const float *input,
template void NLLLossGrad<float, float>(const int n, const int c, const ReductionMode reduction, const float *input,
const int32_t *target, const float *weight, const float *total_weight,
const float *dloss, float *dinput, cudaStream_t stream);
template void NLLLossGrad<float, half>(const int n, const int c, const int reduction, const float *input,
template void NLLLossGrad<float, half>(const int n, const int c, const ReductionMode reduction, const float *input,
const int32_t *target, const half *weight, const half *total_weight,
const float *dloss, float *dinput, cudaStream_t stream);
template void KLDivLoss<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
half *loss, half *tmp_loss, cudaStream_t stream);
template void KLDivLoss<half>(const int &input_size, const ReductionMode &reduction, const half *input_x,
const half *input_y, half *loss, half *tmp_loss, cudaStream_t stream);
template void KLDivLossGrad<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
const half *dloss, half *dx, half *dy, cudaStream_t stream);
template void KLDivLossGrad<half>(const int &input_size, const ReductionMode &reduction, const half *input_x,
const half *input_y, const half *dloss, half *dx, half *dy, cudaStream_t stream);
template void BinaryCrossEntropyLoss<half>(const int &input_size, const int &reduction, const half *input_x,
template void BinaryCrossEntropyLoss<half>(const int &input_size, const ReductionMode &reduction, const half *input_x,
const half *input_y, const half *weight, half *loss, half *tmp_loss,
cudaStream_t stream);
template void BinaryCrossEntropyLossGrad<half>(const int &input_size, const int &reduction, const half *input_x,
const half *input_y, const half *weight, const half *dloss, half *dx,
cudaStream_t stream);
template void BinaryCrossEntropyLossGrad<half>(const int &input_size, const ReductionMode &reduction,
const half *input_x, const half *input_y, const half *weight,
const half *dloss, half *dx, cudaStream_t stream);
template void NLLLoss<half, half>(const int n, const int c, const int reduction, const half *input,
template void NLLLoss<half, half>(const int n, const int c, const ReductionMode reduction, const half *input,
const int32_t *target, const half *weight, half *loss, half *total_weight,
half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
template void NLLLoss<half, float>(const int n, const int c, const int reduction, const half *input,
template void NLLLoss<half, float>(const int n, const int c, const ReductionMode reduction, const half *input,
const int32_t *target, const float *weight, half *loss, float *total_weight,
half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
template void NLLLossGrad<half, half>(const int n, const int c, const int reduction, const half *input,
template void NLLLossGrad<half, half>(const int n, const int c, const ReductionMode reduction, const half *input,
const int32_t *target, const half *weight, const half *total_weight,
const half *dloss, half *dinput, cudaStream_t stream);
template void NLLLossGrad<half, float>(const int n, const int c, const int reduction, const half *input,
template void NLLLossGrad<half, float>(const int n, const int c, const ReductionMode reduction, const half *input,
const int32_t *target, const float *weight, const float *total_weight,
const half *dloss, half *dinput, cudaStream_t stream);

View File

@@ -17,23 +17,31 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
#include <map>
#include <string>
enum class ReductionMode { kNone, kMean, kSum };
static std::map<std::string, ReductionMode> kReductionModeMap{
{"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};
template <typename T>
void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
const T *weight, T *loss, T *tmp_loss, cudaStream_t stream);
template <typename T>
void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
const T *weight, const T *dloss, T *dx, cudaStream_t stream);
void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x,
const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream);
template <typename T>
void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
cudaStream_t stream);
void KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, T *loss,
T *tmp_loss, cudaStream_t stream);
template <typename T>
void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
T *dx, T *dy, cudaStream_t stream);
void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
const T *dloss, T *dx, T *dy, cudaStream_t stream);
template <typename T, typename S>
void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
void NLLLoss(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
const S *weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
template <typename T, typename S>
void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream);
void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
const S *weight, const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
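The kernels' Init methods resolve the "reduction" attribute with kReductionModeMap[reduction]. Note that std::map::operator[] default-inserts on an unknown key, which here yields ReductionMode::kNone (underlying value 0). A find-based lookup, shown below as a hypothetical alternative rather than part of this commit, makes the fallback explicit:

#include <map>
#include <string>

enum class ReductionMode { kNone, kMean, kSum };
static std::map<std::string, ReductionMode> kReductionModeMap{
    {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};

// Hypothetical helper: resolve the attribute without mutating the map;
// the kernels themselves default-initialize reduction_ to kMean.
inline ReductionMode ToReductionMode(const std::string &reduction) {
  auto it = kReductionModeMap.find(reduction);
  return it != kReductionModeMap.end() ? it->second : ReductionMode::kMean;
}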

View File

@@ -34,7 +34,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
is_null_input_(false),
kernel_name_("BinaryCrossEntropy"),
input_size_(1),
reduction_(1),
reduction_(ReductionMode::kMean),
workspace_size_(1) {}
~BinaryCrossEntropyGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -74,9 +74,9 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
reduction_ = kReductionModeMap[reduction];
workspace_size_ = sizeof(T);
if (reduction_ != 0) {
if (reduction_ != ReductionMode::kNone) {
workspace_size_ *= input_size_;
}
InitSizeLists();
@@ -90,7 +90,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
if (weight_defined_) {
input_size_list_.push_back(input_size_ * sizeof(T));
}
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
output_size_list_.push_back(input_size_ * sizeof(T));
} else {
output_size_list_.push_back(sizeof(T));
@@ -103,7 +103,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
bool is_null_input_;
std::string kernel_name_;
size_t input_size_;
int reduction_;
ReductionMode reduction_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;

View File

@@ -31,7 +31,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
public:
BinaryCrossEntropyGradGpuKernel()
: input_size_(1),
reduction_(1),
reduction_(ReductionMode::kMean),
weight_defined_(false),
is_null_input_(false),
kernel_name_("BinaryCrossEntropyGrad") {}
@@ -75,7 +75,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
reduction_ = kReductionModeMap[reduction];
InitSizeLists();
return true;
}
@@ -84,7 +84,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
input_size_list_.push_back(input_size_ * sizeof(T));
} else {
input_size_list_.push_back(sizeof(T));
@@ -97,7 +97,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
int reduction_;
ReductionMode reduction_;
bool weight_defined_;  // true: there are 4 inputs; false: there are 3 inputs (no [weight])
bool is_null_input_;
std::string kernel_name_;

View File

@@ -29,7 +29,7 @@ namespace kernel {
template <typename T>
class KLDivLossGpuKernel : public GpuKernel {
public:
KLDivLossGpuKernel() : input_size_(1), reduction_(1), is_null_input_(false), workspace_size_(0) {}
KLDivLossGpuKernel() : input_size_(1), reduction_(ReductionMode::kMean), is_null_input_(false), workspace_size_(0) {}
~KLDivLossGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -61,9 +61,9 @@ class KLDivLossGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
reduction_ = kReductionModeMap[reduction];
workspace_size_ = sizeof(T);
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
workspace_size_ *= input_size_;
}
InitSizeLists();
@@ -74,7 +74,7 @@ class KLDivLossGpuKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_size_ * sizeof(T));
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
output_size_list_.push_back(input_size_ * sizeof(T));
} else {
output_size_list_.push_back(sizeof(T));
@@ -84,7 +84,7 @@ class KLDivLossGpuKernel : public GpuKernel {
private:
size_t input_size_;
int reduction_;
ReductionMode reduction_;
bool is_null_input_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;

View File

@@ -29,7 +29,7 @@ namespace kernel {
template <typename T>
class KLDivLossGradGpuKernel : public GpuKernel {
public:
KLDivLossGradGpuKernel() : input_size_(1), reduction_(1), is_null_input_(false) {}
KLDivLossGradGpuKernel() : input_size_(1), reduction_(ReductionMode::kMean), is_null_input_(false) {}
~KLDivLossGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -62,7 +62,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
reduction_ = kReductionModeMap[reduction];
InitSizeLists();
return true;
}
@@ -73,7 +73,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
input_size_list_.push_back(input_size_ * sizeof(T));
} else {
input_size_list_.push_back(sizeof(T));
@@ -82,7 +82,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
int reduction_;
ReductionMode reduction_;
bool is_null_input_;
std::vector<size_t> input_size_list_;

View File

@@ -48,8 +48,8 @@ class NLLLossGpuKernel : public GpuKernel {
T *loss_device = GetDeviceAddress<T>(outputs, 0);
S *total_weight_device = GetDeviceAddress<S>(outputs, 1);
T *tmp_loss_device =
reduction_ != 0 ? GetDeviceAddress<T>(workspace, 0) : GetPossiblyNullDeviceAddress<T>(workspace, 0);
T *tmp_loss_device = reduction_ != ReductionMode::kNone ? GetDeviceAddress<T>(workspace, 0)
: GetPossiblyNullDeviceAddress<T>(workspace, 0);
S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
@@ -76,8 +76,8 @@ class NLLLossGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
if ((reduction_ == 2) || (reduction_ == 1)) {
reduction_ = kReductionModeMap[reduction];
if ((reduction_ == ReductionMode::kSum) || (reduction_ == ReductionMode::kMean)) {
tmp_loss_size_ = sizeof(T) * n_;
}
tmp_target_weight_size_ = n_ * sizeof(S);
@@ -91,7 +91,7 @@ class NLLLossGpuKernel : public GpuKernel {
n_ = 0;
c_ = 0;
is_null_input_ = false;
reduction_ = 1; // default value
reduction_ = ReductionMode::kMean; // default value
tmp_loss_size_ = 0;
tmp_target_weight_size_ = 0; // tmp_target_weight (N,) array
input_size_list_.clear();
@@ -105,7 +105,7 @@ class NLLLossGpuKernel : public GpuKernel {
input_size_list_.push_back(n_ * sizeof(int32_t)); // target tensor with shape (N)
input_size_list_.push_back(c_ * sizeof(S)); // weight tensor with shape (C)
if (reduction_ == 0) {
if (reduction_ == ReductionMode::kNone) {
output_size_list_.push_back(n_ * sizeof(T)); // loss output of shape (N,)
} else {
output_size_list_.push_back(sizeof(T)); // scalar loss output
@@ -117,7 +117,7 @@ class NLLLossGpuKernel : public GpuKernel {
private:
size_t input_size_;
int reduction_;
ReductionMode reduction_;
size_t tmp_loss_size_;
size_t tmp_target_weight_size_;
int n_;

View File

@@ -73,8 +73,8 @@ class NLLLossGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = GetReductionInt(reduction);
if (reduction_ == 0) {
reduction_ = kReductionModeMap[reduction];
if (reduction_ == ReductionMode::kNone) {
num_dloss_ = n_;
}
@@ -87,8 +87,8 @@ class NLLLossGradGpuKernel : public GpuKernel {
n_ = 0;
c_ = 0;
is_null_input_ = false;
reduction_ = 1; // default value
num_dloss_ = 1; // default size (scalar)
reduction_ = ReductionMode::kMean; // default value
num_dloss_ = 1; // default size (scalar)
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
@@ -107,7 +107,7 @@ class NLLLossGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
int reduction_;
ReductionMode reduction_;
int n_;
int c_;
bool is_null_input_;