forked from mindspore-Ecosystem/mindspore
!13080 fix embeddinglookup infer
From: @fangzehua Reviewed-by: Signed-off-by:
This commit is contained in:
commit
654771df13
|
@ -362,7 +362,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input
|
|||
|
||||
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
|
||||
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
|
||||
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
|
||||
emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
|
||||
std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
|
||||
|
@ -373,7 +373,7 @@ AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, A
|
|||
AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
|
||||
AnfNodePtr miss_value) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
|
||||
PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
|
||||
std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
|
||||
miss_value};
|
||||
auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
|
||||
|
@ -383,7 +383,7 @@ AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_ta
|
|||
AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
|
||||
AnfNodePtr old_value) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
|
||||
PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
|
||||
update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
|
||||
|
||||
auto params_ori_shp = params->Shape();
|
||||
|
|
|
@ -210,6 +210,8 @@ constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
|
|||
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
|
||||
constexpr auto kPushOpName = "Push";
|
||||
constexpr auto kPullOpName = "Pull";
|
||||
constexpr auto kUpdateCacheOpName = "UpdateCache";
|
||||
constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
|
||||
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
|
||||
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
|
||||
constexpr auto kGatherV2OpName = "Gather";
|
||||
|
|
|
@ -661,7 +661,6 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
|
|||
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto params_shp = params->shape();
|
||||
MS_EXCEPTION_IF_NULL(params);
|
||||
|
@ -673,8 +672,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
|
|||
MS_EXCEPTION_IF_NULL(indices_shp);
|
||||
auto indices_shape = indices_shp->shape();
|
||||
auto indices_max_shape = indices_shp->max_shape();
|
||||
auto indices_min_shape = indices_shp->min_shape();
|
||||
ShapeVector shape;
|
||||
ShapeVector max_shape;
|
||||
ShapeVector min_shape;
|
||||
shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
|
||||
shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
|
||||
if (!indices_max_shape.empty()) {
|
||||
|
@ -683,9 +684,11 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
|
|||
} else {
|
||||
max_shape = shape;
|
||||
}
|
||||
ShapeVector min_shape;
|
||||
for (size_t i = 0; i < max_shape.size(); ++i) {
|
||||
min_shape.emplace_back(1);
|
||||
if (!indices_min_shape.empty()) {
|
||||
min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
|
||||
min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end());
|
||||
} else {
|
||||
min_shape = shape;
|
||||
}
|
||||
|
||||
AbstractTensorPtr ret =
|
||||
|
|
|
@ -78,6 +78,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGather, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
|
||||
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
|
||||
|
@ -199,7 +200,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
{prim::kPrimLess, {InferImplLess, true}},
|
||||
{prim::kPrimStack, {InferImplStack, true}},
|
||||
{prim::kPrimPad, {InferImplPad, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
|
|
Loading…
Reference in New Issue