forked from mindspore-Ecosystem/mindspore
revise logical ops error message
This commit is contained in:
parent
b2a482fb04
commit
6e824737c9
|
@ -35,13 +35,20 @@ abstract::ShapePtr LogicalAndInferShape(const PrimitivePtr &primitive, const std
|
|||
}
|
||||
|
||||
TypePtr LogicalAndInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
auto infer_dtype = input_args[0]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return infer_dtype;
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_dtype);
|
||||
auto y_dtype = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(y_dtype);
|
||||
const std::basic_string<char> kBool = "Tensor[Bool]";
|
||||
std::ostringstream buffer;
|
||||
buffer << "For primitive[LogicalAnd], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
|
||||
if (x_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
|
||||
}
|
||||
if (y_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
|
||||
}
|
||||
return x_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -40,7 +40,8 @@ TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<Abstract
|
|||
auto op_name = prim->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto infer_dtype = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, common_valid_types_with_complex_and_bool, op_name);
|
||||
std::set<TypePtr> local_bool = {kBool};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
|
||||
return kBool;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -35,13 +35,20 @@ abstract::ShapePtr LogicalOrInferShape(const PrimitivePtr &primitive, const std:
|
|||
}
|
||||
|
||||
TypePtr LogicalOrInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
auto infer_dtype = input_args[0]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return infer_dtype;
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_dtype);
|
||||
auto y_dtype = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(y_dtype);
|
||||
const std::basic_string<char> kBool = "Tensor[Bool]";
|
||||
std::ostringstream buffer;
|
||||
buffer << "For primitive[LogicalOr], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
|
||||
if (x_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
|
||||
}
|
||||
if (y_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
|
||||
}
|
||||
return x_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -33,11 +33,20 @@ abstract::ShapePtr LogicalXorInferShape(const PrimitivePtr &primitive, const std
|
|||
}
|
||||
|
||||
TypePtr LogicalXorInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_dtype);
|
||||
auto y_dtype = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(y_dtype);
|
||||
const std::basic_string<char> kBool = "Tensor[Bool]";
|
||||
std::ostringstream buffer;
|
||||
buffer << "For primitive[LogicalXor], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
|
||||
if (x_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
|
||||
}
|
||||
if (y_dtype->ToString() != kBool) {
|
||||
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
|
||||
}
|
||||
return x_dtype;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
Loading…
Reference in New Issue