fix equal infer dtype bug: not bool

This commit is contained in:
lizhenyu 2021-06-22 17:27:22 +08:00
parent e7ea93dacd
commit df69ae0c44
2 changed files with 4 additions and 2 deletions

View File

@ -56,7 +56,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
(void)InferType(primitive, input_args);
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
}
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
} // namespace ops

View File

@ -53,7 +53,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
(void)InferType(primitive, input_args);
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
}
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
} // namespace ops