fix greater op infervalue
This commit is contained in:
parent
5c643a207f
commit
42c2676c1e
|
@ -27,6 +27,7 @@
|
||||||
#include "ops/add.h"
|
#include "ops/add.h"
|
||||||
#include "ops/equal.h"
|
#include "ops/equal.h"
|
||||||
#include "ops/greater_equal.h"
|
#include "ops/greater_equal.h"
|
||||||
|
#include "ops/greater.h"
|
||||||
#include "ops/not_equal.h"
|
#include "ops/not_equal.h"
|
||||||
#include "ops/neg.h"
|
#include "ops/neg.h"
|
||||||
#include "ops/mul.h"
|
#include "ops/mul.h"
|
||||||
|
@ -256,6 +257,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||||
{prim::kPrimNeg, R{ops::NegInfer, nullptr, false}},
|
{prim::kPrimNeg, R{ops::NegInfer, nullptr, false}},
|
||||||
{prim::kPrimTile, R{ops::TileInfer, nullptr, true}},
|
{prim::kPrimTile, R{ops::TileInfer, nullptr, true}},
|
||||||
{prim::kPrimEqual, R{ops::EqualInfer, nullptr, true}},
|
{prim::kPrimEqual, R{ops::EqualInfer, nullptr, true}},
|
||||||
|
{prim::kPrimGreaterEqual, R{ops::GreaterInfer, nullptr, true}},
|
||||||
{prim::kPrimGreaterEqual, R{ops::GreaterEqualInfer, nullptr, true}},
|
{prim::kPrimGreaterEqual, R{ops::GreaterEqualInfer, nullptr, true}},
|
||||||
{prim::kPrimNotEqual, R{ops::NotEqualInfer, nullptr, true}},
|
{prim::kPrimNotEqual, R{ops::NotEqualInfer, nullptr, true}},
|
||||||
{prim::kPrimLog, R{ops::LogInfer, nullptr, true}},
|
{prim::kPrimLog, R{ops::LogInfer, nullptr, true}},
|
||||||
|
|
|
@ -47,6 +47,6 @@ AbstractBasePtr GreaterInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
||||||
auto infer_shape = GreaterInferShape(primitive, input_args);
|
auto infer_shape = GreaterInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(Greater, prim::kPrimGreater, GreaterInfer, nullptr, true);
|
REGISTER_PRIMITIVE_C(kNameGreater, Greater);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -511,13 +511,8 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
||||||
MS_EXCEPTION(TypeError) << buffer.str();
|
MS_EXCEPTION(TypeError) << buffer.str();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto check_type = _CheckTypeSame(types, prim_name, false);
|
(void)_CheckTypeSame(types, prim_name, false);
|
||||||
std::string input_names;
|
return CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
|
||||||
for (const auto &item : types) {
|
|
||||||
(void)input_names.append(item.first);
|
|
||||||
(void)input_names.append(", ");
|
|
||||||
}
|
|
||||||
return CheckTensorSubClass(input_names, check_type, check_list, prim_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||||
|
@ -538,7 +533,7 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return CheckTensorSubClass(type_name, element, check_list, prim_name);
|
return CheckTensorSubClass(type_name, type, check_list, prim_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
||||||
|
@ -577,8 +572,13 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
|
||||||
TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
|
TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
|
||||||
const std::set<TypePtr> &template_types, const string &prim_name,
|
const std::set<TypePtr> &template_types, const string &prim_name,
|
||||||
bool is_mix) {
|
bool is_mix) {
|
||||||
if (CheckType(type, template_types)) {
|
auto real_type = type;
|
||||||
return type;
|
if (type->isa<TensorType>()) {
|
||||||
|
auto tensor_type = type->cast<TensorTypePtr>();
|
||||||
|
real_type = tensor_type->element();
|
||||||
|
}
|
||||||
|
if (CheckType(real_type, template_types)) {
|
||||||
|
return real_type;
|
||||||
}
|
}
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
|
buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
|
||||||
|
@ -617,13 +617,8 @@ TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const
|
||||||
TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
|
TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
|
||||||
const std::set<TypePtr> &valid_values,
|
const std::set<TypePtr> &valid_values,
|
||||||
const std::string &prim_name, const bool allow_mix) {
|
const std::string &prim_name, const bool allow_mix) {
|
||||||
auto arg_ = _CheckTypeSame(args, prim_name, allow_mix);
|
(void)_CheckTypeSame(args, prim_name, allow_mix);
|
||||||
std::string input_names;
|
return CheckTensorSubClass(args.begin()->first, args.begin()->second, valid_values, prim_name, true);
|
||||||
for (const auto &item : args) {
|
|
||||||
(void)input_names.append(item.first);
|
|
||||||
(void)input_names.append(", ");
|
|
||||||
}
|
|
||||||
return CheckTensorSubClass(input_names, arg_, valid_values, prim_name, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
|
||||||
|
|
|
@ -26,7 +26,7 @@ resize_nearest_neighbor_op_info = TBERegOp("DynamicResizeNearestNeighbor") \
|
||||||
.dynamic_shape(True) \
|
.dynamic_shape(True) \
|
||||||
.attr("align_corners", "optional", "bool", "all", "false") \
|
.attr("align_corners", "optional", "bool", "all", "false") \
|
||||||
.attr("half_pixel_centers", "optional", "bool", "all", "false") \
|
.attr("half_pixel_centers", "optional", "bool", "all", "false") \
|
||||||
.input(0, "x", False, "reqxuired", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
.input(1, "size", False, "required", "all") \
|
.input(1, "size", False, "required", "all") \
|
||||||
.output(0, "y", True, "required", "all") \
|
.output(0, "y", True, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \
|
||||||
|
|
|
@ -3683,6 +3683,14 @@ class Greater(_LogicBinaryOp):
|
||||||
[False True False]
|
[False True False]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def infer_value(self, x, y):
|
||||||
|
if x is not None and y is not None:
|
||||||
|
x = x.asnumpy()
|
||||||
|
y = y.asnumpy()
|
||||||
|
out = np.array(np.greater(x, y))
|
||||||
|
return Tensor(out)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GreaterEqual(_LogicBinaryOp):
|
class GreaterEqual(_LogicBinaryOp):
|
||||||
r"""
|
r"""
|
||||||
|
|
|
@ -54,10 +54,10 @@ def test_invert_int_tensor():
|
||||||
with pytest.raises(TypeError) as err:
|
with pytest.raises(TypeError) as err:
|
||||||
net(input_x)
|
net(input_x)
|
||||||
assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \
|
assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \
|
||||||
"but got Int32." in str(err.value)
|
"but got Tensor[Int32]." in str(err.value)
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
with pytest.raises(TypeError) as err:
|
with pytest.raises(TypeError) as err:
|
||||||
net(input_x)
|
net(input_x)
|
||||||
assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \
|
assert "For primitive[LogicalNot], the input argument[x] must be a type of { Tensor[Bool],}, " \
|
||||||
"but got Int32." in str(err.value)
|
"but got Tensor[Int32]." in str(err.value)
|
||||||
|
|
Loading…
Reference in New Issue