From 5f7fc5a0b19566170fde0c09f392300645de4f06 Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Wed, 17 Feb 2021 17:21:19 -0500
Subject: [PATCH] Add uint8, int8, int32 and float64 support to the CumSum GPU kernel

---
 .../gpu/cuda_impl/cumsum_impl.cu              |  14 +-
 .../gpu/math/cumsum_gpu_kernel.cc             |  10 +-
 tests/st/ops/gpu/test_cumsum_op.py            | 139 +++++++++++-------
 3 files changed, 104 insertions(+), 59 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu
index 30622c60e98..1b6ce2f3ec6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -137,6 +137,18 @@ void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, s
   return;
 }
 
+template void CumSum(const uint8_t *input, uint8_t *output, uint8_t *workspace, size_t dim0, size_t dim1,
+                     size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
+                     cudaStream_t stream);
+template void CumSum(const int8_t *input, int8_t *output, int8_t *workspace, size_t dim0, size_t dim1,
+                     size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
+                     cudaStream_t stream);
+template void CumSum(const int32_t *input, int32_t *output, int32_t *workspace, size_t dim0, size_t dim1,
+                     size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
+                     cudaStream_t stream);
+template void CumSum(const double *input, double *output, double *workspace, size_t dim0, size_t dim1,
+                     size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
+                     cudaStream_t stream);
 template void CumSum(const float *input, float *output, float *workspace, size_t dim0, size_t dim1, size_t dim2,
                      size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
 template void CumSum(const half *input, half *output, half *workspace, size_t dim0, size_t dim1, size_t dim2,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc
index e13ec638e8f..937f74c74df 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,14 @@
 
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                      CumSumGpuKernel, uint8_t)
+MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      CumSumGpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      CumSumGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      CumSumGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       CumSumGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
diff --git a/tests/st/ops/gpu/test_cumsum_op.py b/tests/st/ops/gpu/test_cumsum_op.py
index c639c2952d8..c1c974bbfc7 100644
--- a/tests/st/ops/gpu/test_cumsum_op.py
+++ b/tests/st/ops/gpu/test_cumsum_op.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,71 +22,66 @@ from mindspore import Tensor
 from mindspore.common.api import ms_function
 from mindspore.ops import operations as P
 
-x0 = np.random.rand(2, 3, 4, 4).astype(np.float32)
-axis0 = 3
+def cum_sum(nptype):
+    context.set_context(device_target='GPU')
+    x0 = np.random.rand(2, 3, 4, 4).astype(nptype)
+    axis0 = 3
 
-x1 = np.random.rand(2, 3, 4, 4).astype(np.float32)
-axis1 = 3
+    x1 = np.random.rand(2, 3, 4, 4).astype(nptype)
+    axis1 = 3
 
-x2 = np.random.rand(2, 3, 1, 4).astype(np.float32)
-axis2 = 2
+    x2 = np.random.rand(2, 3, 1, 4).astype(nptype)
+    axis2 = 2
 
-x3 = np.random.rand(2, 3, 1, 4).astype(np.float32)
-axis3 = 2
+    x3 = np.random.rand(2, 3, 1, 4).astype(nptype)
+    axis3 = 2
 
-x4 = np.random.rand(2, 3, 4, 4).astype(np.float32)
-axis4 = 1
+    x4 = np.random.rand(2, 3, 4, 4).astype(nptype)
+    axis4 = 1
 
-x5 = np.random.rand(2, 3).astype(np.float32)
-axis5 = 1
+    x5 = np.random.rand(2, 3).astype(nptype)
+    axis5 = 1
 
-x6 = np.random.rand(1, 1, 1, 1).astype(np.float32)
-axis6 = 0
+    x6 = np.random.rand(1, 1, 1, 1).astype(nptype)
+    axis6 = 0
 
-context.set_context(device_target='GPU')
+    class CumSum(nn.Cell):
+        def __init__(self, nptype):
+            super(CumSum, self).__init__()
+
+            self.x0 = Tensor(x0)
+            self.axis0 = axis0
+
+            self.x1 = Tensor(x1)
+            self.axis1 = axis1
+
+            self.x2 = Tensor(x2)
+            self.axis2 = axis2
+
+            self.x3 = Tensor(x3)
+            self.axis3 = axis3
+
+            self.x4 = Tensor(x4)
+            self.axis4 = axis4
+
+            self.x5 = Tensor(x5)
+            self.axis5 = axis5
+
+            self.x6 = Tensor(x6)
+            self.axis6 = axis6
+
+        @ms_function
+        def construct(self):
+            return (P.CumSum()(self.x0, self.axis0),
+                    P.CumSum()(self.x1, self.axis1),
+                    P.CumSum()(self.x2, self.axis2),
+                    P.CumSum()(self.x3, self.axis3),
+                    P.CumSum()(self.x4, self.axis4),
+                    P.CumSum()(self.x5, self.axis5),
+                    P.CumSum()(self.x6, self.axis6))
 
 
-class CumSum(nn.Cell):
-    def __init__(self):
-        super(CumSum, self).__init__()
-
-        self.x0 = Tensor(x0)
-        self.axis0 = axis0
-
-        self.x1 = Tensor(x1)
-        self.axis1 = axis1
-
-        self.x2 = Tensor(x2)
-        self.axis2 = axis2
-
-        self.x3 = Tensor(x3)
-        self.axis3 = axis3
-
-        self.x4 = Tensor(x4)
-        self.axis4 = axis4
-
-        self.x5 = Tensor(x5)
-        self.axis5 = axis5
-
-        self.x6 = Tensor(x6)
-        self.axis6 = axis6
-
-    @ms_function
-    def construct(self):
-        return (P.CumSum()(self.x0, self.axis0),
-                P.CumSum()(self.x1, self.axis1),
-                P.CumSum()(self.x2, self.axis2),
-                P.CumSum()(self.x3, self.axis3),
-                P.CumSum()(self.x4, self.axis4),
-                P.CumSum()(self.x5, self.axis5),
-                P.CumSum()(self.x6, self.axis6))
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_CumSum():
-    cumsum = CumSum()
+    cumsum = CumSum(nptype)
     output = cumsum()
 
     expect0 = np.cumsum(x0, axis=axis0)
@@ -130,3 +125,33 @@ def test_CumSum():
     error6 = np.ones(shape=expect6.shape) * 1.0e-5
     assert np.all(diff6 < error6)
     assert output[6].shape == expect6.shape
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cum_sum_uint8():
+    cum_sum(np.uint8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cum_sum_int8():
+    cum_sum(np.int8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cum_sum_int32():
+    cum_sum(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cum_sum_float16():
+    cum_sum(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cum_sum_float32():
+    cum_sum(np.float32)
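
Editorial note, not part of the patch: the new registrations can also be exercised directly through the public CumSum primitive, without the test harness above. A minimal sketch, assuming a GPU build of MindSpore and direct primitive calls in PyNative mode; the axis and the 1.0e-5 tolerance mirror the test file:

    # Sketch only: drive the newly registered CumSum dtypes through the
    # public op and compare against np.cumsum, as the test above does.
    import numpy as np

    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    for nptype in (np.uint8, np.int8, np.int32, np.float64):
        x = np.random.rand(2, 3, 4, 4).astype(nptype)
        output = P.CumSum()(Tensor(x), 3).asnumpy()
        expect = np.cumsum(x, axis=3)
        # Cast both sides to float64 before differencing, since numpy may
        # promote small integer dtypes while the kernel keeps the input dtype.
        diff = np.abs(output.astype(np.float64) - expect.astype(np.float64))
        assert np.all(diff < 1.0e-5)
        assert output.shape == expect.shape

Note that float16 and float32 already worked before this patch; the loop above covers only the dtypes the patch adds.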