From 5a268a5345147d0caff971a529e54b99564562ee Mon Sep 17 00:00:00 2001 From: wsq3 <877518222@qq.com> Date: Mon, 19 Jul 2021 20:10:20 +0800 Subject: [PATCH] add argminwithvalue operator arithmetic --- ...cc => argmaxandminwithvalue_gpu_kernel.cc} | 21 ++- ...l.h => argmaxandminwithvalue_gpu_kernel.h} | 19 ++- .../gpu/cuda_impl/general_reduction_impl.cu | 2 +- mindspore/ops/operations/array_ops.py | 2 +- tests/st/ops/gpu/test_argminwithvalue_op.py | 146 ++++++++++++++++++ 5 files changed, 177 insertions(+), 13 deletions(-) rename mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/{argmaxwithvalue_gpu_kernel.cc => argmaxandminwithvalue_gpu_kernel.cc} (58%) rename mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/{argmaxwithvalue_gpu_kernel.h => argmaxandminwithvalue_gpu_kernel.h} (82%) create mode 100644 tests/st/ops/gpu/test_argminwithvalue_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.cc similarity index 58% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc rename to mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.cc index 4ddbd35e073..5c97c8847ae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.cc @@ -14,21 +14,34 @@ * limitations under the License. 
*/ -#include "backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h" namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_TWO( ArgMaxWithValue, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - ArgmaxWithValueGpuKernel, double, int) + ArgMaxAndMinWithValueGpuKernel, double, int) MS_REG_GPU_KERNEL_TWO( ArgMaxWithValue, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - ArgmaxWithValueGpuKernel, float, int) + ArgMaxAndMinWithValueGpuKernel, float, int) MS_REG_GPU_KERNEL_TWO( ArgMaxWithValue, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - ArgmaxWithValueGpuKernel, half, int) + ArgMaxAndMinWithValueGpuKernel, half, int) + +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + ArgMaxAndMinWithValueGpuKernel, double, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + ArgMaxAndMinWithValueGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + ArgMaxAndMinWithValueGpuKernel, half, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h similarity index 82% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h rename to mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h index 
0859ad528ce..3806a86512e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h @@ -14,20 +14,22 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_ #include +#include +#include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh" namespace mindspore { namespace kernel { template -class ArgmaxWithValueGpuKernel : public GpuKernel { +class ArgMaxAndMinWithValueGpuKernel : public GpuKernel { public: - ArgmaxWithValueGpuKernel() { ResetResource(); } - ~ArgmaxWithValueGpuKernel() override = default; + ArgMaxAndMinWithValueGpuKernel() { ResetResource(); } + ~ArgMaxAndMinWithValueGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -38,12 +40,14 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 1); S *index = GetDeviceAddress(outputs, 0); - CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output, + CalGeneralReduction(small_, input, bound_, outerSize_, innerSize_, index, output, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + small_ = (kernel_name == "ArgMinWithValue") ? 
true : false; std::vector shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); int64_t dims = shape.size(); @@ -94,6 +98,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { } private: + bool small_ = false; size_t input_size_; size_t output_size_; std::vector input_size_list_; @@ -106,4 +111,4 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu index 534b3745968..05da5996ada 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu @@ -303,7 +303,7 @@ void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_ if (std::is_same::value) { fp16_flag = true; } - T init_K = small ? std::numeric_limits::lowest() : std::numeric_limits::lowest(); + T init_K = small ? std::numeric_limits::max() : std::numeric_limits::lowest(); if (bound <= kMaxThreadLoop) { ThreadReduction<<>>( diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 1b6fc447a16..b45b871c83c 100755 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1849,7 +1849,7 @@ class ArgMinWithValue(PrimitiveWithInfer): TypeError: If `axis` is not an int. 
Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32) diff --git a/tests/st/ops/gpu/test_argminwithvalue_op.py b/tests/st/ops/gpu/test_argminwithvalue_op.py new file mode 100644 index 00000000000..3a4fcd5ba92 --- /dev/null +++ b/tests/st/ops/gpu/test_argminwithvalue_op.py @@ -0,0 +1,146 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetArgminWithValue(nn.Cell): + def __init__(self): + super(NetArgminWithValue, self).__init__() + axis1 = 0 + axis2 = -1 + self.argmin1 = P.ArgMinWithValue(axis1) + self.argmin2 = P.ArgMinWithValue(axis2) + self.argmin3 = P.ArgMinWithValue() + + def construct(self, x): + return (self.argmin1(x), self.argmin2(x), self.argmin3(x)) + + +class NetArgminWithValueBig(nn.Cell): + def __init__(self, axis=0): + super(NetArgminWithValueBig, self).__init__() + self.argmin = P.ArgMinWithValue(axis) + + def construct(self, x): + return self.argmin(x) + + +def argminwithvalue_base(data_type): + x = Tensor(np.array([[1., 20., 5.], + [67., 8., 9.], + [130., 24., 15.], + [0.3, -0.4, -15.]]).astype(data_type)) + expect1 = np.array([3, 3, 
3]).astype(data_type) + expect2 = np.array([0, 1, 2, 2]).astype(data_type) + expect11 = np.array([0.3, -0.4, -15.]).astype(data_type) + expect22 = np.array([1., 8., 15., -15.]).astype(data_type) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argmin = NetArgminWithValue() + output = argmin(x) + assert (output[0][0].asnumpy() == expect1).all() + assert (output[0][1].asnumpy() == expect11).all() + assert (output[1][0].asnumpy() == expect2).all() + assert (output[1][1].asnumpy() == expect22).all() + assert (output[2][0].asnumpy() == expect1).all() + assert (output[2][1].asnumpy() == expect11).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argmin = NetArgminWithValue() + output = argmin(x) + assert (output[0][0].asnumpy() == expect1).all() + assert (output[0][1].asnumpy() == expect11).all() + assert (output[1][0].asnumpy() == expect2).all() + assert (output[1][1].asnumpy() == expect22).all() + assert (output[2][0].asnumpy() == expect1).all() + assert (output[2][1].asnumpy() == expect11).all() + + +def argminwithvalue_3d(data_type, shape_x): + np.random.seed(2) + x_np = np.random.random(shape_x).astype(data_type) + x = Tensor(x_np) + + argmin = NetArgminWithValueBig(0) + output = argmin(x) + expect1 = np.argmin(x_np, axis=0) + expect2 = np.minimum.reduce(x_np, 0) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + argmin = NetArgminWithValueBig(1) + output = argmin(x) + expect1 = np.argmin(x_np, axis=1) + expect2 = np.minimum.reduce(x_np, 1) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + argmin = NetArgminWithValueBig(2) + output = argmin(x) + expect1 = np.argmin(x_np, axis=2) + expect2 = np.minimum.reduce(x_np, 2) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def 
test_argminwithvalue_base_float32(): + argminwithvalue_base(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argminwithvalue_base_float16(): + argminwithvalue_base(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argminwithvalue_3d_float32(): + shape_x = (2, 32, 256) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argminwithvalue_3d(np.float32, shape_x) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argminwithvalue_3d(np.float32, shape_x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argminwithvalue_3d_float16(): + shape_x = (2, 64, 128) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argminwithvalue_3d(np.float16, shape_x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argminwithvalue_3d_big_float32(): + shape_x = (128, 1024, 1) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argminwithvalue_3d(np.float32, shape_x) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argminwithvalue_3d(np.float32, shape_x)