Add cpu op LogicalAnd, LogicalOr, LogicalNot

yanglf1121 2021-01-19 13:23:05 +08:00
parent d4ef0452a6
commit 9c21f78f16
8 changed files with 184 additions and 4 deletions

View File

@@ -180,6 +180,24 @@ void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out,
}
}
template <typename T>
void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] && input2[idx[1]];
}
}
template <typename T>
void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] || input2[idx[1]];
}
}
template <typename T>
void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
@@ -248,6 +266,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = GREATEREQUAL;
} else if (kernel_name == prim::kPrimLessEqual->name()) {
operate_type_ = LESSEQUAL;
} else if (kernel_name == prim::kPrimLogicalAnd->name()) {
operate_type_ = LOGICALAND;
} else if (kernel_name == prim::kPrimLogicalOr->name()) {
operate_type_ = LOGICALOR;
} else if (kernel_name == prim::kPrimAssignAdd->name()) {
operate_type_ = ASSIGNADD;
} else if (kernel_name == prim::kPrimSquaredDifference->name()) {
@@ -366,6 +388,10 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input
std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LESSEQUAL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LOGICALAND) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalAnd<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LOGICALOR) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalOr<T>, this, input1, input2, output, start, end));
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
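A minimal NumPy sketch of what the new LogicalAnd/LogicalOr kernels compute, assuming GenIndex resolves each flat output index to the (possibly broadcast) offsets into the two inputs; the helper below is illustrative only, not part of the kernel.

import numpy as np

def logical_and_broadcast(x, y):
    # Conceptual equivalent of the threaded C++ kernel: resolve broadcasting,
    # then evaluate the element-wise predicate over the flat output range.
    out_shape = np.broadcast_shapes(x.shape, y.shape)
    xb = np.broadcast_to(x, out_shape).ravel()
    yb = np.broadcast_to(y, out_shape).ravel()
    out = np.empty(len(xb), dtype=bool)
    for i in range(out.size):  # the C++ kernel splits this [start, end) range across threads
        out[i] = bool(xb[i]) and bool(yb[i])
    return out.reshape(out_shape)

For example, logical_and_broadcast(np.array([True, False]), np.array([[True], [False]])) yields a (2, 2) boolean result.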

View File

@@ -71,6 +71,10 @@ class ArithmeticCPUKernel : public CPUKernel {
void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T>
void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T>
void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T>
void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end);
std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_;
@@ -269,6 +273,12 @@ MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
} // namespace kernel
} // namespace mindspore
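The two MS_REG_CPU_KERNEL entries above pin LogicalAnd and LogicalOr to bool inputs and a bool output. Conceptually the macro keys a kernel factory by op name plus dtype signature; a rough Python sketch of that idea (names here are illustrative, not MindSpore's actual registry):

KERNEL_REGISTRY = {}

def register_cpu_kernel(op_name, input_dtypes, output_dtypes, kernel_cls):
    # Analogue of MS_REG_CPU_KERNEL: map (name, dtype signature) -> kernel class.
    KERNEL_REGISTRY[(op_name, tuple(input_dtypes), tuple(output_dtypes))] = kernel_cls

class ArithmeticCPUKernelStub:
    """Stand-in for the C++ ArithmeticCPUKernel."""

register_cpu_kernel("LogicalAnd", ("bool", "bool"), ("bool",), ArithmeticCPUKernelStub)
register_cpu_kernel("LogicalOr", ("bool", "bool"), ("bool",), ArithmeticCPUKernelStub)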

View File

@@ -49,6 +49,13 @@ void Neg(const T *in, T *out, size_t start, size_t end) {
}
}
template <typename T>
void LogicalNot(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = !in[i];
}
}
template <typename T>
void OnesLike(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
@@ -99,6 +106,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = ZEROSLIKE;
} else if (kernel_name == prim::kPrimNeg->name()) {
operate_type_ = NEG;
} else if (kernel_name == prim::kPrimLogicalNot->name()) {
operate_type_ = LOGICALNOT;
} else if (kernel_name == prim::kPrimSign->name()) {
operate_type_ = SIGN;
} else if (kernel_name == prim::kPrimFloor->name()) {
@@ -109,6 +118,7 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = GELU;
}
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
}
bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -118,15 +128,55 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeBool) {
LaunchKernelLogic<bool>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support.";
}
return true;
}
template <typename T>
void ArithmeticSelfCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
if (operate_type_ == LOGICALNOT) {
threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
}
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
}
template <typename T>
void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
if (target_dtype_ == kNumberTypeBool) {
LaunchKernelLogic<T>(inputs, outputs);
return;
}
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
@@ -152,6 +202,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
threads.emplace_back(std::thread(Square<T>, input, output, start, end));
} else if (operate_type_ == NEG) {
threads.emplace_back(std::thread(Neg<T>, input, output, start, end));
} else if (operate_type_ == LOGICALNOT) {
threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
} else if (operate_type_ == ONESLIKE) {
threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end));
} else if (operate_type_ == ZEROSLIKE) {
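A small Python sketch of the work-splitting scheme used by LaunchKernelLogic (and LaunchKernel): target roughly 128 elements per thread, cap the thread count at the hardware concurrency, and give each thread a contiguous [start, end) slice. Function and variable names below are illustrative.

import math
import os

def plan_chunks(lens, max_thread_num=None):
    # thread_num = ceil(lens / 128), capped at max_thread_num (the C++ code
    # logs an error and returns instead of clamping when thread_num < 1).
    if max_thread_num is None:
        max_thread_num = os.cpu_count() or 1
    thread_num = max_thread_num if lens >= 128 * max_thread_num else math.ceil(lens / 128.0)
    thread_num = max(int(thread_num), 1)
    once_compute_size = (lens + thread_num - 1) // thread_num
    chunks, start = [], 0
    while start < lens:
        end = min(start + once_compute_size, lens)
        chunks.append((start, end))  # each pair becomes one worker thread
        start += once_compute_size
    return chunks

print(plan_chunks(1000, max_thread_num=8))  # eight slices of 125 elements each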

View File

@@ -35,9 +35,13 @@ class ArithmeticSelfCPUKernel : public CPUKernel {
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
OperateType operate_type_{SQUARE};
TypeId dtype_{kTypeUnknown};
TypeId target_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
@@ -64,6 +68,8 @@ MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticSelfCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -84,6 +84,9 @@ enum OperateType {
EQUAL,
NOTEQUAL,
LESSEQUAL,
LOGICALAND,
LOGICALOR,
LOGICALNOT,
FLOOR,
SQUAREDDIFFERENCE,
GREATER,

View File

@@ -61,6 +61,9 @@ inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");

View File

@@ -2057,7 +2057,7 @@ class FloorDiv(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
@@ -2870,7 +2870,7 @@ class LogicalNot(PrimitiveWithInfer):
Tensor, the shape is the same as the `input_x`, and the dtype is bool.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
@@ -2913,7 +2913,7 @@ class LogicalAnd(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting, and the data type is bool.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
@@ -2948,7 +2948,7 @@ class LogicalOr(_LogicBinaryOp):
Tensor, the shape is the same as the one after broadcasting,and the data type is bool.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([True, False, True]), mindspore.bool_)
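The commit's tests below only cover same-shape inputs; since the binary kernels go through GenIndex, broadcast shapes appear to be supported as well. A short, hedged example exercising that on CPU (not part of the commit):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

class LogicalOrNet(nn.Cell):
    def __init__(self):
        super(LogicalOrNet, self).__init__()
        self.logical_or = P.LogicalOr()

    def construct(self, x, y):
        return self.logical_or(x, y)

# (2, 3) against (1, 3): the second operand should broadcast along axis 0.
x = Tensor(np.array([[True, False, True], [False, False, True]]))
y = Tensor(np.array([[False, True, True]]))
print(LogicalOrNet()(x, y))  # expected [[True True True] [False True True]]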

View File

@@ -0,0 +1,80 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class OpNetWrapper(nn.Cell):
def __init__(self, op):
super(OpNetWrapper, self).__init__()
self.op = op
def construct(self, *inputs):
return self.op(*inputs)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicaland():
op = P.LogicalAnd()
op_wrapper = OpNetWrapper(op)
input_x = Tensor(np.array([True, False, False]))
input_y = Tensor(np.array([True, True, False]))
outputs = op_wrapper(input_x, input_y)
assert np.allclose(outputs.asnumpy(), (True, False, False))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicalor():
op = P.LogicalOr()
op_wrapper = OpNetWrapper(op)
input_x = Tensor(np.array([True, False, False]))
input_y = Tensor(np.array([True, True, False]))
outputs = op_wrapper(input_x, input_y)
assert np.allclose(outputs.asnumpy(), (True, True, False))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_logicalnot():
op = P.LogicalNot()
op_wrapper = OpNetWrapper(op)
input_x = Tensor(np.array([True, False, False]))
outputs = op_wrapper(input_x)
assert np.allclose(outputs.asnumpy(), (False, True, True))
if __name__ == '__main__':
test_logicaland()
test_logicalor()
test_logicalnot()