!17283 Fix different error type for OneHot between Ascend and GPU

Merge pull request !17283 from LiangZhibo/onehot
This commit is contained in:
i-robot 2021-06-24 03:32:11 +00:00 committed by Gitee
commit 9d64607345
1 changed files with 3 additions and 1 deletions

View File

@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
} // namespace
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
auto infer_type = OneHotInferType(primitive, input_args);
auto infer_shape = OneHotInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
} // namespace ops