forked from mindspore-Ecosystem/mindspore
fix embeddinglookupgrad when param shape is one dim
This commit is contained in:
parent
055e99021d
commit
e64a53bf1b
|
@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self):
|
|||
# Reshape the 'new_indices'
|
||||
new_indices_shape_changed = (size_op(new_indices),)
|
||||
new_indices = reshape_op(new_indices, new_indices_shape_changed)
|
||||
x_shp_tail = x_shp[1:]
|
||||
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
||||
actual_dout_shape_changed = new_indices_shape_changed
|
||||
if len(x_shp) > 1:
|
||||
actual_dout_shape_changed += x_shp[1:]
|
||||
actual_dout = reshape_op(dout, actual_dout_shape_changed)
|
||||
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
|
||||
return bprop
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
|
Loading…
Reference in New Issue