!13080 fix embeddinglookup infer

From: @fangzehua
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-12 14:27:15 +08:00 committed by Gitee
commit 654771df13
4 changed files with 13 additions and 8 deletions

View File

@ -362,7 +362,7 @@ void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
@ -373,7 +373,7 @@ AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, A
AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
AnfNodePtr miss_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
miss_value};
auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
@ -383,7 +383,7 @@ AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_ta
AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
AnfNodePtr old_value) {
MS_EXCEPTION_IF_NULL(graph);
PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
auto params_ori_shp = params->Shape();

View File

@ -210,6 +210,8 @@ constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
constexpr auto kPushOpName = "Push";
constexpr auto kPullOpName = "Pull";
constexpr auto kUpdateCacheOpName = "UpdateCache";
constexpr auto kCacheSwapTableOpName = "CacheSwapTable";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kGatherV2OpName = "Gather";

View File

@ -661,7 +661,6 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto params_shp = params->shape();
MS_EXCEPTION_IF_NULL(params);
@ -673,8 +672,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
MS_EXCEPTION_IF_NULL(indices_shp);
auto indices_shape = indices_shp->shape();
auto indices_max_shape = indices_shp->max_shape();
auto indices_min_shape = indices_shp->min_shape();
ShapeVector shape;
ShapeVector max_shape;
ShapeVector min_shape;
shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
if (!indices_max_shape.empty()) {
@ -683,9 +684,11 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit
} else {
max_shape = shape;
}
ShapeVector min_shape;
for (size_t i = 0; i < max_shape.size(); ++i) {
min_shape.emplace_back(1);
if (!indices_min_shape.empty()) {
min_shape.insert(min_shape.end(), indices_min_shape.begin(), indices_min_shape.end());
min_shape.insert(min_shape.end(), params_shape.begin() + 1, params_shape.end());
} else {
min_shape = shape;
}
AbstractTensorPtr ret =

View File

@ -78,6 +78,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGather, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
@ -199,7 +200,6 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimLess, {InferImplLess, true}},
{prim::kPrimStack, {InferImplStack, true}},
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},