forked from mindspore-Ecosystem/mindspore
!26685 [Numpy-Native] fix searchsorted
Merge pull request !26685 from huangmengxi/fix_numpy
This commit is contained in:
commit
4deba2a463
|
@ -1026,7 +1026,7 @@ def searchsorted(x, v, side='left', sorter=None):
|
|||
i = F.fill(mstype.int32, shape, 0)
|
||||
j = F.fill(mstype.int32, shape, a.size)
|
||||
|
||||
sort_range = F.make_range(get_log2_size(F.shape_mul(shape) + 1))
|
||||
sort_range = F.make_range(get_log2_size(F.shape_mul(a.shape) + 1))
|
||||
for _ in sort_range:
|
||||
mid = (i - F.neg_tensor(j))//2
|
||||
mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,))))
|
||||
|
|
|
@ -1869,7 +1869,7 @@ class Tensor(Tensor_):
|
|||
j = tensor_operator_registry.get('fill')(mstype.int32, shape, a.size)
|
||||
|
||||
sort_range = tuple(range(validator.get_log2_size(
|
||||
tensor_operator_registry.get('shape_mul')(shape) + 1)))
|
||||
tensor_operator_registry.get('shape_mul')(a.shape) + 1)))
|
||||
for _ in sort_range:
|
||||
mid = (i - -j)//2
|
||||
mask = less_op(v, tensor_operator_registry.get('gather_nd')(a, mid.reshape(mid.shape + (1,))))
|
||||
|
|
Loading…
Reference in New Issue