!41525 【I5EWOO】Add data types for SigmoidGrad

Merge pull request !41525 from 桂胜楠/sigmoid_grad
This commit is contained in:
i-robot 2022-11-18 11:26:28 +00:00 committed by Gitee
commit 2037c79036
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 102 additions and 15 deletions
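For orientation before the per-file diffs: SigmoidGrad receives the forward output y = sigmoid(x) and the incoming gradient dout, and this change extends the CPU and GPU kernels beyond float32 to float16, float64, complex64 and complex128. Below is a minimal standalone sketch (plain C++, not MindSpore code) of the arithmetic the new paths implement: dx = dout * y * (1 - y) for the real dtypes, and dx = dout * conj(y * (1 - y)) for the complex dtypes.

#include <complex>
#include <cstdio>

// Reference formula for the real dtypes (float16/float32/float64).
template <typename T>
T SigmoidGradRef(T y, T dout) {
  return dout * y * (static_cast<T>(1) - y);
}

// Reference formula for complex64/complex128: same derivative, conjugated,
// matching ComplexSigmoidGrad in the CPU kernel and the complex CUDA kernel.
template <typename T>
std::complex<T> SigmoidGradRef(std::complex<T> y, std::complex<T> dout) {
  return dout * std::conj(y * (std::complex<T>(1) - y));
}

int main() {
  double dx = SigmoidGradRef(0.7310585786, 1.0);  // y = sigmoid(1.0), dx ~= 0.19661
  std::complex<double> cdx =
      SigmoidGradRef(std::complex<double>(0.7, 0.1), std::complex<double>(1.0, 0.0));
  std::printf("dx = %f, cdx = (%f, %f)\n", dx, cdx.real(), cdx.imag());
  return 0;
}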


@ -60,6 +60,7 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
void ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void ComplexSigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
@ -125,13 +126,29 @@ void EltWiseGradCpuTypeFunc<T>::AbsGrad(const T *input1, const T *input2, T *out
template <typename T>
void EltWiseGradCpuTypeFunc<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
if constexpr (!std::is_same<T, float>::value) {
MS_LOG(EXCEPTION) << "For 'SigmoidGrad', the dtype of input must be float.";
if constexpr (std::is_same<T, float>::value) {
int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "For 'SigmoidGrad', execute failed. Error no: " << ret;
}
} else {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = input1[i] * (static_cast<T>(1) - input1[i]);
out[i] = dividend * divisor;
}
}
}
int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "For 'SigmoidGrad', execute failed. Error no: " << ret;
template <typename T>
void EltWiseGradCpuTypeFunc<T>::ComplexSigmoidGrad(const T *input1, const T *input2, T *out, size_t start,
size_t end) const {
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = std::conj(input1[i] * (static_cast<T>(1) - input1[i]));
out[i] = dividend * divisor;
}
}
}
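A note on the (start, end) parameters carried by these functors: the element-wise grad kernels work on flat buffers, and the caller is expected to hand each worker a disjoint [start, end) slice; the scheduling code itself is not part of this diff. A rough model of that sharding, assuming a simple even split:

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Split [0, total) into contiguous [start, end) slices, one per worker, the way
// a call such as func.SigmoidGrad(y, dout, dx, start, end) would be issued per
// thread (hypothetical call shape, for illustration only).
template <typename Fn>
void ShardedFor(size_t total, size_t num_workers, Fn fn) {
  std::vector<std::thread> workers;
  size_t chunk = (total + num_workers - 1) / num_workers;  // ceiling division
  for (size_t t = 0; t < num_workers; ++t) {
    size_t start = t * chunk;
    size_t end = std::min(total, start + chunk);
    if (start >= end) break;
    workers.emplace_back([fn, start, end]() { fn(start, end); });
  }
  for (auto &w : workers) w.join();
}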
@ -440,7 +457,8 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
{prim::kPrimAbsGrad->name(), &EltWiseGradCpuTypeFunc<T>::AbsGrad},
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad},
{prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc<T>::SigmoidGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_ << " with double as input.";
}
@ -452,7 +470,8 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad},
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
{prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc<T>::SigmoidGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
}
@ -519,7 +538,8 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
{prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
{prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc<T>::ComplexSigmoidGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
}
@ -580,7 +600,21 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> ke
&SpecializeEltWiseGradFunc<double>}}},
{kSigmoidGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>}}},
&SpecializeEltWiseGradFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&SpecializeEltWiseGradFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SpecializeEltWiseGradFunc<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&SpecializeEltWiseGradFunc<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&SpecializeEltWiseGradFunc<complex128>}}},
{kSqrtGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>},
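The kSigmoidGrad entry above now registers one (input, input, output) dtype triple per supported type, each bound to SpecializeEltWiseGradFunc<T> for the matching T; a dtype with no registered triple falls through to the "does not support" error in InitFunc. A simplified model of that lookup (not the real KernelAttr machinery):

#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Type-erased element-wise grad functor: (y, dout, dx, n).
using GradFn = std::function<void(const void *, const void *, void *, size_t)>;

template <typename T>
void SigmoidGradOf(const void *y, const void *dout, void *dx, size_t n) {
  auto yp = static_cast<const T *>(y);
  auto dp = static_cast<const T *>(dout);
  auto xp = static_cast<T *>(dx);
  for (size_t i = 0; i < n; ++i) xp[i] = dp[i] * yp[i] * (static_cast<T>(1) - yp[i]);
}

GradFn PickSigmoidGrad(const std::string &dtype) {
  static const std::map<std::string, GradFn> table = {
      {"float32", &SigmoidGradOf<float>},
      {"float64", &SigmoidGradOf<double>},
  };  // complex64/complex128 would bind a conj-based variant in the same way
  auto it = table.find(dtype);
  if (it == table.end()) {
    throw std::runtime_error("SigmoidGrad does not support dtype " + dtype);
  }
  return it->second;
}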


@ -199,6 +199,26 @@ __global__ void TanhGradKernel(const T *__restrict__ input, const T *dout, T *ou
return;
}
template <typename T>
__global__ void SigmoidGradKernel(const T *__restrict__ input, const T *dout, T *output, const size_t count) {
const T one = static_cast<T>(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T divisor = input[i] * (one - input[i]);
output[i] = dout[i] * divisor;
}
return;
}
template <typename T>
__global__ void SigmoidGradKernel(const Complex<T> *input, const Complex<T> *dout, Complex<T> *output,
const size_t count) {
Complex<T> one = static_cast<Complex<T>>(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
output[i] = dout[i] * conj(input[i] * (one - input[i]));
}
return;
}
template <typename T>
__global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
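Both SigmoidGradKernel overloads above use the standard CUDA grid-stride loop: a thread starts at its global index blockIdx.x * blockDim.x + threadIdx.x and advances by the total thread count blockDim.x * gridDim.x, so any count is covered even when it exceeds the number of launched threads. A host-side C++ model of that indexing (illustration only, not the CUDA code):

#include <cstddef>
#include <vector>

// total_threads plays the role of blockDim.x * gridDim.x and tid the role of a
// thread's global index; every element of [0, count) is visited exactly once.
std::vector<int> GridStrideCoverage(size_t total_threads, size_t count) {
  std::vector<int> visits(count, 0);
  for (size_t tid = 0; tid < total_threads; ++tid) {
    for (size_t i = tid; i < count; i += total_threads) {
      visits[i] += 1;  // the kernel computes output[i] = dout[i] * input[i] * (1 - input[i]) here
    }
  }
  return visits;  // every entry ends up equal to 1
}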
@ -352,6 +372,12 @@ void TanhGrad(const T *input, const T *dout, T *output, const size_t count, cuda
return;
}
template <typename T>
void SigmoidGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
SigmoidGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template <typename T>
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@ -403,6 +429,14 @@ template CUDA_LIB_EXPORT void SqrtGrad<Complex<double>>(const Complex<double> *i
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void SqrtGrad<Complex<float>>(const Complex<float> *input, const Complex<float> *dout,
Complex<float> *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void SigmoidGrad<Complex<double>>(const Complex<double> *input, const Complex<double> *dout,
Complex<double> *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void SigmoidGrad<Complex<float>>(const Complex<float> *input, const Complex<float> *dout,
Complex<float> *output, const size_t count,
cudaStream_t cuda_stream);
@ -426,6 +460,8 @@ template CUDA_LIB_EXPORT void ReciprocalGrad<double>(const double *input, const
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void InvGrad<double>(const double *input, const double *dout, double *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void SigmoidGrad<double>(const double *input, const double *dout, double *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
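The added "template CUDA_LIB_EXPORT void SigmoidGrad<...>" lines are explicit instantiations: the kernel templates are defined only in this .cu translation unit, so every dtype that other object files link against must be instantiated and exported here. A minimal C++ illustration of the same pattern (generic example, not the MindSpore sources):

#include <cstddef>

// Conceptually the header: declaration visible to callers.
template <typename T>
void Scale(const T *in, T *out, size_t n, T factor);

// Conceptually the implementation file: definition plus explicit
// instantiations. Without the last two lines, callers in other translation
// units fail to link for float/double, just as the CUDA wrappers would
// without their explicit instantiation lines.
template <typename T>
void Scale(const T *in, T *out, size_t n, T factor) {
  for (size_t i = 0; i < n; ++i) out[i] = in[i] * factor;
}

template void Scale<float>(const float *, float *, size_t, float);
template void Scale<double>(const double *, double *, size_t, double);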


@ -30,6 +30,9 @@ CUDA_LIB_EXPORT void AtanGrad(const T *input, const T *dout, T *output, const si
template <typename T>
CUDA_LIB_EXPORT void TanhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void SigmoidGrad(const T *input, const T *dout, T *output, const size_t count,
cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);


@ -1845,13 +1845,13 @@ template CUDA_LIB_EXPORT void Sigmoid<Complex<float>>(const Complex<float> *inpu
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Square<Complex<float>>(const Complex<float> *input, Complex<float> *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tanh<Complex<float>>(const Complex<float> *input, Complex<float> *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Logarithm<Complex<float>>(const Complex<float> *input, Complex<float> *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Negative<Complex<float>>(const Complex<float> *input, Complex<float> *output,
const size_t count, cudaStream_t cuda_stream);
// complex128
template CUDA_LIB_EXPORT void Exponential<Complex<double>>(const Complex<double> *input, Complex<double> *output,
const size_t count, cudaStream_t cuda_stream);
@ -1899,6 +1899,7 @@ template CUDA_LIB_EXPORT void Logarithm<Complex<double>>(const Complex<double> *
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Negative<Complex<double>>(const Complex<double> *input, Complex<double> *output,
const size_t count, cudaStream_t cuda_stream);
// bool
template CUDA_LIB_EXPORT void Real<bool>(const bool *input, bool *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Imag<bool>(const bool *input, bool *output, const size_t count, cudaStream_t cuda_stream);


@ -60,7 +60,19 @@ std::map<std::string, std::vector<std::pair<KernelAttr, ActivationGradGpuKernelM
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ActivationGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ActivationGradGpuKernelMod::LaunchKernel<half>}}}};
&ActivationGradGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&ActivationGradGpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&ActivationGradGpuKernelMod::LaunchKernel<utils::Complex<float>>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&ActivationGradGpuKernelMod::LaunchKernel<utils::Complex<double>>}}}};
bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
@ -100,9 +112,9 @@ bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
const auto dtype = inputs.at(kIndex0)->GetDtype();
if (((dtype == kNumberTypeFloat64) || (dtype == kNumberTypeComplex64) || (dtype == kNumberTypeComplex128)) &&
(kernel_name_ != kTanhGrad)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', only tanh support complex input, but got " << kernel_name_
<< " with dtype " << TypeIdLabel(inputs.at(kIndex0)->GetDtype());
(kernel_name_ != kTanhGrad) && (kernel_name_ != kSigmoidGrad)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', only tanh and sigmoid support complex input, but got "
<< kernel_name_ << " with dtype " << TypeIdLabel(inputs.at(kIndex0)->GetDtype());
}
return true;
@ -127,7 +139,7 @@ int ActivationGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
const auto dtype = inputs.at(kIndex0)->GetDtype();
if (((dtype == kNumberTypeFloat64) || (dtype == kNumberTypeComplex64) || (dtype == kNumberTypeComplex128)) &&
(kernel_name_ == kTanhGrad)) {
((kernel_name_ == kTanhGrad) || (kernel_name_ == kSigmoidGrad))) {
// Does not call Cudnn
return KRET_OK;
}
@ -198,6 +210,7 @@ bool ActivationGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressP
std::is_same_v<T, double> || std::is_same_v<T, utils::Complex<float>> || std::is_same_v<T, utils::Complex<double>>;
if constexpr (use_unary) {
TanhGrad(y, dy, dx, input_size_list_[0] / sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_));
SigmoidGrad(y, dy, dx, input_size_list_[0] / sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
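Taken together, the changes in this last file register the new dtypes for the GPU activation-grad kernel, allow float64/complex64/complex128 for SigmoidGrad in Init, skip the cuDNN setup for those dtypes in Resize ("Does not call Cudnn"), and let LaunchKernel reach the hand-written CUDA wrappers. A hedged sketch of the use_unary guard, with an inline loop standing in for the real SigmoidGrad launch (not the actual control flow, which this hunk only partially shows):

#include <complex>
#include <cstddef>
#include <type_traits>

// double and the complex dtypes take the hand-written path; other dtypes
// return false so the caller can use the cuDNN activation-backward path.
template <typename T>
bool LaunchSigmoidGradSketch(const T *y, const T *dout, T *dx, size_t n) {
  constexpr bool use_unary = std::is_same_v<T, double> ||
                             std::is_same_v<T, std::complex<float>> ||
                             std::is_same_v<T, std::complex<double>>;
  if constexpr (use_unary) {
    for (size_t i = 0; i < n; ++i) {
      if constexpr (std::is_same_v<T, double>) {
        dx[i] = dout[i] * y[i] * (1.0 - y[i]);
      } else {
        dx[i] = dout[i] * std::conj(y[i] * (T(1) - y[i]));
      }
    }
    return true;   // handled without cuDNN
  }
  return false;    // e.g. float16/float32 keep using the cuDNN path
}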