From e64a53bf1b979ba8613c1a80cb439a3af1c6cad6 Mon Sep 17 00:00:00 2001 From: wuxuejian Date: Fri, 3 Jul 2020 10:59:10 +0800 Subject: [PATCH] fix embeddinglookupgrad when param shape is one dim --- mindspore/ops/_grad/grad_array_ops.py | 5 +++-- tests/st/ops/ascend/test_embedding_lookup.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index b53a7412fca..1155fb7c034 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self): # Reshape the 'new_indices' new_indices_shape_changed = (size_op(new_indices),) new_indices = reshape_op(new_indices, new_indices_shape_changed) - x_shp_tail = x_shp[1:] - actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail + actual_dout_shape_changed = new_indices_shape_changed + if len(x_shp) > 1: + actual_dout_shape_changed += x_shp[1:] actual_dout = reshape_op(dout, actual_dout_shape_changed) return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return bprop diff --git a/tests/st/ops/ascend/test_embedding_lookup.py b/tests/st/ops/ascend/test_embedding_lookup.py index 483fdcdbc4d..6aee25d9da5 100644 --- a/tests/st/ops/ascend/test_embedding_lookup.py +++ b/tests/st/ops/ascend/test_embedding_lookup.py @@ -15,6 +15,7 @@ import numpy as np import mindspore.context as context +import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P