This commit is contained in:
shen_jingxing 2021-06-02 11:36:16 +08:00
parent 41f8d65f2e
commit 3705b138df
6 changed files with 73 additions and 13 deletions

View File

@ -24,6 +24,8 @@
#include "ops/exp.h" #include "ops/exp.h"
#include "ops/real_div.h" #include "ops/real_div.h"
#include "ops/add.h" #include "ops/add.h"
#include "ops/equal.h"
#include "ops/not_equal.h"
#include "abstract/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/infer_functions.h" #include "abstract/infer_functions.h"
#include "ops/tile.h" #include "ops/tile.h"
@ -175,8 +177,9 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimAdd, {ops::AddInfer, nullptr, false}}, {prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
{prim::kPrimSub, {InferImplSub, nullptr, false}}, {prim::kPrimSub, {InferImplSub, nullptr, false}},
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
{prim::kPrimTile, {ops::TileInfer, nullptr, true}}, {prim::kPrimTile, {ops::TileInfer, nullptr, true}},
{prim::kPrimEqual, {ops::EqualInfer, nullptr, true}},
{prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},

View File

@ -46,6 +46,8 @@ constexpr auto kScalarFloor = "ScalarFloor";
constexpr auto kScalarUadd = "ScalarUadd"; constexpr auto kScalarUadd = "ScalarUadd";
constexpr auto kScalarUsub = "ScalarUsub"; constexpr auto kScalarUsub = "ScalarUsub";
constexpr auto kExp = "Exp"; constexpr auto kExp = "Exp";
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
constexpr auto kSub = "Sub"; constexpr auto kSub = "Sub";
constexpr auto kMul = "Mul"; constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv"; constexpr auto kRealDiv = "RealDiv";
@ -112,8 +114,8 @@ inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual"); inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>(kEqual);
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual"); inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>(kNotEqual);
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd"); inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr"); inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot"); inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#include "ops/equal.h"
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "ops/equal.h"
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h" #include "abstract/primitive_infer_map.h"
@ -30,17 +30,27 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(op_name, input_args); return BroadCastInferShape(op_name, input_args);
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr"; MS_LOG(EXCEPTION) << "nullptr";
} }
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType()); types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType()); types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
} }
} // namespace } // namespace

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -25,10 +25,10 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
constexpr auto kNameEqual = "Equal"; constexpr auto kNameEqual = prim::kEqual;
class Equal : public PrimitiveC { class Equal : public PrimitiveC {
public: public:
Equal() : PrimitiveC(kNameEqual) { InitIOName({"x", "y"}, {"output"}); } Equal() : PrimitiveC(prim::kPrimEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
~Equal() = default; ~Equal() = default;
MS_DECLARE_PARENT(Equal, PrimitiveC); MS_DECLARE_PARENT(Equal, PrimitiveC);
void Init() {} void Init() {}

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,9 +15,47 @@
*/ */
#include "ops/not_equal.h" #include "ops/not_equal.h"
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(op_name, input_args);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
}
} // namespace
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual); REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,20 +16,27 @@
#ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_ #ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_
#define MINDSPORE_CORE_OPS_NOT_EQUAL_H_ #define MINDSPORE_CORE_OPS_NOT_EQUAL_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
constexpr auto kNameNotEqual = "NotEqual"; constexpr auto kNameNotEqual = prim::kNotEqual;
class NotEqual : public PrimitiveC { class NotEqual : public PrimitiveC {
public: public:
NotEqual() : PrimitiveC(kNameNotEqual) { InitIOName({"x", "y"}, {"output"}); } NotEqual() : PrimitiveC(prim::kPrimNotEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
~NotEqual() = default; ~NotEqual() = default;
MS_DECLARE_PARENT(NotEqual, PrimitiveC); MS_DECLARE_PARENT(NotEqual, PrimitiveC);
void Init() {} void Init() {}
}; };
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimitiveNotEqualPtr = std::shared_ptr<NotEqual>;
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore