forked from mindspore-Ecosystem/mindspore
!17283 Fix different error type for OneHot between Ascend and GPU
Merge pull request !17283 from LiangZhibo/onehot
This commit is contained in:
commit
9d64607345
|
@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
} // namespace
|
||||
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
|
||||
auto infer_type = OneHotInferType(primitive, input_args);
|
||||
auto infer_shape = OneHotInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
Loading…
Reference in New Issue