revise logical ops error message

This commit is contained in:
w00517672 2023-03-03 17:08:29 +08:00
parent b2a482fb04
commit 6e824737c9
4 changed files with 44 additions and 20 deletions

View File

@ -35,13 +35,20 @@ abstract::ShapePtr LogicalAndInferShape(const PrimitivePtr &primitive, const std
}
TypePtr LogicalAndInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
auto infer_dtype = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kBool};
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return infer_dtype;
auto x_dtype = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_dtype);
auto y_dtype = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(y_dtype);
const std::basic_string<char> kBool = "Tensor[Bool]";
std::ostringstream buffer;
buffer << "For primitive[LogicalAnd], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
if (x_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
}
if (y_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
}
return x_dtype;
}
} // namespace

View File

@ -40,7 +40,8 @@ TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<Abstract
auto op_name = prim->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
auto infer_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, common_valid_types_with_complex_and_bool, op_name);
std::set<TypePtr> local_bool = {kBool};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
return kBool;
}
} // namespace

View File

@ -35,13 +35,20 @@ abstract::ShapePtr LogicalOrInferShape(const PrimitivePtr &primitive, const std:
}
TypePtr LogicalOrInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
auto infer_dtype = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kBool};
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return infer_dtype;
auto x_dtype = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_dtype);
auto y_dtype = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(y_dtype);
const std::basic_string<char> kBool = "Tensor[Bool]";
std::ostringstream buffer;
buffer << "For primitive[LogicalOr], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
if (x_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
}
if (y_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
}
return x_dtype;
}
} // namespace

View File

@ -33,11 +33,20 @@ abstract::ShapePtr LogicalXorInferShape(const PrimitivePtr &primitive, const std
}
TypePtr LogicalXorInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kBool};
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
auto x_dtype = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_dtype);
auto y_dtype = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(y_dtype);
const std::basic_string<char> kBool = "Tensor[Bool]";
std::ostringstream buffer;
buffer << "For primitive[LogicalXor], the input argument[x, y, ] must be a type of {Tensor[Bool], }, but got ";
if (x_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << x_dtype->ToString() << ".";
}
if (y_dtype->ToString() != kBool) {
MS_EXCEPTION(TypeError) << buffer.str() << y_dtype->ToString() << ".";
}
return x_dtype;
}
} // namespace