From ba9ae8097f4df74450c8ec3df2fc12677cfc68d8 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Thu, 18 Feb 2021 15:39:42 -0500 Subject: [PATCH] removed zeroslike from math/unary_op_gpu_kernel and added float64 support fix ci fix ci --- .../gpu/arrays/zeroslike_gpu_kernel.cc | 4 +- .../gpu/cuda_impl/unary_op_impl.cu | 16 -------- .../gpu/cuda_impl/unary_op_impl.cuh | 4 +- .../gpu/math/unary_op_gpu_kernel.cc | 4 -- .../gpu/math/unary_op_gpu_kernel.h | 41 +++++++------------ tests/st/ops/gpu/test_zeroslike_op.py | 15 ++++--- 6 files changed, 28 insertions(+), 56 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc index 74d9c865232..22a8d1b68a6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/zeroslike_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,5 +37,7 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).A MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ZerosLikeGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + ZerosLikeGpuKernel, double) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index 64a4008ca90..06759a050e2 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -196,14 +196,6 @@ __global__ void AtanKernel(const T *input, T *output, const size_t count) { return; } template -__global__ void ZeroslikeKernel(T *output, const size_t count) { - T zero = 0.0; - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - output[i] = zero; - } - return; -} -template __global__ void AbsKernel(const T *input, T *output, const size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { output[i] = abs(input[i]); @@ -328,11 +320,6 @@ void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stre return; } template -void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream) { - ZeroslikeKernel<<>>(output, count); - return; -} -template void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) { AbsKernel<<>>(input, output, count); return; @@ -362,7 +349,6 @@ template void Atan(const double *input, double *output, const size_t cou template void Asinh(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); template void Acosh(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); template void Rsqrt(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); -template void Zeroslike(double *output, const size_t count, cudaStream_t cuda_stream); template void Abs(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); template void Floor(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); @@ -386,7 +372,6 @@ template void Atan(const float *input, float *output, const size_t count, template void Asinh(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Acosh(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Rsqrt(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); -template void Zeroslike(float *output, const size_t count, cudaStream_t cuda_stream); template void Abs(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Floor(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); @@ -409,6 +394,5 @@ template void Atan(const half *input, half *output, const size_t count, cu template void Asinh(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); template void Acosh(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); template void Rsqrt(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); -template void Zeroslike(half *output, const size_t count, cudaStream_t cuda_stream); template void Abs(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); template void Floor(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh index e0347a1d93a..fffe32ca52b 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,8 +55,6 @@ void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stre template void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); template -void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream); -template void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); template void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index 89599fd637b..3079fd1ca6c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -52,10 +52,6 @@ MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32). UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h index 29c4d12b894..8c859def140 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,6 @@ enum UnaryOptype { UNARY_OP_ERFC, UNARY_OP_NEG, UNARY_OP_RECIPROCAL, - UNARY_OP_ZEROSLIKE, UNARY_OP_SQUARE, UNARY_OP_SQRT, UNARY_OP_RSQRT, @@ -51,27 +50,19 @@ enum UnaryOptype { UNARY_OP_FLOOR, UNARY_OP_INVALID_TYPE = 255 }; -static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, - {"Expm1", UNARY_OP_EXPM1}, - {"Log", UNARY_OP_LOG}, - {"Log1p", UNARY_OP_LOG1P}, - {"Erf", UNARY_OP_ERF}, - {"Erfc", UNARY_OP_ERFC}, - {"Neg", UNARY_OP_NEG}, - {"Reciprocal", UNARY_OP_RECIPROCAL}, - {"ZerosLike", UNARY_OP_ZEROSLIKE}, - {"Square", UNARY_OP_SQUARE}, - {"Sqrt", UNARY_OP_SQRT}, - {"Rsqrt", UNARY_OP_RSQRT}, - {"Sin", UNARY_OP_SIN}, - {"Cos", UNARY_OP_COS}, - {"Asin", UNARY_OP_ASIN}, - {"ACos", UNARY_OP_ACOS}, - {"Atan", UNARY_OP_ATAN}, - {"Asinh", UNARY_OP_ASINH}, - {"Acosh", UNARY_OP_ACOSH}, - {"Abs", UNARY_OP_ABS}, - {"Floor", UNARY_OP_FLOOR}}; + +static const std::map kUnaryOpTypeMap = { + {"Exp", UNARY_OP_EXP}, {"Expm1", UNARY_OP_EXPM1}, + {"Log", UNARY_OP_LOG}, {"Log1p", UNARY_OP_LOG1P}, + {"Erf", UNARY_OP_ERF}, {"Erfc", UNARY_OP_ERFC}, + {"Neg", UNARY_OP_NEG}, {"Reciprocal", UNARY_OP_RECIPROCAL}, + {"Square", UNARY_OP_SQUARE}, {"Sqrt", UNARY_OP_SQRT}, + {"Rsqrt", UNARY_OP_RSQRT}, {"Sin", UNARY_OP_SIN}, + {"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN}, + {"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN}, + {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH}, + {"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}}; + template class UnaryOpGpuKernel : public GpuKernel { public: @@ -160,10 +151,6 @@ class UnaryOpGpuKernel : public GpuKernel { Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; } - case UNARY_OP_ZEROSLIKE: { - Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); - return true; - } case UNARY_OP_ABS: { Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; diff --git a/tests/st/ops/gpu/test_zeroslike_op.py b/tests/st/ops/gpu/test_zeroslike_op.py index 25fec97be96..0af1d363f89 100644 --- a/tests/st/ops/gpu/test_zeroslike_op.py +++ b/tests/st/ops/gpu/test_zeroslike_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,6 @@ from mindspore import Tensor from mindspore.ops import operations as P from mindspore.ops.operations import _inner_ops as inner -context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - - class NetZerosLike(nn.Cell): def __init__(self): super(NetZerosLike, self).__init__() @@ -109,7 +106,6 @@ def test_zeros_like_dynamic_int8(): x = Tensor(np.arange(24).reshape(1, 4, 1, 6).astype(np.int8)) output = zeros_like_dynamic(x) expected = np.zeros([1, 4, 1, 6]) - print(output) np.testing.assert_array_equal(output.asnumpy(), expected) @pytest.mark.level0 @@ -148,6 +144,15 @@ def test_zeros_like_dynamic_float32(): expected = np.zeros([3, 7, 3]) np.testing.assert_array_almost_equal(output.asnumpy(), expected) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_zeros_like_dynamic_float64(): + x = Tensor(np.arange(2).reshape(2, 1, 1).astype(np.float64)) + output = zeros_like_dynamic(x) + expected = np.zeros([2, 1, 1]) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard