From 1085bb45e1b2a1ef812040b9524630a0a4978231 Mon Sep 17 00:00:00 2001
From: hj-ustb
Date: Sun, 13 Nov 2022 11:20:35 +0800
Subject: [PATCH] [feat][assistant][ops][I5EWOO] add data type for sigmoidgrad

---
 .../cpu/kernel/eltwise_grad_cpu_kernel.cc     | 52 +++++++++++++++----
 .../cuda_impl/cuda_ops/unary_op_grad_impl.cu  | 36 +++++++++++++
 .../cuda_impl/cuda_ops/unary_op_grad_impl.cuh |  3 ++
 .../cuda_impl/cuda_ops/unary_op_impl.cu       |  3 +-
 .../gpu/kernel/nn/activation_grad_kernel.cc   | 23 ++++++--
 5 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc
index 5c122997b26..470794e06ab 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc
@@ -60,6 +60,7 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
   void ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
+  void ComplexSigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
@@ -125,13 +126,29 @@ void EltWiseGradCpuTypeFunc<T>::AbsGrad(const T *input1, const T *input2, T *out
 
 template <typename T>
 void EltWiseGradCpuTypeFunc<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (!std::is_same<T, float>::value) {
-    MS_LOG(EXCEPTION) << "For 'SigmoidGrad', the dtype of input must be float.";
+  if constexpr (std::is_same<T, float>::value) {
+    int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
+    if (ret == NNACL_ERR) {
+      MS_LOG(EXCEPTION) << "For 'SigmoidGrad', execute failed. Error no: " << ret;
+    }
+  } else {
+    for (size_t i = start; i < end; i++) {
+      T dividend = input2[i];
+      T divisor = input1[i] * (static_cast<T>(1) - input1[i]);
+      out[i] = dividend * divisor;
+    }
   }
+}
 
-  int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
-  if (ret == NNACL_ERR) {
-    MS_LOG(EXCEPTION) << "For 'SigmoidGrad', execute failed. Error no: " << ret;
+template <typename T>
+void EltWiseGradCpuTypeFunc<T>::ComplexSigmoidGrad(const T *input1, const T *input2, T *out, size_t start,
+                                                   size_t end) const {
+  if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
+    for (size_t i = start; i < end; i++) {
+      T dividend = input2[i];
+      T divisor = std::conj(input1[i] * (static_cast<T>(1) - input1[i]));
+      out[i] = dividend * divisor;
+    }
   }
 }
@@ -440,7 +457,8 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c
       {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc::InvGrad},
       {prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc::AcoshGrad},
       {prim::kPrimAbsGrad->name(), &EltWiseGradCpuTypeFunc::AbsGrad},
-      {prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc::ReluGrad}};
+      {prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc::ReluGrad},
+      {prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc::SigmoidGrad}};
   if (elt_map.find(kernel_name_) == elt_map.end()) {
     MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_ << " with double as input.";
   }
@@ -452,7 +470,8 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c
                      std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
     elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc::ReluGrad},
             {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad},
-            {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}};
+            {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad},
+            {prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc::SigmoidGrad}};
   if (elt_map.find(kernel_name_) == elt_map.end()) {
     MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
   }
@@ -519,7 +538,8 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c
       {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc::InvGrad},
       {prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc::SqrtGrad},
       {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad},
-      {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}};
+      {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad},
+      {prim::kPrimSigmoidGrad->name(), &EltWiseGradCpuTypeFunc::ComplexSigmoidGrad}};
   if (elt_map.find(kernel_name_) == elt_map.end()) {
     MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
   }
@@ -580,7 +600,21 @@ static std::map>> ke
      &SpecializeEltWiseGradFunc}}},
   {kSigmoidGrad,
    {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-     &SpecializeEltWiseGradFunc<float>}}},
+     &SpecializeEltWiseGradFunc<float>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+     &SpecializeEltWiseGradFunc<float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+     &SpecializeEltWiseGradFunc<double>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddOutputAttr(kNumberTypeComplex64),
+     &SpecializeEltWiseGradFunc<complex64>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddOutputAttr(kNumberTypeComplex128),
+     &SpecializeEltWiseGradFunc<complex128>}}},
   {kSqrtGrad,
    {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      &SpecializeEltWiseGradFunc<float>},
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cu
index 50688071b01..381788fcb08 100755
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cu
@@ -199,6 +199,26 @@ __global__ void TanhGradKernel(const T *__restrict__ input, const T *dout, T *ou
   return;
 }
 
+template <typename T>
+__global__ void SigmoidGradKernel(const T *__restrict__ input, const T *dout, T *output, const size_t count) {
+  const T one = static_cast<T>(1);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    T divisor = input[i] * (one - input[i]);
+    output[i] = dout[i] * divisor;
+  }
+  return;
+}
+
+template <typename T>
+__global__ void SigmoidGradKernel(const Complex<T> *input, const Complex<T> *dout, Complex<T> *output,
+                                  const size_t count) {
+  Complex<T> one = static_cast<Complex<T>>(1);
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    output[i] = dout[i] * conj(input[i] * (one - input[i]));
+  }
+  return;
+}
+
 template <typename T>
 __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -352,6 +372,12 @@ void TanhGrad(const T *input, const T *dout, T *output, const size_t count, cuda
   return;
 }
 
+template <typename T>
+void SigmoidGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  SigmoidGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
 template <typename T>
 void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@@ -403,6 +429,14 @@ template CUDA_LIB_EXPORT void SqrtGrad<Complex<float>>(const Complex<float> *i
                                                        cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void SqrtGrad<Complex<double>>(const Complex<double> *input, const Complex<double> *dout,
+                                                        Complex<double> *output, const size_t count,
+                                                        cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void SigmoidGrad<Complex<float>>(const Complex<float> *input, const Complex<float> *dout,
+                                                          Complex<float> *output, const size_t count,
+                                                          cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void SigmoidGrad<Complex<double>>(const Complex<double> *input, const Complex<double> *dout,
                                                            Complex<double> *output, const size_t count,
                                                            cudaStream_t cuda_stream);
@@ -426,6 +460,8 @@ template CUDA_LIB_EXPORT void ReciprocalGrad<double>(const double *input, const
                                                      const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void InvGrad<double>(const double *input, const double *dout, double *output,
                                               const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void SigmoidGrad<double>(const double *input, const double *dout, double *output,
+                                                  const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                               cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cuh
index b153df0e73e..11909ea7bee 100755
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_grad_impl.cuh
@@ -30,6 +30,9 @@ CUDA_LIB_EXPORT void AtanGrad(const T *input, const T *dout, T *output, const si
 template <typename T>
 CUDA_LIB_EXPORT void TanhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
+CUDA_LIB_EXPORT void SigmoidGrad(const T *input, const T *dout, T *output, const size_t count,
+                                 cudaStream_t cuda_stream);
+template <typename T>
 CUDA_LIB_EXPORT void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 CUDA_LIB_EXPORT void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cu
index 7631fafc3ea..ac6379c077f 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cu
@@ -1828,13 +1828,13 @@ template CUDA_LIB_EXPORT void Sigmoid<Complex<float>>(const Complex<float> *inpu
                                                       const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Square<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                      const size_t count, cudaStream_t cuda_stream);
-
 template CUDA_LIB_EXPORT void Tanh<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                    const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Logarithm<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                         const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Negative<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                        const size_t count, cudaStream_t cuda_stream);
+
 // complex128
 template CUDA_LIB_EXPORT void Exponential<Complex<double>>(const Complex<double> *input, Complex<double> *output,
                                                            const size_t count, cudaStream_t cuda_stream);
@@ -1880,6 +1880,7 @@ template CUDA_LIB_EXPORT void Logarithm<Complex<double>>(const Complex<double> *
                                                          const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Negative<Complex<double>>(const Complex<double> *input, Complex<double> *output,
                                                         const size_t count, cudaStream_t cuda_stream);
+
 // bool
 template CUDA_LIB_EXPORT void Real(const bool *input, bool *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Imag(const bool *input, bool *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/activation_grad_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/activation_grad_kernel.cc
index 7cdf949d271..53adb8c8598 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/activation_grad_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/activation_grad_kernel.cc
@@ -60,7 +60,19 @@ std::map
      &ActivationGradGpuKernelMod::LaunchKernel<float>},
     {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-     &ActivationGradGpuKernelMod::LaunchKernel<half>}}}};
+     &ActivationGradGpuKernelMod::LaunchKernel<half>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+     &ActivationGradGpuKernelMod::LaunchKernel<double>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddOutputAttr(kNumberTypeComplex64),
+     &ActivationGradGpuKernelMod::LaunchKernel<utils::Complex<float>>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddOutputAttr(kNumberTypeComplex128),
+     &ActivationGradGpuKernelMod::LaunchKernel<utils::Complex<double>>}}}};
 
 bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                       const std::vector<KernelTensorPtr> &outputs) {
@@ -100,9 +112,9 @@ bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
 
   const auto dtype = inputs.at(kIndex0)->GetDtype();
   if (((dtype == kNumberTypeFloat64) || (dtype == kNumberTypeComplex64) || (dtype == kNumberTypeComplex128)) &&
-      (kernel_name_ != kTanhGrad)) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', only tanh support complex input, but got " << kernel_name_
-                  << " with dtype " << TypeIdLabel(inputs.at(kIndex0)->GetDtype());
+      (kernel_name_ != kTanhGrad) && (kernel_name_ != kSigmoidGrad)) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', only tanh and sigmoid support complex input, but got "
+                  << kernel_name_ << " with dtype " << TypeIdLabel(inputs.at(kIndex0)->GetDtype());
   }
 
   return true;
@@ -127,7 +139,7 @@ int ActivationGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
 
   const auto dtype = inputs.at(kIndex0)->GetDtype();
   if (((dtype == kNumberTypeFloat64) || (dtype == kNumberTypeComplex64) || (dtype == kNumberTypeComplex128)) &&
-      (kernel_name_ == kTanhGrad)) {
+      ((kernel_name_ == kTanhGrad) || (kernel_name_ == kSigmoidGrad))) {
     // Does not call Cudnn
     return KRET_OK;
   }
@@ -198,6 +210,7 @@ bool ActivationGradGpuKernelMod::LaunchKernel(const std::vector
   constexpr bool use_unary = std::is_same_v<T, double> || std::is_same_v<T, utils::Complex<float>> ||
                              std::is_same_v<T, utils::Complex<double>>;
   if constexpr (use_unary) {
     TanhGrad(y, dy, dx, input_size_list_[0] / sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_));
+    SigmoidGrad(y, dy, dx, input_size_list_[0] / sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_));
     return true;
   }
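
For reference, the complex-dtype branches added above (ComplexSigmoidGrad on CPU and the Complex<T> overload of SigmoidGradKernel on GPU) both compute dx = dy * conj(y * (1 - y)), where y is the forward sigmoid output and dy is the incoming gradient. Below is a minimal standalone C++ sketch of that formula; it is not part of the patch, and every name in it is local to this example.

// Illustrative only: evaluates the complex sigmoid-gradient expression with std::complex.
#include <complex>
#include <iostream>

int main() {
  using c64 = std::complex<float>;
  const c64 one(1.0f, 0.0f);
  c64 x(0.3f, -0.2f);
  c64 y = one / (one + std::exp(-x));      // forward: y = sigmoid(x)
  c64 dy(1.0f, 0.0f);                      // incoming gradient
  c64 dx = dy * std::conj(y * (one - y));  // same expression as the complex kernels above
  std::cout << dx << '\n';
  return 0;
}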