diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.cc
new file mode 100644
index 00000000000..3a93b97aec3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.cc
@@ -0,0 +1,185 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <map>
+#include <vector>
+#include <deque>
+#include <algorithm>
+#include "backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+const size_t kReduceTypeAll = 1;
+const size_t kReduceTypeAny = 2;
+const size_t kMaxDim = 100;
+static std::map<std::string, size_t> reduce_types_map_ = {{"ReduceAll", 1}, {"ReduceAny", 2}};
+
+template <typename T>
+void ReduceLogicCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+
+  reduce_type_ = reduce_types_map_[kernel_name];
+  if (reduce_type_ == 0) {
+    MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
+  }
+  shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  CheckAxis(kernel_node);
+  if (shape_.empty()) {
+    shape_.push_back(1);
+  }
+  for (size_t i = 0; i < shape_.size(); ++i) {
+    if (shape_[i] <= 0) {
+      MS_LOG(EXCEPTION) << "Shape value is invalid.";
+    }
+    left_dims_ *= shape_[i];
+  }
+  // stride_ is the number of input elements folded into one output value.
+  for (size_t i = 0; i < axis_.size(); ++i) {
+    stride_ *= shape_[axis_[i]];
+  }
+  if (stride_ <= 0) {
+    MS_LOG(EXCEPTION) << "stride_ must be greater than zero.";
+  }
+  left_dims_ = left_dims_ / stride_;
+}
+
+template <typename T>
+bool ReduceLogicCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                     const std::vector<kernel::AddressPtr> & /*workspaces*/,
+                                     const std::vector<kernel::AddressPtr> &outputs) {
+  size_t out_size = left_dims_ * sizeof(T);
+  size_t in_size = stride_ * out_size;
+  if (inputs[0]->size != in_size || outputs[0]->size != out_size) {
+    MS_LOG(EXCEPTION) << "Invalid input or output data size!";
+  }
+  auto input = reinterpret_cast<T *>(inputs[0]->addr);
+  auto output = reinterpret_cast<T *>(outputs[0]->addr);
+  int size = SizeToInt(inputs[0]->size / sizeof(T));
+  std::deque<T> new_inputs(IntToSize(size), false);
+  // Keep the non-reduced axes in order and append the reduced axes, so that
+  // after the transpose every reduced slice is one contiguous run of stride_ elements.
+  std::vector<size_t> transpose_axis;
+  for (size_t i = 0; i < shape_.size(); ++i) {
+    bool insert = true;
+    for (size_t j = 0; j < axis_.size(); ++j) {
+      if (axis_[j] == i) {
+        insert = false;
+        break;
+      }
+    }
+    if (insert) {
+      transpose_axis.push_back(i);
+    }
+  }
+  (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end());
+  Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_inputs[0]);
+  ConvertDataToOutput(&new_inputs[0], output);
+  return true;
+}
+
+template <typename T>
+void ReduceLogicCPUKernel<T>::CheckAxis(const CNodePtr &kernel_node) {
+  auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
+  if (axis_addr->isa<ValueTuple>() || axis_addr->isa<ValueList>()) {
+    std::vector<int> attr_axis;
+    std::vector<int64_t> attr_axis_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, AXIS);
+    (void)std::transform(attr_axis_me.begin(), attr_axis_me.end(), std::back_inserter(attr_axis),
+                         [](const int64_t &value) { return static_cast<int>(value); });
+    if (attr_axis.size() > shape_.size()) {
+      MS_LOG(EXCEPTION) << "Invalid axis size: " << attr_axis.size();
+    } else if (attr_axis.empty()) {
+      // An empty axis tuple means reducing over every dimension.
+      for (size_t i = 0; i < shape_.size(); ++i) {
+        axis_.push_back(i);
+      }
+    } else {
+      for (auto axis : attr_axis) {
+        while (axis < 0) {
+          axis += SizeToInt(shape_.size());
+        }
+        if (IntToSize(axis) >= shape_.size()) {
+          MS_LOG(EXCEPTION) << "Axis value is out of range.";
+        }
+        axis_.push_back(IntToSize(axis));
+      }
+    }
+  } else if (axis_addr->isa<Int64Imm>()) {
+    int axis = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
+    while (axis < 0) {
+      axis += SizeToInt(shape_.size());
+    }
+    if (IntToSize(axis) >= shape_.size()) {
+      MS_LOG(EXCEPTION) << "Axis value is out of range.";
+    }
+    axis_.push_back(IntToSize(axis));
+  } else {
+    MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
+  }
+}
+
+template <typename T>
+void ReduceLogicCPUKernel<T>::ConvertDataToOutput(const T *new_input, T *output) {
+  if (reduce_type_ == kReduceTypeAll) {
+    for (size_t i = 0; i < left_dims_; ++i) {
+      auto value{true};
+      for (size_t k = 0; k < stride_; ++k) {
+        value &= new_input[i * stride_ + k];
+      }
+      output[i] = value;
+    }
+  } else if (reduce_type_ == kReduceTypeAny) {
+    for (size_t i = 0; i < left_dims_; ++i) {
+      auto value{false};
+      for (size_t k = 0; k < stride_; ++k) {
+        value |= new_input[i * stride_ + k];
+      }
+      output[i] = value;
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Array reduce kernel type " << reduce_type_ << " is not supported.";
+  }
+}
+
+template <typename T>
+void ReduceLogicCPUKernel<T>::Transpose(const int size, const T *input, const std::vector<size_t> &input_shape,
+                                        const std::vector<size_t> &input_axis, const int shape_size, T *output) {
+  // size_offset[i] is the number of elements covered by one step along axis i.
+  int size_offset[kMaxDim];
+  size_offset[0] = size / SizeToInt(input_shape[0]);
+  for (int i = 1; i < shape_size; ++i) {
+    size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]);
+  }
+  auto task = [&](size_t start, size_t end) {
+    int pos_array[kMaxDim];
+    for (size_t position = start; position < end; position += 1) {
+      // Decompose the flat position into a multi-index, then re-linearize it
+      // in the permuted layout given by input_axis.
+      size_t temp_position = position;
+      pos_array[0] = temp_position / size_offset[0];
+      for (int i = 1; i < shape_size; ++i) {
+        temp_position -= pos_array[i - 1] * size_offset[i - 1];
+        pos_array[i] = temp_position / size_offset[i];
+      }
+      size_t new_position = pos_array[SizeToInt(input_axis[shape_size - 1])];
+      size_t new_position_size = 1;
+      for (int j = shape_size - 2; j >= 0; j--) {
+        new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]);
+        new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size;
+      }
+      output[new_position] = input[position];
+    }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
+}
+}  // namespace kernel
+}  // namespace mindspore
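[editor note] The kernel above reduces in two phases: Transpose moves the non-reduced
axes to the front, so each group of stride_ consecutive elements is exactly one reduced
slice, which ConvertDataToOutput then folds with &= or |=. A minimal NumPy sketch of the
same strategy, for reference only (the helper reduce_logic and its signature are
illustrative, not part of this patch):

    import numpy as np

    def reduce_logic(x, axis, op="all"):
        # Normalize axis the way CheckAxis does: wrap negatives; () means all dims.
        axes = axis if isinstance(axis, (tuple, list)) else (axis,)
        axes = sorted(a % x.ndim for a in axes) or list(range(x.ndim))
        kept = [i for i in range(x.ndim) if i not in axes]
        stride = int(np.prod([x.shape[a] for a in axes]))  # elements per reduced slice
        left_dims = x.size // stride                       # number of output elements
        # Transpose kept axes to the front; every reduced slice becomes contiguous.
        t = np.transpose(x, kept + axes).reshape(left_dims, stride)
        out = t.all(axis=1) if op == "all" else t.any(axis=1)
        return out.reshape([x.shape[i] for i in kept])

    x = np.array([[True, False], [True, True]])
    assert (reduce_logic(x, 1, "all") == x.all(axis=1)).all()
    assert reduce_logic(x, (), "any") == x.any()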
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.h
new file mode 100644
index 00000000000..b94e52d5fce
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_logic_cpu_kernel.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_LOGIC_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_LOGIC_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ReduceLogicCPUKernel : public CPUKernel {
+ public:
+  ReduceLogicCPUKernel() = default;
+  ~ReduceLogicCPUKernel() override = default;
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  void Transpose(const int size, const T *input, const std::vector<size_t> &input_shape,
+                 const std::vector<size_t> &input_axis, const int shape_size, T *output);
+  void ConvertDataToOutput(const T *input, T *output);
+  void CheckAxis(const CNodePtr &kernel_node);
+  size_t reduce_type_ = 0;
+  std::vector<size_t> axis_;
+  std::vector<size_t> shape_;
+  size_t left_dims_ = 1;
+  size_t stride_ = 1;
+};
+
+MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    ReduceLogicCPUKernel, bool);
+MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    ReduceLogicCPUKernel, bool);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_LOGIC_CPU_KERNEL_H_
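[editor note] The MS_REG_CPU_KERNEL_T lines above register both ops for bool inputs and
outputs only, so on CPU a non-bool tensor must be cast before the reduction. A hedged
sketch (the explicit Cast step is an assumption about caller code, not something this
patch adds):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(device_target="CPU")

    x = Tensor(np.array([[1.0, 0.0], [2.0, 3.0]]), mstype.float32)
    x_bool = P.Cast()(x, mstype.bool_)              # nonzero -> True
    print(P.ReduceAny(keep_dims=False)(x_bool, 0))  # [True, True]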
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 4e18c500140..f15ae1fab8d 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -474,7 +474,7 @@ class ReduceAll(_Reduce):
         the shape of output is :math:`(x_1, x_4, ..., x_R)`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[True, False], [True, True]]))
@@ -516,7 +516,7 @@ class ReduceAny(_Reduce):
         the shape of output is :math:`(x_1, x_4, ..., x_R)`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[True, False], [True, True]]))
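[editor note] The docstring edits above only flip the platform tags; what they enable is
running the documented example with device_target="CPU". A small usage sketch consistent
with the docstring sample (the printed values assume the documented semantics):

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(device_target="CPU")

    input_x = Tensor(np.array([[True, False], [True, True]]))
    print(P.ReduceAll(keep_dims=True)(input_x, 1))  # [[False], [True]]
    print(P.ReduceAny(keep_dims=True)(input_x, 1))  # [[True], [True]]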
diff --git a/tests/st/ops/cpu/test_reduce_op.py b/tests/st/ops/cpu/test_reduce_op.py
index bae0c6c72bb..5876dd797d8 100644
--- a/tests/st/ops/cpu/test_reduce_op.py
+++ b/tests/st/ops/cpu/test_reduce_op.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -61,6 +61,27 @@ class NetReduce(nn.Cell):
                 self.reduce_min(indice, self.axis6))
 
 
+class NetReduceLogic(nn.Cell):
+    def __init__(self):
+        super(NetReduceLogic, self).__init__()
+        self.axis0 = 0
+        self.axis1 = -1
+        self.axis2 = (0, 1, 2)
+        self.axis3 = ()
+        self.reduce_all = P.ReduceAll(False)
+        self.reduce_any = P.ReduceAny(False)
+
+    @ms_function
+    def construct(self, indice):
+        return (self.reduce_all(indice, self.axis0),
+                self.reduce_all(indice, self.axis1),
+                self.reduce_all(indice, self.axis2),
+                self.reduce_all(indice, self.axis3),
+                self.reduce_any(indice, self.axis0),
+                self.reduce_any(indice, self.axis1),
+                self.reduce_any(indice, self.axis2),
+                self.reduce_any(indice, self.axis3),)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
@@ -125,4 +146,38 @@ def test_reduce():
     assert (output[16].asnumpy() == expect_11).all()
     assert (output[17].asnumpy() == 0.0).all()
 
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_reduce_logic():
+    reduce_logic = NetReduceLogic()
+    indice_bool = Tensor([[[False, True, True, True, False, True],
+                           [True, True, True, True, True, False]],
+                          [[True, False, True, True, False, True],
+                           [True, False, False, True, True, True]],
+                          [[True, True, True, False, False, False],
+                           [True, True, True, False, True, True]]])
+    output = reduce_logic(indice_bool)
+    expect_all_1 = np.array([[False, False, True, False, False, False],
+                             [True, False, False, False, True, False]])
+    expect_all_2 = np.array([[False, False], [False, False], [False, False]])
+    expect_all_3 = False
+    expect_all_4 = False
+    expect_any_1 = np.array([[True, True, True, True, False, True], [True, True, True, True, True, True]])
+    expect_any_2 = np.array([[True, True], [True, True], [True, True]])
+    expect_any_3 = True
+    expect_any_4 = True
+
+    assert (output[0].asnumpy() == expect_all_1).all()
+    assert (output[1].asnumpy() == expect_all_2).all()
+    assert (output[2].asnumpy() == expect_all_3).all()
+    assert (output[3].asnumpy() == expect_all_4).all()
+    assert (output[4].asnumpy() == expect_any_1).all()
+    assert (output[5].asnumpy() == expect_any_2).all()
+    assert (output[6].asnumpy() == expect_any_3).all()
+    assert (output[7].asnumpy() == expect_any_4).all()
+
+
 test_reduce()
+test_reduce_logic()
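[editor note] The expected arrays in test_reduce_logic are hand-written; a quick NumPy
cross-check (not part of the patch) confirms them for the axes the test exercises,
namely 0, -1, (0, 1, 2), and the empty tuple, which reduces every dimension:

    import numpy as np

    indice = np.array([[[False, True, True, True, False, True],
                        [True, True, True, True, True, False]],
                       [[True, False, True, True, False, True],
                        [True, False, False, True, True, True]],
                       [[True, True, True, False, False, False],
                        [True, True, True, False, True, True]]])

    assert (indice.all(axis=0) == np.array([[False, False, True, False, False, False],
                                            [True, False, False, False, True, False]])).all()
    assert (indice.all(axis=-1) == np.array([[False, False], [False, False],
                                             [False, False]])).all()
    assert (indice.any(axis=0) == np.array([[True, True, True, True, False, True],
                                            [True, True, True, True, True, True]])).all()
    assert not indice.all() and indice.any()  # covers both (0, 1, 2) and ()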