forked from mindspore-Ecosystem/mindspore
!21004 Add CPU Sort OP
Merge pull request !21004 from zhujingxuan/master
This commit is contained in:
commit
27e0c7d8b8
|
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/sort_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
void SortCpuKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 1) {
|
||||
MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SortCpuKernel expects 1.";
|
||||
}
|
||||
|
||||
size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_count != 2) {
|
||||
MS_LOG(EXCEPTION) << "Number of outputs is " << output_count << ", but should be 2 for SortCpuKernel.";
|
||||
}
|
||||
|
||||
auto x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto input_rank = x_shape_.size();
|
||||
|
||||
descending_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "descending");
|
||||
auto axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
|
||||
if (axis < 0) {
|
||||
axis += input_rank;
|
||||
}
|
||||
|
||||
if ((axis < 0) || (axis >= static_cast<int64_t>(input_rank))) {
|
||||
MS_LOG(EXCEPTION) << "evaluated axis is " << axis << ", but should be in range [0, " << input_rank << "].";
|
||||
}
|
||||
|
||||
size_t axis_t = axis;
|
||||
|
||||
outer_size_ = 1;
|
||||
for (size_t i = 0; i < axis_t; i++) {
|
||||
outer_size_ *= x_shape_[i];
|
||||
}
|
||||
|
||||
axis_size_ = x_shape_[axis_t];
|
||||
|
||||
inner_size_ = 1;
|
||||
for (size_t i = axis_t + 1; i < input_rank; ++i) {
|
||||
inner_size_ *= x_shape_[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SortCpuKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "TopK needs 1 input and 2 outputs, but get inputs: " << inputs.size()
|
||||
<< "outputs: " << outputs.size();
|
||||
}
|
||||
if (inputs[0]->size != outer_size_ * axis_size_ * inner_size_ * sizeof(T)) {
|
||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||
}
|
||||
auto input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto indices = reinterpret_cast<int *>(outputs[1]->addr);
|
||||
|
||||
if (outputs[0]->size != inputs[0]->size) {
|
||||
MS_LOG(EXCEPTION) << "Error output data size!";
|
||||
}
|
||||
|
||||
std::function<bool(size_t, size_t)> comparator;
|
||||
if (descending_) {
|
||||
comparator = [&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; };
|
||||
} else {
|
||||
comparator = [&input](size_t index_1, size_t index_2) { return input[index_1] < input[index_2]; };
|
||||
}
|
||||
|
||||
std::vector<common::Task> tasks;
|
||||
tasks.reserve(outer_size_ * inner_size_);
|
||||
for (size_t i = 0; i < outer_size_; ++i) {
|
||||
const auto out_offset = i * axis_size_ * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; ++j) {
|
||||
const auto axis_offset = out_offset + j;
|
||||
auto task = [this, axis_offset, &input, &indices, &output, &comparator]() {
|
||||
std::vector<size_t> idx(axis_size_);
|
||||
// fill idx starts with out_offset + j, step inner_size_
|
||||
for (size_t k = 0; k < axis_size_; ++k) {
|
||||
idx[k] = axis_offset + k * inner_size_;
|
||||
}
|
||||
|
||||
std::stable_sort(idx.begin(), idx.end(), comparator);
|
||||
|
||||
for (size_t k = 0; k < axis_size_; ++k) {
|
||||
const auto index = axis_offset + k * inner_size_;
|
||||
indices[index] = (SizeToInt(idx[k]) - SizeToInt(axis_offset)) / inner_size_;
|
||||
output[index] = input[idx[k]];
|
||||
}
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(task);
|
||||
}
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SortCpuKernel : public CPUKernel {
|
||||
public:
|
||||
SortCpuKernel() = default;
|
||||
~SortCpuKernel() = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
size_t inner_size_{1};
|
||||
size_t outer_size_{1};
|
||||
size_t axis_size_{1};
|
||||
bool descending_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sort, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||
SortCpuKernel, float16)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sort, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
SortCpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
|
@ -0,0 +1,160 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class SortNet(nn.Cell):
|
||||
def __init__(self, axis, descending):
|
||||
super(SortNet, self).__init__()
|
||||
self.sort = P.Sort(axis, descending)
|
||||
|
||||
def construct(self, x):
|
||||
return self.sort(x)
|
||||
|
||||
|
||||
def sort_1d(descending, nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
x_numpy = np.array([1, -2, 3, 4]).astype(nptype)
|
||||
x = Tensor(x_numpy)
|
||||
sort_net = SortNet(0, descending)
|
||||
output, indices = sort_net(x)
|
||||
|
||||
expected_output = np.sort(x_numpy, 0)
|
||||
expected_indices = np.array([1, 0, 2, 3])
|
||||
if descending:
|
||||
expected_output = expected_output[::-1]
|
||||
expected_indices = expected_indices[::-1]
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
np.testing.assert_array_equal(indices.asnumpy(), expected_indices)
|
||||
|
||||
def sort_3d(descending, nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
x_numpy = np.array([[[1, 2, 3, 4],
|
||||
[8, 7, 2, 0],
|
||||
[9, 4, 1, 8]],
|
||||
[[5, 4, 1, 8],
|
||||
[2, 9, 0, 7],
|
||||
[6, 1, 7, 4]]]).astype(nptype)
|
||||
x = Tensor(x_numpy)
|
||||
|
||||
axis = -1
|
||||
sort_net = SortNet(axis, descending)
|
||||
output, indices = sort_net(x)
|
||||
|
||||
expected_output = np.sort(x_numpy, axis)
|
||||
expected_indices = np.array([[[0, 1, 2, 3],
|
||||
[3, 2, 1, 0],
|
||||
[2, 1, 3, 0]],
|
||||
[[2, 1, 0, 3],
|
||||
[2, 0, 3, 1],
|
||||
[1, 3, 0, 2]]])
|
||||
if descending:
|
||||
expected_output = expected_output[:, :, ::-1]
|
||||
expected_indices = expected_indices[:, :, ::-1]
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
np.testing.assert_array_equal(indices.asnumpy(), expected_indices)
|
||||
|
||||
axis = 1
|
||||
sort_net = SortNet(axis, descending)
|
||||
output, indices = sort_net(x)
|
||||
|
||||
expected_output = np.sort(x_numpy, axis)
|
||||
expected_indices = np.array([[[0, 0, 2, 1],
|
||||
[1, 2, 1, 0],
|
||||
[2, 1, 0, 2]],
|
||||
[[1, 2, 1, 2],
|
||||
[0, 0, 0, 1],
|
||||
[2, 1, 2, 0]]])
|
||||
if descending:
|
||||
expected_output = expected_output[:, ::-1, :]
|
||||
expected_indices = expected_indices[:, ::-1, :]
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
np.testing.assert_array_equal(indices.asnumpy(), expected_indices)
|
||||
|
||||
axis = -3
|
||||
sort_net = SortNet(axis, descending)
|
||||
output, indices = sort_net(x)
|
||||
|
||||
expected_output = np.sort(x_numpy, axis)
|
||||
expected_indices = np.array([[[0, 0, 1, 0],
|
||||
[1, 0, 1, 0],
|
||||
[1, 1, 0, 1]],
|
||||
[[1, 1, 0, 1],
|
||||
[0, 1, 0, 1],
|
||||
[0, 0, 1, 0]]])
|
||||
if descending:
|
||||
expected_output = expected_output[::-1, :, :]
|
||||
expected_indices = expected_indices[::-1, :, :]
|
||||
|
||||
np.testing.assert_array_equal(output.asnumpy(), expected_output)
|
||||
np.testing.assert_array_equal(indices.asnumpy(), expected_indices)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort1d_float16():
|
||||
sort_1d(False, np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort1d_descending_float16():
|
||||
sort_1d(True, np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort1d_float32():
|
||||
sort_1d(False, np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort1d_descending_float32():
|
||||
sort_1d(True, np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort3d_float16():
|
||||
sort_3d(False, np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort3d_descending_float16():
|
||||
sort_3d(True, np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort3d_float32():
|
||||
sort_3d(False, np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sort3d_descending_float32():
|
||||
sort_3d(True, np.float32)
|
Loading…
Reference in New Issue