add rank cpu kernel

zhujingxuan 2021-08-30 17:28:45 +08:00
parent 00ac7b855c
commit b2ec7f2db6
7 changed files with 562 additions and 2 deletions

View File

@@ -375,5 +375,19 @@ std::vector<size_t> CPUKernelUtils::GetBroadcastShape(const std::vector<size_t>
}
return broadcast_shape;
}
void AxisIterator::Init(const std::vector<size_t> &input_shape, size_t axis) {
outer_size_ = 1;
for (size_t i = 0; i < axis; i++) {
outer_size_ *= input_shape[i];
}
axis_size_ = input_shape[axis];
inner_size_ = 1;
for (size_t i = axis + 1; i < input_shape.size(); ++i) {
inner_size_ *= input_shape[i];
}
}
} // namespace kernel
} // namespace mindspore

View File

@@ -70,6 +70,9 @@ const char MIN_PERIODS[] = "min_periods";
const char CENTER[] = "center";
const char METHOD[] = "method";
const char CLOSED[] = "closed";
const char NA_OPTION[] = "na_option";
const char ASCENDING[] = "ascending";
const char PCT[] = "pct";
enum OperateType {
ADD = 0,
@@ -237,6 +240,27 @@ void ParallelLaunch(const CTask &task, size_t count, float block_size = 128.0, C
void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content,
ParallelSearchInfo *parallel_search_info);
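// Walks a tensor along one axis by viewing its shape as [outer, axis, inner] blocks.
// E.g. for input_shape = (2, 3, 4) and axis = 1: OuterSize() = 2, AxisSize() = 3, InnerSize() = 4;
// after SetOffset(o, j), the i-th element along the axis sits at the row-major flat index
// GetPos(i) = o * 3 * 4 + j + i * 4.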
class AxisIterator {
public:
AxisIterator() = default;
virtual ~AxisIterator() = default;
void Init(const std::vector<size_t> &input_shape, size_t axis);
inline void SetOffset(size_t outer_index, size_t inner_index) {
axis_offset_ = outer_index * axis_size_ * inner_size_ + inner_index;
}
inline size_t GetPos(int i) const { return axis_offset_ + i * inner_size_; }
inline size_t OuterSize() const { return outer_size_; }
inline size_t AxisSize() const { return axis_size_; }
inline size_t InnerSize() const { return inner_size_; }
private:
size_t outer_size_{0};
size_t axis_size_{0};
size_t inner_size_{0};
size_t axis_offset_{0};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,304 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/rank_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <map>
#include <numeric>
#include <type_traits>
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
using rank::Method;
using rank::NaOption;
template <typename T>
void RankCpuKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
static const std::map<std::string, Method> kValidMethods = {
{"max", Method::Max}, {"min", Method::Min}, {"average", Method::Average},
{"first", Method::First}, {"dense", Method::Dense},
};
auto method = AnfAlgo::GetNodeAttr<std::string>(kernel_node, METHOD);
if (kValidMethods.find(method) == kValidMethods.end()) {
MS_LOG(EXCEPTION) << "[" << method << "] not supported";
}
method_ = kValidMethods.at(method);
static const std::map<std::string, NaOption> kValidOptions = {
{"keep", NaOption::Keep},
{"top", NaOption::Top},
{"bottom", NaOption::Bottom},
};
auto option = AnfAlgo::GetNodeAttr<std::string>(kernel_node, NA_OPTION);
if (kValidOptions.find(option) == kValidOptions.end()) {
MS_LOG(EXCEPTION) << "[" << option << "] not supported";
}
option_ = kValidOptions.at(option);
ascending_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ASCENDING);
pct_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PCT);
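// a negative axis counts back from the last dimension, e.g. axis = -1 on a 3-D input maps to axis 2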
auto axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis_ = axis < 0 ? LongToSize(axis + SizeToLong(input_shape.size())) : LongToSize(axis);
if (axis_ >= input_shape.size()) {
MS_LOG(EXCEPTION) << "the evaluated axis should be smaller than the dimension of input tensor "
<< input_shape.size() << "D, but got " << axis_;
}
axisIterator_.Init(input_shape, axis_);
SetFunc();
}
template <typename T>
void RankCpuKernel<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
size_t element_size = axisIterator_.OuterSize() * axisIterator_.InnerSize() * axisIterator_.AxisSize();
// id
workspace_size_list_.emplace_back((sizeof(size_t) * element_size));
// copy element
workspace_size_list_.emplace_back((sizeof(T) * element_size));
if constexpr (!std::is_integral_v<T>) {
// nan flags
workspace_size_list_.emplace_back((sizeof(bool) * element_size));
}
}
template <typename T>
void RankCpuKernel<T>::SetFunc() {
switch (method_) {
case Method::Max: {
func_ = [](int i, int duplicate_count, int cumulative_rank, const AxisIterator &axisIterator,
const size_t *sort_idx, float *output_addr) {
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[axisIterator.GetPos(sort_idx[j])] = i + 1;
}
};
} break;
case Method::Min: {
func_ = [](int i, int duplicate_count, int cumulative_rank, const AxisIterator &axisIterator,
const size_t *sort_idx, float *output_addr) {
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[axisIterator.GetPos(sort_idx[j])] = i - duplicate_count + 2;
}
};
} break;
case Method::Average: {
// how avg is computed directly:
// sum = (i - duplicate_count + 1) + (i - duplicate_count + 2) +... + i
// = duplicate_count * (2 * i - duplicate_count + 1) / 2
// rank_sum = sum + duplicate_count = duplicate_count * (2 * i - duplicate_count + 3) / 2
// avg = rank_sum / duplicate_count = (2 * i - duplicate_count + 3) / 2
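// e.g. ascending ranks of {3, 7, 7, 7}: the tied 7s end at i = 3 with duplicate_count = 3,
// so avg = (2 * 3 - 3 + 3) / 2 = 3, i.e. the average of ranks 2, 3 and 4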
func_ = [](int i, int duplicate_count, int cumulative_rank, const AxisIterator &axisIterator,
const size_t *sort_idx, float *output_addr) {
float avg = (2 * i - duplicate_count + 3) / 2.0;
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[axisIterator.GetPos(sort_idx[j])] = avg;
}
};
} break;
case Method::First: {
func_ = [](int i, int duplicate_count, int cumulative_rank, const AxisIterator &axisIterator,
const size_t *sort_idx, float *output_addr) {
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[axisIterator.GetPos(sort_idx[j])] = j + 1;
}
};
} break;
case Method::Dense: {
func_ = [](int i, int duplicate_count, int cumulative_rank, const AxisIterator &axisIterator,
const size_t *sort_idx, float *output_addr) {
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[axisIterator.GetPos(sort_idx[j])] = cumulative_rank;
}
};
} break;
case Method::MethodNotDefined:
default:
MS_LOG(EXCEPTION) << "method not init";
}
}
template <typename T>
void RankCpuKernel<T>::Launch1DInt(const T *input_addr, size_t *sort_idx, T *values, const AxisIterator &iter,
float *output_addr) const {
const int n = SizeToInt(iter.AxisSize());
for (int i = 0; i < n; ++i) {
values[i] = input_addr[iter.GetPos(i)];
}
std::iota(sort_idx, sort_idx + iter.AxisSize(), 0);
if (ascending_) {
std::stable_sort(sort_idx, sort_idx + iter.AxisSize(),
[values](size_t lhs, size_t rhs) { return values[lhs] < values[rhs]; });
} else {
std::stable_sort(sort_idx, sort_idx + iter.AxisSize(),
[values](size_t lhs, size_t rhs) { return values[lhs] > values[rhs]; });
}
int cumulative_rank = 1;
int duplicate_count = 0;
for (int i = 0; i < n; ++i) {
duplicate_count++;
if ((i == n - 1) || (values[sort_idx[i]] != values[sort_idx[i + 1]])) {
func_(i, duplicate_count, cumulative_rank, iter, sort_idx, output_addr);
cumulative_rank++;
duplicate_count = 0;
}
}
if (pct_) {
// pct calculation
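// dense divides by the number of distinct values (cumulative_rank ends one past it);
// the other methods divide by the number of elements n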
if (method_ == Method::Dense) {
auto size = static_cast<float>(cumulative_rank - 1);
for (int i = 0; i < n; ++i) {
output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size;
}
} else {
auto size = static_cast<float>(n);
for (int i = 0; i < n; ++i) {
output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size;
}
}
}
}
template <typename T>
void RankCpuKernel<T>::Launch1DFloat(const T *input_addr, size_t *sort_idx, T *values, bool *is_nan,
const AxisIterator &iter, float *output_addr) const {
const int n = SizeToInt(iter.AxisSize());
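// NaNs are replaced by a sentinel value so the stable sort pushes them to the end that
// na_option requests: "top" ranks NaNs first, "keep"/"bottom" rank them last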
T nan_padding_value;
if (ascending_ != (option_ == NaOption::Top)) {
nan_padding_value = std::numeric_limits<T>::max();
} else {
// use lowest(), not min(): for floating-point types min() is the smallest positive value
nan_padding_value = std::numeric_limits<T>::lowest();
}
for (int i = 0; i < n; ++i) {
const T value = input_addr[iter.GetPos(i)];
if (std::isnan(value)) {
values[i] = nan_padding_value;
is_nan[i] = true;
} else {
values[i] = value;
is_nan[i] = false;
}
}
std::iota(sort_idx, sort_idx + iter.AxisSize(), 0);
if (ascending_) {
std::stable_sort(sort_idx, sort_idx + iter.AxisSize(),
[values](size_t lhs, size_t rhs) { return values[lhs] < values[rhs]; });
} else {
std::stable_sort(sort_idx, sort_idx + iter.AxisSize(),
[values](size_t lhs, size_t rhs) { return values[lhs] > values[rhs]; });
}
int cumulative_rank = 1;
int duplicate_count = 0;
int nans_count = 0;
for (int i = 0; i < n; ++i) {
duplicate_count++;
if ((i == n - 1) || (values[sort_idx[i]] != values[sort_idx[i + 1]]) ||
(is_nan[sort_idx[i]] != is_nan[sort_idx[i + 1]])) {
if ((option_ == NaOption::Keep) && is_nan[sort_idx[i]]) {
for (int j = i - duplicate_count + 1; j < i + 1; ++j) {
output_addr[iter.GetPos(sort_idx[j])] = NAN;
}
} else {
func_(i, duplicate_count, cumulative_rank, iter, sort_idx, output_addr);
}
if (is_nan[sort_idx[i]]) {
nans_count = duplicate_count;
}
cumulative_rank++;
duplicate_count = 0;
}
}
if (pct_) {
// pct calculation
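// with na_option == "keep" the NaN entries stay NaN and are excluded from the pct
// denominator: one distinct group for dense, nans_count elements for the other methods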
if (method_ == Method::Dense) {
auto size = static_cast<float>(cumulative_rank - 1);
if ((option_ == NaOption::Keep && (nans_count > 0))) {
size--;
}
for (int i = 0; i < n; ++i) {
output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size;
}
} else {
auto size = static_cast<float>(n);
if (option_ == NaOption::Keep) {
size -= static_cast<float>(nans_count);
}
for (int i = 0; i < n; ++i) {
output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size;
}
}
}
}
template <typename T>
bool RankCpuKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(EXCEPTION) << "input or output num error";
}
if constexpr (std::is_integral_v<T>) {
if (workspace.size() != 2) {
MS_LOG(EXCEPTION) << "workspace num error";
}
} else {
if (workspace.size() != 3) {
MS_LOG(EXCEPTION) << "workspace num error";
}
}
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto ids_addr = reinterpret_cast<size_t *>(workspace[0]->addr);
auto values_addr = reinterpret_cast<T *>(workspace[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
std::vector<common::Task> tasks;
tasks.reserve(axisIterator_.OuterSize() * axisIterator_.InnerSize());
for (size_t i = 0; i < axisIterator_.OuterSize(); ++i) {
for (size_t j = 0; j < axisIterator_.InnerSize(); ++j) {
tasks.emplace_back([this, i, j, input_addr, ids_addr, values_addr, workspace, output_addr]() {
AxisIterator iter(axisIterator_);
iter.SetOffset(i, j);
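// each (outer, inner) pair owns a disjoint scratch slice of AxisSize() entries,
// so the tasks can run concurrently without synchronization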
size_t offset = (i * iter.InnerSize() + j) * iter.AxisSize();
size_t *sort_idx = ids_addr + offset;
T *values = values_addr + offset;
if constexpr (std::is_integral_v<T>) {
Launch1DInt(input_addr, sort_idx, values, iter, output_addr);
} else {
auto flags_addr = reinterpret_cast<bool *>(workspace[2]->addr);
bool *is_nan = flags_addr + offset;
Launch1DFloat(input_addr, sort_idx, values, is_nan, iter, output_addr);
}
return common::SUCCESS;
});
}
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_
#include <functional>
#include <string>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
namespace mindspore {
namespace kernel {
namespace rank {
enum Method : int {
Average,
Max,
Min,
First,
Dense,
MethodNotDefined,
};
enum NaOption : int {
Keep,
Top,
Bottom,
OptionNotDefined,
};
} // namespace rank
template <typename T>
class RankCpuKernel : public CPUKernel {
public:
RankCpuKernel() = default;
~RankCpuKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
void InitInputOutputSize(const CNodePtr &kernel_node) override;
void SetFunc();
void Launch1DInt(const T *input_addr, size_t *sort_idx, T *values, const AxisIterator &iter,
float *output_addr) const;
void Launch1DFloat(const T *input_addr, size_t *sort_idx, T *values, bool *is_nan, const AxisIterator &iter,
float *output_addr) const;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t rank_size_{0};
// shape info
AxisIterator axisIterator_{};
// parameters
size_t axis_{0};
rank::Method method_{rank::MethodNotDefined};
std::function<void(int, int, int, const AxisIterator &, const size_t *, float *)> func_;
rank::NaOption option_{rank::OptionNotDefined};
bool ascending_{true};
bool pct_{false};
};
MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
RankCpuKernel, float)
MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
RankCpuKernel, double)
MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), RankCpuKernel,
int32_t)
MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), RankCpuKernel,
int64_t)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_

View File

@@ -23,7 +23,7 @@
namespace mindspore {
namespace kernel {
using rolling::Method;
template <typename T, typename S>
void RollingCpuKernel<T, S>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);

View File

@@ -24,6 +24,7 @@
namespace mindspore {
namespace kernel {
namespace rolling {
enum Method : int {
Max,
Min,
@@ -32,6 +33,7 @@ enum Method : int {
Std,
Var,
};
}  // namespace rolling
template <typename T, typename S>
class RollingCpuKernel : public CPUKernel {
public:
@@ -53,7 +55,7 @@ class RollingCpuKernel : public CPUKernel {
size_t axis_{0};
bool center_{false};
std::string closed_{};
Method method_{};
rolling::Method method_{};
std::function<S(const T *input_addr, int outer_offset, size_t start, size_t end, int col)> reduceMethod_{};
// shape info
size_t outer_size_{0};

View File

@@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List
from random import sample
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
import numpy as np
import pandas as pd
import pytest
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Rank(PrimitiveWithInfer):
    """
    Rank op frontend implementation
    """

    # size_t axis_{0};
    # rank::Method method_{rank::MethodNotDefined};
    # rank::NaOption option_{rank::OptionNotDefined};
    # bool ascending_{true};
    # bool pct_{false};
    @prim_attr_register
    def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool):
        """Initialize Rank"""
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)
        self.na_option = validator.check_value_type("na_option", na_option, [str], self.name)
        self.ascending = validator.check_value_type("ascending", ascending, [bool], self.name)
        self.pct = validator.check_value_type("pct", pct, [bool], self.name)
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        out_shapes = x['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': mstype.float32,
            'value': None
        }
class RankNet(nn.Cell):
    def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool):
        super(RankNet, self).__init__()
        self.rank = Rank(axis, method, na_option, ascending, pct)

    def construct(self, x):
        return self.rank(x)
def pandas_rank(arr, **kwargs):
    ser = pd.DataFrame(arr)
    result = ser.rank(**kwargs)
    return result.to_numpy()
@pytest.mark.parametrize('shape', [(10,)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average'])
@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"])
@pytest.mark.parametrize('ascending', [True, False])
@pytest.mark.parametrize('pct', [False, True])
def test_rank_1d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, na_option: str):
    np.random.seed(0)
    if dtype in (np.int32, np.int64):
        arr = np.random.randint(0, 100, size=shape).astype(dtype)
    else:
        arr = np.random.random(size=shape).astype(dtype)
        arr.flat[sample(range(arr.size), int(arr.size / 10))] = np.nan

    pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option).flatten()
    rank = RankNet(0, method=method, ascending=ascending, pct=pct, na_option=na_option)
    mind_result = rank(Tensor(arr)).asnumpy()

    print('arr: \n', arr, arr.dtype, arr.shape)
    print('pandas: \n', pd_result, pd_result.dtype, pd_result.shape)
    print('mind: \n', mind_result, mind_result.dtype, mind_result.shape)
    print(f'method: {method}, ascending: {ascending}, pct: {pct}, na_option: {na_option}')
    assert np.allclose(pd_result, mind_result, equal_nan=True)
@pytest.mark.parametrize('shape', [(5, 6)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average'])
@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"])
@pytest.mark.parametrize('axis', [0, 1])
@pytest.mark.parametrize('ascending', [True, False])
@pytest.mark.parametrize('pct', [False, True])
def test_rank_2d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, axis: int, na_option: str):
    np.random.seed(0)
    if dtype in (np.int32, np.int64):
        arr = np.random.randint(0, 100, size=shape).astype(dtype)
    else:
        arr = np.random.random(size=shape).astype(dtype)
        arr.flat[sample(range(arr.size), int(arr.size / 10))] = np.nan

    pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option, axis=axis)
    rank = RankNet(axis=axis, method=method, ascending=ascending, pct=pct, na_option=na_option)
    mind_result = rank(Tensor(arr)).asnumpy()

    print('arr: \n', arr, arr.dtype, arr.shape)
    print('pandas: \n', pd_result, pd_result.dtype, pd_result.shape)
    print('mind: \n', mind_result, mind_result.dtype, mind_result.shape)
    print(f'axis: {axis}, method: {method}, ascending: {ascending}, pct: {pct}, na_option: {na_option}')
    assert np.allclose(pd_result, mind_result, equal_nan=True)