diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 3fc2fad836b..cb4e2753b55 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -167,7 +167,12 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p ShapeVector max_shape = shape->shape(); auto ids = std::make_shared(input->element(), std::make_shared(ids_shape, min_shape, max_shape)); - auto ids_idx = std::make_shared(std::make_shared(32), shape->shape()); + // Currently we choose the same data type as input for the idx. + TypePtr ids_idx_type = kInt32; + if (input->element() != nullptr && input->element()->GetTypeTrack() != nullptr) { + ids_idx_type = input->element()->GetTypeTrack(); + } + auto ids_idx = std::make_shared(ids_idx_type, shape->shape()); // outputs: ids, ids_idx AbstractBasePtrList elements = {ids, ids_idx}; return std::make_shared(elements);