add GPU CTCLoss
This commit is contained in:
parent
f226789f82
commit
e1b31c7baa
|
@ -0,0 +1,446 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <limits>
|
||||
#include "ctcloss_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
__device__ T LogSumExp(const T logprob1, const T logprob2) {
|
||||
if (logprob1 == logprob2 && logprob1 == -std::numeric_limits<T>::infinity()) {
|
||||
return logprob1;
|
||||
} else {
|
||||
return (logprob1 > logprob2) ? logprob1 + log1pf(expf(logprob2 - logprob1))
|
||||
: logprob2 + log1pf(expf(logprob1 - logprob2));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalculateFwdVarKernel(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs,
|
||||
const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
|
||||
int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
|
||||
bool ignore_longer_outputs_than_inputs) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
if (sequence_length[i] == 0 ||
|
||||
(ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
|
||||
} else {
|
||||
T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
|
||||
int *label_value_with_blank_cur = &label_value_with_blank[0];
|
||||
if (i > 0) {
|
||||
label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
|
||||
}
|
||||
int numclass = blank + 1;
|
||||
int U = 2 * label_squence_length[i] + 1;
|
||||
int Ti = sequence_length[i];
|
||||
int low = 0;
|
||||
int high = 0;
|
||||
log_alpha_b_cur[0] = log(softmax_probs[i * numclass + blank]);
|
||||
int label0 = blank;
|
||||
if (U > 1) {
|
||||
label0 = label_value_with_blank_cur[1];
|
||||
log_alpha_b_cur[maxtime] = log(softmax_probs[i * numclass + label0]);
|
||||
}
|
||||
for (int t = 1; t < Ti; ++t) {
|
||||
low = 0;
|
||||
high = U;
|
||||
int low_limit = U - (2 * (Ti - t));
|
||||
int high_limit = 2 * (t + 1);
|
||||
if (low_limit > low) {
|
||||
low = low_limit;
|
||||
}
|
||||
if (high_limit < U) {
|
||||
high = high_limit;
|
||||
}
|
||||
for (int u = low; u < high; ++u) {
|
||||
T sum_log_alpha = -std::numeric_limits<T>::infinity();
|
||||
if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
|
||||
sum_log_alpha = log_alpha_b_cur[u * maxtime + t - 1];
|
||||
}
|
||||
if (u > 0) {
|
||||
sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 1) * maxtime + t - 1]);
|
||||
}
|
||||
if (u > 1) {
|
||||
const bool matching_labels_merge =
|
||||
ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u - 2]);
|
||||
if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
|
||||
sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 2) * maxtime + t - 1]);
|
||||
}
|
||||
}
|
||||
log_alpha_b_cur[u * maxtime + t] =
|
||||
log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + t * numclass * batch]) + sum_log_alpha;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalculateBwdVarKernel(T *log_beta_b, int *label_value_with_blank, T *softmax_probs,
|
||||
const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
|
||||
int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
|
||||
bool ignore_longer_outputs_than_inputs) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
if (sequence_length[i] == 0 ||
|
||||
(ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
|
||||
} else {
|
||||
T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
|
||||
int *label_value_with_blank_cur = &label_value_with_blank[0];
|
||||
if (i > 0) {
|
||||
label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
|
||||
}
|
||||
int numclass = blank + 1;
|
||||
int U = 2 * label_squence_length[i] + 1;
|
||||
int Ti = sequence_length[i];
|
||||
int low = 0;
|
||||
int high = 0;
|
||||
if (U > 1) {
|
||||
for (int u = U - 2; u < U; ++u) {
|
||||
log_beta_b_cur[u * maxtime + Ti - 1] = 0;
|
||||
}
|
||||
} else {
|
||||
log_beta_b_cur[Ti - 1] = 0;
|
||||
log_beta_b_cur[Ti - 2] = 0;
|
||||
}
|
||||
for (int t = Ti - 2; t >= 0; --t) {
|
||||
low = 0;
|
||||
high = U;
|
||||
int low_limit = U - (2 * (Ti - t));
|
||||
int high_limit = 2 * (t + 1);
|
||||
if (low_limit > low) {
|
||||
low = low_limit;
|
||||
}
|
||||
if (high_limit < U) {
|
||||
high = high_limit;
|
||||
}
|
||||
for (int u = low; u < high; ++u) {
|
||||
if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
|
||||
log_beta_b_cur[u * maxtime + t] = LogSumExp(
|
||||
log_beta_b_cur[u * maxtime + t],
|
||||
log_beta_b_cur[u * maxtime + t + 1] +
|
||||
log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + (t + 1) * numclass * batch]));
|
||||
}
|
||||
if (u + 1 < U) {
|
||||
log_beta_b_cur[u * maxtime + t] = LogSumExp(
|
||||
log_beta_b_cur[u * maxtime + t],
|
||||
log_beta_b_cur[(u + 1) * maxtime + t + 1] +
|
||||
log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 1] + (t + 1) * numclass * batch]));
|
||||
}
|
||||
if (u + 2 < U) {
|
||||
const bool matching_labels_merge =
|
||||
ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u + 2]);
|
||||
if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
|
||||
log_beta_b_cur[u * maxtime + t] = LogSumExp(
|
||||
log_beta_b_cur[u * maxtime + t],
|
||||
log_beta_b_cur[(u + 2) * maxtime + t + 1] +
|
||||
log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 2] + (t + 1) * numclass * batch]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ProbInitKernel(T *prob_num, int size) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
prob_num[i] = -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void LogBInitKernel(T *log_b, int log_prob_size) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < log_prob_size; i += blockDim.x * gridDim.x) {
|
||||
log_b[i] = -std::numeric_limits<T>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CTCLossKernel(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch,
|
||||
int SOffSet, int maxtime, int numclass, const int *sequence_length,
|
||||
int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num,
|
||||
bool ignore_longer_outputs_than_inputs) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
if (sequence_length[i] == 0 ||
|
||||
(ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
|
||||
} else {
|
||||
T *grad_cur = &grads[i * numclass];
|
||||
const T *softmax_probs_cur = &softmax_probs[i * numclass];
|
||||
T *prob_num_cur = &prob_num[i * numclass];
|
||||
int U = 2 * label_squence_length[i] + 1;
|
||||
T log_pzx = -std::numeric_limits<T>::infinity();
|
||||
const T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
|
||||
const T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
|
||||
int *label_value_with_blank_cur = &label_value_with_blank[0];
|
||||
if (i > 0) {
|
||||
label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
|
||||
}
|
||||
for (int u = 0; u < U; ++u) {
|
||||
log_pzx = LogSumExp(log_pzx, log_alpha_b_cur[u * maxtime] + log_beta_b_cur[u * maxtime]);
|
||||
}
|
||||
cost[i] = -log_pzx;
|
||||
// grad
|
||||
int L = numclass;
|
||||
int Ti = sequence_length[i];
|
||||
if (log_pzx == -std::numeric_limits<T>::infinity()) {
|
||||
for (int t = 0; t < Ti; ++t) {
|
||||
for (int l = 0; l < L; ++l) {
|
||||
grad_cur[t * numclass * batch + l] = softmax_probs_cur[t * numclass * batch + l];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t < Ti; ++t) {
|
||||
for (int u = 0; u < U; ++u) {
|
||||
int l = label_value_with_blank_cur[u];
|
||||
prob_num_cur[t * batch * numclass + l] =
|
||||
LogSumExp(prob_num_cur[t * batch * numclass + l],
|
||||
log_alpha_b_cur[u * maxtime + t] + log_beta_b_cur[u * maxtime + t]);
|
||||
}
|
||||
for (int l = 0; l < L; ++l) {
|
||||
grad_cur[t * numclass * batch + l] =
|
||||
softmax_probs_cur[t * numclass * batch + l] - expf(prob_num_cur[t * batch * numclass + l] - log_pzx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void InnerSoftMaxKernel(const T *probs, T *softmax_probs, const int *sequence_length, int max_time,
|
||||
int batch, int numclass) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch * max_time; i += blockDim.x * gridDim.x) {
|
||||
int k = i / batch;
|
||||
int m = i % batch;
|
||||
if (k < sequence_length[m]) {
|
||||
T maxCoeff = 0.;
|
||||
T sumCoeff = 0.;
|
||||
for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
|
||||
if (probs[j] > maxCoeff) {
|
||||
maxCoeff = probs[j];
|
||||
}
|
||||
}
|
||||
for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
|
||||
sumCoeff += exp(probs[j] - maxCoeff);
|
||||
softmax_probs[j] = exp(probs[j] - maxCoeff);
|
||||
}
|
||||
for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
|
||||
softmax_probs[j] /= sumCoeff;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void GenLabelValuePCRKernel(int *label_value_sp, int *label_value_pcr, int *label_squence_length,
|
||||
int *cum_labels_length, int batch) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
int L = label_squence_length[i];
|
||||
label_squence_length[i] = 0;
|
||||
int offset = 0;
|
||||
if (i > 0) {
|
||||
offset = cum_labels_length[i - 1];
|
||||
}
|
||||
for (int l = offset; l < L; ++l) {
|
||||
if (l == offset || label_value_sp[l] != label_value_sp[l - 1]) {
|
||||
label_value_pcr[offset + label_squence_length[i]++] = label_value_sp[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void UpdateLengthKernel(int *label_squence_length, int *cum_labels_length, int *max_labels_length,
|
||||
int batch) {
|
||||
max_labels_length[0] = 0;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
if (label_squence_length[i] > max_labels_length[0]) {
|
||||
max_labels_length[0] = label_squence_length[i];
|
||||
}
|
||||
if (i == 0) {
|
||||
cum_labels_length[i] = label_squence_length[i];
|
||||
} else {
|
||||
cum_labels_length[i] = label_squence_length[i] + cum_labels_length[i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream) {
|
||||
int log_prob_size = SOffSet * batch * maxtime;
|
||||
LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_beta_b, log_prob_size);
|
||||
CalculateBwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
|
||||
log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
|
||||
blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream) {
|
||||
int log_prob_size = SOffSet * batch * maxtime;
|
||||
LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_alpha_b, log_prob_size);
|
||||
CalculateFwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
|
||||
log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
|
||||
blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InnerSoftMax(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, int batch, int numclass,
|
||||
cudaStream_t stream) {
|
||||
InnerSoftMaxKernel<<<GET_BLOCKS(batch * max_time), GET_THREADS, 0, stream>>>(probs, softmax_probs, sequence_length,
|
||||
max_time, batch, numclass);
|
||||
}
|
||||
|
||||
__global__ void GenLabelWithBlankKernel(int *label_value, int *label_value_with_blank, int *label_squence_length,
|
||||
int *precum_labels_length, int *cum_labels_length, int batch, int blank) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
int offset = 0;
|
||||
int offset1 = 0;
|
||||
if (i > 0) {
|
||||
offset = 2 * cum_labels_length[i - 1] + i;
|
||||
offset1 = precum_labels_length[i - 1];
|
||||
}
|
||||
for (int j = 0; j < label_squence_length[i]; ++j) {
|
||||
label_value_with_blank[offset + 2 * j] = blank;
|
||||
label_value_with_blank[offset + 2 * j + 1] = label_value[offset1 + j];
|
||||
}
|
||||
label_value_with_blank[offset + 2 * label_squence_length[i]] = blank;
|
||||
}
|
||||
}
|
||||
|
||||
void GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
|
||||
int *precum_labels_length, int *cum_labels_length, int batch, int blank, cudaStream_t stream) {
|
||||
GenLabelWithBlankKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
|
||||
label_value, label_value_with_blank, label_squence_length, precum_labels_length, cum_labels_length, batch, blank);
|
||||
}
|
||||
|
||||
void GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, int *cum_labels_length,
|
||||
int *max_labels_length, int batch, cudaStream_t stream) {
|
||||
GenLabelValuePCRKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_value_pcr,
|
||||
label_squence_length, cum_labels_length, batch);
|
||||
UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
|
||||
}
|
||||
|
||||
__global__ void GenLabelValueKernel(int *label_value_sp, const int64_t *label_indices, const int *label_values,
|
||||
int *label_squence_length, int *cum_labels_length, int size) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
int64_t b = label_indices[i * 2];
|
||||
int offset = 0;
|
||||
if (b > 0) {
|
||||
offset = cum_labels_length[b - 1];
|
||||
}
|
||||
int64_t index = offset + label_indices[i * 2 + 1];
|
||||
label_value_sp[index] = label_values[i];
|
||||
}
|
||||
}
|
||||
__global__ void LabelValueInitKernel(int *label_value_sp, int size, int blank) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
label_value_sp[i] = blank;
|
||||
}
|
||||
}
|
||||
__global__ void RecalculateLengthKernel(int *label_value_sp, int *label_squence_length, int *cum_labels_length,
|
||||
int batch, int blank) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
|
||||
int offset = 0;
|
||||
if (i > 0) {
|
||||
offset = cum_labels_length[i - 1];
|
||||
}
|
||||
int L = label_squence_length[i];
|
||||
label_squence_length[i] = 0;
|
||||
for (int j = offset; j < offset + L; ++j) {
|
||||
if (label_value_sp[j] >= blank) {
|
||||
break;
|
||||
} else {
|
||||
label_squence_length[i]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
|
||||
int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, int blank,
|
||||
int batch, cudaStream_t stream) {
|
||||
LabelValueInitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, size, blank);
|
||||
GenLabelValueKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, label_indices, label_values,
|
||||
label_squence_length, cum_labels_length, size);
|
||||
RecalculateLengthKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_squence_length,
|
||||
cum_labels_length, batch, blank);
|
||||
UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
|
||||
}
|
||||
|
||||
__global__ void CalculatePreLengthKernel(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
|
||||
int *max_labels_length, const int64_t *label_indices, int batch, int size) {
|
||||
max_labels_length[0] = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
label_squence_length[label_indices[i * 2]]++;
|
||||
if (max_labels_length[0] < label_indices[i * 2]) {
|
||||
max_labels_length[0] = label_indices[i * 2];
|
||||
}
|
||||
}
|
||||
precum_labels_length[0] = label_squence_length[0];
|
||||
cum_labels_length[0] = label_squence_length[0];
|
||||
for (int i = 1; i < batch; ++i) {
|
||||
cum_labels_length[i] = cum_labels_length[i - 1] + label_squence_length[i];
|
||||
precum_labels_length[i] = precum_labels_length[i - 1] + label_squence_length[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void CalculateMaxSequenceKernel(const int *sequence_length, int *max_labels_length, int batch) {
|
||||
max_labels_length[0] = 0;
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
if (sequence_length[i] > max_labels_length[0]) {
|
||||
max_labels_length[0] = sequence_length[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream) {
|
||||
CalculateMaxSequenceKernel<<<1, 1, 0, stream>>>(sequence_length, max_labels_length, batch);
|
||||
}
|
||||
|
||||
void CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
|
||||
int *max_labels_length, const int64_t *label_indices, int batch, int size,
|
||||
cudaStream_t stream) {
|
||||
CalculatePreLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, precum_labels_length, cum_labels_length,
|
||||
max_labels_length, label_indices, batch, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, int SOffSet,
|
||||
int maxtime, int numclass, const int *sequence_length, int *label_squence_length, int *cum_labels_length,
|
||||
T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, cudaStream_t stream) {
|
||||
ProbInitKernel<<<GET_BLOCKS(maxtime * batch * numclass), GET_THREADS, 0, stream>>>(prob_num,
|
||||
maxtime * batch * numclass);
|
||||
CTCLossKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
|
||||
log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, maxtime, numclass, sequence_length,
|
||||
label_squence_length, cum_labels_length, cost, grads, prob_num, ignore_longer_outputs_than_inputs);
|
||||
}
|
||||
|
||||
template void CalculateFwdVar<float>(float *log_alpha_b, int *label_value_with_blank, float *softmax_probs,
|
||||
const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
|
||||
int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
|
||||
bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template void CalculateBwdVar<float>(float *log_beta_b, int *label_value_with_blank, float *softmax_probs,
|
||||
const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
|
||||
int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
|
||||
bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template void InnerSoftMax<float>(const float *probs, float *softmax_probs, const int *sequence_length, int max_time,
|
||||
int batch, int numclass, cudaStream_t stream);
|
||||
|
||||
template void CTCLoss<float>(float *log_alpha_b, float *log_beta_b, float *softmax_probs, int *label_value_with_blank,
|
||||
int batch, int SOffSet, int maxtime, int numclass, const int *sequence_length,
|
||||
int *label_squence_length, int *cum_labels_length, float *cost, float *grads,
|
||||
float *prob_num, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
||||
|
||||
template <typename T>
|
||||
void CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
|
||||
bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
|
||||
int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void InnerSoftMax(const T *probs, T *softmax_cost, const int *sequence_length, int max_time, int batch, int numclass,
|
||||
cudaStream_t stream);
|
||||
|
||||
void GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length, int *cum_labels_length,
|
||||
int *max_labels_length, int batch, cudaStream_t stream);
|
||||
|
||||
void GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
|
||||
int *precum_labels_length, int *cum_labels_length, int batch, int blank, cudaStream_t stream);
|
||||
|
||||
void GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
|
||||
int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size, int blank,
|
||||
int batch, cudaStream_t stream);
|
||||
|
||||
void CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
|
||||
int *max_labels_length, const int64_t *label_indices, int batch, int size, cudaStream_t stream);
|
||||
void CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream);
|
||||
template <typename T>
|
||||
void CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch, int SOffSet,
|
||||
int maxtime, int numclass, const int *sequence_length, int *label_squence_length, int *cum_labels_length,
|
||||
T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMPL_CUH
|
|
@ -1,31 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLoss,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,192 +1,233 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CtcLossGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CtcLossGpuKernel()
|
||||
: cudnn_handle_(nullptr),
|
||||
probs_desc_(nullptr),
|
||||
ctcloss_desc_(nullptr),
|
||||
label_size_(0),
|
||||
input_lengths_size_(0),
|
||||
label_lengths_size_(0) {}
|
||||
~CtcLossGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
float *probs = GetDeviceAddress<float>(inputs, 0);
|
||||
float *costs = GetDeviceAddress<float>(outputs, 0);
|
||||
float *grads = GetDeviceAddress<float>(outputs, 1);
|
||||
|
||||
// Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
|
||||
int *labels_host = nullptr;
|
||||
int *no_blank_labels_host = nullptr;
|
||||
void *input_lengths_host = nullptr;
|
||||
void *label_lengths_host = nullptr;
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
|
||||
CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetCTCLossWorkspaceSize(
|
||||
cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
|
||||
"cudnnGetCTCLossWorkspaceSize failed.");
|
||||
void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
|
||||
if (workspace == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
|
||||
probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
|
||||
"cudnnCtcLoss failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
|
||||
FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (probs_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
|
||||
}
|
||||
probs_dims_[0] = probs_shape[0];
|
||||
probs_dims_[1] = probs_shape[1];
|
||||
probs_dims_[2] = probs_shape[2];
|
||||
|
||||
auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
if (labels_dims.size() != 1 && labels_dims.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
|
||||
}
|
||||
label_size_ = sizeof(int);
|
||||
for (auto i : labels_dims) {
|
||||
label_size_ *= i;
|
||||
}
|
||||
|
||||
auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
input_lengths_size_ = input_length_dims[0] * sizeof(int);
|
||||
auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
label_lengths_size_ = label_length_dims[0] * sizeof(int);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_),
|
||||
"cudnnSetTensorNdDescriptorEx failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT,
|
||||
CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN),
|
||||
"cudnnSetCTCLossDescriptorEx failed.");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed.");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
input_size_list_.push_back(label_size_);
|
||||
input_size_list_.push_back(input_lengths_size_);
|
||||
input_size_list_.push_back(label_lengths_size_);
|
||||
|
||||
output_size_list_.push_back(probs_dims_[1] * sizeof(float));
|
||||
output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
}
|
||||
|
||||
private:
|
||||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
|
||||
}
|
||||
|
||||
void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
|
||||
}
|
||||
|
||||
void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
|
||||
}
|
||||
|
||||
void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
// remove blank element
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
|
||||
if (labels_host[i] != 0) {
|
||||
no_blank_labels_host[j] = labels_host[i];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
cudnnHandle_t cudnn_handle_;
|
||||
cudnnTensorDescriptor_t probs_desc_;
|
||||
cudnnCTCLossDescriptor_t ctcloss_desc_;
|
||||
int probs_dims_[3] = {0};
|
||||
int label_size_;
|
||||
int input_lengths_size_;
|
||||
int label_lengths_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/ctcloss_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CtcLossGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CtcLossGpuKernel()
|
||||
: label_indice_size_(0),
|
||||
label_size_(0),
|
||||
squence_lengths_size_(0),
|
||||
preprocess_collapse_repeated_(false),
|
||||
ctc_merge_repeated_(true),
|
||||
ignore_longer_outputs_than_inputs_(false) {}
|
||||
~CtcLossGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
const T *probs = GetDeviceAddress<T>(inputs, 0);
|
||||
const int64_t *label_indices = GetDeviceAddress<int64_t>(inputs, 1);
|
||||
const int *label_values = GetDeviceAddress<int>(inputs, 2);
|
||||
const int *sequence_length = GetDeviceAddress<int>(inputs, 3);
|
||||
T *costs = GetDeviceAddress<T>(outputs, 0);
|
||||
T *grads = GetDeviceAddress<T>(outputs, 1);
|
||||
T *softmax_probs = GetDeviceAddress<T>(workspace, 0);
|
||||
int *cum_labels_length = GetDeviceAddress<int>(workspace, 1);
|
||||
int *label_squence_length = GetDeviceAddress<int>(workspace, 2);
|
||||
int *label_value_sp = GetDeviceAddress<int>(workspace, 3);
|
||||
int *label_value_pcr = GetDeviceAddress<int>(workspace, 4);
|
||||
T *prob_num = GetDeviceAddress<T>(workspace, 5);
|
||||
int *precum_labels_length = GetDeviceAddress<int>(workspace, 6);
|
||||
int *max_labels_length = GetDeviceAddress<int>(workspace, 7);
|
||||
int numclass = SizeToInt(probs_dims_[2]);
|
||||
int batch = SizeToInt(probs_dims_[1]);
|
||||
int max_time = SizeToInt(probs_dims_[0]);
|
||||
int max_sequence = 0;
|
||||
CalculateMaxSequence(sequence_length, max_labels_length, batch, stream);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(&max_sequence, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
if (max_time < max_sequence) {
|
||||
MS_LOG(EXCEPTION) << "max_time should be greater than sequence length.";
|
||||
}
|
||||
InnerSoftMax(probs, softmax_probs, sequence_length, max_time, batch, numclass, stream);
|
||||
MemsetForWS(label_value_pcr, cum_labels_length, label_squence_length, costs, grads, stream);
|
||||
int max_labels_length_host = 0;
|
||||
int batch_label = 0;
|
||||
int *label_value_with_blank = nullptr;
|
||||
T *log_alpha_b = nullptr;
|
||||
T *log_beta_b = nullptr;
|
||||
CalculatePreLength(label_squence_length, precum_labels_length, cum_labels_length, max_labels_length, label_indices,
|
||||
batch, label_size_ / sizeof(int), stream);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(&batch_label, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
if (batch != batch_label + 1) {
|
||||
MS_LOG(EXCEPTION) << "label batch should be equal to input batch.";
|
||||
}
|
||||
GenLabelValue(label_value_sp, label_indices, label_values, label_squence_length, cum_labels_length,
|
||||
max_labels_length, label_size_ / sizeof(int), numclass - 1, batch, stream);
|
||||
if (preprocess_collapse_repeated_) {
|
||||
GenLabelValuePCR(label_value_sp, label_value_pcr, label_squence_length, cum_labels_length, max_labels_length,
|
||||
batch, stream);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(&max_labels_length_host, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
int SOffSet = 2 * max_labels_length_host + 1;
|
||||
int log_prob_size = batch * SOffSet * max_time;
|
||||
if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
|
||||
MS_LOG(EXCEPTION) << "output size is greater than input size.";
|
||||
}
|
||||
MemManageForCus(&log_alpha_b, &log_beta_b, &label_value_with_blank, cum_labels_length, log_prob_size, batch,
|
||||
stream);
|
||||
|
||||
if (preprocess_collapse_repeated_) {
|
||||
GenLabelWithBlank(label_value_pcr, label_value_with_blank, label_squence_length, precum_labels_length,
|
||||
cum_labels_length, batch, numclass - 1, stream);
|
||||
} else {
|
||||
GenLabelWithBlank(label_value_sp, label_value_with_blank, label_squence_length, precum_labels_length,
|
||||
cum_labels_length, batch, numclass - 1, stream);
|
||||
}
|
||||
|
||||
CalculateFwdVar(log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_, batch,
|
||||
SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
|
||||
ignore_longer_outputs_than_inputs_, stream);
|
||||
CalculateBwdVar(log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_, batch,
|
||||
SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
|
||||
ignore_longer_outputs_than_inputs_, stream);
|
||||
CTCLoss(log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, max_time, numclass,
|
||||
sequence_length, label_squence_length, cum_labels_length, costs, grads, prob_num,
|
||||
ignore_longer_outputs_than_inputs_, stream);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
FreeMem(label_value_with_blank, log_alpha_b, log_beta_b);
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (probs_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
|
||||
}
|
||||
probs_dims_[0] = probs_shape[0];
|
||||
probs_dims_[1] = probs_shape[1];
|
||||
probs_dims_[2] = probs_shape[2];
|
||||
auto indice_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
if (labels_dims.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
|
||||
}
|
||||
if (indice_dims.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "labels indice dims: " << indice_dims.size() << " not support.";
|
||||
}
|
||||
label_size_ = sizeof(int);
|
||||
for (auto i : labels_dims) {
|
||||
label_size_ *= i;
|
||||
}
|
||||
label_indice_size_ = sizeof(int64_t);
|
||||
for (auto i : indice_dims) {
|
||||
label_indice_size_ *= i;
|
||||
}
|
||||
auto squence_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
squence_lengths_size_ = squence_length_dims[0] * sizeof(int);
|
||||
preprocess_collapse_repeated_ = GetAttr<bool>(kernel_node, "preprocess_collapse_repeated");
|
||||
ctc_merge_repeated_ = GetAttr<bool>(kernel_node, "ctc_merge_repeated");
|
||||
ignore_longer_outputs_than_inputs_ = GetAttr<bool>(kernel_node, "ignore_longer_outputs_than_inputs");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T));
|
||||
input_size_list_.push_back(label_indice_size_);
|
||||
input_size_list_.push_back(label_size_);
|
||||
input_size_list_.push_back(squence_lengths_size_);
|
||||
workspace_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T));
|
||||
workspace_size_list_.push_back(squence_lengths_size_);
|
||||
workspace_size_list_.push_back(squence_lengths_size_);
|
||||
workspace_size_list_.push_back(label_size_);
|
||||
workspace_size_list_.push_back(label_size_);
|
||||
workspace_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T));
|
||||
workspace_size_list_.push_back(squence_lengths_size_);
|
||||
workspace_size_list_.push_back(sizeof(int));
|
||||
output_size_list_.push_back(probs_dims_[1] * sizeof(T));
|
||||
output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T));
|
||||
}
|
||||
void MemsetForWS(int *label_value_pcr, int *cum_labels_length, int *label_squence_length, T *costs, T *grads,
|
||||
cudaStream_t stream) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(label_value_pcr, static_cast<int>(0), label_size_, stream),
|
||||
"cudaMemSet failed in CtcLossGpuKernel::Launch.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(cum_labels_length, static_cast<int>(0), squence_lengths_size_, stream),
|
||||
"cudaMemSet failed in CtcLossGpuKernel::Launch.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemsetAsync(label_squence_length, static_cast<int>(0), squence_lengths_size_, stream),
|
||||
"cudaMemSet failed in CtcLossGpuKernel::Launch.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(costs, static_cast<T>(0), probs_dims_[1] * sizeof(T), stream),
|
||||
"cudaMemSet failed in CtcLossGpuKernel::Launch.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemsetAsync(grads, static_cast<T>(0), probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(T), stream),
|
||||
"cudaMemSet failed in CtcLossGpuKernel::Launch.");
|
||||
}
|
||||
void MemManageForCus(T **log_alpha_b, T **log_beta_b, int **label_value_with_blank, int *cum_labels_length,
|
||||
int log_prob_size, int batch, cudaStream_t stream) {
|
||||
int total_labels_size_host = 0;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(log_alpha_b), sizeof(T) * log_prob_size),
|
||||
"cudaMalloc failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(log_beta_b), sizeof(T) * log_prob_size),
|
||||
"cudaMalloc failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&total_labels_size_host, cum_labels_length + batch - 1, sizeof(int),
|
||||
cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMalloc(reinterpret_cast<void **>(label_value_with_blank), sizeof(int) * (2 * total_labels_size_host + batch)),
|
||||
"cudaMalloc failed.");
|
||||
}
|
||||
|
||||
void FreeMem(int *label_value_with_blank, T *log_alpha_b, T *log_beta_b) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(label_value_with_blank), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(log_alpha_b), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(log_beta_b), "cudaFree failed.");
|
||||
}
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
size_t probs_dims_[3] = {0};
|
||||
int label_indice_size_;
|
||||
int label_size_;
|
||||
int squence_lengths_size_;
|
||||
bool preprocess_collapse_repeated_;
|
||||
bool ctc_merge_repeated_;
|
||||
bool ignore_longer_outputs_than_inputs_;
|
||||
T kLogZero_ = -std::numeric_limits<T>::infinity();
|
||||
}; // namespace kernel
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CtcLossV2GpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,192 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CtcLossV2GpuKernel : public GpuKernel {
|
||||
public:
|
||||
CtcLossV2GpuKernel()
|
||||
: cudnn_handle_(nullptr),
|
||||
probs_desc_(nullptr),
|
||||
ctcloss_desc_(nullptr),
|
||||
label_size_(0),
|
||||
input_lengths_size_(0),
|
||||
label_lengths_size_(0) {}
|
||||
~CtcLossV2GpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
float *probs = GetDeviceAddress<float>(inputs, 0);
|
||||
float *costs = GetDeviceAddress<float>(outputs, 0);
|
||||
float *grads = GetDeviceAddress<float>(outputs, 1);
|
||||
|
||||
// Copy labels/input_lengths/label_length to host as cudnn7.x.x requires
|
||||
int *labels_host = nullptr;
|
||||
int *no_blank_labels_host = nullptr;
|
||||
void *input_lengths_host = nullptr;
|
||||
void *label_lengths_host = nullptr;
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
|
||||
CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetCTCLossWorkspaceSize(
|
||||
cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
|
||||
"cudnnGetCTCLossWorkspaceSize failed.");
|
||||
void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
|
||||
if (workspace == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
|
||||
probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
|
||||
"cudnnCtcLoss failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
|
||||
FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (probs_shape.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
|
||||
}
|
||||
probs_dims_[0] = probs_shape[0];
|
||||
probs_dims_[1] = probs_shape[1];
|
||||
probs_dims_[2] = probs_shape[2];
|
||||
|
||||
auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
if (labels_dims.size() != 1 && labels_dims.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
|
||||
}
|
||||
label_size_ = sizeof(int);
|
||||
for (auto i : labels_dims) {
|
||||
label_size_ *= i;
|
||||
}
|
||||
|
||||
auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
input_lengths_size_ = input_length_dims[0] * sizeof(int);
|
||||
auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
label_lengths_size_ = label_length_dims[0] * sizeof(int);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_),
|
||||
"cudnnSetTensorNdDescriptorEx failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT,
|
||||
CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN),
|
||||
"cudnnSetCTCLossDescriptorEx failed.");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed.");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
input_size_list_.push_back(label_size_);
|
||||
input_size_list_.push_back(input_lengths_size_);
|
||||
input_size_list_.push_back(label_lengths_size_);
|
||||
|
||||
output_size_list_.push_back(probs_dims_[1] * sizeof(float));
|
||||
output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float));
|
||||
}
|
||||
|
||||
private:
|
||||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
|
||||
}
|
||||
|
||||
void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
|
||||
}
|
||||
|
||||
void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
|
||||
}
|
||||
|
||||
void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
|
||||
const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpyAsync failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
// remove blank element
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
|
||||
if (labels_host[i] != 0) {
|
||||
no_blank_labels_host[j] = labels_host[i];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
cudnnHandle_t cudnn_handle_;
|
||||
cudnnTensorDescriptor_t probs_desc_;
|
||||
cudnnCTCLossDescriptor_t ctcloss_desc_;
|
||||
int probs_dims_[3] = {0};
|
||||
int label_size_;
|
||||
int input_lengths_size_;
|
||||
int label_lengths_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
|
Loading…
Reference in New Issue