From b2ec7f2db63429579528580eb14d06fb4373107b Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Mon, 30 Aug 2021 17:28:45 +0800 Subject: [PATCH] add rank cpu kernel --- .../backend/kernel_compiler/cpu/cpu_kernel.cc | 14 + .../backend/kernel_compiler/cpu/cpu_kernel.h | 24 ++ .../kernel_compiler/cpu/rank_cpu_kernel.cc | 304 ++++++++++++++++++ .../kernel_compiler/cpu/rank_cpu_kernel.h | 90 ++++++ .../kernel_compiler/cpu/rolling_cpu_kernel.cc | 2 +- .../kernel_compiler/cpu/rolling_cpu_kernel.h | 4 +- tests/st/ops/cpu/test_rank_op.py | 126 ++++++++ 7 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_rank_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc index 05625a45fb0..4f4d9b5283e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc @@ -375,5 +375,19 @@ std::vector CPUKernelUtils::GetBroadcastShape(const std::vector } return broadcast_shape; } + +void AxisIterator::Init(const std::vector &input_shape, size_t axis) { + outer_size_ = 1; + for (size_t i = 0; i < axis; i++) { + outer_size_ *= input_shape[i]; + } + + axis_size_ = input_shape[axis]; + + inner_size_ = 1; + for (size_t i = axis + 1; i < input_shape.size(); ++i) { + inner_size_ *= input_shape[i]; + } +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index bd06d9841be..bfdba981bdd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -70,6 +70,9 @@ const char MIN_PERIODS[] = "min_periods"; const char CENTER[] = "center"; const char METHOD[] = "method"; const char CLOSED[] = "closed"; +const char NA_OPTION[] = "na_option"; +const char ASCENDING[] = "ascending"; +const char PCT[] = "pct"; enum OperateType { ADD = 0, @@ -237,6 +240,27 @@ void ParallelLaunch(const CTask &task, size_t count, float block_size = 128.0, C void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content, ParallelSearchInfo *parallel_search_info); +class AxisIterator { + public: + AxisIterator() = default; + virtual ~AxisIterator() = default; + void Init(const std::vector &input_shape, size_t axis); + + inline void SetOffset(size_t outer_index, size_t inner_index) { + axis_offset_ = outer_index * axis_size_ * inner_size_ + inner_index; + } + inline size_t GetPos(int i) const { return axis_offset_ + i * inner_size_; } + + inline size_t OuterSize() const { return outer_size_; } + inline size_t AxisSize() const { return axis_size_; } + inline size_t InnerSize() const { return inner_size_; } + + private: + size_t outer_size_{0}; + size_t axis_size_{0}; + size_t inner_size_{0}; + size_t axis_offset_{0}; +}; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.cc new file mode 100644 index 00000000000..bcff275be27 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.cc @@ -0,0 +1,304 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/rank_cpu_kernel.h" +#include +#include +#include +#include +#include "common/thread_pool.h" + +namespace mindspore { +namespace kernel { +using rank::Method; +using rank::NaOption; +template +void RankCpuKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + static const std::map kValidMethods = { + {"max", Method::Max}, {"min", Method::Min}, {"average", Method::Average}, + {"first", Method::First}, {"dense", Method::Dense}, + }; + auto method = AnfAlgo::GetNodeAttr(kernel_node, METHOD); + if (kValidMethods.find(method) == kValidMethods.end()) { + MS_LOG(EXCEPTION) << "[" << method << "] not supported"; + } + method_ = kValidMethods.at(method); + + static const std::map kValidOptions = { + {"keep", NaOption::Keep}, + {"top", NaOption::Top}, + {"bottom", NaOption::Bottom}, + }; + auto option = AnfAlgo::GetNodeAttr(kernel_node, NA_OPTION); + if (kValidOptions.find(option) == kValidOptions.end()) { + MS_LOG(EXCEPTION) << "[" << option << "] not supported"; + } + option_ = kValidOptions.at(option); + + ascending_ = AnfAlgo::GetNodeAttr(kernel_node, ASCENDING); + pct_ = AnfAlgo::GetNodeAttr(kernel_node, PCT); + auto axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + axis_ = axis < 0 ? LongToSize(axis + SizeToLong(input_shape.size())) : LongToSize(axis); + if (axis_ >= input_shape.size()) { + MS_LOG(EXCEPTION) << "the evaluated axis should be smaller than the dimension of input tensor " + << input_shape.size() << "D, but got " << axis_; + } + + axisIterator_.Init(input_shape, axis_); + SetFunc(); +} + +template +void RankCpuKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + size_t element_size = axisIterator_.OuterSize() * axisIterator_.InnerSize() * axisIterator_.AxisSize(); + // id + workspace_size_list_.emplace_back((sizeof(size_t) * element_size)); + // copy element + workspace_size_list_.emplace_back((sizeof(T) * element_size)); + if constexpr (!std::is_integral_v) { + // nan flags + workspace_size_list_.emplace_back((sizeof(bool) * element_size)); + } +} + +template +void RankCpuKernel::SetFunc() { + switch (method_) { + case Method::Max: { + func_ = [](int i, int duplicate_count, int culmutive_rank, const AxisIterator &axisIterator, + const size_t *sort_idx, float *output_addr) { + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[axisIterator.GetPos(sort_idx[j])] = i + 1; + } + }; + } break; + case Method::Min: { + func_ = [](int i, int duplicate_count, int culmutive_rank, const AxisIterator &axisIterator, + const size_t *sort_idx, float *output_addr) { + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[axisIterator.GetPos(sort_idx[j])] = i - duplicate_count + 2; + } + }; + } break; + case Method::Average: { + // how avg is computed directly: + // sum = (i - duplicate_count + 1) + (i - duplicate_count + 2) +... + i + // = duplicate_count * (2 * i - duplicate_count + 1) / 2 + // rank_sum = sum + duplicate_count = duplicate_count * (2 * i - duplicate_count + 3) / 2 + // avg = rank_sum / duplicate_count = (2 * i - duplicate_count + 3) / 2 + func_ = [](int i, int duplicate_count, int culmutive_rank, const AxisIterator &axisIterator, + const size_t *sort_idx, float *output_addr) { + float avg = (2 * i - duplicate_count + 3) / 2.0; + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[axisIterator.GetPos(sort_idx[j])] = avg; + } + }; + } break; + case Method::First: { + func_ = [](int i, int duplicate_count, int culmutive_rank, const AxisIterator &axisIterator, + const size_t *sort_idx, float *output_addr) { + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[axisIterator.GetPos(sort_idx[j])] = j + 1; + } + }; + } break; + case Method::Dense: { + func_ = [](int i, int duplicate_count, int culmutive_rank, const AxisIterator &axisIterator, + const size_t *sort_idx, float *output_addr) { + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[axisIterator.GetPos(sort_idx[j])] = culmutive_rank; + } + }; + } break; + case Method::MethodNotDefined: + default: + MS_LOG(EXCEPTION) << "method not init"; + } +} + +template +void RankCpuKernel::Launch1DInt(const T *input_addr, size_t *sort_idx, T *values, const AxisIterator &iter, + float *output_addr) const { + const int n = axisIterator_.AxisSize(); + for (int i = 0; i < n; ++i) { + values[i] = input_addr[iter.GetPos(i)]; + } + + std::iota(sort_idx, sort_idx + iter.AxisSize(), 0); + if (ascending_) { + std::stable_sort(sort_idx, sort_idx + iter.AxisSize(), + [values](size_t lhs, size_t rhs) { return values[lhs] < values[rhs]; }); + } else { + std::stable_sort(sort_idx, sort_idx + iter.AxisSize(), + [values](size_t lhs, size_t rhs) { return values[lhs] > values[rhs]; }); + } + + int culmutive_rank = 1; + int duplicate_count = 0; + + for (int i = 0; i < n; ++i) { + duplicate_count++; + if ((i == n - 1) || (values[sort_idx[i]] != values[sort_idx[i + 1]])) { + func_(i, duplicate_count, culmutive_rank, iter, sort_idx, output_addr); + culmutive_rank++; + duplicate_count = 0; + } + } + + if (pct_) { + // pct calculation + if (method_ == Method::Dense) { + auto size = static_cast(culmutive_rank - 1); + for (int i = 0; i < n; ++i) { + output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size; + } + } else { + auto size = static_cast(n); + for (int i = 0; i < n; ++i) { + output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size; + } + } + } +} + +template +void RankCpuKernel::Launch1DFloat(const T *input_addr, size_t *sort_idx, T *values, bool *is_nan, + const AxisIterator &iter, float *output_addr) const { + const int n = axisIterator_.AxisSize(); + T nan_padding_value; + if (ascending_ != (option_ == NaOption::Top)) { + nan_padding_value = std::numeric_limits::max(); + } else { + nan_padding_value = std::numeric_limits::min(); + } + + for (int i = 0; i < n; ++i) { + const T value = input_addr[iter.GetPos(i)]; + if (std::isnan(value)) { + values[i] = nan_padding_value; + is_nan[i] = true; + } else { + values[i] = value; + is_nan[i] = false; + } + } + + std::iota(sort_idx, sort_idx + iter.AxisSize(), 0); + if (ascending_) { + std::stable_sort(sort_idx, sort_idx + iter.AxisSize(), + [values](size_t lhs, size_t rhs) { return values[lhs] < values[rhs]; }); + } else { + std::stable_sort(sort_idx, sort_idx + iter.AxisSize(), + [values](size_t lhs, size_t rhs) { return values[lhs] > values[rhs]; }); + } + + int culmutive_rank = 1; + int duplicate_count = 0; + int nans_count = 0; + + for (int i = 0; i < n; ++i) { + duplicate_count++; + + if ((i == n - 1) || (values[sort_idx[i]] != values[sort_idx[i + 1]]) || + (is_nan[sort_idx[i]] != is_nan[sort_idx[i + 1]])) { + if ((option_ == NaOption::Keep) && is_nan[sort_idx[i]]) { + for (int j = i - duplicate_count + 1; j < i + 1; ++j) { + output_addr[iter.GetPos(sort_idx[j])] = NAN; + } + } else { + func_(i, duplicate_count, culmutive_rank, iter, sort_idx, output_addr); + } + if (is_nan[sort_idx[i]]) { + nans_count = duplicate_count; + } + culmutive_rank++; + duplicate_count = 0; + } + } + + if (pct_) { + // pct calculation + if (method_ == Method::Dense) { + auto size = static_cast(culmutive_rank - 1); + if ((option_ == NaOption::Keep && (nans_count > 0))) { + size--; + } + for (int i = 0; i < n; ++i) { + output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size; + } + } else { + auto size = static_cast(n); + if (option_ == NaOption::Keep) { + size -= static_cast(nans_count); + } + for (int i = 0; i < n; ++i) { + output_addr[iter.GetPos(i)] = output_addr[iter.GetPos(i)] / size; + } + } + } +} + +template +bool RankCpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + if (inputs.size() != 1 || outputs.size() != 1) { + MS_LOG(EXCEPTION) << "input or output num error"; + } + if constexpr (std::is_integral_v) { + if (workspace.size() != 2) { + MS_LOG(EXCEPTION) << "workspace num error"; + } + } else { + if (workspace.size() != 3) { + MS_LOG(EXCEPTION) << "workspace num error"; + } + } + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto ids_addr = reinterpret_cast(workspace[0]->addr); + auto values_addr = reinterpret_cast(workspace[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + std::vector tasks; + tasks.reserve(axisIterator_.OuterSize() * axisIterator_.InnerSize()); + for (size_t i = 0; i < axisIterator_.OuterSize(); ++i) { + for (size_t j = 0; j < axisIterator_.InnerSize(); ++j) { + tasks.emplace_back([this, i, j, input_addr, ids_addr, values_addr, workspace, output_addr]() { + AxisIterator iter(axisIterator_); + iter.SetOffset(i, j); + + size_t offset = (i * iter.InnerSize() + j) * iter.AxisSize(); + size_t *sort_idx = ids_addr + offset; + T *values = values_addr + offset; + + if constexpr (std::is_integral_v) { + Launch1DInt(input_addr, sort_idx, values, iter, output_addr); + } else { + auto flags_addr = reinterpret_cast(workspace[2]->addr); + bool *is_nan = flags_addr + offset; + Launch1DFloat(input_addr, sort_idx, values, is_nan, iter, output_addr); + } + return common::SUCCESS; + }); + } + } + common::ThreadPool::GetInstance().SyncRun(tasks); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.h new file mode 100644 index 00000000000..50964c77d0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rank_cpu_kernel.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/cpu/nnacl/op_base.h" + +namespace mindspore { +namespace kernel { +namespace rank { +enum Method : int { + Average, + Max, + Min, + First, + Dense, + MethodNotDefined, +}; +enum NaOption : int { + Keep, + Top, + Bottom, + OptionNotDefined, +}; +} // namespace rank +template +class RankCpuKernel : public CPUKernel { + public: + RankCpuKernel() = default; + ~RankCpuKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + void SetFunc(); + + void Launch1DInt(const T *input_addr, size_t *sort_idx, T *values, const AxisIterator &iter, + float *output_addr) const; + void Launch1DFloat(const T *input_addr, size_t *sort_idx, T *values, bool *is_nan, const AxisIterator &iter, + float *output_addr) const; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t rank_size_; + // shape info + AxisIterator axisIterator_{}; + // parameters + size_t axis_{0}; + rank::Method method_{rank::MethodNotDefined}; + std::function func_; + rank::NaOption option_{rank::OptionNotDefined}; + bool ascending_{true}; + bool pct_{false}; +}; + +MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + RankCpuKernel, float) + +MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), + RankCpuKernel, double) + +MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), RankCpuKernel, + int32_t) + +MS_REG_CPU_KERNEL_T(Rank, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), RankCpuKernel, + int64_t) + +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANK_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc index 2cef9a0af2e..76ba551e3f0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace kernel { - +using rolling::Method; template void RollingCpuKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h index 5e7fa778ac2..ebdcb25a1e4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h @@ -24,6 +24,7 @@ namespace mindspore { namespace kernel { +namespace rolling { enum Method : int { Max, Min, @@ -32,6 +33,7 @@ enum Method : int { Std, Var, }; +} template class RollingCpuKernel : public CPUKernel { public: @@ -53,7 +55,7 @@ class RollingCpuKernel : public CPUKernel { size_t axis_{0}; bool center_{false}; std::string closed_{}; - Method method_{}; + rolling::Method method_{}; std::function reduceMethod_{}; // shape info size_t outer_size_{0}; diff --git a/tests/st/ops/cpu/test_rank_op.py b/tests/st/ops/cpu/test_rank_op.py new file mode 100644 index 00000000000..38581b8be98 --- /dev/null +++ b/tests/st/ops/cpu/test_rank_op.py @@ -0,0 +1,126 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import List +from random import sample +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import PrimitiveWithInfer, prim_attr_register +from mindspore._checkparam import Validator as validator +from mindspore.common import dtype as mstype +import numpy as np +import pandas as pd +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Rank(PrimitiveWithInfer): + """ + Shift op frontend implementation + """ + + # size_t axis_{0}; + # rank::Method method_{rank::MethodNotDefined}; + # rank::NaOption option_{rank::OptionNotDefined}; + # bool ascending_{true}; + # bool pct_{false}; + @prim_attr_register + def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool): + """Initialize Sort""" + self.axis = validator.check_value_type("axis", axis, [int], self.name) + self.method = validator.check_value_type("method", method, [str], self.name) + self.na_option = validator.check_value_type("na_option", na_option, [str], self.name) + self.ascending = validator.check_value_type("ascending", ascending, [bool], self.name) + self.pct = validator.check_value_type("pct", pct, [bool], self.name) + + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def __infer__(self, x): + out_shapes = x['shape'] + return { + 'shape': tuple(out_shapes), + 'dtype': mstype.float32, + 'value': None + } + + +class RankNet(nn.Cell): + def __init__(self, axis: int, method: str, na_option: str, ascending: bool, pct: bool): + super(RankNet, self).__init__() + self.rank = Rank(axis, method, na_option, ascending, pct) + + def construct(self, x): + return self.rank(x) + + +def pandas_rank(arr, **kwargs): + ser = pd.DataFrame(arr) + result = ser.rank(**kwargs) + return result.to_numpy() + + +@pytest.mark.parametrize('shape', [(10,)]) +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64]) +@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average']) +@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"]) +@pytest.mark.parametrize('ascending', [True, False]) +@pytest.mark.parametrize('pct', [False, True]) +def test_rank_1d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, na_option: str): + np.random.seed(0) + + if dtype in (np.int32, np.int64): + arr = np.random.randint(0, 100, size=shape).astype(dtype) + else: + arr = np.random.random(size=shape).astype(dtype) + arr.flat[sample(range(arr.size), int(arr.size / 10))] = np.nan + + pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option).flatten() + rank = RankNet(0, method=method, ascending=ascending, pct=pct, na_option=na_option) + mind_result = rank(Tensor(arr)).asnumpy() + + print('arr: \n', arr, arr.dtype, arr.shape) + print('pandas: \n', pd_result, pd_result.dtype, pd_result.shape) + print('mind: \n', mind_result, mind_result.dtype, mind_result.shape) + print(f'method: {method}, ascending: {ascending}, pct: {pct} na_option: {na_option}') + assert np.allclose(pd_result, mind_result, equal_nan=True) + + +@pytest.mark.parametrize('shape', [(5, 6)]) +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64]) +@pytest.mark.parametrize('method', ['dense', 'first', 'max', 'min', 'average']) +@pytest.mark.parametrize('na_option', ["keep", "top", "bottom"]) +@pytest.mark.parametrize('axis', [0, 1]) +@pytest.mark.parametrize('ascending', [True, False]) +@pytest.mark.parametrize('pct', [False, True]) +def test_rank_2d(shape: List[int], dtype, method: str, ascending: bool, pct: bool, axis: int, na_option: str): + np.random.seed(0) + + if dtype in (np.int32, np.int64): + arr = np.random.randint(0, 100, size=shape).astype(dtype) + else: + arr = np.random.random(size=shape).astype(dtype) + arr.flat[sample(range(arr.size), int(arr.size / 10))] = np.nan + + pd_result = pandas_rank(arr, method=method, ascending=ascending, pct=pct, na_option=na_option, axis=axis) + rank = RankNet(axis=axis, method=method, ascending=ascending, pct=pct, na_option=na_option) + mind_result = rank(Tensor(arr)).asnumpy() + + print('arr: \n', arr, arr.dtype, arr.shape) + print('pandas: \n', pd_result, pd_result.dtype, pd_result.shape) + print('mind: \n', mind_result, mind_result.dtype, mind_result.shape) + print(f'axis: {axis}, method: {method}, ascending: {ascending}, pct: {pct} na_option: {na_option}') + assert np.allclose(pd_result, mind_result, equal_nan=True)