forked from mindspore-Ecosystem/mindspore
!22336 add rolling cpu kernel
Merge pull request !22336 from zhujingxuan/master
This commit is contained in:
commit
c710b57efe
|
@ -65,6 +65,11 @@ const char SORTED[] = "sorted";
|
|||
// String constants naming operator attributes.
// METHOD, WINDOW, MIN_PERIODS, CENTER and CLOSED are the keys that RollingCpuKernel::InitKernel passes to
// AnfAlgo::GetNodeAttr to read the Rolling configuration.
// NOTE(review): ADJ_ST, ADJ_dT and PERIODS are not referenced by the code visible here; presumably they serve
// other kernels - confirm against their users.
const char ADJ_ST[] = "adjoint_st";
|
||||
const char ADJ_dT[] = "adjoint_dt";
|
||||
const char PERIODS[] = "periods";
|
||||
const char WINDOW[] = "window";
|
||||
const char MIN_PERIODS[] = "min_periods";
|
||||
const char CENTER[] = "center";
|
||||
const char METHOD[] = "method";
|
||||
const char CLOSED[] = "closed";
|
||||
|
||||
enum OperateType {
|
||||
ADD = 0,
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/rolling_cpu_kernel.h"
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <limits>
|
||||
#include <functional>
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T, typename S>
|
||||
void RollingCpuKernel<T, S>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
static const std::map<std::string, Method> kValidMethods = {
|
||||
{"max", Method::Max}, {"min", Method::Min}, {"mean", Method::Mean},
|
||||
{"sum", Method::Sum}, {"std", Method::Std}, {"var", Method::Var},
|
||||
};
|
||||
auto method = AnfAlgo::GetNodeAttr<std::string>(kernel_node, METHOD);
|
||||
if (kValidMethods.find(method) == kValidMethods.end()) {
|
||||
MS_LOG(EXCEPTION) << "[" << method << "] not supported";
|
||||
}
|
||||
method_ = kValidMethods.at(method);
|
||||
window_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, WINDOW);
|
||||
if (window_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "window size should not less than 0, but got " << window_;
|
||||
}
|
||||
min_periods_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, MIN_PERIODS);
|
||||
if (min_periods_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "min_periods should not less than 0, but got " << min_periods_;
|
||||
}
|
||||
center_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CENTER);
|
||||
axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
closed_ = AnfAlgo::GetNodeAttr<std::string>(kernel_node, CLOSED);
|
||||
if (axis_ < 0) {
|
||||
axis_ += input_shape.size();
|
||||
}
|
||||
if ((axis_ < 0) || (axis_ >= static_cast<int64_t>(input_shape.size()))) {
|
||||
MS_LOG(EXCEPTION) << "axis should be smaller than the dimension of input tensor " << input_shape.size()
|
||||
<< "D, but got " << axis_;
|
||||
}
|
||||
AxisCalculate(input_shape);
|
||||
RollingBoundsCalculate();
|
||||
MethodSwitch();
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void RollingCpuKernel<T, S>::AxisCalculate(const std::vector<size_t> &input_shape) {
|
||||
outer_size_ = 1;
|
||||
for (int i = 0; i < axis_; i++) {
|
||||
outer_size_ *= input_shape[i];
|
||||
}
|
||||
|
||||
axis_size_ = input_shape[axis_];
|
||||
|
||||
inner_size_ = 1;
|
||||
for (int i = axis_ + 1; i < static_cast<int>(input_shape.size()); ++i) {
|
||||
inner_size_ *= input_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void RollingCpuKernel<T, S>::RollingBoundsCalculate() {
|
||||
int offset = 0;
|
||||
if (center_) {
|
||||
offset = (window_ - 1) / 2;
|
||||
}
|
||||
starts_.resize(axis_size_);
|
||||
ends_.resize(axis_size_);
|
||||
int start_offset = 0;
|
||||
int end_offset = 0;
|
||||
if (closed_ == "left") {
|
||||
start_offset -= 1;
|
||||
end_offset -= 1;
|
||||
} else if (closed_ == "both") {
|
||||
start_offset -= 1;
|
||||
} else if (closed_ == "neither") {
|
||||
end_offset -= 1;
|
||||
}
|
||||
for (int i = 0; i < axis_size_; ++i) {
|
||||
int end = offset + i + 1;
|
||||
int start = end - window_;
|
||||
ends_[i] = std::max(0, std::min(end + end_offset, axis_size_));
|
||||
starts_[i] = std::max(0, std::min(start + start_offset, axis_size_));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void RollingCpuKernel<T, S>::MethodSwitch() {
|
||||
switch (method_) {
|
||||
case Method::Max:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
T max_value = std::numeric_limits<T>::min();
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
if (max_value < input_addr[index]) {
|
||||
max_value = input_addr[index];
|
||||
}
|
||||
}
|
||||
return max_value;
|
||||
};
|
||||
break;
|
||||
case Method::Min:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
T min_value = std::numeric_limits<T>::max();
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
if (min_value > input_addr[index]) {
|
||||
min_value = input_addr[index];
|
||||
}
|
||||
}
|
||||
return min_value;
|
||||
};
|
||||
break;
|
||||
case Method::Sum:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
T sum = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum += input_addr[index];
|
||||
}
|
||||
return sum;
|
||||
};
|
||||
break;
|
||||
case Method::Mean:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
T sum = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum += input_addr[index];
|
||||
}
|
||||
return sum * 1.0 / (ends_[w] - starts_[w]);
|
||||
};
|
||||
break;
|
||||
case Method::Var:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
// float for division
|
||||
float n = ends_[w] - starts_[w];
|
||||
T sum1 = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum1 += input_addr[index];
|
||||
}
|
||||
double mean = sum1 / n;
|
||||
double sum2 = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum2 += (input_addr[index] - mean) * (input_addr[index] - mean);
|
||||
}
|
||||
// ddof = 1
|
||||
return sum2 / (n - 1);
|
||||
};
|
||||
break;
|
||||
case Method::Std:
|
||||
reduceMethod_ = [this](const T *input_addr, int outer_offset, int w, int col) {
|
||||
// float for division
|
||||
float n = ends_[w] - starts_[w];
|
||||
T sum1 = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum1 += input_addr[index];
|
||||
}
|
||||
double mean = sum1 / n;
|
||||
double sum2 = 0;
|
||||
for (int x = starts_[w]; x < ends_[w]; ++x) {
|
||||
int index = outer_offset + x * inner_size_ + col;
|
||||
sum2 += (input_addr[index] - mean) * (input_addr[index] - mean);
|
||||
}
|
||||
// ddof = 1
|
||||
return std::sqrt(sum2 / (n - 1));
|
||||
};
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "reduce method is not yet supported: " << method_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
bool RollingCpuKernel<T, S>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
size_t input_size = inputs[0]->size / sizeof(T);
|
||||
if (input_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "Input data size is 0.";
|
||||
}
|
||||
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<S *>(outputs[0]->addr);
|
||||
|
||||
for (int i = 0; i < outer_size_; ++i) {
|
||||
int outer_offset = i * axis_size_ * inner_size_;
|
||||
for (int col = 0; col < inner_size_; ++col) {
|
||||
for (int w = 0; w < axis_size_; ++w) {
|
||||
int result_offset = outer_offset + w * inner_size_ + col;
|
||||
if (ends_[w] - starts_[w] < min_periods_) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
output_addr[result_offset] = std::nanf("");
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
output_addr[result_offset] = std::nan("");
|
||||
} else {
|
||||
// integer values not support nan
|
||||
output_addr[result_offset] = 0;
|
||||
}
|
||||
} else {
|
||||
output_addr[result_offset] = reduceMethod_(input_addr, outer_offset, w, col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ROLLING_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ROLLING_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Reduction applied to every rolling window; the enumerators keep their implicit numbering 0..5.
enum Method : int {
  Max = 0,
  Min = 1,
  Mean = 2,
  Sum = 3,
  Std = 4,
  Var = 5,
};
|
||||
template <typename T, typename S>
|
||||
class RollingCpuKernel : public CPUKernel {
|
||||
public:
|
||||
RollingCpuKernel() = default;
|
||||
~RollingCpuKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void AxisCalculate(const std::vector<size_t> &input_shape);
|
||||
void RollingBoundsCalculate();
|
||||
void MethodSwitch();
|
||||
|
||||
int64_t window_{0};
|
||||
int64_t min_periods_{0};
|
||||
int axis_{0};
|
||||
bool center_{false};
|
||||
std::string closed_{};
|
||||
Method method_{};
|
||||
std::function<S(const T *input_addr, int outer_offset, int w, int col)> reduceMethod_{};
|
||||
// shape info
|
||||
int outer_size_{0};
|
||||
int axis_size_{0};
|
||||
int inner_size_{0};
|
||||
// rolling info
|
||||
std::vector<int> starts_{};
|
||||
std::vector<int> ends_{};
|
||||
};
|
||||
|
||||
// Registrations of RollingCpuKernel<T, S> under the operator name Rolling, one per input/output dtype pair:
// float32 -> float32, float64 -> float64, int32 -> int32, int64 -> int64, int32 -> float32 and
// int64 -> float64. The last two macro arguments are the template arguments T (input) and S (output).
// NOTE(review): the macro is defined outside this file; presumably the KernelAttr is what kernel selection
// matches against the node's dtypes - confirm in cpu_kernel_factory.h.
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
RollingCpuKernel, float, float)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
RollingCpuKernel, double, double)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
RollingCpuKernel, int32_t, int32_t)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
RollingCpuKernel, int64_t, int64_t)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
RollingCpuKernel, int32_t, float)
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Rolling, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
RollingCpuKernel, int64_t, double)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ROLLING_CPU_KERNEL_H_
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, List
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class Rolling(PrimitiveWithInfer):
    """
    Rolling op frontend implementation

    Wraps the ``Rolling`` primitive for this test. The constructor arguments are type-checked and stored
    as attributes; the inferred output has the same shape and dtype as the input ``x``.
    """

    @prim_attr_register
    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        """Initialize Rolling"""
        self.window = validator.check_value_type("window", window, [int], self.name)
        self.min_periods = validator.check_value_type("min_periods", min_periods, [int], self.name)
        self.center = validator.check_value_type("center", center, [bool], self.name)
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.closed = validator.check_value_type("closed", closed, [str], self.name)
        self.method = validator.check_value_type("method", method, [str], self.name)

        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def __infer__(self, x):
        """Return the inference result: the shape and dtype of ``x``, with no constant value."""
        out_shapes = x['shape']
        return {
            'shape': tuple(out_shapes),
            'dtype': x['dtype'],
            'value': None
        }

    def infer_dtype(self, x_dtype):
        """Check that ``x_dtype`` is float32, float64, int32 or int64 and return it unchanged."""
        # Arguments are (arg_name, arg_type, valid_dtypes, prim_name); the previous call passed the dtype
        # as the name and the primitive name as the list of valid dtypes.
        validator.check_tensor_dtype_valid("x", x_dtype,
                                           [mstype.float32, mstype.float64, mstype.int32, mstype.int64],
                                           self.name)
        return x_dtype
|
||||
|
||||
|
||||
class RollingNet(nn.Cell):
    """Minimal cell whose only operation is a ``Rolling`` primitive built from the given settings."""

    def __init__(self, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str):
        super().__init__()
        rolling_op = Rolling(window, min_periods, center, axis, closed, method)
        self.rolling = rolling_op

    def construct(self, x):
        """Apply the rolling primitive to ``x`` and return its output."""
        result = self.rolling(x)
        return result
|
||||
|
||||
|
||||
def get_window_bounds(num_values: int, window_size: int, center: bool, closed: str = 'right') -> Tuple[List, List]:
    """
    Compute the half-open window ``[start, end)`` for every position of an axis.

    Args:
        num_values: length of the axis being rolled over.
        window_size: nominal number of elements per window.
        center: when True, shift every window right by ``(window_size - 1) // 2``.
        closed: one of 'left', 'both', 'right', 'neither'; 'left' and 'both' move the start back by one,
            'left' and 'neither' move the end back by one.

    Returns:
        Two lists ``(starts, ends)`` of length ``num_values``, each clipped to ``[0, num_values]``.
    """
    assert closed in {'left', 'both', 'right', 'neither'}

    shift = (window_size - 1) // 2 if center else 0
    extend_start = closed in ('left', 'both')
    shrink_end = closed in ('left', 'neither')

    positions = np.arange(num_values, dtype=np.int64)
    upper = positions + (shift + 1)
    lower = upper - window_size
    if extend_start:
        lower = lower - 1
    if shrink_end:
        upper = upper - 1

    clipped_lower = np.clip(lower, 0, num_values)
    clipped_upper = np.clip(upper, 0, num_values)
    return list(clipped_lower), list(clipped_upper)
|
||||
|
||||
|
||||
def numpy_rolling(array: np.ndarray, window: int, min_periods: int, center: bool, axis: int, closed: str,
                  method: str) -> np.ndarray:
    """
    NumPy reference implementation of the rolling reduction.

    For every position along ``axis`` the window given by ``get_window_bounds`` is reduced with ``method``
    ('std' and 'var' use ``ddof=1``). Windows with fewer than ``min_periods`` elements yield NaN for
    float32/float64 input and 0 for int32/int64 input. The result has the shape and dtype of ``array``.
    """
    assert window > 0
    assert 0 < min_periods <= window
    assert axis in range(-array.ndim, array.ndim)
    reducers = {
        'max': np.max,
        'min': np.min,
        'mean': np.mean,
        'sum': np.sum,
        'std': partial(np.std, ddof=1),
        'var': partial(np.var, ddof=1),
    }
    assert method in reducers
    reduce_fn = reducers[method]

    lower_bounds, upper_bounds = get_window_bounds(array.shape[axis], window, center, closed)

    # Every per-position result is a slab with extent 1 along the rolling axis.
    slab_shape = list(array.shape)
    slab_shape[axis] = 1

    # Slab used for positions whose window is too short.
    placeholder = np.empty(slab_shape)
    if array.dtype in (np.float32, np.float64):
        placeholder[:] = np.nan
    elif array.dtype in (np.int32, np.int64):
        placeholder[:] = 0

    slabs = []
    for low, high in zip(lower_bounds, upper_bounds):
        if high - low < min_periods:
            slabs.append(placeholder.copy())
            continue
        selector = [slice(None)] * array.ndim
        selector[axis] = slice(low, high)
        reduced = reduce_fn(array[tuple(selector)], axis=axis, keepdims=True)
        slabs.append(reduced.reshape(slab_shape))

    stacked = np.stack(slabs, axis=axis)
    return stacked.reshape(array.shape).astype(array.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shape', [(10, 8, 15, 7), (5, 3, 8, 10)])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
@pytest.mark.parametrize('window, min_periods', [(3, 3), (5, 3)])
@pytest.mark.parametrize('center', [True, False])
@pytest.mark.parametrize('axis', [2, 3, -1])
@pytest.mark.parametrize('closed', ['left', 'both', 'right', 'neither'])
@pytest.mark.parametrize('method', ['max', 'min', 'mean', 'sum', 'std', 'var'])
def test_two_way(shape: List[int], dtype, window: int, min_periods: int, center: bool, axis: int, closed: str,
                 method: str) -> None:
    """
    Compare the Rolling CPU kernel with the NumPy reference on random data.

    Integer dtypes draw values from [0, 100); floating dtypes draw from [0, 1). NaN entries are treated
    as equal when the two results are compared.
    """
    if dtype in (np.int32, np.int64):
        # Request the parametrized dtype explicitly: without it randint returns the platform default
        # integer type, so the int32 case would never feed int32 data to the kernel.
        arr = np.random.randint(0, 100, size=shape, dtype=dtype)
    else:
        arr = np.random.random(shape).astype(dtype)
    expect_result = numpy_rolling(arr, window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                                  method=method)
    rolling = RollingNet(window=window, min_periods=min_periods, center=center, axis=axis, closed=closed,
                         method=method)
    actual_result = rolling(Tensor(arr)).asnumpy()
    print('arr: \n', arr, arr.dtype, arr.shape)
    print('np: \n', expect_result, expect_result.dtype, expect_result.shape)
    print('mine: \n', actual_result, actual_result.dtype, actual_result.shape)
    print(f'center: {center}, axis: {axis}, method: {method}')
    assert np.allclose(expect_result, actual_result, equal_nan=True)
|
Loading…
Reference in New Issue