!40789 optimize CTC forward by thread parallelization

Merge pull request !40789 from zhujingxuan/CTCLoss
i-robot 2022-08-25 14:19:51 +00:00 committed by Gitee
commit 7978f99dd2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 117 additions and 62 deletions
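In essence, the CTC forward (alpha) recursion computes row t of log_alpha from row t-1 at columns s, s-1 and s-2 only, so the loop over padded-target positions s can be hoisted outside the loop over time t: each s becomes an independent chain over t, with alpha[t-1][s] carried in a register, and that chain is what the GPU kernel below hands to one thread. A minimal standalone sketch of the recursion in that order (illustrative names only, not the kernel code):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Numerically stable log(exp(a) + exp(b) + exp(c)), mirroring the kernel body.
double LogSumExp3(double a, double b, double c) {
  double m = std::max({a, b, c});
  if (m == -std::numeric_limits<double>::infinity()) m = 0;  // all-(-inf) guard
  return std::log(std::exp(a - m) + std::exp(b - m) + std::exp(c - m)) + m;
}

// alpha is [T][2*L+1] over the blank-padded target `padded`; row 0 is pre-initialized.
void AlphaForward(const std::vector<std::vector<double>> &log_probs,  // [T][num_labels]
                  const std::vector<int> &padded, std::vector<std::vector<double>> &alpha) {
  const double kNegInf = -std::numeric_limits<double>::infinity();
  for (size_t s = 0; s < padded.size(); ++s) {               // s outermost: columns s-1, s-2 already done
    const bool three = s > 1 && padded[s - 2] != padded[s];  // loop-invariant in t, hoisted
    double a1 = alpha[0][s];                                 // alpha[t-1][s], carried in a register
    for (size_t t = 1; t < log_probs.size(); ++t) {
      const double a2 = s > 0 ? alpha[t - 1][s - 1] : kNegInf;
      const double a3 = three ? alpha[t - 1][s - 2] : kNegInf;
      a1 = LogSumExp3(a1, a2, a3) + log_probs[t][padded[s]];
      alpha[t][s] = a1;
    }
  }
}

Serially the interchange is legal because columns s-1 and s-2 finish before column s starts; run concurrently with one thread per s, it instead needs a barrier per time step, which is exactly the __syncthreads() added in the CUDA hunk.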


@@ -91,10 +91,12 @@ void CTCLossV2CpuKernelMod::LossCompute(S *log_probs_p, S *log_alpha_p, T *tar_p
     log_alpha_p[log_alpha_it(batch, 0, 1)] =
       log_probs_p[log_probs_it(0, batch, GetBlankPaddedTarget(tar_p, offset, 1))];
   }
-  for (int64_t t = 1; t < input_length; t++) {
-    for (int64_t s = 0; s < target_mul * target_length + 1; s++) {
-      auto current_target_prime = GetBlankPaddedTarget(tar_p, offset, s);
-      S log_a1 = log_alpha_p[log_alpha_it(batch, t - 1, s)];
+  for (int64_t s = 0; s < target_mul * target_length + 1; s++) {
+    auto current_target_prime = GetBlankPaddedTarget(tar_p, offset, s);
+    bool three_sum = (s > 1) && (GetBlankPaddedTarget(tar_p, offset, s - target_mul) != current_target_prime);
+    // a1 is the result of the previous loop
+    S log_a1 = log_alpha_p[log_alpha_it(batch, 0, s)];
+    for (int64_t t = 1; t < input_length; t++) {
       S log_max = log_a1;
       S log_a2, log_a3;
       if (s > 0) {
@@ -103,7 +105,7 @@ void CTCLossV2CpuKernelMod::LossCompute(S *log_probs_p, S *log_alpha_p, T *tar_p
       } else {
         log_a2 = neg_inf;
       }
-      if ((s > 1) && (GetBlankPaddedTarget(tar_p, offset, s - target_mul) != current_target_prime)) {
+      if (three_sum) {
         log_a3 = log_alpha_p[log_alpha_it(batch, t - 1, s - target_mul)];
         log_max = std::max(log_a3, log_max);
       } else {
@@ -112,9 +114,10 @@ void CTCLossV2CpuKernelMod::LossCompute(S *log_probs_p, S *log_alpha_p, T *tar_p
       if (log_max == neg_inf) {
         log_max = 0;
       }
-      log_alpha_p[log_alpha_it(batch, t, s)] =
-        std::log(std::exp(log_a1 - log_max) + std::exp(log_a2 - log_max) + std::exp(log_a3 - log_max)) + log_max +
-        log_probs_p[log_probs_it(t, batch, current_target_prime)];
+      S log_three_sum = std::log(std::exp(log_a1 - log_max) + std::exp(log_a2 - log_max) + std::exp(log_a3 - log_max)) +
+                        log_max + log_probs_p[log_probs_it(t, batch, current_target_prime)];
+      log_alpha_p[log_alpha_it(batch, t, s)] = log_three_sum;
+      log_a1 = log_three_sum;
     }
   }
 }
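The CPU-side change above is pure hoisting: current_target_prime and the three_sum test depend only on s, so they move out of the t loop, and log_alpha[t-1][s] is carried in log_a1 instead of being re-read every iteration (target_mul is the 2-interval of the blank padding). For reference, GetBlankPaddedTarget just indexes the virtual blank-padded sequence [blank, y0, blank, y1, ..., blank] without materializing it; a sketch consistent with the overload added in the CUDA file below (the offset-taking variant is assumed to differ only by the base index into target):

#include <cstdint>

template <typename T>
T BlankPaddedAt(const T *target, int64_t offset, int64_t idx, T blank) {
  // Even positions hold the implicit blank; odd position 2k + 1 maps to the k-th real label.
  return (idx % 2 == 0) ? blank : target[offset + idx / 2];
}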


@@ -17,8 +17,10 @@
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_v2_impl.cuh"
 #include <thrust/device_ptr.h>
 #include <thrust/fill.h>
 #include <type_traits>
 #include <limits>
+#include <algorithm>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
 template <typename T>
 __device__ __forceinline__ T LogSumExp(T a, T b) {
@@ -44,43 +46,75 @@ __device__ __forceinline__ int64_t GetBlankPaddedTarget(const T *target, int64_t
   }
 }
+template <typename T>
+__device__ __forceinline__ int64_t GetBlankPaddedTarget(const T *target, int64_t idx, T blank) {
+  constexpr int64_t interval = 2;
+  if (idx % interval == 0) {
+    return blank;
+  } else {
+    return target[(idx / interval)];
+  }
+}
 __device__ __forceinline__ size_t GetOffset3D(dim3 dims, size_t x, size_t y, size_t z) {
   return x * dims.y * dims.z + y * dims.z + z;
 }
 template <typename S, typename T>
-__device__ __forceinline__ void LossCompute(const S *log_probs_p, S *log_alpha_p, const T *tar_p, int64_t input_length,
-                                            int64_t target_length, int64_t offset, int64_t batch, T blank,
-                                            dim3 log_probs_shape, dim3 log_alpha_shape) {
+__device__ __forceinline__ void LossCompute(const S *log_probs_p, const T *target_p, int64_t input_length,
+                                            int64_t target_length, int64_t max_target_length, int64_t batch, T blank,
+                                            dim3 log_probs_shape, dim3 log_alpha_shape, S *log_alpha_p) {
   constexpr S neg_inf = -std::numeric_limits<S>::infinity();
-  if (target_length > 0) {
-    log_alpha_p[GetOffset3D(log_alpha_shape, batch, 0, 1)] =
-      log_probs_p[GetOffset3D(log_probs_shape, 0, batch, GetBlankPaddedTarget(tar_p, offset, 1, blank))];
+  const int64_t offset = max_target_length * batch;
+  const int64_t padded_max_target_length = 2 * max_target_length + 1;
+  const int64_t padded_target_length = 2 * target_length + 1;
+  // Init first line where t == 0
+  for (int64_t block_s = 0; block_s < padded_max_target_length; block_s += blockDim.y) {
+    int64_t s = block_s + threadIdx.y;
+    if (s == 0) {
+      log_alpha_p[GetOffset3D(log_alpha_shape, batch, 0, 0)] =
+        log_probs_p[GetOffset3D(log_probs_shape, 0, batch, blank)];
+    } else if (s == 1 && target_length > 0) {
+      log_alpha_p[GetOffset3D(log_alpha_shape, batch, 0, 1)] =
+        log_probs_p[GetOffset3D(log_probs_shape, 0, batch, GetBlankPaddedTarget(target_p, offset, 1, blank))];
+    }
+  }
-  for (int64_t t = 1; t < input_length; t++) {
-    for (int64_t s = 0; s < 2 * target_length + 1; s++) {
-      auto current_target_prime = GetBlankPaddedTarget(tar_p, offset, s, blank);
-      S log_a1 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, t - 1, s)];
-      S log_max = log_a1;
-      S log_a2, log_a3;
-      if (s > 0) {
-        log_a2 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, t - 1, s - 1)];
-        log_max = max(log_a2, log_max);
-      } else {
-        log_a2 = neg_inf;
+  for (int64_t block_s = 0; block_s < padded_max_target_length; block_s += blockDim.y) {
+    int64_t s = block_s + threadIdx.y;
+    // Loop is based on max_target_length so that every thread reaches the __syncthreads() below
+    if (s < padded_target_length) {
+      bool valid_s = target_length > 0;
+      auto current_target_prime = valid_s ? GetBlankPaddedTarget(target_p, offset, s, blank) : blank;
+      bool three_sum =
+        valid_s && (s > 1) && (GetBlankPaddedTarget(target_p, offset, s - 2, blank) != current_target_prime);
+      // a1 is the result of the previous loop
+      S log_a1 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, 0, s)];
+      // Starts with t = 1
+      for (int64_t t = 1; t < input_length; t++) {
+        __syncthreads();
+        S log_max = log_a1;
+        S log_a2, log_a3;
+        if (s > 0) {
+          log_a2 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, t - 1, s - 1)];
+          log_max = max(log_a2, log_max);
+        } else {
+          log_a2 = neg_inf;
+        }
+        if (three_sum) {
+          log_a3 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, t - 1, s - 2)];
+          log_max = max(log_a3, log_max);
+        } else {
+          log_a3 = neg_inf;
+        }
+        if (log_max == neg_inf) {
+          log_max = 0;
+        }
+        S log_three_sum =
+          std::log(std::exp(log_a1 - log_max) + std::exp(log_a2 - log_max) + std::exp(log_a3 - log_max)) + log_max +
+          log_probs_p[GetOffset3D(log_probs_shape, t, batch, current_target_prime)];
+        log_alpha_p[GetOffset3D(log_alpha_shape, batch, t, s)] = log_three_sum;
+        log_a1 = log_three_sum;
+      }
-      if ((s > 1) && (GetBlankPaddedTarget(tar_p, offset, s - 2, blank) != current_target_prime)) {
-        log_a3 = log_alpha_p[GetOffset3D(log_alpha_shape, batch, t - 1, s - 2)];
-        log_max = max(log_a3, log_max);
-      } else {
-        log_a3 = neg_inf;
-      }
-      if (log_max == neg_inf) {
-        log_max = 0;
-      }
-      log_alpha_p[GetOffset3D(log_alpha_shape, batch, t, s)] =
-        std::log(std::exp(log_a1 - log_max) + std::exp(log_a2 - log_max) + std::exp(log_a3 - log_max)) + log_max +
-        log_probs_p[GetOffset3D(log_probs_shape, t, batch, current_target_prime)];
     }
   }
 }
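In the new LossCompute, threadIdx.y owns one (or a strided set of) padded-target positions s while threadIdx.x selects the batch element; since column s reads row t-1 at s-1 and s-2, which neighboring threads wrote, one __syncthreads() per time step keeps rows consistent. A stripped-down single-batch toy of the same pattern (hypothetical names; launch with blockDim.x == S so every thread reaches each barrier, which is also why the real kernel loops s up to the padded max target length rather than the per-sample length):

#include <cuda_runtime.h>
#include <math_constants.h>

// alpha and emit are [T][S] row-major; row 0 of alpha is pre-initialized.
// Simplification: the real kernel additionally disables a3 when
// padded[s-2] == padded[s] (the three_sum test) and handles batching.
__global__ void AlphaForwardToy(const float *emit, float *alpha, int T, int S) {
  const int s = threadIdx.x;            // one thread per padded-target position
  float a1 = alpha[s];                  // alpha[t-1][s], carried in a register
  for (int t = 1; t < T; ++t) {
    __syncthreads();                    // row t-1 fully written before reading neighbors
    const float a2 = (s > 0) ? alpha[(t - 1) * S + (s - 1)] : -CUDART_INF_F;
    const float a3 = (s > 1) ? alpha[(t - 1) * S + (s - 2)] : -CUDART_INF_F;
    float m = fmaxf(a1, fmaxf(a2, a3));
    if (m == -CUDART_INF_F) m = 0.0f;
    a1 = logf(expf(a1 - m) + expf(a2 - m) + expf(a3 - m)) + m + emit[t * S + s];
    alpha[t * S + s] = a1;              // writes hit row t, never the row being read
  }
}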
@@ -136,29 +170,37 @@ __device__ __forceinline__ void GradCompute(const S *log_probs, const S *log_alp
 template <typename S, typename T>
 __global__ void CTCLossV2Kernel(const S *log_probs_p, const T *target_p, const T *input_len_p, const T *target_len_p,
                                 int64_t max_target_length, int64_t time_series, int64_t batch_size, T blank,
-                                dim3 log_probs_shape, dim3 log_alpha_shape, S *neg_log_p, S *log_alpha_p) {
+                                dim3 log_probs_shape, dim3 log_alpha_shape, S *log_alpha_p) {
+  int64_t b = threadIdx.x + blockIdx.x * blockDim.x;
+  if (b >= batch_size) {
+    return;
+  }
+  int64_t input_length = input_len_p[b];
+  int64_t target_length = target_len_p[b];
+  CUDA_KERNEL_ASSERT(input_length >= 0 && "For 'CTCLossV2', input_length should be non-negative.")
+  CUDA_KERNEL_ASSERT(target_length >= 0 && "For 'CTCLossV2', target_length should be non-negative.")
+  CUDA_KERNEL_ASSERT(target_length <= max_target_length &&
+                     "For 'CTCLossV2', target_length should be less equal to targets.shape[1].")
+  CUDA_KERNEL_ASSERT(input_length >= target_length &&
+                     "For 'CTCLossV2', input_length should be greater equal to target_length.")
+  CUDA_KERNEL_ASSERT(input_length <= time_series &&
+                     "For 'CTCLossV2', input_length should be less equal to probs.shape[0].")
+  LossCompute<S, T>(log_probs_p, target_p, input_length, target_length, max_target_length, b, blank, log_probs_shape,
+                    log_alpha_shape, log_alpha_p);
+}
+template <typename S, typename T>
+__global__ void LogLikelihoodKernel(const S *log_alpha_p, const T *input_length_p, const T *target_length_p,
+                                    int64_t batch_size, dim3 log_alpha_shape, S *neg_log_p) {
   constexpr S neg_inf = -std::numeric_limits<S>::infinity();
   for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < batch_size; b += blockDim.x * gridDim.x) {
-    int64_t input_len = input_len_p[b];
-    int64_t tar_len = target_len_p[b];
-    CUDA_KERNEL_ASSERT(input_len >= 0 && "For 'CTCLossV2', input_length should be non-negative.")
-    CUDA_KERNEL_ASSERT(tar_len >= 0 && "For 'CTCLossV2', target_length should be non-negative.")
-    CUDA_KERNEL_ASSERT(tar_len <= max_target_length &&
-                       "For 'CTCLossV2', target_length should be less equal to targets.shape[1].")
-    CUDA_KERNEL_ASSERT(input_len >= tar_len &&
-                       "For 'CTCLossV2', input_length should be greater equal to target_length.")
-    CUDA_KERNEL_ASSERT(input_len <= time_series &&
-                       "For 'CTCLossV2', input_length should be less equal to probs.shape[0].")
-    int64_t offset = max_target_length * b;
-    log_alpha_p[GetOffset3D(log_alpha_shape, b, 0, 0)] = log_probs_p[GetOffset3D(log_probs_shape, 0, b, blank)];
-    LossCompute<S, T>(log_probs_p, log_alpha_p, target_p, input_len, tar_len, offset, b, blank, log_probs_shape,
-                      log_alpha_shape);
-    if (tar_len == 0) {
-      neg_log_p[b] = -log_alpha_p[GetOffset3D(log_alpha_shape, b, input_len - 1, 0)];
+    int64_t input_length = input_length_p[b];
+    int64_t target_length = target_length_p[b];
+    if (target_length == 0) {
+      neg_log_p[b] = -log_alpha_p[GetOffset3D(log_alpha_shape, b, input_length - 1, 0)];
     } else {
-      S l1 = log_alpha_p[GetOffset3D(log_alpha_shape, b, input_len - 1, tar_len * 2)];
-      S l2 = log_alpha_p[GetOffset3D(log_alpha_shape, b, input_len - 1, tar_len * 2 - 1)];
+      S l1 = log_alpha_p[GetOffset3D(log_alpha_shape, b, input_length - 1, target_length * 2)];
+      S l2 = log_alpha_p[GetOffset3D(log_alpha_shape, b, input_length - 1, target_length * 2 - 1)];
       S m = max(l1, l2);
       m = ((m == neg_inf) ? 0 : m);
       S log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
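For reference, the readout in LogLikelihoodKernel is the standard CTC termination: a valid alignment must end in either the final label or the trailing blank of the padded target, so with log-domain alpha, T = input_length and L = target_length,

  loss[b] = -log( exp(alpha[T-1][2L]) + exp(alpha[T-1][2L-1]) )

which the kernel evaluates with the usual max-shift (the m above) for numerical stability.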
@@ -176,9 +218,20 @@ void CalCTCLossV2(const S *log_probs_p, const T *target_p, const T *input_len_p,
   thrust::device_ptr<S> dev_ptr(log_alpha_p);
   thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + alpha_size, neg_inf);
-  CTCLossV2Kernel<<<CUDA_BLOCKS(device_id, batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-    log_probs_p, target_p, input_len_p, target_len_p, max_target_length, time_series, batch_size, blank,
-    log_probs_shape, log_alpha_shape, neg_log_p, log_alpha_p);
+  const int64_t padded_target_length = 2 * max_target_length + 1;
+  const uint64_t padded_target_length_power2 = 1ull << Log2Ceil64(padded_target_length);
+  const uint64_t max_threads = CUDA_THREADS(device_id);
+  const uint64_t threads_per_batch = std::min(max_threads, padded_target_length_power2);
+  const unsigned int batches_per_block = std::min(max_threads / threads_per_batch, static_cast<uint64_t>(batch_size));
+  dim3 blocks((batch_size + batches_per_block - 1) / batches_per_block);
+  dim3 threads(batches_per_block, threads_per_batch);
+  CTCLossV2Kernel<<<blocks, threads, 0, cuda_stream>>>(log_probs_p, target_p, input_len_p, target_len_p,
+                                                       max_target_length, time_series, batch_size, blank,
+                                                       log_probs_shape, log_alpha_shape, log_alpha_p);
+  LogLikelihoodKernel<<<CUDA_BLOCKS(device_id, batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    log_alpha_p, input_len_p, target_len_p, batch_size, log_alpha_shape, neg_log_p);
 }
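The launch geometry packs both levels of parallelism into one block: y covers the padded target positions (rounded up to a power of two so the warp layout stays regular), x covers as many batch elements as the thread budget then allows. Worked through under a 1024-thread budget: max_target_length = 50 gives a padded length of 101, rounded up to 128 threads in y; 1024 / 128 = 8 batch items per block in x; and ceil(batch_size / 8) blocks. A host-side sketch of the same arithmetic (Log2Ceil64 is assumed, per its name and use here, to return ceil(log2(n)); CUDA_THREADS is MindSpore's per-device thread budget):

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

// Stand-in for the util.cuh helper: smallest e with 2^e >= n (assumed semantics).
uint64_t Log2Ceil64(uint64_t n) {
  uint64_t e = 0;
  while ((1ull << e) < n) ++e;
  return e;
}

void MakeGeometry(int64_t max_target_length, int64_t batch_size, uint64_t max_threads,
                  dim3 *blocks, dim3 *threads) {
  const uint64_t padded_len = 2 * max_target_length + 1;
  const uint64_t threads_per_batch = std::min(max_threads, 1ull << Log2Ceil64(padded_len));
  const unsigned int batches_per_block =
    std::min(max_threads / threads_per_batch, static_cast<uint64_t>(batch_size));
  *threads = dim3(batches_per_block, threads_per_batch);                      // x: batch, y: position s
  *blocks = dim3((batch_size + batches_per_block - 1) / batches_per_block);   // ceil-div over batch
}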
template <typename S, typename T>
@@ -188,7 +241,7 @@ __global__ void CTCLossV2GradKernel(const S *grad_out, const S *log_probs, const
                                     int64_t max_target_length, bool zero_infinity, T blank, dim3 log_probs_shape,
                                     dim3 log_alpha_shape, S *grad) {
   constexpr S neg_inf = -std::numeric_limits<S>::infinity();
-  for (int64_t b = 0; b < batch_size; b++) {
+  for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < batch_size; b += blockDim.x * gridDim.x) {
     S nll = neg_log_likelihood[b];
     if (zero_infinity && nll == std::numeric_limits<S>::infinity()) {
       for (int t = 0; t < time_series; t++) {
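The grad kernel's serial for (int64_t b = 0; ...) over the whole batch becomes the standard grid-stride idiom, so the same code is correct for any grid size while now processing batch elements in parallel; generically:

__global__ void GridStrideOverBatch(float *out, int batch_size) {
  // Each thread starts at its global index and hops by the total thread count,
  // covering every b even when gridDim.x * blockDim.x < batch_size.
  for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < batch_size; b += blockDim.x * gridDim.x) {
    out[b] = 0.0f;  // per-batch gradient work stands in here
  }
}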


@@ -19,7 +19,6 @@
 #include <memory>
 #include <algorithm>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_v2_impl.cuh"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/batchnorm_fold_impl.cuh"
 #include "mindspore/core/ops/ctc_loss_v2.h"
 #include "abstract/utils.h"