forked from mindspore-Ecosystem/mindspore
Equal
This commit is contained in:
parent
41f8d65f2e
commit
3705b138df
|
@ -24,6 +24,8 @@
|
|||
#include "ops/exp.h"
|
||||
#include "ops/real_div.h"
|
||||
#include "ops/add.h"
|
||||
#include "ops/equal.h"
|
||||
#include "ops/not_equal.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "ops/tile.h"
|
||||
|
@ -175,8 +177,9 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
{prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
|
||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
||||
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
||||
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
||||
{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
|
||||
{prim::kPrimEqual, {ops::EqualInfer, nullptr, true}},
|
||||
{prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
||||
|
|
|
@ -46,6 +46,8 @@ constexpr auto kScalarFloor = "ScalarFloor";
|
|||
constexpr auto kScalarUadd = "ScalarUadd";
|
||||
constexpr auto kScalarUsub = "ScalarUsub";
|
||||
constexpr auto kExp = "Exp";
|
||||
constexpr auto kEqual = "Equal";
|
||||
constexpr auto kNotEqual = "NotEqual";
|
||||
constexpr auto kSub = "Sub";
|
||||
constexpr auto kMul = "Mul";
|
||||
constexpr auto kRealDiv = "RealDiv";
|
||||
|
@ -112,8 +114,8 @@ inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
|||
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
||||
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>(kEqual);
|
||||
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>(kNotEqual);
|
||||
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
|
||||
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
|
||||
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,12 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/equal.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "ops/equal.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
@ -30,17 +30,27 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,10 +25,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameEqual = "Equal";
|
||||
constexpr auto kNameEqual = prim::kEqual;
|
||||
class Equal : public PrimitiveC {
|
||||
public:
|
||||
Equal() : PrimitiveC(kNameEqual) { InitIOName({"x", "y"}, {"output"}); }
|
||||
Equal() : PrimitiveC(prim::kPrimEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
|
||||
~Equal() = default;
|
||||
MS_DECLARE_PARENT(Equal, PrimitiveC);
|
||||
void Init() {}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,9 +15,47 @@
|
|||
*/
|
||||
|
||||
#include "ops/not_equal.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,20 +16,27 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
||||
#define MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameNotEqual = "NotEqual";
|
||||
constexpr auto kNameNotEqual = prim::kNotEqual;
|
||||
class NotEqual : public PrimitiveC {
|
||||
public:
|
||||
NotEqual() : PrimitiveC(kNameNotEqual) { InitIOName({"x", "y"}, {"output"}); }
|
||||
NotEqual() : PrimitiveC(prim::kPrimNotEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
|
||||
~NotEqual() = default;
|
||||
MS_DECLARE_PARENT(NotEqual, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimitiveNotEqualPtr = std::shared_ptr<NotEqual>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue