forked from mindspore-Ecosystem/mindspore
fix infer of onehot
This commit is contained in:
parent
ac6d75b803
commit
156c774490
|
@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
||||||
} // namespace
|
} // namespace
|
||||||
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
|
auto infer_type = OneHotInferType(primitive, input_args);
|
||||||
|
auto infer_shape = OneHotInferShape(primitive, input_args);
|
||||||
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
|
|
Loading…
Reference in New Issue