!26685 [Numpy-Native] fix searchsorted

Merge pull request !26685 from huangmengxi/fix_numpy
This commit is contained in:
i-robot 2021-11-24 02:23:04 +00:00 committed by Gitee
commit 4deba2a463
2 changed files with 2 additions and 2 deletions

View File

@ -1026,7 +1026,7 @@ def searchsorted(x, v, side='left', sorter=None):
i = F.fill(mstype.int32, shape, 0)
j = F.fill(mstype.int32, shape, a.size)
sort_range = F.make_range(get_log2_size(F.shape_mul(shape) + 1))
sort_range = F.make_range(get_log2_size(F.shape_mul(a.shape) + 1))
for _ in sort_range:
mid = (i - F.neg_tensor(j))//2
mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,))))

View File

@ -1869,7 +1869,7 @@ class Tensor(Tensor_):
j = tensor_operator_registry.get('fill')(mstype.int32, shape, a.size)
sort_range = tuple(range(validator.get_log2_size(
tensor_operator_registry.get('shape_mul')(shape) + 1)))
tensor_operator_registry.get('shape_mul')(a.shape) + 1)))
for _ in sort_range:
mid = (i - -j)//2
mask = less_op(v, tensor_operator_registry.get('gather_nd')(a, mid.reshape(mid.shape + (1,))))