forked from mindspore-Ecosystem/mindspore
add Shift cpu OP
This commit is contained in:
parent
bff5cbda9c
commit
c7febdcca4
|
@ -62,6 +62,8 @@ const char DELTA[] = "delta";
|
|||
const char SORTED[] = "sorted";
|
||||
const char ADJ_ST[] = "adjoint_st";
|
||||
const char ADJ_dT[] = "adjoint_dt";
|
||||
const char PERIODS[] = "periods";
|
||||
const char FILL_VALUE[] = "fill_value";
|
||||
|
||||
enum OperateType {
|
||||
ADD = 0,
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/shift_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void ShiftCpuKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 1) {
|
||||
MS_LOG(EXCEPTION) << input_count << " inputs were provided, but Shift expects 1.";
|
||||
}
|
||||
|
||||
size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_count != 1) {
|
||||
MS_LOG(EXCEPTION) << "Number of outputs is " << output_count << ", but should be 1 for Shift.";
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
|
||||
|
||||
periods_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, PERIODS);
|
||||
auto fill_value = AnfAlgo::GetNodeAttr<float>(kernel_node, FILL_VALUE);
|
||||
if constexpr (std::is_same<T, int64_t>::value || std::is_same<T, int32_t>::value) {
|
||||
if (std::isnan(fill_value)) {
|
||||
MS_LOG(EXCEPTION) << "for integer input, nan is not supported for fill_value ";
|
||||
}
|
||||
}
|
||||
fill_value_ = static_cast<T>(fill_value);
|
||||
|
||||
auto axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
if (axis < 0) {
|
||||
axis += input_shape.size();
|
||||
}
|
||||
if ((axis < 0) || (axis >= static_cast<int64_t>(input_shape.size()))) {
|
||||
MS_LOG(EXCEPTION) << "axis should be smaller than the dimension of input tensor " << input_shape.size()
|
||||
<< "D, but got " << axis;
|
||||
}
|
||||
|
||||
// calculating axis size
|
||||
size_t axis_t = axis;
|
||||
|
||||
outer_size_ = 1;
|
||||
for (size_t i = 0; i < axis_t; i++) {
|
||||
outer_size_ *= input_shape[i];
|
||||
}
|
||||
|
||||
axis_size_ = input_shape[axis_t];
|
||||
|
||||
inner_size_ = 1;
|
||||
for (size_t i = axis_t + 1; i < input_shape.size(); ++i) {
|
||||
inner_size_ *= input_shape[i];
|
||||
}
|
||||
|
||||
// index calculation
|
||||
if (periods_ > 0) {
|
||||
fill_begin_ = 0;
|
||||
fill_size_ = periods_;
|
||||
|
||||
copy_src_begin_ = 0;
|
||||
copy_dst_begin_ = periods_;
|
||||
copy_size_ = input_shape[axis] - periods_;
|
||||
} else if (periods_ < 0) {
|
||||
fill_begin_ = input_shape[axis] + periods_;
|
||||
fill_size_ = -periods_;
|
||||
|
||||
copy_src_begin_ = -periods_;
|
||||
copy_dst_begin_ = 0;
|
||||
copy_size_ = input_shape[axis] + periods_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ShiftCpuKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Sort needs 1 input and 1 outputs, but get inputs: " << inputs.size()
|
||||
<< "outputs: " << outputs.size();
|
||||
}
|
||||
auto input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
if (outputs[0]->size != inputs[0]->size) {
|
||||
MS_LOG(EXCEPTION) << "Error output data size!";
|
||||
}
|
||||
|
||||
// if periods_ is 0, do nothing
|
||||
if (periods_ == 0) {
|
||||
// directly copy input to output
|
||||
memcpy(output, input, inputs[0]->size);
|
||||
return true;
|
||||
}
|
||||
|
||||
// periods is larger than size, all value of the tensor would be fill_value
|
||||
if (std::abs(periods_) >= static_cast<int>(axis_size_)) {
|
||||
std::fill_n(output, outer_size_ * axis_size_ * inner_size_, fill_value_);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (inputs[0]->size != outer_size_ * axis_size_ * inner_size_ * sizeof(T)) {
|
||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||
}
|
||||
|
||||
// check if the tensor is linear
|
||||
if ((inner_size_ == 1) && (outer_size_ == 1)) {
|
||||
// treat it as a simple 1D array
|
||||
memcpy(output + copy_dst_begin_, input + copy_src_begin_, copy_size_ * sizeof(T));
|
||||
std::fill_n(output + fill_begin_, fill_size_, fill_value_);
|
||||
return true;
|
||||
}
|
||||
|
||||
// normal procedure
|
||||
std::vector<common::Task> tasks;
|
||||
tasks.reserve(outer_size_);
|
||||
for (size_t i = 0; i < outer_size_; ++i) {
|
||||
tasks.emplace_back([this, i, input, output] {
|
||||
size_t offset = i * axis_size_ * inner_size_;
|
||||
size_t input_offset = offset + copy_src_begin_ * inner_size_;
|
||||
size_t output_offset = offset + copy_dst_begin_ * inner_size_;
|
||||
memcpy(output + output_offset, input + input_offset, copy_size_ * inner_size_ * sizeof(T));
|
||||
size_t fill_offset = offset + fill_begin_ * inner_size_;
|
||||
std::fill_n(output + fill_offset, fill_size_ * inner_size_, fill_value_);
|
||||
return common::SUCCESS;
|
||||
});
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ShiftCpuKernel : public CPUKernel {
|
||||
public:
|
||||
ShiftCpuKernel() = default;
|
||||
~ShiftCpuKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
// inputs
|
||||
int64_t periods_{0};
|
||||
T fill_value_{0};
|
||||
|
||||
// slice info
|
||||
size_t outer_size_{0};
|
||||
size_t axis_size_{0};
|
||||
size_t inner_size_{0};
|
||||
|
||||
size_t copy_src_begin_{0};
|
||||
size_t copy_dst_begin_{0};
|
||||
size_t copy_size_{0};
|
||||
|
||||
size_t fill_begin_{0};
|
||||
size_t fill_size_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ShiftCpuKernel,
|
||||
bool)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ShiftCpuKernel, float)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ShiftCpuKernel, double)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ShiftCpuKernel,
|
||||
int32_t)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Shift, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ShiftCpuKernel,
|
||||
int64_t)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SORT_CPU_KERNEL_H_
|
|
@ -0,0 +1,230 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class Shift(PrimitiveWithInfer):
|
||||
"""
|
||||
Shift op frontend implementation
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, periods=1, axis=-1, fill_value=np.nan):
|
||||
"""Initialize Sort"""
|
||||
self.periods = validator.check_value_type("periods", periods, [int], self.name)
|
||||
self.axis = validator.check_value_type("axis", axis, [int], self.name)
|
||||
self.fill_value = validator.check_value_type("fill_value", fill_value, [float], self.name)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid("x_dtype", x_dtype,
|
||||
[mstype.float32, mstype.float64, mstype.int32, mstype.int64, mstype.bool_],
|
||||
self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class ShiftNet(nn.Cell):
|
||||
def __init__(self, periods=1, axis=-1, fill_value=np.nan):
|
||||
super(ShiftNet, self).__init__()
|
||||
self.shift = Shift(periods, axis, fill_value)
|
||||
|
||||
def construct(self, x):
|
||||
return self.shift(x)
|
||||
|
||||
|
||||
def numpy_shift(array: np.ndarray, periods: int, axis: int, fill_value=np.nan) -> np.ndarray:
|
||||
"""
|
||||
numpy implementation for validation
|
||||
"""
|
||||
size = array.shape[axis]
|
||||
assert periods in range(-size, size)
|
||||
assert axis in range(-array.ndim, array.ndim)
|
||||
|
||||
copy_src_indices = [slice(None)] * array.ndim
|
||||
copy_dst_indices = [slice(None)] * array.ndim
|
||||
fill_indices = [slice(None)] * array.ndim
|
||||
|
||||
if periods > 0:
|
||||
fill_indices[axis] = slice(None, periods)
|
||||
copy_src_indices[axis] = slice(None, -periods)
|
||||
copy_dst_indices[axis] = slice(periods, None)
|
||||
elif periods < 0:
|
||||
fill_indices[axis] = slice(periods, None)
|
||||
copy_src_indices[axis] = slice(-periods, None)
|
||||
copy_dst_indices[axis] = slice(None, periods)
|
||||
else:
|
||||
return array.copy()
|
||||
|
||||
result = np.empty_like(array)
|
||||
result[tuple(fill_indices)] = fill_value
|
||||
result[tuple(copy_dst_indices)] = array[tuple(copy_src_indices)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compare(arr: np.ndarray, periods: int, axis: int, fill_value=np.nan):
|
||||
numpy_result = numpy_shift(arr, periods=periods, axis=axis, fill_value=fill_value)
|
||||
shift = ShiftNet(periods=periods, axis=axis, fill_value=fill_value)
|
||||
mindspore_result = shift(Tensor(arr)).asnumpy()
|
||||
|
||||
print('numpy:\n')
|
||||
print(numpy_result)
|
||||
print('mindspore:\n')
|
||||
print(mindspore_result)
|
||||
assert np.allclose(numpy_result, mindspore_result, equal_nan=True)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('fill_value, dtype',
|
||||
[(0.0, np.float32),
|
||||
(0.0, np.float64),
|
||||
(0.0, np.int32),
|
||||
(0.0, np.int64),
|
||||
(0.0, np.bool_),
|
||||
(1.0, np.float32),
|
||||
(1.0, np.float64),
|
||||
(1.0, np.int32),
|
||||
(1.0, np.int64),
|
||||
(1.0, np.bool_),
|
||||
(np.nan, np.float32),
|
||||
(np.nan, np.float64),
|
||||
(np.nan, np.bool_)]
|
||||
)
|
||||
def test_no_shift(fill_value, dtype):
|
||||
arr = np.random.random((40, 60, 50, 30)).astype(dtype)
|
||||
|
||||
compare(arr, axis=0, periods=0, fill_value=fill_value)
|
||||
compare(arr, axis=1, periods=0, fill_value=fill_value)
|
||||
compare(arr, axis=2, periods=0, fill_value=fill_value)
|
||||
compare(arr, axis=3, periods=0, fill_value=fill_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('fill_value, dtype',
|
||||
[(0.0, np.float32),
|
||||
(0.0, np.float64),
|
||||
(0.0, np.int32),
|
||||
(0.0, np.int64),
|
||||
(0.0, np.bool_),
|
||||
(1.0, np.float32),
|
||||
(1.0, np.float64),
|
||||
(1.0, np.int32),
|
||||
(1.0, np.int64),
|
||||
(1.0, np.bool_),
|
||||
(np.nan, np.float32),
|
||||
(np.nan, np.float64),
|
||||
(np.nan, np.bool_)]
|
||||
)
|
||||
def test_fancy_1d(fill_value, dtype):
|
||||
arr = np.random.random((1, 1, 50, 1)).astype(dtype)
|
||||
|
||||
axis = 2
|
||||
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=28, fill_value=fill_value)
|
||||
|
||||
arr = np.random.random((70, 1, 1, 1)).astype(dtype)
|
||||
|
||||
axis = 0
|
||||
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=28, fill_value=fill_value)
|
||||
|
||||
arr = np.random.random((1, 1, 1, 80)).astype(dtype)
|
||||
|
||||
axis = 3
|
||||
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=28, fill_value=fill_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('fill_value, dtype',
|
||||
[(0.0, np.float32),
|
||||
(0.0, np.float64),
|
||||
(0.0, np.int32),
|
||||
(0.0, np.int64),
|
||||
(0.0, np.bool_),
|
||||
(1.0, np.float32),
|
||||
(1.0, np.float64),
|
||||
(1.0, np.int32),
|
||||
(1.0, np.int64),
|
||||
(1.0, np.bool_),
|
||||
(np.nan, np.float32),
|
||||
(np.nan, np.float64),
|
||||
(np.nan, np.bool_)]
|
||||
)
|
||||
def test_2d(fill_value, dtype):
|
||||
arr = np.random.random((30, 40)).astype(dtype)
|
||||
axis = 0
|
||||
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=27, fill_value=fill_value)
|
||||
|
||||
axis = 1
|
||||
compare(arr, axis=axis, periods=-35, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=28, fill_value=fill_value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('fill_value, dtype',
|
||||
[(0.0, np.float32),
|
||||
(0.0, np.float64),
|
||||
(0.0, np.int32),
|
||||
(0.0, np.int64),
|
||||
(0.0, np.bool_),
|
||||
(1.0, np.float32),
|
||||
(1.0, np.float64),
|
||||
(1.0, np.int32),
|
||||
(1.0, np.int64),
|
||||
(1.0, np.bool_),
|
||||
(np.nan, np.float32),
|
||||
(np.nan, np.float64),
|
||||
(np.nan, np.bool_)]
|
||||
)
|
||||
def test_4d(fill_value, dtype):
|
||||
arr = np.random.random((30, 40, 50, 60)).astype(dtype)
|
||||
|
||||
axis = 0
|
||||
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=28, fill_value=fill_value)
|
||||
|
||||
axis = 1
|
||||
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=34, fill_value=fill_value)
|
||||
|
||||
axis = 2
|
||||
compare(arr, axis=axis, periods=-24, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=48, fill_value=fill_value)
|
||||
|
||||
axis = 3
|
||||
compare(arr, axis=axis, periods=-48, fill_value=fill_value)
|
||||
compare(arr, axis=axis, periods=52, fill_value=fill_value)
|
Loading…
Reference in New Issue