index_add fix wrong validation
This commit is contained in:
parent
70e7877155
commit
72f4a71b38
|
@ -4397,10 +4397,10 @@ class MatrixInverse(PrimitiveWithInfer):
|
|||
|
||||
class IndexAdd(PrimitiveWithInfer):
|
||||
"""
|
||||
Adds tenosr y to specified axis and indices of tensor x.
|
||||
Adds tensor y to specified axis and indices of tensor x.
|
||||
|
||||
Args:
|
||||
axis (int): The dimension along wich to index.
|
||||
axis (int): The dimension along which to index.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
|
||||
|
@ -4453,8 +4453,6 @@ class IndexAdd(PrimitiveWithInfer):
|
|||
validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_BOTH, 'axis', self.name)
|
||||
axis = self.axis if self.axis >= 0 else x_rank + self.axis
|
||||
for dim in range(x_rank):
|
||||
if dim == axis:
|
||||
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.GE, self.name)
|
||||
else:
|
||||
if dim != axis:
|
||||
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
|
Loading…
Reference in New Issue