forked from mindspore-Ecosystem/mindspore
!12844 Add float64 support to AbsGrad and SqrtGrad
From: @peilin-wang Reviewed-by: @tom__chen, @robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
504f45566b
|
@ -170,6 +170,23 @@ void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count
|
|||
return;
|
||||
}
|
||||
|
||||
template void SqrtGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AsinGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ACosGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AtanGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AsinhGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AcoshGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void ReciprocalGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
|
|
|
@ -31,6 +31,10 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
BroadcastOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AbsGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
BroadcastOpGpuKernel, double)
|
||||
|
||||
// fp32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryGradOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
|
|
@ -183,7 +183,8 @@ class SqrtGrad(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x_dtype, dout_dtype):
|
||||
args = {"x": x_dtype, "dout": dout_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||
valid_types = [mstype.float16, mstype.float32, mstype.float64]
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue