add Shift cpu OP

zhujingxuan 2021-08-09 09:47:58 +08:00
parent bff5cbda9c
commit c7febdcca4
4 changed files with 451 additions and 0 deletions


@@ -62,6 +62,8 @@ const char DELTA[] = "delta";
const char SORTED[] = "sorted";
const char ADJ_ST[] = "adjoint_st";
const char ADJ_dT[] = "adjoint_dt";
const char PERIODS[] = "periods";
const char FILL_VALUE[] = "fill_value";
enum OperateType {
ADD = 0,


@@ -0,0 +1,146 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/shift_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
template <typename T>
void ShiftCpuKernel<T>::InitKernel(const CNodePtr &kernel_node) {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {
MS_LOG(EXCEPTION) << input_count << " inputs were provided, but Shift expects 1.";
}
size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_count != 1) {
MS_LOG(EXCEPTION) << "Number of outputs is " << output_count << ", but should be 1 for Shift.";
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
periods_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, PERIODS);
auto fill_value = AnfAlgo::GetNodeAttr<float>(kernel_node, FILL_VALUE);
if constexpr (std::is_same<T, int64_t>::value || std::is_same<T, int32_t>::value) {
if (std::isnan(fill_value)) {
MS_LOG(EXCEPTION) << "for integer input, nan is not supported for fill_value ";
}
}
fill_value_ = static_cast<T>(fill_value);
auto axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
if (axis < 0) {
axis += input_shape.size();
}
if ((axis < 0) || (axis >= static_cast<int64_t>(input_shape.size()))) {
MS_LOG(EXCEPTION) << "axis should be smaller than the dimension of input tensor " << input_shape.size()
<< "D, but got " << axis;
}
// calculating axis size
size_t axis_t = axis;
outer_size_ = 1;
for (size_t i = 0; i < axis_t; i++) {
outer_size_ *= input_shape[i];
}
axis_size_ = input_shape[axis_t];
inner_size_ = 1;
for (size_t i = axis_t + 1; i < input_shape.size(); ++i) {
inner_size_ *= input_shape[i];
}
// index calculation
if (periods_ > 0) {
fill_begin_ = 0;
fill_size_ = periods_;
copy_src_begin_ = 0;
copy_dst_begin_ = periods_;
copy_size_ = input_shape[axis] - periods_;
} else if (periods_ < 0) {
fill_begin_ = input_shape[axis] + periods_;
fill_size_ = -periods_;
copy_src_begin_ = -periods_;
copy_dst_begin_ = 0;
copy_size_ = input_shape[axis] + periods_;
}
}
template <typename T>
bool ShiftCpuKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(EXCEPTION) << "Sort needs 1 input and 1 outputs, but get inputs: " << inputs.size()
<< "outputs: " << outputs.size();
}
auto input = reinterpret_cast<T *>(inputs[0]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr);
if (outputs[0]->size != inputs[0]->size) {
MS_LOG(EXCEPTION) << "Error output data size!";
}
// if periods_ is 0, do nothing
if (periods_ == 0) {
// directly copy input to output
memcpy(output, input, inputs[0]->size);
return true;
}
// if |periods| is not smaller than the axis size, every element of the output is fill_value
if (std::abs(periods_) >= static_cast<int>(axis_size_)) {
std::fill_n(output, outer_size_ * axis_size_ * inner_size_, fill_value_);
return true;
}
if (inputs[0]->size != outer_size_ * axis_size_ * inner_size_ * sizeof(T)) {
MS_LOG(EXCEPTION) << "Error input data size!";
}
// check if the tensor is linear
if ((inner_size_ == 1) && (outer_size_ == 1)) {
// treat it as a simple 1D array
memcpy(output + copy_dst_begin_, input + copy_src_begin_, copy_size_ * sizeof(T));
std::fill_n(output + fill_begin_, fill_size_, fill_value_);
return true;
}
// normal procedure
std::vector<common::Task> tasks;
tasks.reserve(outer_size_);
for (size_t i = 0; i < outer_size_; ++i) {
tasks.emplace_back([this, i, input, output] {
size_t offset = i * axis_size_ * inner_size_;
size_t input_offset = offset + copy_src_begin_ * inner_size_;
size_t output_offset = offset + copy_dst_begin_ * inner_size_;
memcpy(output + output_offset, input + input_offset, copy_size_ * inner_size_ * sizeof(T));
size_t fill_offset = offset + fill_begin_ * inner_size_;
std::fill_n(output + fill_offset, fill_size_ * inner_size_, fill_value_);
return common::SUCCESS;
});
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel
} // namespace mindspore


@@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SHIFT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SHIFT_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ShiftCpuKernel : public CPUKernel {
public:
ShiftCpuKernel() = default;
~ShiftCpuKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
// inputs
int64_t periods_{0};
T fill_value_{0};
// slice info
size_t outer_size_{0};
size_t axis_size_{0};
size_t inner_size_{0};
size_t copy_src_begin_{0};
size_t copy_dst_begin_{0};
size_t copy_size_{0};
size_t fill_begin_{0};
size_t fill_size_{0};
};
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ShiftCpuKernel,
bool)
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ShiftCpuKernel, float)
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ShiftCpuKernel, double)
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ShiftCpuKernel,
int32_t)
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ShiftCpuKernel,
int64_t)
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SHIFT_CPU_KERNEL_H_


@@ -0,0 +1,230 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Shift(PrimitiveWithInfer):
"""
Shift op frontend implementation
"""
@prim_attr_register
def __init__(self, periods=1, axis=-1, fill_value=np.nan):
"""Initialize Sort"""
self.periods = validator.check_value_type("periods", periods, [int], self.name)
self.axis = validator.check_value_type("axis", axis, [int], self.name)
self.fill_value = validator.check_value_type("fill_value", fill_value, [float], self.name)
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x_dtype", x_dtype,
[mstype.float32, mstype.float64, mstype.int32, mstype.int64, mstype.bool_],
self.name)
return x_dtype
class ShiftNet(nn.Cell):
def __init__(self, periods=1, axis=-1, fill_value=np.nan):
super(ShiftNet, self).__init__()
self.shift = Shift(periods, axis, fill_value)
def construct(self, x):
return self.shift(x)
def numpy_shift(array: np.ndarray, periods: int, axis: int, fill_value=np.nan) -> np.ndarray:
"""
numpy implementation for validation
"""
size = array.shape[axis]
assert periods in range(-size, size)
assert axis in range(-array.ndim, array.ndim)
copy_src_indices = [slice(None)] * array.ndim
copy_dst_indices = [slice(None)] * array.ndim
fill_indices = [slice(None)] * array.ndim
if periods > 0:
fill_indices[axis] = slice(None, periods)
copy_src_indices[axis] = slice(None, -periods)
copy_dst_indices[axis] = slice(periods, None)
elif periods < 0:
fill_indices[axis] = slice(periods, None)
copy_src_indices[axis] = slice(-periods, None)
copy_dst_indices[axis] = slice(None, periods)
else:
return array.copy()
result = np.empty_like(array)
result[tuple(fill_indices)] = fill_value
result[tuple(copy_dst_indices)] = array[tuple(copy_src_indices)]
return result
def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan):
numpy_result = numpy_shift(arr, periods=periods, axis=axis, fill_value=fill_value)
shift = ShiftNet(periods=periods, axis=axis, fill_value=fill_value)
mindspore_result = shift(Tensor(arr)).asnumpy()
print('numpy:\n')
print(numpy_result)
print('mindspore:\n')
print(mindspore_result)
assert np.allclose(numpy_result, mindspore_result, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('fill_value, dtype',
[(0.0, np.float32),
(0.0, np.float64),
(0.0, np.int32),
(0.0, np.int64),
(0.0, np.bool_),
(1.0, np.float32),
(1.0, np.float64),
(1.0, np.int32),
(1.0, np.int64),
(1.0, np.bool_),
(np.nan, np.float32),
(np.nan, np.float64),
(np.nan, np.bool_)]
)
def test_no_shift(fill_value, dtype):
arr = np.random.random((40, 60, 50, 30)).astype(dtype)
compare(arr, axis=0, periods=0, fill_value=fill_value)
compare(arr, axis=1, periods=0, fill_value=fill_value)
compare(arr, axis=2, periods=0, fill_value=fill_value)
compare(arr, axis=3, periods=0, fill_value=fill_value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('fill_value, dtype',
[(0.0, np.float32),
(0.0, np.float64),
(0.0, np.int32),
(0.0, np.int64),
(0.0, np.bool_),
(1.0, np.float32),
(1.0, np.float64),
(1.0, np.int32),
(1.0, np.int64),
(1.0, np.bool_),
(np.nan, np.float32),
(np.nan, np.float64),
(np.nan, np.bool_)]
)
def test_fancy_1d(fill_value, dtype):
arr = np.random.random((1, 1, 50, 1)).astype(dtype)
axis = 2
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
compare(arr, axis=axis, periods=28, fill_value=fill_value)
arr = np.random.random((70, 1, 1, 1)).astype(dtype)
axis = 0
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
compare(arr, axis=axis, periods=28, fill_value=fill_value)
arr = np.random.random((1, 1, 1, 80)).astype(dtype)
axis = 3
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
compare(arr, axis=axis, periods=28, fill_value=fill_value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('fill_value, dtype',
[(0.0, np.float32),
(0.0, np.float64),
(0.0, np.int32),
(0.0, np.int64),
(0.0, np.bool_),
(1.0, np.float32),
(1.0, np.float64),
(1.0, np.int32),
(1.0, np.int64),
(1.0, np.bool_),
(np.nan, np.float32),
(np.nan, np.float64),
(np.nan, np.bool_)]
)
def test_2d(fill_value, dtype):
arr = np.random.random((30, 40)).astype(dtype)
axis = 0
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
compare(arr, axis=axis, periods=27, fill_value=fill_value)
axis = 1
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
compare(arr, axis=axis, periods=28, fill_value=fill_value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('fill_value, dtype',
[(0.0, np.float32),
(0.0, np.float64),
(0.0, np.int32),
(0.0, np.int64),
(0.0, np.bool_),
(1.0, np.float32),
(1.0, np.float64),
(1.0, np.int32),
(1.0, np.int64),
(1.0, np.bool_),
(np.nan, np.float32),
(np.nan, np.float64),
(np.nan, np.bool_)]
)
def test_4d(fill_value, dtype):
arr = np.random.random((30, 40, 50, 60)).astype(dtype)
axis = 0
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
compare(arr, axis=axis, periods=28, fill_value=fill_value)
axis = 1
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
compare(arr, axis=axis, periods=34, fill_value=fill_value)
axis = 2
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
compare(arr, axis=axis, periods=48, fill_value=fill_value)
axis = 3
compare(arr, axis=axis, periods=-48, fill_value=fill_value)
compare(arr, axis=axis, periods=52, fill_value=fill_value)