fix equal infer dtype bug: not bool
This commit is contained in:
parent
e7ea93dacd
commit
df69ae0c44
|
@ -56,7 +56,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr EqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
(void)InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
|
||||
} // namespace ops
|
||||
|
|
|
@ -53,7 +53,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
(void)InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), kBool);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||
} // namespace ops
|
||||
|
|
Loading…
Reference in New Issue