forked from mindspore-Ecosystem/mindspore
!21844 [EmbeddingLookup] Add a check for the case of scalar indices
Merge pull request !21844 from Xiaoda/83-add-a-check-for-embeddinglookup-scalar-indices-r.14
This commit is contained in:
commit
c1a90b0870
|
@ -5721,6 +5721,9 @@ class EmbeddingLookup(PrimitiveWithCheck):
|
||||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||||
params_shp = params['shape']
|
params_shp = params['shape']
|
||||||
|
indices_shp = indices['shape']
|
||||||
|
if not indices_shp:
|
||||||
|
raise ValueError("'indices' should NOT be a scalar.")
|
||||||
if len(params_shp) > 2:
|
if len(params_shp) > 2:
|
||||||
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue