fix infer of onehot

This commit is contained in:
l00591931 2021-05-29 11:03:40 +08:00
parent ac6d75b803
commit 156c774490
1 changed files with 3 additions and 1 deletions

View File

@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
} // namespace
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
auto infer_type = OneHotInferType(primitive, input_args);
auto infer_shape = OneHotInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
} // namespace ops