diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index b85568f505e..8c94e57eba5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -62,6 +62,8 @@ const char DELTA[] = "delta"; const char SORTED[] = "sorted"; const char ADJ_ST[] = "adjoint_st"; const char ADJ_dT[] = "adjoint_dt"; +const char PERIODS[] = "periods"; +const char FILL_VALUE[] = "fill_value"; enum OperateType { ADD = 0, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.cc new file mode 100644 index 00000000000..42cf8de9429 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/shift_cpu_kernel.h" +#include +#include +#include "common/thread_pool.h" + +namespace mindspore { +namespace kernel { +template +void ShiftCpuKernel::InitKernel(const CNodePtr &kernel_node) { + size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_count != 1) { + MS_LOG(EXCEPTION) << input_count << " inputs were provided, but Shift expects 1."; + } + + size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_count != 1) { + MS_LOG(EXCEPTION) << "Number of outputs is " << output_count << ", but should be 1 for Shift."; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex); + + periods_ = AnfAlgo::GetNodeAttr(kernel_node, PERIODS); + auto fill_value = AnfAlgo::GetNodeAttr(kernel_node, FILL_VALUE); + if constexpr (std::is_same::value || std::is_same::value) { + if (std::isnan(fill_value)) { + MS_LOG(EXCEPTION) << "for integer input, nan is not supported for fill_value "; + } + } + fill_value_ = static_cast(fill_value); + + auto axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis < 0) { + axis += input_shape.size(); + } + if ((axis < 0) || (axis >= static_cast(input_shape.size()))) { + MS_LOG(EXCEPTION) << "axis should be smaller than the dimension of input tensor " << input_shape.size() + << "D, but got " << axis; + } + + // calculating axis size + size_t axis_t = axis; + + outer_size_ = 1; + for (size_t i = 0; i < axis_t; i++) { + outer_size_ *= input_shape[i]; + } + + axis_size_ = input_shape[axis_t]; + + inner_size_ = 1; + for (size_t i = axis_t + 1; i < input_shape.size(); ++i) { + inner_size_ *= input_shape[i]; + } + + // index calculation + if (periods_ > 0) { + fill_begin_ = 0; + fill_size_ = periods_; + + copy_src_begin_ = 0; + copy_dst_begin_ = periods_; + copy_size_ = input_shape[axis] - periods_; + } else if (periods_ < 0) { + fill_begin_ = input_shape[axis] + periods_; + fill_size_ = -periods_; + + copy_src_begin_ = -periods_; + copy_dst_begin_ = 0; + copy_size_ = input_shape[axis] + periods_; + } +} + +template +bool ShiftCpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + if (inputs.size() != 1 || outputs.size() != 1) { + MS_LOG(EXCEPTION) << "Sort needs 1 input and 1 outputs, but get inputs: " << inputs.size() + << "outputs: " << outputs.size(); + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + + if (outputs[0]->size != inputs[0]->size) { + MS_LOG(EXCEPTION) << "Error output data size!"; + } + + // if periods_ is 0, do nothing + if (periods_ == 0) { + // directly copy input to output + memcpy(output, input, inputs[0]->size); + return true; + } + + // periods is larger than size, all value of the tensor would be fill_value + if (std::abs(periods_) >= static_cast(axis_size_)) { + std::fill_n(output, outer_size_ * axis_size_ * inner_size_, fill_value_); + return true; + } + + if (inputs[0]->size != outer_size_ * axis_size_ * inner_size_ * sizeof(T)) { + MS_LOG(EXCEPTION) << "Error input data size!"; + } + + // check if the tensor is linear + if ((inner_size_ == 1) && (outer_size_ == 1)) { + // treat it as a simple 1D array + memcpy(output + copy_dst_begin_, input + copy_src_begin_, copy_size_ * sizeof(T)); + std::fill_n(output + fill_begin_, fill_size_, fill_value_); + return true; + } + + // normal procedure + std::vector tasks; + tasks.reserve(outer_size_); + for (size_t i = 0; i < outer_size_; ++i) { + tasks.emplace_back([this, i, input, output] { + size_t offset = i * axis_size_ * inner_size_; + size_t input_offset = offset + copy_src_begin_ * inner_size_; + size_t output_offset = offset + copy_dst_begin_ * inner_size_; + memcpy(output + output_offset, input + input_offset, copy_size_ * inner_size_ * sizeof(T)); + size_t fill_offset = offset + fill_begin_ * inner_size_; + std::fill_n(output + fill_offset, fill_size_ * inner_size_, fill_value_); + return common::SUCCESS; + }); + } + common::ThreadPool::GetInstance().SyncRun(tasks); + return true; +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.h new file mode 100644 index 00000000000..2482760803e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/shift_cpu_kernel.h @@ -0,0 +1,73 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/cpu/nnacl/op_base.h" + +namespace mindspore { +namespace kernel { +template +class ShiftCpuKernel : public CPUKernel { + public: + ShiftCpuKernel() = default; + ~ShiftCpuKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + // inputs + int64_t periods_{0}; + T fill_value_{0}; + + // slice info + size_t outer_size_{0}; + size_t axis_size_{0}; + size_t inner_size_{0}; + + size_t copy_src_begin_{0}; + size_t copy_dst_begin_{0}; + size_t copy_size_{0}; + + size_t fill_begin_{0}; + size_t fill_size_{0}; +}; + +MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ShiftCpuKernel, + bool) + +MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ShiftCpuKernel, float) + +MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + ShiftCpuKernel, double) + +MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ShiftCpuKernel, + int32_t) + +MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ShiftCpuKernel, + int64_t) + +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_shift_op.py b/tests/st/ops/cpu/test_shift_op.py new file mode 100644 index 00000000000..43b25a7c954 --- /dev/null +++ b/tests/st/ops/cpu/test_shift_op.py @@ -0,0 +1,230 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import PrimitiveWithInfer, prim_attr_register +from mindspore._checkparam import Validator as validator +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class Shift(PrimitiveWithInfer): + """ + Shift op frontend implementation + """ + + @prim_attr_register + def __init__(self, periods=1, axis=-1, fill_value=np.nan): + """Initialize Sort""" + self.periods = validator.check_value_type("periods", periods, [int], self.name) + self.axis = validator.check_value_type("axis", axis, [int], self.name) + self.fill_value = validator.check_value_type("fill_value", fill_value, [float], self.name) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x_dtype", x_dtype, + [mstype.float32, mstype.float64, mstype.int32, mstype.int64, mstype.bool_], + self.name) + return x_dtype + + +class ShiftNet(nn.Cell): + def __init__(self, periods=1, axis=-1, fill_value=np.nan): + super(ShiftNet, self).__init__() + self.shift = Shift(periods, axis, fill_value) + + def construct(self, x): + return self.shift(x) + + +def numpy_shift(array: np.ndarray, periods: int, axis: int, fill_value=np.nan) -> np.ndarray: + """ + numpy implementation for validation + """ + size = array.shape[axis] + assert periods in range(-size, size) + assert axis in range(-array.ndim, array.ndim) + + copy_src_indices = [slice(None)] * array.ndim + copy_dst_indices = [slice(None)] * array.ndim + fill_indices = [slice(None)] * array.ndim + + if periods > 0: + fill_indices[axis] = slice(None, periods) + copy_src_indices[axis] = slice(None, -periods) + copy_dst_indices[axis] = slice(periods, None) + elif periods < 0: + fill_indices[axis] = slice(periods, None) + copy_src_indices[axis] = slice(-periods, None) + copy_dst_indices[axis] = slice(None, periods) + else: + return array.copy() + + result = np.empty_like(array) + result[tuple(fill_indices)] = fill_value + result[tuple(copy_dst_indices)] = array[tuple(copy_src_indices)] + + return result + + +def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan): + numpy_result = numpy_shift(arr, periods=periods, axis=axis, fill_value=fill_value) + shift = ShiftNet(periods=periods, axis=axis, fill_value=fill_value) + mindspore_result = shift(Tensor(arr)).asnumpy() + + print('numpy:\n') + print(numpy_result) + print('mindspore:\n') + print(mindspore_result) + assert np.allclose(numpy_result, mindspore_result, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('fill_value, dtype', + [(0.0, np.float32), + (0.0, np.float64), + (0.0, np.int32), + (0.0, np.int64), + (0.0, np.bool_), + (1.0, np.float32), + (1.0, np.float64), + (1.0, np.int32), + (1.0, np.int64), + (1.0, np.bool_), + (np.nan, np.float32), + (np.nan, np.float64), + (np.nan, np.bool_)] + ) +def test_no_shift(fill_value, dtype): + arr = np.random.random((40, 60, 50, 30)).astype(dtype) + + compare(arr, axis=0, periods=0, fill_value=fill_value) + compare(arr, axis=1, periods=0, fill_value=fill_value) + compare(arr, axis=2, periods=0, fill_value=fill_value) + compare(arr, axis=3, periods=0, fill_value=fill_value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('fill_value, dtype', + [(0.0, np.float32), + (0.0, np.float64), + (0.0, np.int32), + (0.0, np.int64), + (0.0, np.bool_), + (1.0, np.float32), + (1.0, np.float64), + (1.0, np.int32), + (1.0, np.int64), + (1.0, np.bool_), + (np.nan, np.float32), + (np.nan, np.float64), + (np.nan, np.bool_)] + ) +def test_fancy_1d(fill_value, dtype): + arr = np.random.random((1, 1, 50, 1)).astype(dtype) + + axis = 2 + compare(arr, axis=axis, periods=-35, fill_value=fill_value) + compare(arr, axis=axis, periods=28, fill_value=fill_value) + + arr = np.random.random((70, 1, 1, 1)).astype(dtype) + + axis = 0 + compare(arr, axis=axis, periods=-35, fill_value=fill_value) + compare(arr, axis=axis, periods=28, fill_value=fill_value) + + arr = np.random.random((1, 1, 1, 80)).astype(dtype) + + axis = 3 + compare(arr, axis=axis, periods=-35, fill_value=fill_value) + compare(arr, axis=axis, periods=28, fill_value=fill_value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('fill_value, dtype', + [(0.0, np.float32), + (0.0, np.float64), + (0.0, np.int32), + (0.0, np.int64), + (0.0, np.bool_), + (1.0, np.float32), + (1.0, np.float64), + (1.0, np.int32), + (1.0, np.int64), + (1.0, np.bool_), + (np.nan, np.float32), + (np.nan, np.float64), + (np.nan, np.bool_)] + ) +def test_2d(fill_value, dtype): + arr = np.random.random((30, 40)).astype(dtype) + axis = 0 + compare(arr, axis=axis, periods=-24, fill_value=fill_value) + compare(arr, axis=axis, periods=27, fill_value=fill_value) + + axis = 1 + compare(arr, axis=axis, periods=-35, fill_value=fill_value) + compare(arr, axis=axis, periods=28, fill_value=fill_value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('fill_value, dtype', + [(0.0, np.float32), + (0.0, np.float64), + (0.0, np.int32), + (0.0, np.int64), + (0.0, np.bool_), + (1.0, np.float32), + (1.0, np.float64), + (1.0, np.int32), + (1.0, np.int64), + (1.0, np.bool_), + (np.nan, np.float32), + (np.nan, np.float64), + (np.nan, np.bool_)] + ) +def test_4d(fill_value, dtype): + arr = np.random.random((30, 40, 50, 60)).astype(dtype) + + axis = 0 + compare(arr, axis=axis, periods=-24, fill_value=fill_value) + compare(arr, axis=axis, periods=28, fill_value=fill_value) + + axis = 1 + compare(arr, axis=axis, periods=-24, fill_value=fill_value) + compare(arr, axis=axis, periods=34, fill_value=fill_value) + + axis = 2 + compare(arr, axis=axis, periods=-24, fill_value=fill_value) + compare(arr, axis=axis, periods=48, fill_value=fill_value) + + axis = 3 + compare(arr, axis=axis, periods=-48, fill_value=fill_value) + compare(arr, axis=axis, periods=52, fill_value=fill_value)