!20514 add argminwithvalue_gpu_kernel operator

Merge pull request !20514 from 吴书全/PR0719
This commit is contained in:
i-robot 2021-07-20 22:14:33 +00:00 committed by Gitee
commit 39744cda6c
5 changed files with 177 additions and 13 deletions

View File

@ -14,21 +14,34 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h"
#include "backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
ArgmaxWithValueGpuKernel, double, int)
ArgMaxAndMinWithValueGpuKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
ArgmaxWithValueGpuKernel, float, int)
ArgMaxAndMinWithValueGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
ArgmaxWithValueGpuKernel, half, int)
ArgMaxAndMinWithValueGpuKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
ArgMaxAndMinWithValueGpuKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
ArgMaxAndMinWithValueGpuKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
ArgMaxAndMinWithValueGpuKernel, half, int)
} // namespace kernel
} // namespace mindspore

View File

@ -14,20 +14,22 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class ArgmaxWithValueGpuKernel : public GpuKernel {
class ArgMaxAndMinWithValueGpuKernel : public GpuKernel {
public:
ArgmaxWithValueGpuKernel() { ResetResource(); }
~ArgmaxWithValueGpuKernel() override = default;
ArgMaxAndMinWithValueGpuKernel() { ResetResource(); }
~ArgMaxAndMinWithValueGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -38,12 +40,14 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 1);
S *index = GetDeviceAddress<S>(outputs, 0);
CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output,
CalGeneralReduction(small_, input, bound_, outerSize_, innerSize_, index, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
small_ = (kernel_name == "ArgMinWithValue") ? true : false;
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
int64_t dims = shape.size();
@ -94,6 +98,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
}
private:
bool small_ = false;
size_t input_size_;
size_t output_size_;
std::vector<size_t> input_size_list_;
@ -106,4 +111,4 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_

View File

@ -303,7 +303,7 @@ void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_
if (std::is_same<T, half>::value) {
fp16_flag = true;
}
T init_K = small ? std::numeric_limits<T>::lowest() : std::numeric_limits<T>::lowest();
T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();
if (bound <= kMaxThreadLoop) {
ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(

View File

@ -1849,7 +1849,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
TypeError: If `axis` is not an int.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class NetArgminWithValue(nn.Cell):
def __init__(self):
super(NetArgminWithValue, self).__init__()
axis1 = 0
axis2 = -1
self.argmin1 = P.ArgMinWithValue(axis1)
self.argmin2 = P.ArgMinWithValue(axis2)
self.argmin3 = P.ArgMinWithValue()
def construct(self, x):
return (self.argmin1(x), self.argmin2(x), self.argmin3(x))
class NetArgminWithValueBig(nn.Cell):
def __init__(self, axis=0):
super(NetArgminWithValueBig, self).__init__()
self.argmin = P.ArgMinWithValue(axis)
def construct(self, x):
return self.argmin(x)
def argminwithvalue_base(data_type):
x = Tensor(np.array([[1., 20., 5.],
[67., 8., 9.],
[130., 24., 15.],
[0.3, -0.4, -15.]]).astype(data_type))
expect1 = np.array([3, 3, 3]).astype(data_type)
expect2 = np.array([0, 1, 2, 2]).astype(data_type)
expect11 = np.array([0.3, -0.4, -15.]).astype(data_type)
expect22 = np.array([1., 8., 15., -15.]).astype(data_type)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argmin = NetArgminWithValue()
output = argmin(x)
assert (output[0][0].asnumpy() == expect1).all()
assert (output[0][1].asnumpy() == expect11).all()
assert (output[1][0].asnumpy() == expect2).all()
assert (output[1][1].asnumpy() == expect22).all()
assert (output[2][0].asnumpy() == expect1).all()
assert (output[2][1].asnumpy() == expect11).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argmin = NetArgminWithValue()
output = argmin(x)
assert (output[0][0].asnumpy() == expect1).all()
assert (output[0][1].asnumpy() == expect11).all()
assert (output[1][0].asnumpy() == expect2).all()
assert (output[1][1].asnumpy() == expect22).all()
assert (output[2][0].asnumpy() == expect1).all()
assert (output[2][1].asnumpy() == expect11).all()
def argminwithvalue_3d(data_type, shape_x):
np.random.seed(2)
x_np = np.random.random(shape_x).astype(data_type)
x = Tensor(x_np)
argmin = NetArgminWithValueBig(0)
output = argmin(x)
expect1 = np.argmin(x_np, axis=0)
expect2 = np.minimum.reduce(x_np, 0)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
argmin = NetArgminWithValueBig(1)
output = argmin(x)
expect1 = np.argmin(x_np, axis=1)
expect2 = np.minimum.reduce(x_np, 1)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
argmin = NetArgminWithValueBig(2)
output = argmin(x)
expect1 = np.argmin(x_np, axis=2)
expect2 = np.minimum.reduce(x_np, 2)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_base_float32():
argminwithvalue_base(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_base_float16():
argminwithvalue_base(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_float32():
shape_x = (2, 32, 256)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argminwithvalue_3d(np.float32, shape_x)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argminwithvalue_3d(np.float32, shape_x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_float16():
shape_x = (2, 64, 128)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argminwithvalue_3d(np.float16, shape_x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_big_float32():
shape_x = (128, 1024, 1)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argminwithvalue_3d(np.float32, shape_x)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argminwithvalue_3d(np.float32, shape_x)