forked from mindspore-Ecosystem/mindspore
fix infer of onehot
This commit is contained in:
parent
ac6d75b803
commit
156c774490
|
@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
} // namespace
|
||||
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
|
||||
auto infer_type = OneHotInferType(primitive, input_args);
|
||||
auto infer_shape = OneHotInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
Loading…
Reference in New Issue