forked from mindspore-Ecosystem/mindspore
!12844 Add float64 support to AbsGrad and SqrtGrad
From: @peilin-wang Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
504f45566b
|
@ -170,6 +170,23 @@ void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template void SqrtGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void RsqrtGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void AsinGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void ACosGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void AtanGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void AsinhGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void AcoshGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void ReciprocalGrad<double>(const double *input, const double *dout, double *output, const size_t count,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||||
|
|
|
@ -31,6 +31,10 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
BroadcastOpGpuKernel, double)
|
BroadcastOpGpuKernel, double)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
AbsGrad,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
BroadcastOpGpuKernel, double)
|
||||||
|
|
||||||
// fp32
|
// fp32
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
SqrtGrad,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
UnaryGradOpGpuKernel, double)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
SqrtGrad,
|
SqrtGrad,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
|
|
@ -183,7 +183,8 @@ class SqrtGrad(PrimitiveWithInfer):
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype, dout_dtype):
|
def infer_dtype(self, x_dtype, dout_dtype):
|
||||||
args = {"x": x_dtype, "dout": dout_dtype}
|
args = {"x": x_dtype, "dout": dout_dtype}
|
||||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
valid_types = [mstype.float16, mstype.float32, mstype.float64]
|
||||||
|
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue