From 3705b138df816b0182c8bc9649e8820fab065870 Mon Sep 17 00:00:00 2001 From: shen_jingxing Date: Wed, 2 Jun 2021 11:36:16 +0800 Subject: [PATCH] Equal --- .../core/abstract/primitive_infer_map.cc | 5 ++- mindspore/core/base/core_ops.h | 6 ++- mindspore/core/ops/equal.cc | 16 ++++++-- mindspore/core/ops/equal.h | 6 +-- mindspore/core/ops/not_equal.cc | 40 ++++++++++++++++++- mindspore/core/ops/not_equal.h | 13 ++++-- 6 files changed, 73 insertions(+), 13 deletions(-) diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 5c41e841559..03af4064052 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -24,6 +24,8 @@ #include "ops/exp.h" #include "ops/real_div.h" #include "ops/add.h" +#include "ops/equal.h" +#include "ops/not_equal.h" #include "abstract/abstract_function.h" #include "abstract/infer_functions.h" #include "ops/tile.h" @@ -175,8 +177,9 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { {prim::kPrimAdd, {ops::AddInfer, nullptr, false}}, {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, {prim::kPrimSub, {InferImplSub, nullptr, false}}, - {prim::kPrimEqual, {InferImplEqual, nullptr, true}}, {prim::kPrimTile, {ops::TileInfer, nullptr, true}}, + {prim::kPrimEqual, {ops::EqualInfer, nullptr, true}}, + {prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}}, {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index e580471ae20..e24f37707df 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -46,6 +46,8 @@ constexpr auto kScalarFloor = "ScalarFloor"; constexpr auto kScalarUadd = "ScalarUadd"; constexpr auto kScalarUsub = "ScalarUsub"; constexpr auto kExp = "Exp"; +constexpr auto kEqual = "Equal"; +constexpr auto kNotEqual = "NotEqual"; constexpr auto kSub = "Sub"; constexpr auto kMul = "Mul"; constexpr auto kRealDiv = "RealDiv"; @@ -112,8 +114,8 @@ inline const PrimitivePtr kPrimGreater = std::make_shared("Greater"); inline const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); inline const PrimitivePtr kPrimLess = std::make_shared("Less"); inline const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); -inline const PrimitivePtr kPrimEqual = std::make_shared("Equal"); -inline const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); +inline const PrimitivePtr kPrimEqual = std::make_shared(kEqual); +inline const PrimitivePtr kPrimNotEqual = std::make_shared(kNotEqual); inline const PrimitivePtr kPrimLogicalAnd = std::make_shared("LogicalAnd"); inline const PrimitivePtr kPrimLogicalOr = std::make_shared("LogicalOr"); inline const PrimitivePtr kPrimLogicalNot = std::make_shared("LogicalNot"); diff --git a/mindspore/core/ops/equal.cc b/mindspore/core/ops/equal.cc index b9066430bee..4b829b25f15 100644 --- a/mindspore/core/ops/equal.cc +++ b/mindspore/core/ops/equal.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ +#include "ops/equal.h" #include #include #include #include #include -#include "ops/equal.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" @@ -30,17 +30,27 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } return BroadCastInferShape(op_name, input_args); } TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto op_name = prim->name(); + CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { MS_LOG(EXCEPTION) << "nullptr"; } std::map types; types.emplace("x", input_args[0]->BuildType()); types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name); } } // namespace diff --git a/mindspore/core/ops/equal.h b/mindspore/core/ops/equal.h index e5f86c19f2c..aebeae317b4 100644 --- a/mindspore/core/ops/equal.h +++ b/mindspore/core/ops/equal.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,10 +25,10 @@ namespace mindspore { namespace ops { -constexpr auto kNameEqual = "Equal"; +constexpr auto kNameEqual = prim::kEqual; class Equal : public PrimitiveC { public: - Equal() : PrimitiveC(kNameEqual) { InitIOName({"x", "y"}, {"output"}); } + Equal() : PrimitiveC(prim::kPrimEqual->name()) { InitIOName({"x", "y"}, {"output"}); } ~Equal() = default; MS_DECLARE_PARENT(Equal, PrimitiveC); void Init() {} diff --git a/mindspore/core/ops/not_equal.cc b/mindspore/core/ops/not_equal.cc index c303ee7159e..494a7e67e53 100644 --- a/mindspore/core/ops/not_equal.cc +++ b/mindspore/core/ops/not_equal.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,47 @@ */ #include "ops/not_equal.h" +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" namespace mindspore { namespace ops { +namespace { +abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + return BroadCastInferShape(op_name, input_args); +} + +TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto op_name = prim->name(); + CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + std::map types; + types.emplace("x", input_args[0]->BuildType()); + types.emplace("y", input_args[1]->BuildType()); + return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name); +} +} // namespace + +AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(InferType(primitive, input_args), + InferShape(primitive, input_args)->shape()); +} REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/not_equal.h b/mindspore/core/ops/not_equal.h index 0ed50774001..852dc3ecc34 100644 --- a/mindspore/core/ops/not_equal.h +++ b/mindspore/core/ops/not_equal.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,27 @@ #ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_ #define MINDSPORE_CORE_OPS_NOT_EQUAL_H_ +#include +#include + #include "ops/primitive_c.h" #include "abstract/abstract_value.h" #include "utils/check_convert_utils.h" namespace mindspore { namespace ops { -constexpr auto kNameNotEqual = "NotEqual"; +constexpr auto kNameNotEqual = prim::kNotEqual; class NotEqual : public PrimitiveC { public: - NotEqual() : PrimitiveC(kNameNotEqual) { InitIOName({"x", "y"}, {"output"}); } + NotEqual() : PrimitiveC(prim::kPrimNotEqual->name()) { InitIOName({"x", "y"}, {"output"}); } ~NotEqual() = default; MS_DECLARE_PARENT(NotEqual, PrimitiveC); void Init() {} }; + +AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimitiveNotEqualPtr = std::shared_ptr; } // namespace ops } // namespace mindspore