From ad389316cb7bb0d520cb50da515b965762950a9d Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Tue, 17 Aug 2021 15:32:48 +0800 Subject: [PATCH] add rolling cpu kernel --- .../backend/kernel_compiler/cpu/cpu_kernel.h | 5 + .../kernel_compiler/cpu/rolling_cpu_kernel.cc | 231 ++++++++++++++++++ .../kernel_compiler/cpu/rolling_cpu_kernel.h | 88 +++++++ tests/st/ops/cpu/test_rolling_op.py | 144 +++++++++++ 4 files changed, 468 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_rolling_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index c3bd29f7e65..bd06d9841be 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -65,6 +65,11 @@ const char SORTED[] = "sorted"; const char ADJ_ST[] = "adjoint_st"; const char ADJ_dT[] = "adjoint_dt"; const char PERIODS[] = "periods"; +const char WINDOW[] = "window"; +const char MIN_PERIODS[] = "min_periods"; +const char CENTER[] = "center"; +const char METHOD[] = "method"; +const char CLOSED[] = "closed"; enum OperateType { ADD = 0, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc new file mode 100644 index 00000000000..32680933f63 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc @@ -0,0 +1,231 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/rolling_cpu_kernel.h" +#include +#include +#include +#include +#include +#include "common/thread_pool.h" + +namespace mindspore { +namespace kernel { + +template +void RollingCpuKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + static const std::map kValidMethods = { + {"max", Method::Max}, {"min", Method::Min}, {"mean", Method::Mean}, + {"sum", Method::Sum}, {"std", Method::Std}, {"var", Method::Var}, + }; + auto method = AnfAlgo::GetNodeAttr(kernel_node, METHOD); + if (kValidMethods.find(method) == kValidMethods.end()) { + MS_LOG(EXCEPTION) << "[" << method << "] not supported"; + } + method_ = kValidMethods.at(method); + window_ = AnfAlgo::GetNodeAttr(kernel_node, WINDOW); + if (window_ <= 0) { + MS_LOG(EXCEPTION) << "window size should not less than 0, but got " << window_; + } + min_periods_ = AnfAlgo::GetNodeAttr(kernel_node, MIN_PERIODS); + if (min_periods_ <= 0) { + MS_LOG(EXCEPTION) << "min_periods should not less than 0, but got " << min_periods_; + } + center_ = AnfAlgo::GetNodeAttr(kernel_node, CENTER); + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + closed_ = AnfAlgo::GetNodeAttr(kernel_node, CLOSED); + if (axis_ < 0) { + axis_ += input_shape.size(); + } + if ((axis_ < 0) || (axis_ >= static_cast(input_shape.size()))) { + MS_LOG(EXCEPTION) << "axis should be smaller than the dimension of input tensor " << input_shape.size() + << "D, but got " << axis_; + } + AxisCalculate(input_shape); + RollingBoundsCalculate(); + MethodSwitch(); +} + +template +void RollingCpuKernel::AxisCalculate(const std::vector &input_shape) { + outer_size_ = 1; + for (int i = 0; i < axis_; i++) { + outer_size_ *= input_shape[i]; + } + + axis_size_ = input_shape[axis_]; + + inner_size_ = 1; + for (int i = axis_ + 1; i < static_cast(input_shape.size()); ++i) { + inner_size_ *= input_shape[i]; + } +} + +template +void RollingCpuKernel::RollingBoundsCalculate() { + int offset = 0; + if (center_) { + offset = (window_ - 1) / 2; + } + starts_.resize(axis_size_); + ends_.resize(axis_size_); + int start_offset = 0; + int end_offset = 0; + if (closed_ == "left") { + start_offset -= 1; + end_offset -= 1; + } else if (closed_ == "both") { + start_offset -= 1; + } else if (closed_ == "neither") { + end_offset -= 1; + } + for (int i = 0; i < axis_size_; ++i) { + int end = offset + i + 1; + int start = end - window_; + ends_[i] = std::max(0, std::min(end + end_offset, axis_size_)); + starts_[i] = std::max(0, std::min(start + start_offset, axis_size_)); + } +} + +template +void RollingCpuKernel::MethodSwitch() { + switch (method_) { + case Method::Max: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + T max_value = std::numeric_limits::min(); + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + if (max_value < input_addr[index]) { + max_value = input_addr[index]; + } + } + return max_value; + }; + break; + case Method::Min: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + T min_value = std::numeric_limits::max(); + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + if (min_value > input_addr[index]) { + min_value = input_addr[index]; + } + } + return min_value; + }; + break; + case Method::Sum: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + T sum = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum += input_addr[index]; + } + return sum; + }; + break; + case Method::Mean: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + T sum = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum += input_addr[index]; + } + return sum * 1.0 / (ends_[w] - starts_[w]); + }; + break; + case Method::Var: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + // float for division + float n = ends_[w] - starts_[w]; + T sum1 = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum1 += input_addr[index]; + } + double mean = sum1 / n; + double sum2 = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum2 += (input_addr[index] - mean) * (input_addr[index] - mean); + } + // ddof = 1 + return sum2 / (n - 1); + }; + break; + case Method::Std: + reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) { + // float for division + float n = ends_[w] - starts_[w]; + T sum1 = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum1 += input_addr[index]; + } + double mean = sum1 / n; + double sum2 = 0; + for (int x = starts_[w]; x < ends_[w]; ++x) { + int index = outer_offset + x * inner_size_ + col; + sum2 += (input_addr[index] - mean) * (input_addr[index] - mean); + } + // ddof = 1 + return std::sqrt(sum2 / (n - 1)); + }; + break; + default: + MS_LOG(EXCEPTION) << "reduce method is not yet supported: " << method_; + } +} + +template +bool RollingCpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + size_t input_size = inputs[0]->size / sizeof(T); + if (input_size == 0) { + MS_LOG(EXCEPTION) << "Input data size is 0."; + } + + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + for (int i = 0; i < outer_size_; ++i) { + int outer_offset = i * axis_size_ * inner_size_; + for (int col = 0; col < inner_size_; ++col) { + for (int w = 0; w < axis_size_; ++w) { + int result_offset = outer_offset + w * inner_size_ + col; + if (ends_[w] - starts_[w] < min_periods_) { + if constexpr (std::is_same_v) { + output_addr[result_offset] = std::nanf(""); + } else if constexpr (std::is_same_v) { + output_addr[result_offset] = std::nan(""); + } else { + // integer values not support nan + output_addr[result_offset] = 0; + } + } else { + output_addr[result_offset] = reduceMethod_(input_addr, outer_offset, w, col); + } + } + } + } + + return true; +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h new file mode 100644 index 00000000000..d46cb409303 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.h @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ROLLING_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ROLLING_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/cpu/nnacl/op_base.h" + +namespace mindspore { +namespace kernel { +enum Method : int { + Max, + Min, + Mean, + Sum, + Std, + Var, +}; +template +class RollingCpuKernel : public CPUKernel { + public: + RollingCpuKernel() = default; + ~RollingCpuKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void AxisCalculate(const std::vector &input_shape); + void RollingBoundsCalculate(); + void MethodSwitch(); + + int64_t window_{0}; + int64_t min_periods_{0}; + int axis_{0}; + bool center_{false}; + std::string closed_{}; + Method method_{}; + std::function reduceMethod_{}; + // shape info + int outer_size_{0}; + int axis_size_{0}; + int inner_size_{0}; + // rolling info + std::vector starts_{}; + std::vector ends_{}; +}; + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + RollingCpuKernel, float, float) + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + RollingCpuKernel, double, double) + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + RollingCpuKernel, int32_t, int32_t) + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + RollingCpuKernel, int64_t, int64_t) + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + RollingCpuKernel, int32_t, float) + +MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + RollingCpuKernel, int64_t, double) + +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_rolling_op.py b/tests/st/ops/cpu/test_rolling_op.py new file mode 100644 index 00000000000..b9c4fd097ea --- /dev/null +++ b/tests/st/ops/cpu/test_rolling_op.py @@ -0,0 +1,144 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from functools import partial +from typing import Tuple, List +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import PrimitiveWithInfer, prim_attr_register +from mindspore._checkparam import Validator as validator +from mindspore.common import dtype as mstype +import numpy as np +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Rolling(PrimitiveWithInfer): + """ + Shift op frontend implementation + """ + + @prim_attr_register + def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str, + method: str): + """Initialize Sort""" + self.window = validator.check_value_type("window", window, [int], self.name) + self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name) + self.center = validator.check_value_type("center", center, [bool], self.name) + self.axis = validator.check_value_type("axis", axis, [int], self.name) + self.closed = validator.check_value_type("closed", closed, [str], self.name) + self.method = validator.check_value_type("method", method, [str], self.name) + + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def __infer__(self, x): + out_shapes = x['shape'] + return { + 'shape': tuple(out_shapes), + 'dtype': x['dtype'], + 'value': None + } + + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid(x_dtype, [mstype.float32, mstype.float64, mstype.int32, mstype.int64], + self.name, True) + return x_dtype + + +class RollingNet(nn.Cell): + def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str, + method: str): + super(RollingNet, self).__init__() + self.rolling = Rolling(window, min_periods, center, axis, closed, method) + + def construct(self, x): + return self.rolling(x) + + +def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]: + assert closed in {'left', 'both', 'right', 'neither'} + offset = (window_size - 1) // 2 if center else 0 + + end = np.arange(offset + 1, num_values + 1 + offset, dtype=np.int64) + start = end - window_size + if closed in {'left', 'both'}: + start -= 1 + if closed in {'left', 'neither'}: + end -= 1 + + end = np.clip(end, 0, num_values) + start = np.clip(start, 0, num_values) + + return list(start), list(end) + + +def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str, + method: str) -> np.ndarray: + assert window > 0 + assert 0 < min_periods <= window + assert axis in range(-array.ndim, array.ndim) + reduce_map = {'max': np.max, 'min': np.min, 'mean': np.mean, 'sum': np.sum, 'std': partial(np.std, ddof=1), + 'var': partial(np.var, ddof=1)} + assert method in reduce_map + + size = array.shape[axis] + start, end = get_window_bounds(size, window, center, closed) + + rolling_indices = [[slice(None)] * array.ndim for _ in range(len(start))] + for i, j, indice in zip(start, end, rolling_indices): + indice[axis] = None if j - i < min_periods else slice(i, j) + # print(f'i={i}, j={j}, index={index}, indice={rolling_indices[index][axis]}') + + shape = list(array.shape) + shape[axis] = 1 + nan_array = np.empty(shape) + if array.dtype == np.float32 or array.dtype == np.float64: + nan_array[:] = np.nan + elif array.dtype == np.int32 or array.dtype == np.int64: + nan_array[:] = 0 + + arrays = [ + nan_array.copy() if not indice[axis] + else reduce_map[method](array[tuple(indice)], axis=axis, keepdims=True).reshape(shape) + for indice in rolling_indices] + + return np.stack(arrays, axis=axis).reshape(array.shape).astype(array.dtype) + + +@pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)]) +@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64]) +@pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)]) +@pytest.mark.parametrize('center', [True, False]) +@pytest.mark.parametrize('axis', [2, 3, -1]) +@pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither']) +@pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var']) +def test_two_way(shape: List[int], dtype, window: int, min_periods: int, center: bool, axis: int, closed: str, + method: str) -> np.ndarray: + if dtype in (np.int32, np.int64): + arr = np.random.randint(0, 100, size=shape) + else: + arr = np.random.random(shape).astype(dtype) + expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis, closed=closed, + method=method) + rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed, + method=method) + actual_result = rolling(Tensor(arr)).asnumpy() + print('arr: \n', arr, arr.dtype, arr.shape) + print('np: \n', expect_result, expect_result.dtype, expect_result.shape) + print('mine: \n', actual_result, actual_result.dtype, actual_result.shape) + print(f'center: {center}, axis: {axis}, method: {method}') + assert np.allclose(expect_result, actual_result, equal_nan=True)