diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc new file mode 100644 index 00000000000..24c8a9a7301 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + ArgmaxWithValueGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + ArgmaxWithValueGpuKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h new file mode 100644 index 00000000000..9d42f31eb6c --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ + +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh" +namespace mindspore { +namespace kernel { +template +class ArgmaxWithValueGpuKernel : public GpuKernel { + public: + ArgmaxWithValueGpuKernel() + : input_size_(0), + output_size_(0), + workspace_size_(0), + axis_(0), + dims_(1), + bound_(0), + outerSize_(0), + innerSize_(0) {} + ~ArgmaxWithValueGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 1); + S *index = GetDeviceAddress(outputs, 0); + CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, axis_, dims_, index, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); + dims_ = shape_.size(); + + axis_ = GetAttr(kernel_node, "axis"); + if (axis_ < 0) { + axis_ += dims_; + } + input_size_ = sizeof(T); + for (auto x : shape_) { + input_size_ *= x; + } + output_size_ = sizeof(S); + for (auto x : output_shape) { + output_size_ *= x; + } + bound_ = shape_[axis_]; + outerSize_ = 1; + for (int i = axis_ - 1; i >= 0; i--) { + outerSize_ *= shape_[i]; + } + + innerSize_ = 1; + for (int i = axis_ + 1; i < dims_; i++) { + innerSize_ *= shape_[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T)); + } + + private: + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + std::vector shape_; + int axis_; + int dims_; + int bound_; + int outerSize_; + int innerSize_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu new file mode 100644 index 00000000000..47a794cdcd9 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "argmaxwithvalue_impl.cuh" +#include "device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" +template +__global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, int outerSize, int innerSize, + S* index, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + for (int i = 0; i < outerSize; i++) { + int inputOutterOffset = i * innerSize * bound; + int outputOutterOffset = i * innerSize; + for (int j = 0; j < innerSize; j++) { + auto outputInnerOffset = outputOutterOffset + j; + S idx = 0; + T maxData = input[j + inputOutterOffset]; + for (S c = 0; c < bound; c++) { + int offset = j + c * innerSize; + auto inputData = input[inputOutterOffset + offset]; + idx = inputData > maxData ? c : idx; + maxData = inputData > maxData ? inputData : maxData; + } + output[outputInnerOffset] = maxData; + index[outputInnerOffset] = idx; + } + } + } + return; +} + +template +void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, + int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream) { + ArgmaxWithValue<<>>(size, input, bound_, outerSize_, innerSize_, + index, output); + return; +} + +template void CalArgmaxWithValue(size_t size, const float* input, const int bound_, const int outerSize_, + const int innerSize_, int axis_, int dims_, int* index, float* output, + cudaStream_t cuda_stream); +template void CalArgmaxWithValue(size_t size, const half* input, const int bound_, const int outerSize_, + const int innerSize_, int axis_, int dims_, int* index, half* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh new file mode 100644 index 00000000000..eebe4c8fa63 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh @@ -0,0 +1,22 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ +template +void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, + int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ diff --git a/tests/st/ops/gpu/test_argmaxwithvalue_op.py b/tests/st/ops/gpu/test_argmaxwithvalue_op.py new file mode 100644 index 00000000000..6ce729a6cbc --- /dev/null +++ b/tests/st/ops/gpu/test_argmaxwithvalue_op.py @@ -0,0 +1,68 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetArgmaxWithValue(nn.Cell): + def __init__(self): + super(NetArgmaxWithValue, self).__init__() + axis1 = 0 + axis2 = -1 + self.argmax1 = P.ArgMaxWithValue(axis1) + self.argmax2 = P.ArgMaxWithValue(axis2) + self.argmax3 = P.ArgMaxWithValue() + + def construct(self, x): + return (self.argmax1(x), self.argmax2(x), self.argmax3(x)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argmaxwithvalue(): + x = Tensor(np.array([[1., 20., 5.], + [67., 8., 9.], + [130., 24., 15.], + [0.3, -0.4, -15.]]).astype(np.float32)) + expect1 = np.array([2, 2, 2]).astype(np.float32) + expect2 = np.array([1, 0, 0, 0]).astype(np.float32) + expect11 = np.array([130, 24, 15]).astype(np.float32) + expect22 = np.array([20, 67, 130, 0.3]).astype(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argmax = NetArgmaxWithValue() + output = argmax(x) + assert (output[0][0].asnumpy() == expect1).all() + assert (output[0][1].asnumpy() == expect11).all() + assert (output[1][0].asnumpy() == expect2).all() + assert (output[1][1].asnumpy() == expect22).all() + assert (output[2][0].asnumpy() == expect1).all() + assert (output[2][1].asnumpy() == expect11).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argmax = NetArgmaxWithValue() + output = argmax(x) + assert (output[0][0].asnumpy() == expect1).all() + assert (output[0][1].asnumpy() == expect11).all() + assert (output[1][0].asnumpy() == expect2).all() + assert (output[1][1].asnumpy() == expect22).all() + assert (output[2][0].asnumpy() == expect1).all() + assert (output[2][1].asnumpy() == expect11).all()