adding SearchSorted operator to CPU

huangbo77 2021-06-03 10:24:52 +08:00
parent ab599aa23b
commit 04d7094aff
5 changed files with 392 additions and 1 deletion


@@ -0,0 +1,109 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/searchsorted_cpu_kernel.h"
#include <vector>
#include <numeric>
#include <functional>
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputSize = 2;
constexpr size_t kOutputSize = 1;
} // namespace
template <typename S, typename T>
void SearchSortedCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  right_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "right");
  sequence_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  values_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  search_len = sequence_shape_.back();
}
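
// Equivalent to std::lower_bound: returns a pointer to the first element
// in [seq_start, seq_end) that is not less than key.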
template <typename S, typename T>
const S *SearchSortedCPUKernel<S, T>::CustomizedLowerBound(const S *seq_start, const S *seq_end, const S key) {
  while (seq_start < seq_end) {
    const S *mid = seq_start + ((seq_end - seq_start) >> 1);
    if (!(key <= *mid)) {
      seq_start = mid + 1;
    } else {
      seq_end = mid;
    }
  }
  return seq_start;
}

template <typename S, typename T>
bool SearchSortedCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> &,
                                         const std::vector<kernel::AddressPtr> &outputs) {
  CheckParam(inputs, outputs);
  auto sequence = reinterpret_cast<S *>(inputs[0]->addr);
  auto values = reinterpret_cast<S *>(inputs[1]->addr);
  auto output = reinterpret_cast<T *>(outputs[0]->addr);
  size_t elem_num = inputs[1]->size / sizeof(S);
  size_t seq_dim = sequence_shape_.size();
  size_t search_repeat = values_shape_.back();
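  // Values and sequence share all leading dimensions, so value i is searched within
  // row i / search_repeat of the sequence; a 1-D sequence is shared by every value.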
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      auto seq_start = (seq_dim == 1) ? sequence : sequence + (i / search_repeat) * search_len;
      output[i] = right_ ? std::upper_bound(seq_start, seq_start + search_len, values[i]) - seq_start
                         : CustomizedLowerBound(seq_start, seq_start + search_len, values[i]) - seq_start;
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num);
  return true;
}

template <typename S, typename T>
void SearchSortedCPUKernel<S, T>::CheckParam(const std::vector<AddressPtr> &inputs,
                                             const std::vector<AddressPtr> &outputs) {
  // inputs: sequence, values
  if (inputs.size() != kInputSize) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but SearchSorted needs " << kInputSize
                      << " inputs.";
  }
  // outputs: positions
  if (outputs.size() != kOutputSize) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but SearchSorted needs " << kOutputSize
                      << " outputs.";
  }
  if (outputs[0]->size / sizeof(T) != inputs[1]->size / sizeof(S)) {
    MS_LOG(EXCEPTION) << "The output size " << outputs[0]->size / sizeof(T)
                      << " must match the size of the input values " << inputs[1]->size / sizeof(S);
  }
  auto sequence = reinterpret_cast<S *>(inputs[0]->addr);
  size_t list_count = std::accumulate(sequence_shape_.begin(), sequence_shape_.end() - 1, size_t(1),
                                      std::multiplies<size_t>());
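  // Check in parallel that every innermost row of the sequence is non-decreasing.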
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      for (size_t j = 0; j < search_len - 1; j++) {
        if (sequence[i * search_len + j] > sequence[i * search_len + j + 1]) {
          MS_LOG(EXCEPTION) << "The input sequence must be sorted.";
        }
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, list_count);
}
} // namespace kernel
} // namespace mindspore


@@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename S, typename T>
class SearchSortedCPUKernel : public CPUKernel {
 public:
  SearchSortedCPUKernel() = default;
  ~SearchSortedCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override;

 private:
  const S *CustomizedLowerBound(const S *seq_start, const S *seq_end, const S key);
  void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  bool right_{false};
  std::vector<size_t> sequence_shape_;
  std::vector<size_t> values_shape_;
  std::vector<size_t> output_shape_;
  size_t search_len;
};
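
// Register the kernel for every supported sequence/values dtype
// (float64/float32/int64/int32/int16/int8), once for int32 output indices and once for int64.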
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
SearchSortedCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, double, int64_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
SearchSorted,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
SearchSortedCPUKernel, int8_t, int64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SEARCHSORTED_CPU_KERNEL_H_


@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect)
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedSelect, SearchSorted)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
@@ -438,6 +438,7 @@ __all__ = [
"SparseTensorDenseMatmul",
"MatrixInverse",
"Range",
"SearchSorted",
"IndexAdd",
"PQC",
"Evolution",


@@ -5338,3 +5338,58 @@ class MaskedSelect(PrimitiveWithCheck):
    def check_dtype(self, x_dtype, mask_dtype):
        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)


class SearchSorted(PrimitiveWithInfer):
    """
    Finds the indices in the innermost dimension of `sequence` such that, if the corresponding values in `values`
    were inserted before those indices, the order of the innermost dimension of `sequence` would be preserved.

    Args:
        out_int32 (bool): Output datatype. Optional. If True, the output datatype will be int32;
            if False, the output datatype will be int64. Default is False.
        right (bool): Search strategy. Optional. If True, return the last suitable index found;
            if False, return the first such index. Default is False.

    Inputs:
        - **sequence** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_{R-1}, x_R)` or :math:`(x_1)`.
          It must contain a monotonically increasing sequence along the innermost dimension.
        - **values** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_{R-1}, x_S)`.

    Outputs:
        Tensor containing the indices from the innermost dimension of `sequence` such that, if the corresponding
        values in `values` were inserted before those indices, the order of `sequence` would be preserved.
        The shape of the tensor is :math:`(x_1, x_2, ..., x_{R-1}, x_S)`, the same as the shape of `values`;
        its datatype is int32 if `out_int32` is True, otherwise int64.

    Raises:
        ValueError: If `sequence` is not 1-dimensional and its dimensions, except the last one,
            do not match those of `values`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> sequence = Tensor(np.array([[0, 1, 3, 5, 7], [2, 4, 6, 8, 10]]), mindspore.float32)
        >>> values = Tensor(np.array([[3, 6, 9], [3, 6, 9]]), mindspore.float32)
        >>> output = ops.SearchSorted()(sequence, values)
        >>> print(output)
        [[2 4 5]
         [1 2 4]]
    """

    @prim_attr_register
    def __init__(self, out_int32=False, right=False):
        """Initialize SearchSorted"""
        self.out_int32 = validator.check_value_type("out_int32", out_int32, [bool], self.name)
        self.right = validator.check_value_type("right", right, [bool], self.name)
        self.init_prim_io_names(inputs=['sequence', 'values'], outputs=['positions'])

    def infer_shape(self, sequence_shape, values_shape):
        if len(sequence_shape) != 1 and sequence_shape[:-1] != values_shape[:-1]:
            raise ValueError(f"Sequence should be 1 dimensional or have all but the last dimension matching "
                             f"the dimensions of values, but got sequence's shape: {sequence_shape} "
                             f"and values' shape: {values_shape}.")
        return values_shape

    def infer_dtype(self, sequence_dtype, values_dtype):
        args = {"sequence_dtype": sequence_dtype, "values_dtype": values_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return mstype.tensor_type(mstype.int32) if self.out_int32 else mstype.tensor_type(mstype.int64)
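
For reference, `right=True` matches numpy's `searchsorted` with `side='right'` and `right=False` matches `side='left'`, applied row by row along the innermost dimension. A minimal cross-check sketch in the style of the tests added below; the `Net` wrapper, the random shapes, and the dtypes are illustrative only and assume a CPU build that includes this operator:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Net(nn.Cell):
    def __init__(self, right):
        super(Net, self).__init__()
        self.searchsorted = P.SearchSorted(right=right)

    def construct(self, sequence, values):
        return self.searchsorted(sequence, values)


sequence = np.sort(np.random.rand(2, 5).astype(np.float32), axis=-1)
values = np.random.rand(2, 3).astype(np.float32)
out = Net(right=True)(Tensor(sequence), Tensor(values)).asnumpy()

# Row-by-row reference using numpy's searchsorted with side='right'.
expect = np.stack([np.searchsorted(sequence[i], values[i], side='right')
                   for i in range(sequence.shape[0])])
assert (out == expect).all()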


@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class SearchSortedNet(nn.Cell):
    def __init__(self, out_int32=False, right=False):
        super(SearchSortedNet, self).__init__()
        self.searchsorted = P.SearchSorted(out_int32=out_int32, right=right)

    def construct(self, sequence, values):
        return self.searchsorted(sequence, values)
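

# The four tests below cover both search strategies (right=True and right=False)
# and both output index dtypes (int32 and int64).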
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_right_out32():
    np.random.seed(1)
    input1 = np.sort(np.array(np.random.randint(10, size=(2, 3, 9)), dtype=np.int32), axis=-1)
    sequence = Tensor(input1, mstype.int32)
    input2 = np.array(np.random.randint(10, size=(2, 3, 1)), dtype=np.int32)
    values = Tensor(input2, mstype.int32)
    net = SearchSortedNet(out_int32=True, right=True)
    output = net(sequence, values)
    expect = [[[9],
               [3],
               [6]],
              [[5],
               [9],
               [8]]]
    assert output.dtype == mstype.int32
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out32():
    np.random.seed(1)
    input1 = np.sort(np.array(np.random.randint(10, size=(2, 3, 9)), dtype=np.int64), axis=-1)
    sequence = Tensor(input1, mstype.int64)
    input2 = np.array(np.random.randint(10, size=(2, 3, 1)), dtype=np.int64)
    values = Tensor(input2, mstype.int64)
    net = SearchSortedNet(out_int32=True, right=False)
    output = net(sequence, values)
    expect = [[[8],
               [0],
               [3]],
              [[5],
               [8],
               [7]]]
    assert output.dtype == mstype.int32
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_right_out64():
    np.random.seed(1)
    input1 = np.sort(np.array(np.random.random((2, 5)), dtype=np.float32), axis=-1)
    sequence = Tensor(input1, mstype.float32)
    input2 = np.array(np.random.random((2, 3)), dtype=np.float32)
    values = Tensor(input2, mstype.float32)
    net = SearchSortedNet(out_int32=False, right=True)
    output = net(sequence, values)
    expect = [[4, 4, 2],
              [5, 0, 5]]
    assert output.dtype == mstype.int64
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_out64():
    np.random.seed(1)
    input1 = np.sort(np.array(np.random.random((5)), dtype=np.float64), axis=-1)
    sequence = Tensor(input1, mstype.float64)
    input2 = np.array(np.random.random((2, 3)), dtype=np.float64)
    values = Tensor(input2, mstype.float64)
    net = SearchSortedNet(out_int32=False, right=False)
    output = net(sequence, values)
    expect = [[1, 2, 3],
              [3, 4, 4]]
    assert output.dtype == mstype.int64
    assert (output.asnumpy() == expect).all()