Add CTCLossGrad GPU Kernel

zhujingxuan 2022-08-12 09:47:57 +08:00
parent abf2225625
commit 69eb5047e9
8 changed files with 647 additions and 22 deletions
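For reference, the new GradCompute / CTCLossV2GradKernel follow the standard CTC backward pass. With the blank-padded target l' = (blank, l_1, blank, ..., l_{S'}, blank) of length 2S'+1 and p_t(c) = exp(log_probs[t][b][c]), the backward variables are filled from t = input_length - 1 down to 0 by

\[
\beta_t(s) = \log p_t(l'_s) + \log\!\left(e^{\beta_{t+1}(s)} + e^{\beta_{t+1}(s+1)} + \mathbb{1}[\,l'_{s+2} \neq l'_s\,]\, e^{\beta_{t+1}(s+2)}\right),
\]

where terms whose index exceeds 2S' are dropped. The gradient with respect to the log-probabilities, scaled by the incoming grad_out g_b, is then

\[
\frac{\partial \mathcal{L}_b}{\partial \log p_t(c)}
  = \left( p_t(c) - \exp\!\Big( \operatorname{LSE}_{s:\, l'_s = c}\big(\alpha_t(s) + \beta_t(s)\big) + \mathrm{nll}_b - \log p_t(c) \Big) \right) g_b,
\]

with nll_b the per-sample negative log-likelihood from the forward pass; entries at time steps beyond input_length are zeroed.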

View File

@ -57,19 +57,36 @@ bool CTCLossV2GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
// Getting values
auto kernel_ptr = std::make_shared<ops::CTCLossV2Grad>(base_operator->GetPrim());
blank_ = kernel_ptr->get_blank();
reduction_ = kernel_ptr->get_reduction();
zero_infinity_ = kernel_ptr->get_zero_infinity();
auto log_probs_shape = inputs[kIndex1]->GetShapeVector();
T_ = log_probs_shape[kIndex0];
batch_size_ = log_probs_shape[kIndex1];
num_labels_ = log_probs_shape[kIndex2];
target_shape_ = inputs[kIndex2]->GetShapeVector();
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int CTCLossV2GradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto log_probs_shape = inputs[kIndex1]->GetShapeVector();
T_ = log_probs_shape[kIndex0];
batch_size_ = log_probs_shape[kIndex1];
num_labels_ = log_probs_shape[kIndex2];
const auto target_shape = inputs[kIndex2]->GetShapeVector();
max_target_length_ = target_shape[kIndex1];
const size_t scalar_type_size = abstract::TypeIdSize(inputs[kIndex0]->GetDtype());
workspace_size_list_.clear();
workspace_size_list_ = {
LongToSize(batch_size_ * T_ * (target_mul * max_target_length_ + 1)) * scalar_type_size,
};
return KRET_OK;
}
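The log_beta scratch buffer that the old LaunchKernel heap-allocated is now declared as a workspace here; with N = batch_size_, T = T_, S = max_target_length_, and target_mul = 2, its size is

\[
\text{workspace}_0 = N \cdot T \cdot (2S + 1) \cdot \text{sizeof(scalar)} \ \text{bytes}.
\]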
template <typename scalar_t, typename target_t>
void ComputeGrad(scalar_t *log_probs, const NdTensorIterator<kDim3> &log_probs_it, SoftParam params,
scalar_t *log_alpha, const NdTensorIterator<kDim3> &log_alpha_it, scalar_t *log_beta,
@ -131,20 +148,16 @@ bool CTCLossV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
auto target_lengths = reinterpret_cast<target_t *>(inputs[kIndex4]->addr);
auto neg_log_likelihood = reinterpret_cast<scalar_t *>(inputs[kIndex5]->addr);
auto log_alpha = reinterpret_cast<scalar_t *>(inputs[kIndex6]->addr);
auto log_beta = reinterpret_cast<scalar_t *>(workspace[kIndex0]->addr);
auto grad = reinterpret_cast<scalar_t *>(outputs[kIndex0]->addr);
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
std::fill(grad, grad + (T_ * batch_size_ * num_labels_), neginf);
NdTensorIterator<kDim3> log_probs_it(T_, batch_size_, num_labels_);
NdTensorIterator<kDim3> grad_it(T_, batch_size_, num_labels_);
int64_t max_target_length = target_shape_[kIndex1];
std::vector<int64_t> tg_batch_offsets(batch_size_);
int64_t tg_batch_stride = target_shape_[kIndex1];
for (int64_t i = 0; i < batch_size_; i++) {
tg_batch_offsets[i] = i * tg_batch_stride;
}
scalar_t *log_beta = new scalar_t[batch_size_ * T_ * (target_mul * max_target_length + 1)]();
NdTensorIterator<kDim3> log_alpha_it(batch_size_, T_, target_mul * max_target_length + 1);
NdTensorIterator<kDim3> log_beta_it(batch_size_, T_, target_mul * max_target_length + 1);
NdTensorIterator<kDim3> log_alpha_it(batch_size_, T_, target_mul * max_target_length_ + 1);
NdTensorIterator<kDim3> log_beta_it(batch_size_, T_, target_mul * max_target_length_ + 1);
for (int64_t b = 0; b < batch_size_; b++) {
scalar_t nll = neg_log_likelihood[b];
if (zero_infinity_ && nll == std::numeric_limits<scalar_t>::infinity()) {
@ -157,9 +170,9 @@ bool CTCLossV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
}
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t tg_batch_offset = tg_batch_offsets[b];
int64_t tg_batch_offset = max_target_length_ * b;
if (input_length > 0) {
for (size_t s = 0; s < kIndex2 * max_target_length + 1; s++) {
for (size_t s = 0; s < kIndex2 * max_target_length_ + 1; s++) {
log_beta[log_beta_it(b, input_length - 1, s)] = neginf;
}
log_beta[log_beta_it(b, input_length - 1, target_mul * target_length)] =
@ -193,7 +206,7 @@ bool CTCLossV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
}
}
}
delete[] log_beta;
return true;
}

View File

@ -40,6 +40,9 @@ class CTCLossV2GradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
@ -49,9 +52,8 @@ class CTCLossV2GradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
int64_t T_ = 0;
int64_t batch_size_ = 0;
int64_t num_labels_ = 0;
std::vector<int64_t> target_shape_;
int64_t max_target_length_ = 0;
int32_t blank_ = 0;
std::string reduction_ = "mean";
bool zero_infinity_ = false;
// Dealing with multiple types
template <typename scalar_t, typename target_t>

View File

@ -85,6 +85,54 @@ __device__ __forceinline__ void LossCompute(const S *log_probs_p, S *log_alpha_p
}
}
template <typename S, typename T>
__device__ __forceinline__ void GradCompute(const S *log_probs, const S *log_alpha, S *log_beta, int64_t blank,
int64_t input_length, int64_t target_length, int64_t tg_batch_offset,
int64_t b, const T *targets, dim3 log_probs_shape, dim3 log_alpha_shape,
S *grad) {
constexpr S neg_inf = -std::numeric_limits<S>::infinity();
for (int64_t t = input_length - 2; t >= 0; t--) {
for (int64_t s = 2 * target_length; s >= 0; s--) {
S lb1 = log_beta[GetOffset3D(log_alpha_shape, b, t + 1, s)];
S lbmax = lb1;
S lb2, lb3;
auto current_target_prime = GetBlankPaddedTarget<T>(targets, tg_batch_offset, s, blank);
if (s < 2 * target_length) {
lb2 = log_beta[GetOffset3D(log_alpha_shape, b, t + 1, s + 1)];
if (lb2 > lbmax) {
lbmax = lb2;
}
} else {
lb2 = neg_inf;
}
if ((s < 2 * target_length - 1) &&
(GetBlankPaddedTarget<T>(targets, tg_batch_offset, s + 2, blank) != current_target_prime)) {
lb3 = log_beta[GetOffset3D(log_alpha_shape, b, t + 1, s + 2)];
if (lb3 > lbmax) {
lbmax = lb3;
}
} else {
lb3 = neg_inf;
}
if (lbmax == neg_inf) {
lbmax = 0;
}
log_beta[GetOffset3D(log_alpha_shape, b, t, s)] =
std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) + std::exp(lb3 - lbmax)) + lbmax +
log_probs[GetOffset3D(log_probs_shape, t, b, current_target_prime)];
S log_alpha_beta =
log_alpha[GetOffset3D(log_alpha_shape, b, t, s)] + log_beta[GetOffset3D(log_alpha_shape, b, t, s)];
S &lcab = grad[GetOffset3D(log_probs_shape, t, b, current_target_prime)];
if (lcab == neg_inf) {
lcab = log_alpha_beta;
} else {
S max_val = max(lcab, log_alpha_beta);
lcab = std::log(std::exp(lcab - max_val) + std::exp(log_alpha_beta - max_val)) + max_val;
}
}
}
}
template <typename S, typename T>
__global__ void CTCLossV2Kernel(const S *log_probs_p, const T *target_p, const T *input_len_p, const T *target_len_p,
int64_t max_target_length, int64_t time_series, int64_t batch_size, T blank,
@ -133,6 +181,78 @@ void CalCTCLossV2(const S *log_probs_p, const T *target_p, const T *input_len_p,
log_probs_shape, log_alpha_shape, neg_log_p, log_alpha_p);
}
template <typename S, typename T>
__global__ void CTCLossV2GradKernel(const S *grad_out, const S *log_probs, const T *targets, const T *input_lengths,
const T *target_lengths, const S *neg_log_likelihood, const S *log_alpha,
S *log_beta, int64_t batch_size, int64_t time_series, int64_t num_labels,
int64_t max_target_length, bool zero_infinity, T blank, dim3 log_probs_shape,
dim3 log_alpha_shape, S *grad) {
constexpr S neg_inf = -std::numeric_limits<S>::infinity();
// One batch element per thread (grid-stride loop): each b is owned by a single thread, so the
// read-modify-write accumulation into grad[t][b][c] and log_beta[b][...] stays race free.
for (int64_t b = blockIdx.x * blockDim.x + threadIdx.x; b < batch_size; b += blockDim.x * gridDim.x) {
S nll = neg_log_likelihood[b];
if (zero_infinity && nll == std::numeric_limits<S>::infinity()) {
for (int t = 0; t < time_series; t++) {
for (int c = 0; c < num_labels; c++) {
grad[GetOffset3D(log_probs_shape, t, b, c)] = 0;
}
}
continue;
}
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t tg_batch_offset = max_target_length * b;
if (input_length > 0) {
for (size_t s = 0; s < 2 * max_target_length + 1; s++) {
log_beta[GetOffset3D(log_alpha_shape, b, input_length - 1, s)] = neg_inf;
}
log_beta[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length)] =
log_probs[GetOffset3D(log_probs_shape, input_length - 1, b, blank)];
grad[GetOffset3D(log_probs_shape, input_length - 1, b, blank)] =
log_alpha[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length)] +
log_beta[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length)];
if (target_length > 0) {
auto current_target_prime = GetBlankPaddedTarget(targets, tg_batch_offset, 2 * target_length - 1, blank);
log_beta[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length - 1)] =
log_probs[GetOffset3D(log_probs_shape, input_length - 1, b, current_target_prime)];
grad[GetOffset3D(log_probs_shape, input_length - 1, b, current_target_prime)] =
log_alpha[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length - 1)] +
log_beta[GetOffset3D(log_alpha_shape, b, input_length - 1, 2 * target_length - 1)];
}
}
GradCompute<S, T>(log_probs, log_alpha, log_beta, blank, input_length, target_length, tg_batch_offset, b, targets,
log_probs_shape, log_alpha_shape, grad);
S gr = grad_out[b];
for (int64_t t = 0; t < input_length; t++) {
for (int64_t c = 0; c < num_labels; c++) {
S &res = grad[GetOffset3D(log_probs_shape, t, b, c)];
S lp = log_probs[GetOffset3D(log_probs_shape, t, b, c)];
res = (std::exp(lp) - std::exp(res + nll - lp)) * gr;
}
}
for (auto l = input_length; l < time_series; l++) {
for (int c = 0; c < num_labels; c++) {
grad[GetOffset3D(log_probs_shape, l, b, c)] = 0;
}
}
}
}
template <typename S, typename T>
void CalCTCLossGradV2(const S *grad_out, const S *log_probs, const T *targets, const T *input_lengths,
const T *target_lengths, const S *neg_log_likelihood, const S *log_alpha, S *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length,
bool zero_infinity, T blank, dim3 log_probs_shape, dim3 log_alpha_shape, S *grad,
uint32_t device_id, cudaStream_t cuda_stream) {
constexpr S neg_inf = -std::numeric_limits<S>::infinity();
const size_t grad_size = log_probs_shape.x * log_probs_shape.y * log_probs_shape.z;
thrust::device_ptr<S> dev_ptr(grad);
thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + grad_size, neg_inf);
CTCLossV2GradKernel<<<CUDA_BLOCKS(device_id, batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
grad_out, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, log_beta, batch_size,
time_series, num_labels, max_target_length, zero_infinity, blank, log_probs_shape, log_alpha_shape, grad);
}
template CUDA_LIB_EXPORT void CalCTCLossV2<float, int>(const float *log_probs_p, const int *target_p,
const int *input_len_p, const int *target_len_p,
int64_t batch_size, int64_t target_stride, int64_t time_series,
@ -157,3 +277,28 @@ template CUDA_LIB_EXPORT void CalCTCLossV2<double, int64_t>(
const double *log_probs_p, const int64_t *target_p, const int64_t *input_len_p, const int64_t *target_len_p,
int64_t batch_size, int64_t target_stride, int64_t time_series, int64_t blank, dim3 log_probs_shape,
dim3 log_alpha_shape, double *neg_log_p, double *log_alpha_p, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCTCLossGradV2<float, int>(
const float *grad_out, const float *log_probs, const int *targets, const int *input_lengths,
const int *target_lengths, const float *neg_log_likelihood, const float *log_alpha, float *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length, bool zero_infinity, int blank,
dim3 log_probs_shape, dim3 log_alpha_shape, float *grad, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCTCLossGradV2<double, int>(
const double *grad_out, const double *log_probs, const int *targets, const int *input_lengths,
const int *target_lengths, const double *neg_log_likelihood, const double *log_alpha, double *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length, bool zero_infinity, int blank,
dim3 log_probs_shape, dim3 log_alpha_shape, double *grad, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCTCLossGradV2<float, int64_t>(
const float *grad_out, const float *log_probs, const int64_t *targets, const int64_t *input_lengths,
const int64_t *target_lengths, const float *neg_log_likelihood, const float *log_alpha, float *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length, bool zero_infinity,
int64_t blank, dim3 log_probs_shape, dim3 log_alpha_shape, float *grad, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCTCLossGradV2<double, int64_t>(
const double *grad_out, const double *log_probs, const int64_t *targets, const int64_t *input_lengths,
const int64_t *target_lengths, const double *neg_log_likelihood, const double *log_alpha, double *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length, bool zero_infinity,
int64_t blank, dim3 log_probs_shape, dim3 log_alpha_shape, double *grad, uint32_t device_id,
cudaStream_t cuda_stream);
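GetOffset3D and GetBlankPaddedTarget are used throughout these kernels but are defined in the shared cuda_ops utilities rather than in this diff. A minimal sketch of the semantics the code above assumes (signatures are illustrative, not the actual declarations):

__device__ __forceinline__ size_t GetOffset3D(dim3 shape, int64_t i, int64_t j, int64_t k) {
  // Assumed: plain row-major offset into a (shape.x, shape.y, shape.z) tensor.
  return (static_cast<size_t>(i) * shape.y + static_cast<size_t>(j)) * shape.z + static_cast<size_t>(k);
}
template <typename T>
__device__ __forceinline__ int64_t GetBlankPaddedTarget(const T *targets, int64_t offset, int64_t idx, int64_t blank) {
  // Assumed: idx-th symbol of the blank-padded target l' = (blank, l_1, blank, ..., l_S, blank):
  // blank at even idx, the ((idx - 1) / 2)-th real label at odd idx.
  return (idx % 2 == 0) ? blank : static_cast<int64_t>(targets[offset + (idx - 1) / 2]);
}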

View File

@ -22,4 +22,11 @@ void CalCTCLossV2(const S *log_probs_p, const T *target_p, const T *input_len_p,
int64_t batch_size, int64_t target_stride, int64_t time_series, T blank, dim3 log_probs_shape,
dim3 log_alpha_shape, S *neg_log_p, S *log_alpha_p, uint32_t device_id, cudaStream_t cuda_stream);
template <typename S, typename T>
void CalCTCLossGradV2(const S *grad_out, const S *log_probs, const T *targets, const T *input_lengths,
const T *target_lengths, const S *neg_log_likelihood, const S *log_alpha, S *log_beta,
int64_t batch_size, int64_t time_series, int64_t num_labels, int64_t max_target_length,
bool zero_infinity, T blank, dim3 log_probs_shape, dim3 log_alpha_shape, S *grad,
uint32_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_V2_IMPL_CUH_

View File

@ -0,0 +1,142 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/ctcloss_v2_grad_gpu_kernel.h"
#include <memory>
#include "abstract/utils.h"
#include "mindspore/core/ops/ctc_loss_v2_grad.h"
namespace mindspore {
namespace kernel {
namespace {
using KernelRunFunc = CTCLossV2GradGpuKernelMod::KernelRunFunc;
constexpr int64_t kInterval = 2;
} // namespace
bool CTCLossV2GradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
// Getting values
auto kernel_ptr = std::make_shared<ops::CTCLossV2Grad>(base_operator->GetPrim());
blank_ = kernel_ptr->get_blank();
zero_infinity_ = kernel_ptr->get_zero_infinity();
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int CTCLossV2GradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto log_probs_shape = inputs[kIndex1]->GetShapeVector();
time_series_ = log_probs_shape[kIndex0];
batch_size_ = log_probs_shape[kIndex1];
num_labels_ = log_probs_shape[kIndex2];
const auto target_shape = inputs[kIndex2]->GetShapeVector();
max_target_length_ = target_shape[kIndex1];
log_probs_shape_.x = LongToSize(time_series_);
log_probs_shape_.y = LongToSize(batch_size_);
log_probs_shape_.z = LongToSize(num_labels_);
log_alpha_shape_.x = LongToSize(batch_size_);
log_alpha_shape_.y = LongToSize(time_series_);
log_alpha_shape_.z = LongToSize(kInterval * max_target_length_ + 1);
const size_t scalar_type_size = abstract::TypeIdSize(inputs[kIndex0]->GetDtype());
workspace_size_list_.clear();
workspace_size_list_ = {
LongToSize(batch_size_ * time_series_ * (kInterval * max_target_length_ + 1)) * scalar_type_size,
};
return KRET_OK;
}
template <typename scalar_t, typename target_t>
bool CTCLossV2GradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
auto grad_out = reinterpret_cast<scalar_t *>(inputs[kIndex0]->addr);
auto log_probs = reinterpret_cast<scalar_t *>(inputs[kIndex1]->addr);
auto targets = reinterpret_cast<target_t *>(inputs[kIndex2]->addr);
auto input_lengths = reinterpret_cast<target_t *>(inputs[kIndex3]->addr);
auto target_lengths = reinterpret_cast<target_t *>(inputs[kIndex4]->addr);
auto neg_log_likelihood = reinterpret_cast<scalar_t *>(inputs[kIndex5]->addr);
auto log_alpha = reinterpret_cast<scalar_t *>(inputs[kIndex6]->addr);
auto log_beta = reinterpret_cast<scalar_t *>(workspace[kIndex0]->addr);
auto grad = reinterpret_cast<scalar_t *>(outputs[kIndex0]->addr);
CalCTCLossGradV2<scalar_t, target_t>(grad_out, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood,
log_alpha, log_beta, batch_size_, time_series_, num_labels_, max_target_length_,
zero_infinity_, blank_, log_probs_shape_, log_alpha_shape_, grad, device_id_,
stream_ptr_);
return true;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &CTCLossV2GradGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, CTCLossV2GradGpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&CTCLossV2GradGpuKernelMod::LaunchKernel<float, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&CTCLossV2GradGpuKernelMod::LaunchKernel<double, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&CTCLossV2GradGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&CTCLossV2GradGpuKernelMod::LaunchKernel<double, int64_t>},
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CTCLossV2Grad, CTCLossV2GradGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,80 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_V2_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_V2_GRAD_GPU_KERNEL_H_
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_v2_impl.cuh"
namespace mindspore {
namespace kernel {
class CTCLossV2GradGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<CTCLossV2GradGpuKernelMod> {
public:
CTCLossV2GradGpuKernelMod() = default;
~CTCLossV2GradGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
stream_ptr_ = reinterpret_cast<cudaStream_t>(cuda_stream);
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
[[nodiscard]] const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename S, typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
// Variables for the operator itself
int64_t blank_{0};
// Stands for T
int64_t time_series_{0};
// Stands for N
int64_t batch_size_{0};
// Stands for C
int64_t num_labels_{0};
// Stands for S
int64_t max_target_length_{0};
dim3 log_probs_shape_;
dim3 log_alpha_shape_;
bool zero_infinity_ = false;
bool is_null_input_{false};
cudaStream_t stream_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_V2_GRAD_GPU_KERNEL_H_

View File

@ -17,8 +17,29 @@ import pytest
import numpy as np
from scipy.special import log_softmax
from mindspore import Tensor, context
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
class Net(nn.Cell):
def __init__(self, blank, reduction):
super(Net, self).__init__()
self.loss = P.CTCLossV2(blank=blank, reduction=reduction)
def construct(self, input_matrix, target, input_lengths, target_lengths):
x, _ = self.loss(input_matrix, target, input_lengths, target_lengths)
return x
class GradData(nn.Cell):
def __init__(self, network):
super(GradData, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=False)
self.network = network
def construct(self, input_matrix, target, input_lengths, target_lengths):
return self.grad(self.network)(input_matrix, target, input_lengths, target_lengths)[0]
def logsumexp(a, b):
@ -124,6 +145,103 @@ def test_ctc_loss_v2_un_padded(batch, data_type):
compare_to_numpy(method, input_matrix, target, input_lengths, target_lengths)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu
def test_ctc_loss_v2_un_padded_grad():
"""
Feature: Test the gradient of CTCLossV2 (CTCLossV2Grad).
Description: The input is not padded and target_sequences may be equal to input_sequences.
Expectation: Gradient matches the precomputed expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
np.random.seed(0)
batch = 10
data_type = np.float64
method = 'none'
input_sequences = 5
classes = 3
target_sequences = input_sequences
target_sequences_min = 1
input_matrix = log_softmax(np.random.randn(input_sequences, batch, classes), 2).astype(data_type)
input_lengths = np.full(shape=(batch,), fill_value=input_sequences, dtype=np.int64)
target_lengths = np.random.randint(low=target_sequences_min, high=target_sequences, size=(batch,), dtype=np.int64)
target = np.random.randint(low=1, high=classes, size=(batch, np.max(target_lengths)), dtype=np.int64)
input_matrix = Tensor(input_matrix)
target = Tensor(target)
input_lengths = Tensor(input_lengths)
target_lengths = Tensor(target_lengths)
net = Net(blank=0, reduction=method)
loss = net(input_matrix, target, input_lengths, target_lengths)
print(np.mean(loss.asnumpy()))
expected_grad = np.array([[[2.21999385e-01, 1.49367328e-01, -3.71366713e-01],
[1.21524177e-01, -1.44682444e-01, 2.31582675e-02],
[-2.72130267e-01, 6.46665395e-02, 2.07463727e-01],
[-2.08145533e-01, 1.66320573e-01, 4.18249597e-02],
[-6.40181898e-02, 2.33895304e-01, -1.69877115e-01],
[1.66409322e-01, -2.88602261e-01, 1.22192939e-01],
[7.30902539e-01, 2.27492367e-01, -9.58394906e-01],
[4.02847968e-01, -5.02608798e-01, 9.97608303e-02],
[np.nan, np.nan, np.nan],
[3.54625312e-02, -4.78671463e-01, 4.43208932e-01]],
[[-2.47471377e-01, 4.80327110e-01, -2.32855733e-01],
[-7.19044568e-04, -3.50625805e-01, 3.51344849e-01],
[-7.06902188e-03, -8.43104544e-02, 9.13794763e-02],
[-9.66319775e-02, 2.63241041e-01, -1.66609063e-01],
[-1.33521992e-01, 8.99922273e-01, -7.66400281e-01],
[3.06852929e-02, -2.73818291e-02, -3.30346375e-03],
[-8.59374974e-01, 5.70923102e-01, 2.88451871e-01],
[-3.81211648e-01, 2.52157946e-01, 1.29053702e-01],
[np.nan, np.nan, np.nan],
[5.36556015e-03, -5.05351326e-02, 4.51695724e-02]],
[[-2.80587684e-01, 4.22536598e-01, -1.41948913e-01],
[7.27670649e-03, 1.89209901e-01, -1.96486607e-01],
[3.08220162e-02, -2.15289462e-01, 1.84467446e-01],
[-3.79525891e-01, 4.86187906e-01, -1.06662015e-01],
[-1.89419282e-01, 5.92301671e-02, 1.30189115e-01],
[-5.41303138e-02, 2.43931812e-01, -1.89801498e-01],
[3.48393864e-01, 5.03232105e-01, -8.51625969e-01],
[5.76510068e-01, -6.26906613e-01, 5.03965454e-02],
[np.nan, np.nan, np.nan],
[4.65760597e-02, 5.56476308e-02, -1.02223690e-01]],
[[-2.14977953e-01, 6.41234229e-01, -4.26256276e-01],
[5.60402917e-02, 1.95392934e-01, -2.51433226e-01],
[1.01160497e-01, -2.41139199e-01, 1.39978702e-01],
[-6.77672365e-01, 7.89331393e-01, -1.11659028e-01],
[-3.17830985e-01, 8.17107077e-01, -4.99276092e-01],
[-6.53148112e-02, 7.75629981e-02, -1.22481869e-02],
[-6.13690461e-01, 2.48194256e-01, 3.65496205e-01],
[2.56408238e-01, 4.37941616e-02, -3.00202400e-01],
[np.nan, np.nan, np.nan],
[-3.73922094e-02, 3.48893393e-01, -3.11501184e-01]],
[[6.82148555e-02, 1.06153890e-01, -1.74368746e-01],
[-3.82024509e-02, 9.73708746e-02, -5.91684237e-02],
[-3.38166563e-02, -1.84766114e-01, 2.18582770e-01],
[-3.88080543e-01, 1.25803041e-01, 2.62277502e-01],
[-5.94619350e-02, 4.98396907e-01, -4.38934972e-01],
[-4.53783646e-01, 3.90447024e-01, 6.33366220e-02],
[7.26180335e-01, 1.63813757e-01, -8.89994092e-01],
[3.35863956e-01, -7.44304322e-01, 4.08440366e-01],
[np.nan, np.nan, np.nan],
[-7.70374805e-02, 6.78337545e-02, 9.20372600e-03]]])
grad = GradData(net)(input_matrix, target, input_lengths, target_lengths)
print(grad.shape)
print(grad)
assert np.allclose(grad.asnumpy(), expected_grad, equal_nan=True)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_cpu

View File

@ -17,8 +17,29 @@ import pytest
import numpy as np
from scipy.special import log_softmax
from mindspore import Tensor, context
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation
class Net(nn.Cell):
def __init__(self, blank, reduction):
super(Net, self).__init__()
self.loss = P.CTCLossV2(blank=blank, reduction=reduction)
def construct(self, input_matrix, target, input_lengths, target_lengths):
x, _ = self.loss(input_matrix, target, input_lengths, target_lengths)
return x
class GradData(nn.Cell):
def __init__(self, network):
super(GradData, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=False)
self.network = network
def construct(self, input_matrix, target, input_lengths, target_lengths):
return self.grad(self.network)(input_matrix, target, input_lengths, target_lengths)[0]
def logsumexp(a, b):
@ -124,6 +145,103 @@ def test_ctc_loss_v2_un_padded(batch, data_type):
compare_to_numpy(method, input_matrix, target, input_lengths, target_lengths)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
def test_ctc_loss_v2_un_padded_grad():
"""
Feature: Test the gradient of CTCLossV2 (CTCLossV2Grad).
Description: The input is not padded and target_sequences may be equal to input_sequences.
Expectation: Gradient matches the precomputed expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
np.random.seed(0)
batch = 10
data_type = np.float64
method = 'none'
input_sequences = 5
classes = 3
target_sequences = input_sequences
target_sequences_min = 1
input_matrix = log_softmax(np.random.randn(input_sequences, batch, classes), 2).astype(data_type)
input_lengths = np.full(shape=(batch,), fill_value=input_sequences, dtype=np.int64)
target_lengths = np.random.randint(low=target_sequences_min, high=target_sequences, size=(batch,), dtype=np.int64)
target = np.random.randint(low=1, high=classes, size=(batch, np.max(target_lengths)), dtype=np.int64)
input_matrix = Tensor(input_matrix)
target = Tensor(target)
input_lengths = Tensor(input_lengths)
target_lengths = Tensor(target_lengths)
net = Net(blank=0, reduction=method)
loss = net(input_matrix, target, input_lengths, target_lengths)
print(np.mean(loss.asnumpy()))
expected_grad = np.array([[[2.21999385e-01, 1.49367328e-01, -3.71366713e-01],
[1.21524177e-01, -1.44682444e-01, 2.31582675e-02],
[-2.72130267e-01, 6.46665395e-02, 2.07463727e-01],
[-2.08145533e-01, 1.66320573e-01, 4.18249597e-02],
[-6.40181898e-02, 2.33895304e-01, -1.69877115e-01],
[1.66409322e-01, -2.88602261e-01, 1.22192939e-01],
[7.30902539e-01, 2.27492367e-01, -9.58394906e-01],
[4.02847968e-01, -5.02608798e-01, 9.97608303e-02],
[np.nan, np.nan, np.nan],
[3.54625312e-02, -4.78671463e-01, 4.43208932e-01]],
[[-2.47471377e-01, 4.80327110e-01, -2.32855733e-01],
[-7.19044568e-04, -3.50625805e-01, 3.51344849e-01],
[-7.06902188e-03, -8.43104544e-02, 9.13794763e-02],
[-9.66319775e-02, 2.63241041e-01, -1.66609063e-01],
[-1.33521992e-01, 8.99922273e-01, -7.66400281e-01],
[3.06852929e-02, -2.73818291e-02, -3.30346375e-03],
[-8.59374974e-01, 5.70923102e-01, 2.88451871e-01],
[-3.81211648e-01, 2.52157946e-01, 1.29053702e-01],
[np.nan, np.nan, np.nan],
[5.36556015e-03, -5.05351326e-02, 4.51695724e-02]],
[[-2.80587684e-01, 4.22536598e-01, -1.41948913e-01],
[7.27670649e-03, 1.89209901e-01, -1.96486607e-01],
[3.08220162e-02, -2.15289462e-01, 1.84467446e-01],
[-3.79525891e-01, 4.86187906e-01, -1.06662015e-01],
[-1.89419282e-01, 5.92301671e-02, 1.30189115e-01],
[-5.41303138e-02, 2.43931812e-01, -1.89801498e-01],
[3.48393864e-01, 5.03232105e-01, -8.51625969e-01],
[5.76510068e-01, -6.26906613e-01, 5.03965454e-02],
[np.nan, np.nan, np.nan],
[4.65760597e-02, 5.56476308e-02, -1.02223690e-01]],
[[-2.14977953e-01, 6.41234229e-01, -4.26256276e-01],
[5.60402917e-02, 1.95392934e-01, -2.51433226e-01],
[1.01160497e-01, -2.41139199e-01, 1.39978702e-01],
[-6.77672365e-01, 7.89331393e-01, -1.11659028e-01],
[-3.17830985e-01, 8.17107077e-01, -4.99276092e-01],
[-6.53148112e-02, 7.75629981e-02, -1.22481869e-02],
[-6.13690461e-01, 2.48194256e-01, 3.65496205e-01],
[2.56408238e-01, 4.37941616e-02, -3.00202400e-01],
[np.nan, np.nan, np.nan],
[-3.73922094e-02, 3.48893393e-01, -3.11501184e-01]],
[[6.82148555e-02, 1.06153890e-01, -1.74368746e-01],
[-3.82024509e-02, 9.73708746e-02, -5.91684237e-02],
[-3.38166563e-02, -1.84766114e-01, 2.18582770e-01],
[-3.88080543e-01, 1.25803041e-01, 2.62277502e-01],
[-5.94619350e-02, 4.98396907e-01, -4.38934972e-01],
[-4.53783646e-01, 3.90447024e-01, 6.33366220e-02],
[7.26180335e-01, 1.63813757e-01, -8.89994092e-01],
[3.35863956e-01, -7.44304322e-01, 4.08440366e-01],
[np.nan, np.nan, np.nan],
[-7.70374805e-02, 6.78337545e-02, 9.20372600e-03]]])
grad = GradData(net)(input_matrix, target, input_lengths, target_lengths)
print(grad.shape)
print(grad)
assert np.allclose(grad.asnumpy(), expected_grad, equal_nan=True)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training