From 09b5e545dd99795aa2b0d732155d0f766d88735f Mon Sep 17 00:00:00 2001 From: zheng_pengfei <18865382565@163.com> Date: Thu, 23 Dec 2021 20:35:26 +0800 Subject: [PATCH] [feat] [assistant] [I48OBC] add new LogicalNot --- mindspore/core/ops/logical_not.cc | 32 +++++++++++----- mindspore/core/ops/logical_not.h | 4 +- .../mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../ops/_op_impl/tbe/logical_not_ds.py | 37 +++++++++++++++++++ .../mindspore/ops/operations/math_ops.py | 14 +------ tests/ut/python/pipeline/parse/test_invert.py | 8 ++-- 6 files changed, 68 insertions(+), 28 deletions(-) create mode 100644 mindspore/python/mindspore/ops/_op_impl/tbe/logical_not_ds.py diff --git a/mindspore/core/ops/logical_not.cc b/mindspore/core/ops/logical_not.cc index 5b71d133ee9..e6edb1bea24 100644 --- a/mindspore/core/ops/logical_not.cc +++ b/mindspore/core/ops/logical_not.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,26 @@ * limitations under the License. */ -#include "ops/logical_not.h" - +#include +#include #include +#include +#include +#include "ops/logical_not.h" #include "ops/op_utils.h" +#include "utils/check_convert_utils.h" namespace mindspore { namespace ops { namespace { abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - return std::make_shared(in_shape); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + auto in_shape = shape_map[kShape]; + auto min_shape = shape_map[kMinShape]; + auto max_shape = shape_map[kMaxShape]; + return std::make_shared(in_shape, min_shape, max_shape); } TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector &input_args) { @@ -35,14 +42,21 @@ TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vectorBuildType(); std::set local_bool = {kBool}; - return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name); + return infer_dtype; } } // namespace + AbstractBasePtr LogicalNotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return std::make_shared(LogicalNotInferType(primitive, input_args), - LogicalNotInferShape(primitive, input_args)->shape()); + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); + auto infer_type = LogicalNotInferType(primitive, input_args); + auto infer_shape = LogicalNotInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } -REGISTER_PRIMITIVE_C(kNameLogicalNot, LogicalNot); + +REGISTER_PRIMITIVE_EVAL_IMPL(LogicalNot, prim::kPrimLogicalNot, LogicalNotInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/logical_not.h b/mindspore/core/ops/logical_not.h index fe0ff299c86..3056b20af58 100644 --- a/mindspore/core/ops/logical_not.h +++ b/mindspore/core/ops/logical_not.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,9 +37,9 @@ class MS_CORE_API LogicalNot : public PrimitiveC { /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.LogicalNot for the inputs. void Init() {} }; - AbstractBasePtr LogicalNotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); +using kPrimLogicalNotPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index d9f1ebb831d..9c8568e2419 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -203,6 +203,7 @@ from .less_equal import _less_equal_tbe from .less_equal_ds import _less_equal_ds_tbe from .logical_and import _logical_and_tbe from .logical_not import _logical_not_tbe +from .logical_not_ds import _logical_not_ds_tbe from .logical_or import _logical_or_tbe from .logical_or_ds import _logical_or_ds_tbe from .reduce_max import _reduce_max_tbe diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/logical_not_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/logical_not_ds.py new file mode 100644 index 00000000000..fa0ef228e67 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/logical_not_ds.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LogicalNot op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +logical_not_op_info = TBERegOp("LogicalNot") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("logical_not.so") \ + .compute_cost(10) \ + .kernel_name("logical_not") \ + .dynamic_shape(True) \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", True, "required", "all") \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.BOOL_None, DataType.BOOL_None) \ + .get_op_info() + + +@op_info_register(logical_not_op_info) +def _logical_not_ds_tbe(): + """LogicalNot TBE register""" + return diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index 6ed49dcc974..123f8b8858c 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -3873,7 +3873,7 @@ class LessEqual(_LogicBinaryOp): """ -class LogicalNot(PrimitiveWithInfer): +class LogicalNot(Primitive): """ Computes the "logical NOT" of a tensor element-wise. @@ -3907,18 +3907,6 @@ class LogicalNot(PrimitiveWithInfer): """Initialize LogicalNot""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_dtype): - validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name + " or '~' operator") - return mstype.tensor_type(mstype.bool_) - - def infer_value(self, x): - if x is not None: - x = x.asnumpy() - return Tensor(np.logical_not(x)) - return None class LogicalAnd(_LogicBinaryOp): diff --git a/tests/ut/python/pipeline/parse/test_invert.py b/tests/ut/python/pipeline/parse/test_invert.py index 985f4a4f038..8792e8139b4 100644 --- a/tests/ut/python/pipeline/parse/test_invert.py +++ b/tests/ut/python/pipeline/parse/test_invert.py @@ -53,11 +53,11 @@ def test_invert_int_tensor(): context.set_context(mode=context.PYNATIVE_MODE) with pytest.raises(TypeError) as err: net(input_x) - assert "For 'LogicalNot or '~' operator', the type of 'x' should be Tensor[Bool], " \ - "but got Tensor[Int32]" in str(err.value) + assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \ + "but got Int32." in str(err.value) context.set_context(mode=context.GRAPH_MODE) with pytest.raises(TypeError) as err: net(input_x) - assert "For 'LogicalNot or '~' operator', the type of 'x' should be Tensor[Bool], " \ - "but got Tensor[Int32]" in str(err.value) + assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \ + "but got Int32." in str(err.value)