Add CPU ops LogicalAnd, LogicalOr, LogicalNot

yanglf1121 2021-01-19 13:23:05 +08:00
parent d4ef0452a6
commit 9c21f78f16
8 changed files with 184 additions and 4 deletions
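For context, the three ops are exposed through mindspore.ops.operations; the short sketch below mirrors the new test file at the end of this commit and shows a minimal CPU-target call. It is an illustration only, not part of the diff; the tensor values are arbitrary.

# Minimal CPU-target usage of the three new ops (values illustrative; this
# mirrors the docstring examples and the test file added in this commit).
import numpy as np

import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

x = Tensor(np.array([True, False, False]))
y = Tensor(np.array([True, True, False]))

print(P.LogicalAnd()(x, y))  # element values: [True, False, False]
print(P.LogicalOr()(x, y))   # element values: [True, True, False]
print(P.LogicalNot()(x))     # element values: [False, True, True]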

View File

@@ -180,6 +180,24 @@ void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out,
  }
}
template <typename T>
void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    std::vector<size_t> idx;
    GenIndex(i, &idx);
    out[i] = input1[idx[0]] && input2[idx[1]];
  }
}
template <typename T>
void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    std::vector<size_t> idx;
    GenIndex(i, &idx);
    out[i] = input1[idx[0]] || input2[idx[1]];
  }
}
template <typename T>
void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
@@ -248,6 +266,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = GREATEREQUAL;
  } else if (kernel_name == prim::kPrimLessEqual->name()) {
    operate_type_ = LESSEQUAL;
  } else if (kernel_name == prim::kPrimLogicalAnd->name()) {
    operate_type_ = LOGICALAND;
  } else if (kernel_name == prim::kPrimLogicalOr->name()) {
    operate_type_ = LOGICALOR;
  } else if (kernel_name == prim::kPrimAssignAdd->name()) {
    operate_type_ = ASSIGNADD;
  } else if (kernel_name == prim::kPrimSquaredDifference->name()) {
@@ -366,6 +388,10 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input
        std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == LESSEQUAL) {
      threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == LOGICALAND) {
      threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalAnd<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == LOGICALOR) {
      threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalOr<T>, this, input1, input2, output, start, end));
    } else {
      MS_LOG(EXCEPTION) << "Not support " << operate_type_;
    }
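The two new binary kernels rely on GenIndex to map each flat output index to the matching (possibly broadcast) element of each input. The snippet below is a rough Python analogue of that mapping, assuming standard NumPy-style broadcasting semantics; it is an illustration only, not the actual GenIndex implementation.

# Rough Python analogue of the GenIndex-based broadcast lookup used by
# LogicalAnd / LogicalOr above (illustration only, not the C++ code):
# map a flat output index to the flat index of each input under
# NumPy-style broadcasting.
def gen_index(flat_out_idx, out_shape, in_shape0, in_shape1):
    # Decompose the flat output index into per-dimension coordinates.
    coords = []
    rem = flat_out_idx
    for dim in reversed(out_shape):
        coords.append(rem % dim)
        rem //= dim
    coords.reverse()

    def to_flat(in_shape):
        # Right-align the input shape against the output shape and clamp
        # coordinates to 0 on broadcast (size-1) dimensions.
        offset = len(out_shape) - len(in_shape)
        flat, stride = 0, 1
        for axis in reversed(range(len(in_shape))):
            coord = coords[axis + offset] if in_shape[axis] != 1 else 0
            flat += coord * stride
            stride *= in_shape[axis]
        return flat

    return to_flat(in_shape0), to_flat(in_shape1)


# Output element (1, 1) of a (3, 4) result reads input0[1, 0] and input1[0, 1].
print(gen_index(5, [3, 4], [3, 1], [1, 4]))  # -> (1, 1)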

View File

@@ -71,6 +71,10 @@ class ArithmeticCPUKernel : public CPUKernel {
  void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
  template <typename T>
  void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
  template <typename T>
  void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end);
  template <typename T>
  void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end);
  std::vector<size_t> input_shape0_;
  std::vector<size_t> input_shape1_;
  std::vector<size_t> input_element_num0_;
@@ -269,6 +273,12 @@ MS_REG_CPU_KERNEL(
  LessEqual,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
  ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
  LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
  ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
  LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
  ArithmeticCPUKernel);
}  // namespace kernel
}  // namespace mindspore

View File

@@ -49,6 +49,13 @@ void Neg(const T *in, T *out, size_t start, size_t end) {
  }
}
template <typename T>
void LogicalNot(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = !in[i];
  }
}
template <typename T>
void OnesLike(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
@@ -99,6 +106,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = ZEROSLIKE;
  } else if (kernel_name == prim::kPrimNeg->name()) {
    operate_type_ = NEG;
  } else if (kernel_name == prim::kPrimLogicalNot->name()) {
    operate_type_ = LOGICALNOT;
  } else if (kernel_name == prim::kPrimSign->name()) {
    operate_type_ = SIGN;
  } else if (kernel_name == prim::kPrimFloor->name()) {
@@ -109,6 +118,7 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = GELU;
  }
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
}
bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -118,15 +128,55 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
    LaunchKernel<int>(inputs, outputs);
  } else if (dtype_ == kNumberTypeBool) {
    LaunchKernelLogic<bool>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
  }
  return true;
}
template <typename T>
void ArithmeticSelfCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &inputs,
                                                const std::vector<AddressPtr> &outputs) {
  T *input = reinterpret_cast<T *>(inputs[0]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  auto max_thread_num = std::thread::hardware_concurrency();
  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  std::vector<std::thread> threads;
  if (thread_num < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
    return;
  }
  threads.reserve(thread_num);
  size_t start = 0;
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
    return;
  }
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    if (operate_type_ == LOGICALNOT) {
      threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
    }
    start += once_compute_size;
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i].join();
  }
}
template <typename T>
void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &outputs) {
  if (target_dtype_ == kNumberTypeBool) {
    LaunchKernelLogic<T>(inputs, outputs);
    return;
  }
  T *input = reinterpret_cast<T *>(inputs[0]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
@@ -152,6 +202,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
      threads.emplace_back(std::thread(Square<T>, input, output, start, end));
    } else if (operate_type_ == NEG) {
      threads.emplace_back(std::thread(Neg<T>, input, output, start, end));
    } else if (operate_type_ == LOGICALNOT) {
      threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
    } else if (operate_type_ == ONESLIKE) {
      threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end));
    } else if (operate_type_ == ZEROSLIKE) {
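LaunchKernelLogic splits the output into chunks of roughly 128 elements, capped at the hardware thread count, and hands each chunk to one std::thread. The Python snippet below sketches only that chunking arithmetic, reusing the names and the 128-element heuristic visible in the diff above; it is not MindSpore code.

import math


# Sketch of the work splitting in LaunchKernelLogic above (illustration only):
# at most one thread per ~128 output elements, capped at hardware concurrency.
def split_into_chunks(lens, max_thread_num):
    thread_num = math.ceil(lens / 128.0) if lens < 128 * max_thread_num else max_thread_num
    if thread_num < 1:
        return []
    once_compute_size = (lens + thread_num - 1) // thread_num  # ceiling division
    chunks, start = [], 0
    while start < lens:
        end = min(start + once_compute_size, lens)
        chunks.append((start, end))  # each (start, end) range becomes one worker thread
        start += once_compute_size
    return chunks


# 300 elements on a machine reporting 8 hardware threads -> 3 chunks of 100.
print(split_into_chunks(300, 8))  # [(0, 100), (100, 200), (200, 300)]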

View File

@@ -35,9 +35,13 @@ class ArithmeticSelfCPUKernel : public CPUKernel {
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  template <typename T>
  void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 private:
  OperateType operate_type_{SQUARE};
  TypeId dtype_{kTypeUnknown};
  TypeId target_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@@ -64,6 +68,8 @@ MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                  ArithmeticSelfCPUKernel);
}  // namespace kernel
}  // namespace mindspore

View File

@@ -84,6 +84,9 @@ enum OperateType {
  EQUAL,
  NOTEQUAL,
  LESSEQUAL,
  LOGICALAND,
  LOGICALOR,
  LOGICALNOT,
  FLOOR,
  SQUAREDDIFFERENCE,
  GREATER,

View File

@@ -61,6 +61,9 @@ inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");

View File

@@ -2057,7 +2057,7 @@ class FloorDiv(_MathBinaryOp):
        and the data type is the one with higher precision or higher digits among the two inputs.
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    Examples:
        >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
@@ -2870,7 +2870,7 @@ class LogicalNot(PrimitiveWithInfer):
        Tensor, the shape is the same as the `input_x`, and the dtype is bool.
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    Examples:
        >>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
@@ -2913,7 +2913,7 @@ class LogicalAnd(_LogicBinaryOp):
        Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    Examples:
        >>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
@@ -2948,7 +2948,7 @@ class LogicalOr(_LogicBinaryOp):
        Tensor, the shape is the same as the one after broadcasting,and the data type is bool.
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    Examples:
        >>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)

View File

@@ -0,0 +1,80 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class OpNetWrapper(nn.Cell):
    def __init__(self, op):
        super(OpNetWrapper, self).__init__()
        self.op = op

    def construct(self, *inputs):
        return self.op(*inputs)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicaland():
    op = P.LogicalAnd()
    op_wrapper = OpNetWrapper(op)

    input_x = Tensor(np.array([True, False, False]))
    input_y = Tensor(np.array([True, True, False]))
    outputs = op_wrapper(input_x, input_y)

    assert np.allclose(outputs.asnumpy(), (True, False, False))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicalor():
    op = P.LogicalOr()
    op_wrapper = OpNetWrapper(op)

    input_x = Tensor(np.array([True, False, False]))
    input_y = Tensor(np.array([True, True, False]))
    outputs = op_wrapper(input_x, input_y)

    assert np.allclose(outputs.asnumpy(), (True, True, False))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicalnot():
    op = P.LogicalNot()
    op_wrapper = OpNetWrapper(op)

    input_x = Tensor(np.array([True, False, False]))
    outputs = op_wrapper(input_x)

    assert np.allclose(outputs.asnumpy(), (False, True, True))


if __name__ == '__main__':
    test_logicaland()
    test_logicalor()
    test_logicalnot()