forked from mindspore-Ecosystem/mindspore
!3244 Support 2-dimensional targets for CTCLossV2
Merge pull request !3244 from yangyongjie/yangyongjie
This commit is contained in:
commit
1625a27ae5
|
@ -51,10 +51,12 @@ class CtcLossGpuKernel : public GpuKernel {
|
|||
float *grads = GetDeviceAddress<float>(outputs, 1);
|
||||
|
||||
// Copy labels/input_lengths/label_lengths to host memory, as cuDNN 7.x.x requires.
|
||||
void *labels_host = nullptr;
|
||||
int *labels_host = nullptr;
|
||||
int *no_blank_labels_host = nullptr;
|
||||
void *input_lengths_host = nullptr;
|
||||
void *label_lengths_host = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
|
@ -68,12 +70,21 @@ class CtcLossGpuKernel : public GpuKernel {
|
|||
"cudaMemcpyAsync failed.");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
|
||||
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < inputs[1]->size / sizeof(int); i++) {
|
||||
if (labels_host[i] != 0) {
|
||||
no_blank_labels_host[j] = labels_host[i];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host),
|
||||
reinterpret_cast<int *>(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
|
||||
ctcloss_desc_, &workspace_size),
|
||||
cudnnGetCTCLossWorkspaceSize(
|
||||
cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host),
|
||||
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, &workspace_size),
|
||||
"cudnnGetCTCLossWorkspaceSize failed.");
|
||||
void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
|
||||
if (workspace == nullptr) {
|
||||
|
@ -81,7 +92,7 @@ class CtcLossGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(labels_host),
|
||||
cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast<int *>(no_blank_labels_host),
|
||||
reinterpret_cast<int *>(label_lengths_host), reinterpret_cast<int *>(input_lengths_host), costs,
|
||||
probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size),
|
||||
"cudnnCtcLoss failed.");
|
||||
|
@ -91,6 +102,7 @@ class CtcLossGpuKernel : public GpuKernel {
|
|||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
|
Loading…
Reference in New Issue