diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc index 6960c72b7a3..cb09b7db313 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc @@ -362,7 +362,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) { MS_EXCEPTION_IF_NULL(graph); - PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup; + PrimitivePtr emb_lookup_primitive = std::make_shared(kEmbeddingLookupOpName); emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); emb_lookup_primitive->set_attr(kAttrOffset, MakeValue(0)); std::vector emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices}; @@ -373,7 +373,7 @@ AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, A AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx, AnfNodePtr miss_value) { MS_EXCEPTION_IF_NULL(graph); - PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable; + PrimitivePtr cache_swap_table_primitive = std::make_shared(kCacheSwapTableOpName); std::vector cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx, miss_value}; auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes); @@ -383,7 +383,7 @@ AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_ta AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx, AnfNodePtr old_value) { MS_EXCEPTION_IF_NULL(graph); - PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache; + PrimitivePtr update_cache_primitive = std::make_shared(kUpdateCacheOpName); update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); auto params_ori_shp = params->Shape(); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 1d6008c6a87..9532c3ff3c0 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -210,6 +210,8 @@ constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate"; constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate"; constexpr auto kPushOpName = "Push"; constexpr auto kPullOpName = "Pull"; +constexpr auto kUpdateCacheOpName = "UpdateCache"; +constexpr auto kCacheSwapTableOpName = "CacheSwapTable"; constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; constexpr auto kGatherV2OpName = "Gather"; diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index d09919eb7fd..d5937dbe899 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -661,7 +661,6 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); auto params = CheckArg(op_name, args_spec_list, 0); auto params_shp = params->shape(); MS_EXCEPTION_IF_NULL(params); @@ -673,8 +672,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit MS_EXCEPTION_IF_NULL(indices_shp); auto indices_shape = indices_shp->shape(); auto indices_max_shape = indices_shp->max_shape(); + auto indices_min_shape = indices_shp->min_shape(); ShapeVector shape; ShapeVector max_shape; + ShapeVector min_shape; shape.insert(shape.end(), indices_shape.begin(), indices_shape.end()); shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end()); if (!indices_max_shape.empty()) { @@ -683,9 +684,11 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit } else { max_shape = shape; } - ShapeVector min_shape; - for (size_t i = 0; i < max_shape.size(); ++i) { - min_shape.emplace_back(1); + if (!indices_min_shape.empty()) { + min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end()); + min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end()); + } else { + min_shape = shape; } AbstractTensorPtr ret = diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index e1cd7eb4d3a..43ea9cf3926 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -78,6 +78,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimUnique, {InferImplUnique, true}}, {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, {prim::kPrimGather, {InferImplGatherV2, true}}, + {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, @@ -199,7 +200,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { {prim::kPrimLess, {InferImplLess, true}}, {prim::kPrimStack, {InferImplStack, true}}, {prim::kPrimPad, {InferImplPad, true}}, - {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, {prim::kPrimDiv, {InferImplDiv, true}}, {prim::kPrimRealDiv, {InferImplRealDiv, true}},