!3699 fix scatterop error msg
Merge pull request !3699 from fangzehua/fix_scatterop
This commit is contained in:
commit
c700fc5515
|
@ -50,7 +50,7 @@ class _ScatterOp(PrimitiveWithInfer):
|
||||||
|
|
||||||
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
|
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
|
||||||
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
|
||||||
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
|
raise ValueError(f"For '{prim_name}', "
|
||||||
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
|
||||||
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class _ScatterNdOp(_ScatterOp):
|
||||||
validator.check('the dimension of x', len(x_shape),
|
validator.check('the dimension of x', len(x_shape),
|
||||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
|
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
|
||||||
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or updates_shape = "
|
raise ValueError(f"For '{prim_name}', updates_shape = "
|
||||||
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
|
f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
|
||||||
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue