!21699 [EmbeddingLookup] Add a check for scalar indices of EmbeddingLookup

Merge pull request !21699 from Xiaoda/80-add-a-check-for-embeddinglookup-scalar-indice
This commit is contained in:
i-robot 2021-08-12 07:57:46 +00:00 committed by Gitee
commit e7b56e2f15
1 changed files with 3 additions and 0 deletions

View File

@ -5725,6 +5725,9 @@ class EmbeddingLookup(PrimitiveWithCheck):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
indices_shp = indices['shape']
if not indices_shp:
raise ValueError("'indices' should NOT be a scalar.")
params_shp = params['shape']
if len(params_shp) > 2:
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))