From a211f2f5951ebf929c9533ee018df7f51af58b5c Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Thu, 12 Aug 2021 10:53:44 +0800 Subject: [PATCH] add a check for scalar indices of embeddinglookup --- mindspore/ops/operations/array_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 50afe154728..f855a2da3ca 100755 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -5725,6 +5725,9 @@ class EmbeddingLookup(PrimitiveWithCheck): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) + indices_shp = indices['shape'] + if not indices_shp: + raise ValueError("'indices' should NOT be a scalar.") params_shp = params['shape'] if len(params_shp) > 2: raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))