From a0645c41fe5a7de3da6450fcdd16e8001cf44a1b Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Wed, 3 Mar 2021 17:18:00 -0500
Subject: [PATCH] add float64 support to absgrad and sqrtgrad

---
 .../gpu/cuda_impl/unary_op_grad_impl.cu  | 17 +++++++++++++++++
 .../gpu/math/broadcast_gpu_kernel.cc     |  4 ++++
 .../gpu/math/unary_op_grad_gpu_kernel.cc |  4 ++++
 mindspore/ops/operations/_grad_ops.py    |  3 ++-
 4 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
index 3cc485b752a..4beb1e0c58a 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
@@ -170,6 +170,23 @@ void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count
   return;
 }
 
+template void SqrtGrad(const double *input, const double *dout, double *output, const size_t count,
+                       cudaStream_t cuda_stream);
+template void RsqrtGrad(const double *input, const double *dout, double *output, const size_t count,
+                        cudaStream_t cuda_stream);
+template void AsinGrad(const double *input, const double *dout, double *output, const size_t count,
+                       cudaStream_t cuda_stream);
+template void ACosGrad(const double *input, const double *dout, double *output, const size_t count,
+                       cudaStream_t cuda_stream);
+template void AtanGrad(const double *input, const double *dout, double *output, const size_t count,
+                       cudaStream_t cuda_stream);
+template void AsinhGrad(const double *input, const double *dout, double *output, const size_t count,
+                        cudaStream_t cuda_stream);
+template void AcoshGrad(const double *input, const double *dout, double *output, const size_t count,
+                        cudaStream_t cuda_stream);
+template void ReciprocalGrad(const double *input, const double *dout, double *output, const size_t count,
+                             cudaStream_t cuda_stream);
+
 template void SqrtGrad(const float *input, const float *dout, float *output, const size_t count,
                        cudaStream_t cuda_stream);
 template void RsqrtGrad(const float *input, const float *dout, float *output, const size_t count,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 8d232afcd7f..7c191a9e0b4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -31,6 +31,10 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Div,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   BroadcastOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(
+  AbsGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
 // fp32
 MS_REG_GPU_KERNEL_ONE(
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
index 116168f9e50..ece51aa3004 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
@@ -18,6 +18,10 @@
 
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SqrtGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  UnaryGradOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   SqrtGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 02f17e6f41c..5aa8b7d36f9 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -183,7 +183,8 @@ class SqrtGrad(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype, dout_dtype):
         args = {"x": x_dtype, "dout": dout_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
+        valid_types = [mstype.float16, mstype.float32, mstype.float64]
+        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
         return x_dtype
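
Testing note (not part of the patch): a minimal sketch of how the new float64
paths could be exercised on a GPU build of MindSpore. It assumes SqrtGrad(y, dout)
returns dout / (2 * y) for the forward output y = sqrt(x), and AbsGrad(x, dout)
returns dout * sign(x), matching the kernels touched above; the tolerance and the
PyNative mode choice are illustrative, not part of this change.

import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations import _grad_ops as G

# PyNative mode runs the primitives eagerly on the GPU kernels registered above.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# SqrtGrad consumes the forward output y = sqrt(x) plus the upstream gradient.
y = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float64))
dout = Tensor(np.ones(3, dtype=np.float64))
out = G.SqrtGrad()(y, dout)
np.testing.assert_allclose(out.asnumpy(), 1.0 / (2.0 * y.asnumpy()), rtol=1e-12)

# AbsGrad scales the upstream gradient by the sign of the original input x.
x = Tensor(np.array([-1.0, 0.5, -2.0], dtype=np.float64))
out = G.AbsGrad()(x, dout)
np.testing.assert_allclose(out.asnumpy(), np.sign(x.asnumpy()), rtol=1e-12)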