From d65a5affba80d96139f73cf5189359c83af0ebb0 Mon Sep 17 00:00:00 2001
From: xcnick
Date: Wed, 20 Jan 2021 14:40:59 +0800
Subject: [PATCH] fix cpu/gpu argmax op

---
 .../kernel_compiler/cpu/argmax_cpu_kernel.cc  | 104 ++++++++++++------
 .../kernel_compiler/cpu/argmax_cpu_kernel.h   |  13 ++-
 .../gpu/arrays/argmax_gpu_kernel.cc           |   8 +-
 .../gpu/arrays/argmax_gpu_kernel.h            |  66 +++++------
 .../gpu/cuda_impl/argmax_impl.cu              |  92 +++++-----------
 .../gpu/cuda_impl/argmax_impl.cuh             |   4 +-
 tests/st/ops/cpu/test_argmax_op.py            |  62 ++++++++---
 tests/st/ops/gpu/test_argmax_op.py            |  72 ++++++++----
 8 files changed, 234 insertions(+), 187 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc
index 513464f14ba..216cb266ee3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc
@@ -18,48 +18,82 @@
 namespace mindspore {
 namespace kernel {
-void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (shape.size() != 2) {
-    MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size();
-  }
-  batch_size_ = shape[0];
-  class_num_ = shape[1];
-
-  int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
-  if (axis != -1 && axis != 1) {
-    MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis;
+namespace {
+size_t get_element_num(const std::vector<size_t> &shape) {
+  size_t size = 1;
+  for (size_t i = 0; i < shape.size(); i++) {
+    size *= shape[i];
   }
+  return size;
 }
 
-bool ArgmaxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                             const std::vector<kernel::AddressPtr> & /*workspaces*/,
-                             const std::vector<kernel::AddressPtr> &outputs) {
-  if (inputs.empty() || outputs.empty()) {
-    MS_LOG(EXCEPTION) << "input or output empty!";
+template <typename T>
+bool check_validation(const std::vector<size_t> &shape, const size_t num_before_axis, const size_t num_after_axis,
+                      const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "Wrong number of inputs or outputs!";
+    return false;
+  }
+  size_t data_size = sizeof(T);
+  size_t input_size = get_element_num(shape) * data_size;
+  size_t output_num = num_before_axis * num_after_axis;
+  size_t output_size = output_num * sizeof(int);
+  if (inputs[0]->size != input_size || outputs[0]->size != output_size) {
+    MS_LOG(EXCEPTION) << "invalid input or output data size!";
+    return false;
+  }
+  return true;
+}
+}  // namespace
+
+template <typename T>
+void ArgmaxCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  size_t shape_len = shape_.size();
+  int64_t axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
+  axis += shape_len;
+  if (axis < 0) {
+    MS_LOG(EXCEPTION) << "Invalid axis:" << axis << ", should in range [-1, " << shape_len - 1 << "]";
+  }
+  axis = axis % static_cast<int64_t>(shape_len);
+  num_before_axis_ = 1;
+  num_after_axis_ = 1;
+  for (size_t i = 0; i < shape_len; i++) {
+    if (static_cast<int64_t>(i) < axis) {
+      num_before_axis_ *= shape_[i];
+    } else if (static_cast<int64_t>(i) > axis) {
+      num_after_axis_ *= shape_[i];
+    }
+  }
+  dim_axis_ = shape_[axis];
+}
+
+template <typename T>
+bool ArgmaxCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> & /*workspaces*/,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
+    return false;
   }
-  size_t batch_float_size = batch_size_ * sizeof(float);
-  size_t batch_class_float_size = class_num_ * batch_float_size;
-  if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) {
-    MS_LOG(EXCEPTION) << "invalid input or output data size!";
-  }
-  auto input = reinterpret_cast<float *>(inputs[0]->addr);
-  auto output = reinterpret_cast<int *>(outputs[0]->addr);
-  size_t row_start = 0;
-  for (size_t i = 0; i < batch_size_; ++i) {
-    size_t max_index = 0;
-    float max_value = input[row_start];
-    for (size_t j = 1; j < class_num_; ++j) {
-      size_t index = row_start + j;
-      if (input[index] > max_value) {
-        max_value = input[index];
-        max_index = j;
+  auto input = reinterpret_cast<T *>(inputs[0]->addr);
+  auto output = reinterpret_cast<int *>(outputs[0]->addr);
+
+  for (size_t i = 0; i < num_before_axis_; i++) {
+    size_t src_index_i = i * dim_axis_ * num_after_axis_;
+    for (size_t j = 0; j < num_after_axis_; j++) {
+      std::vector<float> array_axis;
+      size_t src_index_j = src_index_i + j;
+      for (size_t k = 0; k < dim_axis_; k++) {
+        size_t src_index_k = k * num_after_axis_ + src_index_j;
+        array_axis.push_back(static_cast<float>(input[src_index_k]));
       }
+      auto max_ops = std::max_element(array_axis.begin(), array_axis.end());
+      auto max_index = static_cast<int>(std::distance(array_axis.begin(), max_ops));
+      auto dst_index = i * num_after_axis_ + j;
+      output[dst_index] = max_index;
     }
-    output[i] = SizeToInt(max_index);
-    row_start += class_num_;
   }
   return true;
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
index dc2cfcefacf..33edc96df86 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
@@ -22,6 +22,7 @@
 namespace mindspore {
 namespace kernel {
+template <typename T>
 class ArgmaxCPUKernel : public CPUKernel {
  public:
   ArgmaxCPUKernel() = default;
@@ -33,12 +34,16 @@ class ArgmaxCPUKernel : public CPUKernel {
               const std::vector<AddressPtr> &outputs) override;
 
  private:
-  size_t class_num_{0};
-  size_t batch_size_{0};
+  std::vector<size_t> shape_;
+  size_t num_before_axis_;
+  size_t num_after_axis_;
+  size_t dim_axis_;
 };
 
-MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
-                  ArgmaxCPUKernel);
+MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                    ArgmaxCPUKernel, float);
+MS_REG_CPU_KERNEL_T(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                    ArgmaxCPUKernel, float16);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
index 881abe6d7db..d1be400473a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
@@ -18,9 +18,9 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
-                      ArgmaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
-                      ArgmaxGpuKernel, half)
+MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, half, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
index 18b6f231f2a..ed5e0dbb5c0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
@@ -23,11 +23,10 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh"
 namespace mindspore {
 namespace kernel {
-#define ARGMAX_MAX_DIMENSION 2
-template <typename T>
+template <typename T, typename S>
 class ArgmaxGpuKernel : public GpuKernel {
  public:
-  ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {}
+  ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), bound_(0), outer_size_(0), inner_size_(0) {}
   ~ArgmaxGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -37,47 +36,38 @@ class ArgmaxGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input = GetDeviceAddress<T>(inputs, 0);
-    int *output = GetDeviceAddress<int>(outputs, 0);
-    CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output,
-              reinterpret_cast<cudaStream_t>(stream_ptr));
+    S *output = GetDeviceAddress<S>(outputs, 0);
+    CalArgmax(input, bound_, outer_size_, inner_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input.";
-      return false;
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    int64_t dims = shape.size();
+    int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
+    if (axis < 0) {
+      axis += dims;
     }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output.";
-      return false;
+    input_size_ = sizeof(T);
+    for (auto x : shape) {
+      input_size_ *= x;
     }
-    auto output_type = GetValue<TypePtr>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type"));
-    if (output_type->type_id() != TypeId::kNumberTypeInt32) {
-      MS_LOG(EXCEPTION) << "Argmax only supports int32 output type.";
+    output_size_ = sizeof(S);
+    for (auto x : output_shape) {
+      output_size_ *= x;
     }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (input_shape.size() > ARGMAX_MAX_DIMENSION) {
-      MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but Argmax supports max " << ARGMAX_MAX_DIMENSION
-                        << "-D inputs.";
+    bound_ = static_cast<S>(shape[axis]);
+    if (shape[axis] != static_cast<size_t>(bound_)) {
+      MS_LOG(EXCEPTION) << "Bound's shape is larger than index type and overflows when casting.";
     }
-
-    axis_ = GetAttr<int64_t>(kernel_node, "axis");
-    if (axis_ < 0) {
-      axis_ += static_cast<int64_t>(input_shape.size());
+    outer_size_ = 1;
+    for (int64_t i = axis - 1; i >= 0; i--) {
+      outer_size_ *= shape[i];
     }
-    if (input_shape.size() == 1) {
-      batch_size_ = 0;
-      channel_size_ = input_shape[0];
-      input_size_ = sizeof(T) * channel_size_;
-      output_size_ = sizeof(int);
-    } else {
-      batch_size_ = input_shape[0];
-      channel_size_ = input_shape[1];
-      input_size_ = sizeof(T) * batch_size_ * channel_size_;
-      output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_;
+    inner_size_ = 1;
+    for (int64_t i = axis + 1; i < dims; i++) {
+      inner_size_ *= shape[i];
     }
     InitSizeLists();
     return true;
@@ -96,9 +86,9 @@ class ArgmaxGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  size_t batch_size_;
-  size_t channel_size_;
-  int64_t axis_;
+  S bound_;
+  size_t outer_size_;
+  size_t inner_size_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
index 116bd9f42db..2bf80aab528 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
@@ -17,72 +17,36 @@
 #include "argmax_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 #include "include/cuda_fp16.h"
-template <typename T>
-__global__ void Argmax1D(const T *input, const int channel_size, int *output) {
-  int max_index = 0;
-  T max = input[0];
-  for (int pos = 1; pos < channel_size; pos++) {
-    if (max < input[pos]) {
-      max = input[pos];
-      max_index = pos;
+template <typename T, typename S>
+__global__ void Argmax(const T *input, const S bound, const size_t outer_size,
+                       const size_t inner_size, S *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size * inner_size;
+       pos += gridDim.x * blockDim.x) {
+    size_t x = pos / inner_size % outer_size;
+    size_t y = pos % inner_size;
+    S idx = 0;
+    size_t input_offset = x * bound * inner_size + 0 * inner_size + y;
+    T max_data = input[input_offset];
+    for (S i = 1; i < bound; i++) {
+      input_offset = x * bound * inner_size + i * inner_size + y;
+      auto input_data = input[input_offset];
+      idx = input_data > max_data ? i : idx;
+      max_data = input_data > max_data ? input_data : max_data;
     }
-  }
-  output[0] = max_index;
-  return;
-}
-template <typename T>
-__global__ void ArgmaxDefault2D(const T *input, const int batch_size, const int channel_size, int *output) {
-  int pos;
-  int max_index;
-  T max;
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) {
-    max = input[i * channel_size];
-    max_index = 0;
-    for (int j = 1; j < channel_size; j++) {
-      pos = i * channel_size + j;
-      if (max < input[pos]) {
-        max = input[pos];
-        max_index = j;
-      }
-    }
-
-    output[i] = max_index;
-  }
-  return;
-}
-template <typename T>
-__global__ void ArgmaxAxis2D(const T *input, const int batch_size, const int channel_size, int *output) {
-  int pos;
-  int max_index;
-  T max;
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
-    max = input[i];
-    max_index = 0;
-    for (int j = 1; j < batch_size; j++) {
-      pos = j * channel_size + i;
-      if (max < input[pos]) {
-        max = input[pos];
-        max_index = j;
-      }
-    }
-    output[i] = max_index;
-  }
-  return;
-}
-template <typename T>
-void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output,
-               cudaStream_t cuda_stream) {
-  if (batch_size == 0) {
-    Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output);
-  } else if (axis == 1) {
-    ArgmaxDefault2D<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
-  } else {
-    ArgmaxAxis2D<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
+    output[pos] = idx;
   }
   return;
 }
-template void CalArgmax<float>(const float *input, const int batch_size, const int channel_size, const int64_t axis,
-                               int *output, cudaStream_t cuda_stream);
-template void CalArgmax<half>(const half *input, const int batch_size, const int channel_size, const int64_t axis,
-                              int *output, cudaStream_t cuda_stream);
+template <typename T, typename S>
+void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size,
+               S *output, cudaStream_t cuda_stream) {
+  Argmax<<<GET_BLOCKS(outer_size * inner_size), GET_THREADS, 0, cuda_stream>>>(input, bound, outer_size, inner_size,
+                                                                               output);
+  return;
+}
+
+template void CalArgmax<float, int>(const float *input, const int bound, const size_t outer_size,
+                                    const size_t inner_size, int *output, cudaStream_t cuda_stream);
+template void CalArgmax<half, int>(const half *input, const int bound, const size_t outer_size,
+                                   const size_t inner_size, int *output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
index ddebaca7e1b..5b80eb85b48 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
@@ -16,8 +16,8 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_
-template <typename T>
-void CalArgmax(const T *input, const int batch_size, const int channel_size, const int64_t axis, int *output,
+template <typename T, typename S>
+void CalArgmax(const T *input, const S bound, const size_t outer_size, const size_t inner_size, S *output,
                cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_ARGMAX_IMPL_CUH_
diff --git a/tests/st/ops/cpu/test_argmax_op.py b/tests/st/ops/cpu/test_argmax_op.py
index fdafd7750f8..6de1ad8d421 100644
--- a/tests/st/ops/cpu/test_argmax_op.py
+++ b/tests/st/ops/cpu/test_argmax_op.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
+import random
+from functools import reduce
 import numpy as np
 import pytest
 
@@ -20,33 +22,59 @@
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
 
 
 class NetArgmax(nn.Cell):
-    def __init__(self):
+    def __init__(self, axis=0):
         super(NetArgmax, self).__init__()
-        self.argmax = P.Argmax(output_type=mstype.int32)
-        x = Tensor(np.array([[1., 20., 5.],
-                             [67., 8., 9.],
-                             [130., 24., 15.]]).astype(np.float32))
-        self.x = Parameter(initializer(x, x.shape), name='x')
+        self.argmax = ops.Argmax(axis=axis, output_type=mstype.int32)
 
-    def construct(self):
-        return self.argmax(self.x)
+    def construct(self, x):
+        return self.argmax(x)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_argmax():
-    Argmax = NetArgmax()
-    output = Argmax()
-    print("================================")
-    expect = np.array([1, 0, 0]).astype(np.float32)
-    print(output)
+def test_argmax_1d():
+    x = Tensor(np.array([1., 20., 5.]).astype(np.float32))
+    Argmax = NetArgmax(axis=0)
+    output = Argmax(x)
+    expect = np.array([1]).astype(np.float32)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_argmax_2d():
+    x = Tensor(np.array([[1., 20., 5.],
+                         [67., 8., 9.],
+                         [130., 24., 15.]]).astype(np.float32))
+    Argmax_axis_0 = NetArgmax(axis=0)
+    output = Argmax_axis_0(x)
+    expect = np.array([2, 2, 2]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+    Argmax_axis_1 = NetArgmax(axis=1)
+    output = Argmax_axis_1(x)
+    expect = np.array([1, 0, 0]).astype(np.float32)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_argmax_high_dims():
+    for dim in range(3, 10):
+        shape = np.random.randint(1, 10, size=dim)
+        x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32)
+        x = x.reshape(shape)
+
+        rnd_axis = random.randint(-dim + 1, dim - 1)
+        Argmax = NetArgmax(axis=rnd_axis)
+        ms_output = Argmax(Tensor(x))
+        np_output = np.argmax(x, axis=rnd_axis)
+        assert (ms_output.asnumpy() == np_output).all()
diff --git a/tests/st/ops/gpu/test_argmax_op.py b/tests/st/ops/gpu/test_argmax_op.py
index ccf492cfb41..b251cbe7d5f 100644
--- a/tests/st/ops/gpu/test_argmax_op.py
+++ b/tests/st/ops/gpu/test_argmax_op.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
+import random
+from functools import reduce
 import numpy as np
 import pytest
 
@@ -20,43 +22,67 @@
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 
 class NetArgmax(nn.Cell):
-    def __init__(self):
+    def __init__(self, axis=0):
         super(NetArgmax, self).__init__()
-        axis1 = 0
-        axis2 = -1
-        self.argmax1 = P.Argmax(axis1, output_type=mstype.int32)
-        self.argmax2 = P.Argmax(axis2, output_type=mstype.int32)
-        self.argmax3 = P.Argmax(output_type=mstype.int32)
+        self.argmax = ops.Argmax(axis, output_type=mstype.int32)
 
     def construct(self, x):
-        return (self.argmax1(x), self.argmax2(x), self.argmax3(x))
+        return self.argmax(x)
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_argmax():
+def test_argmax_1d():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
+
+        x = Tensor(np.array([1., 20., 5.]).astype(np.float32))
+        Argmax = NetArgmax(axis=0)
+        output = Argmax(x)
+        expect = np.array([1]).astype(np.float32)
+        assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmax_2d():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
+
     x = Tensor(np.array([[1., 20., 5.],
                          [67., 8., 9.],
                          [130., 24., 15.],
                          [0.3, -0.4, -15.]]).astype(np.float32))
-    expect1 = np.array([2, 2, 2]).astype(np.int32)
-    expect2 = np.array([1, 0, 0, 0]).astype(np.int32)
+    Argmax_axis_0 = NetArgmax(axis=0)
+    output = Argmax_axis_0(x)
+    expect = np.array([2, 2, 2]).astype(np.int32)
+    assert (output.asnumpy() == expect).all()
 
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    argmax = NetArgmax()
-    output = argmax(x)
-    assert (output[0].asnumpy() == expect1).all()
-    assert (output[1].asnumpy() == expect2).all()
-    assert (output[2].asnumpy() == expect2).all()
+    Argmax_axis_1 = NetArgmax(axis=1)
+    output = Argmax_axis_1(x)
+    expect = np.array([1, 0, 0, 0]).astype(np.int32)
+    assert (output.asnumpy() == expect).all()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    argmax1 = NetArgmax()
-    output1 = argmax1(x)
-    assert (output1[0].asnumpy() == expect1).all()
-    assert (output1[1].asnumpy() == expect2).all()
-    assert (output1[2].asnumpy() == expect2).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmax_high_dims():
+    for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
+        context.set_context(mode=mode, device_target="GPU")
+        for dim in range(3, 10):
+            shape = np.random.randint(1, 10, size=dim)
+            x = np.random.randn(reduce(lambda x, y: x * y, shape)).astype(np.float32)
+            x = x.reshape(shape)
+
+            rnd_axis = random.randint(-dim + 1, dim - 1)
+            Argmax = NetArgmax(axis=rnd_axis)
+            ms_output = Argmax(Tensor(x))
+            np_output = np.argmax(x, axis=rnd_axis)
+            assert (ms_output.asnumpy() == np_output).all()