!20514 add argminwithvalue_gpu_kernel operator
Merge pull request !20514 from 吴书全/PR0719
This commit is contained in:
commit
39744cda6c
|
@ -14,21 +14,34 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMaxWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArgmaxWithValueGpuKernel, double, int)
|
||||
ArgMaxAndMinWithValueGpuKernel, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMaxWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArgmaxWithValueGpuKernel, float, int)
|
||||
ArgMaxAndMinWithValueGpuKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMaxWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArgmaxWithValueGpuKernel, half, int)
|
||||
ArgMaxAndMinWithValueGpuKernel, half, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMinWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArgMaxAndMinWithValueGpuKernel, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMinWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArgMaxAndMinWithValueGpuKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
ArgMinWithValue,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArgMaxAndMinWithValueGpuKernel, half, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,20 +14,22 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class ArgmaxWithValueGpuKernel : public GpuKernel {
|
||||
class ArgMaxAndMinWithValueGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ArgmaxWithValueGpuKernel() { ResetResource(); }
|
||||
~ArgmaxWithValueGpuKernel() override = default;
|
||||
ArgMaxAndMinWithValueGpuKernel() { ResetResource(); }
|
||||
~ArgMaxAndMinWithValueGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
|
@ -38,12 +40,14 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 1);
|
||||
S *index = GetDeviceAddress<S>(outputs, 0);
|
||||
CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output,
|
||||
CalGeneralReduction(small_, input, bound_, outerSize_, innerSize_, index, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
small_ = (kernel_name == "ArgMinWithValue") ? true : false;
|
||||
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
|
||||
int64_t dims = shape.size();
|
||||
|
@ -94,6 +98,7 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
private:
|
||||
bool small_ = false;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
|
@ -106,4 +111,4 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARGMAXANDMINWITHVALUEGPUKERNEL_H_
|
|
@ -303,7 +303,7 @@ void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_
|
|||
if (std::is_same<T, half>::value) {
|
||||
fp16_flag = true;
|
||||
}
|
||||
T init_K = small ? std::numeric_limits<T>::lowest() : std::numeric_limits<T>::lowest();
|
||||
T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();
|
||||
|
||||
if (bound <= kMaxThreadLoop) {
|
||||
ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(
|
||||
|
|
|
@ -1849,7 +1849,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
TypeError: If `axis` is not an int.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetArgminWithValue(nn.Cell):
    """Cell that applies ArgMinWithValue three ways to the same input.

    The three operators reduce along axis 0, along axis -1, and along the
    operator's default axis respectively.
    """

    def __init__(self):
        super().__init__()
        self.argmin1 = P.ArgMinWithValue(0)
        self.argmin2 = P.ArgMinWithValue(-1)
        # No axis argument: exercises the operator's default axis.
        self.argmin3 = P.ArgMinWithValue()

    def construct(self, x):
        """Return the three operator results, in axis-0, axis--1, default order."""
        along_first = self.argmin1(x)
        along_last = self.argmin2(x)
        along_default = self.argmin3(x)
        return (along_first, along_last, along_default)
|
||||
|
||||
|
||||
class NetArgminWithValueBig(nn.Cell):
    """Cell wrapping a single ArgMinWithValue operator on a configurable axis."""

    def __init__(self, axis=0):
        super().__init__()
        self.argmin = P.ArgMinWithValue(axis)

    def construct(self, x):
        """Apply the wrapped ArgMinWithValue operator to ``x``."""
        result = self.argmin(x)
        return result
|
||||
|
||||
|
||||
def argminwithvalue_base(data_type):
    """Check ArgMinWithValue on a fixed 4x3 matrix in PyNative and Graph mode.

    Args:
        data_type: NumPy dtype used for the input and for the expected arrays.
    """
    matrix = np.array([[1., 20., 5.],
                       [67., 8., 9.],
                       [130., 24., 15.],
                       [0.3, -0.4, -15.]]).astype(data_type)
    x = Tensor(matrix)

    # Expected (index, value) pairs for the axis-0 and axis--1 reductions.
    index_axis0 = np.array([3, 3, 3]).astype(data_type)
    value_axis0 = np.array([0.3, -0.4, -15.]).astype(data_type)
    index_last = np.array([0, 1, 2, 2]).astype(data_type)
    value_last = np.array([1., 8., 15., -15.]).astype(data_type)

    # One entry per output of NetArgminWithValue, in output order; the third
    # (default-axis) output is checked against the axis-0 expectations.
    expectations = (
        (index_axis0, value_axis0),
        (index_last, value_last),
        (index_axis0, value_axis0),
    )

    for run_mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=run_mode, device_target="GPU")
        net = NetArgminWithValue()
        results = net(x)
        for position, (want_index, want_value) in enumerate(expectations):
            assert (results[position][0].asnumpy() == want_index).all()
            assert (results[position][1].asnumpy() == want_value).all()
|
||||
|
||||
|
||||
def argminwithvalue_3d(data_type, shape_x):
    """Compare ArgMinWithValue with NumPy on a seeded random 3-D tensor.

    The execution mode and device target are not set here; the tests that
    call this helper set the context first.

    Args:
        data_type: NumPy dtype of the generated input.
        shape_x: Shape of the random input; axes 0, 1 and 2 are each reduced.
    """
    np.random.seed(2)
    x_np = np.random.random(shape_x).astype(data_type)
    x = Tensor(x_np)

    for reduce_axis in (0, 1, 2):
        net = NetArgminWithValueBig(reduce_axis)
        result = net(x)
        want_index = np.argmin(x_np, axis=reduce_axis)
        want_value = np.minimum.reduce(x_np, reduce_axis)
        # Output 0 holds the indices, output 1 the minimum values.
        assert (result[0].asnumpy() == want_index).all()
        assert (result[1].asnumpy() == want_value).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_base_float32():
    """Run the fixed-matrix ArgMinWithValue check with float32 data."""
    argminwithvalue_base(data_type=np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_base_float16():
    """Run the fixed-matrix ArgMinWithValue check with float16 data."""
    argminwithvalue_base(data_type=np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_float32():
    """Run the 3-D float32 check on shape (2, 32, 256) in PyNative then Graph mode."""
    shape_x = (2, 32, 256)
    for run_mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=run_mode, device_target="GPU")
        argminwithvalue_3d(np.float32, shape_x)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_float16():
    """Run the 3-D float16 check on shape (2, 64, 128); Graph mode only."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    argminwithvalue_3d(data_type=np.float16, shape_x=(2, 64, 128))
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argminwithvalue_3d_big_float32():
    """Run the 3-D float32 check on shape (128, 1024, 1) in PyNative then Graph mode."""
    shape_x = (128, 1024, 1)
    for run_mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=run_mode, device_target="GPU")
        argminwithvalue_3d(np.float32, shape_x)
|
Loading…
Reference in New Issue