forked from mindspore-Ecosystem/mindspore
Equal
This commit is contained in:
parent
41f8d65f2e
commit
3705b138df
|
@ -24,6 +24,8 @@
|
||||||
#include "ops/exp.h"
|
#include "ops/exp.h"
|
||||||
#include "ops/real_div.h"
|
#include "ops/real_div.h"
|
||||||
#include "ops/add.h"
|
#include "ops/add.h"
|
||||||
|
#include "ops/equal.h"
|
||||||
|
#include "ops/not_equal.h"
|
||||||
#include "abstract/abstract_function.h"
|
#include "abstract/abstract_function.h"
|
||||||
#include "abstract/infer_functions.h"
|
#include "abstract/infer_functions.h"
|
||||||
#include "ops/tile.h"
|
#include "ops/tile.h"
|
||||||
|
@ -175,8 +177,9 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||||
{prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
|
{prim::kPrimAdd, {ops::AddInfer, nullptr, false}},
|
||||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
||||||
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
{prim::kPrimSub, {InferImplSub, nullptr, false}},
|
||||||
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
|
||||||
{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
|
{prim::kPrimTile, {ops::TileInfer, nullptr, true}},
|
||||||
|
{prim::kPrimEqual, {ops::EqualInfer, nullptr, true}},
|
||||||
|
{prim::kPrimNotEqual, {ops::NotEqualInfer, nullptr, true}},
|
||||||
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
||||||
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
||||||
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
||||||
|
|
|
@ -46,6 +46,8 @@ constexpr auto kScalarFloor = "ScalarFloor";
|
||||||
constexpr auto kScalarUadd = "ScalarUadd";
|
constexpr auto kScalarUadd = "ScalarUadd";
|
||||||
constexpr auto kScalarUsub = "ScalarUsub";
|
constexpr auto kScalarUsub = "ScalarUsub";
|
||||||
constexpr auto kExp = "Exp";
|
constexpr auto kExp = "Exp";
|
||||||
|
constexpr auto kEqual = "Equal";
|
||||||
|
constexpr auto kNotEqual = "NotEqual";
|
||||||
constexpr auto kSub = "Sub";
|
constexpr auto kSub = "Sub";
|
||||||
constexpr auto kMul = "Mul";
|
constexpr auto kMul = "Mul";
|
||||||
constexpr auto kRealDiv = "RealDiv";
|
constexpr auto kRealDiv = "RealDiv";
|
||||||
|
@ -112,8 +114,8 @@ inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
||||||
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||||
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||||
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||||
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>(kEqual);
|
||||||
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>(kNotEqual);
|
||||||
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
|
inline const PrimitivePtr kPrimLogicalAnd = std::make_shared<Primitive>("LogicalAnd");
|
||||||
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
|
inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalOr");
|
||||||
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,12 +14,12 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "ops/equal.h"
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/equal.h"
|
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
@ -30,17 +30,27 @@ namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
return BroadCastInferShape(op_name, input_args);
|
return BroadCastInferShape(op_name, input_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto op_name = prim->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||||
MS_LOG(EXCEPTION) << "nullptr";
|
MS_LOG(EXCEPTION) << "nullptr";
|
||||||
}
|
}
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
types.emplace("x", input_args[0]->BuildType());
|
||||||
types.emplace("y", input_args[1]->BuildType());
|
types.emplace("y", input_args[1]->BuildType());
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -25,10 +25,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameEqual = "Equal";
|
constexpr auto kNameEqual = prim::kEqual;
|
||||||
class Equal : public PrimitiveC {
|
class Equal : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
Equal() : PrimitiveC(kNameEqual) { InitIOName({"x", "y"}, {"output"}); }
|
Equal() : PrimitiveC(prim::kPrimEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
|
||||||
~Equal() = default;
|
~Equal() = default;
|
||||||
MS_DECLARE_PARENT(Equal, PrimitiveC);
|
MS_DECLARE_PARENT(Equal, PrimitiveC);
|
||||||
void Init() {}
|
void Init() {}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -15,9 +15,47 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ops/not_equal.h"
|
#include "ops/not_equal.h"
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
return BroadCastInferShape(op_name, input_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto op_name = prim->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
std::map<std::string, TypePtr> types;
|
||||||
|
types.emplace("x", input_args[0]->BuildType());
|
||||||
|
types.emplace("y", input_args[1]->BuildType());
|
||||||
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||||
|
InferShape(primitive, input_args)->shape());
|
||||||
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,20 +16,27 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
#ifndef MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
||||||
#define MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
#define MINDSPORE_CORE_OPS_NOT_EQUAL_H_
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "abstract/abstract_value.h"
|
#include "abstract/abstract_value.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameNotEqual = "NotEqual";
|
constexpr auto kNameNotEqual = prim::kNotEqual;
|
||||||
class NotEqual : public PrimitiveC {
|
class NotEqual : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
NotEqual() : PrimitiveC(kNameNotEqual) { InitIOName({"x", "y"}, {"output"}); }
|
NotEqual() : PrimitiveC(prim::kPrimNotEqual->name()) { InitIOName({"x", "y"}, {"output"}); }
|
||||||
~NotEqual() = default;
|
~NotEqual() = default;
|
||||||
MS_DECLARE_PARENT(NotEqual, PrimitiveC);
|
MS_DECLARE_PARENT(NotEqual, PrimitiveC);
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args);
|
||||||
|
using PrimitiveNotEqualPtr = std::shared_ptr<NotEqual>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue