diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index 2c0e6f7905c..45603f8e73a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -87,6 +87,14 @@ struct FloorDivFunc { __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } }; +template <typename T, typename S> +struct AbsGradFunc { /* Element-wise functor: returns rhs scaled by the sign of lhs, i.e. -rhs for lhs < 0, rhs for lhs > 0 and 0 for lhs == 0. Presumably lhs is the forward input of Abs and rhs the incoming gradient -- confirm against the AbsGrad caller. */ + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { + T zero = 0.0; + return lhs < zero ? -rhs : (lhs > zero ? rhs : zero); /* lhs == 0 yields 0 (sign(0) == 0) instead of passing rhs through unchanged */ + } +}; + template <> struct PowerFunc { @@ -149,6 +157,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const case BROADCAST_TYPE_FLOORDIV: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, output); + case BROADCAST_TYPE_ABSGRAD: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); } } @@ -192,6 +203,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const return NoBroadcastOperator>(nums, input0, input1, output); case BROADCAST_TYPE_FLOORDIV: return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_ABSGRAD: + return NoBroadcastOperator>(nums, input0, input1, output); } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index e81cc16e33d..7d762c34d93 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -30,6 +30,7 @@ enum BroadcastOpType { BROADCAST_TYPE_SUB = 7, BROADCAST_TYPE_ADD = 8, BROADCAST_TYPE_FLOORDIV = 9, + BROADCAST_TYPE_ABSGRAD = 10, BROADCAST_TYPE_INVALID = 0xffffffff, }; diff --git
a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc index f5fffc0a4bd..7232e9a3f55 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -55,6 +55,10 @@ MS_REG_GPU_KERNEL_TWO( FloorDiv, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) /* AbsGrad entry with float, float template arguments: two kNumberTypeFloat32 inputs, one kNumberTypeFloat32 output */ // fp16 MS_REG_GPU_KERNEL_TWO( @@ -93,6 +97,10 @@ MS_REG_GPU_KERNEL_TWO( FloorDiv, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) /* AbsGrad entry with half, half template arguments: two kNumberTypeFloat16 inputs, one kNumberTypeFloat16 output */ // int32 MS_REG_GPU_KERNEL_TWO( @@ -113,5 +121,8 @@ MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO( FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) /* AbsGrad entry with int, int template arguments: two kNumberTypeInt32 inputs, one kNumberTypeInt32 output */ } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index 7cbc2f692e8..b6ac5a36887 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -96,10 +96,10 @@
class BroadcastOpGpuKernel : public GpuKernel { std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); static std::map kBroadcastTypeMap = { - {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, - {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, - {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, - {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, + {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, + {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, + {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, + {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, }; auto iter = kBroadcastTypeMap.find(kernel_name);