forked from mindspore-Ecosystem/mindspore
!16360 numpy-native bug fix
From: @jachua
Reviewed-by: @liangchenghui, @guoqi1024
Signed-off-by: @liangchenghui
commit 44334dfb56
@@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32)
                       BroadcastToGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       BroadcastToGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      BroadcastToGpuKernel, int16_t)
+MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      BroadcastToGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      BroadcastToGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
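Note: these registrations are what let integer tensors reach the GPU BroadcastTo kernel directly. A minimal usage sketch through mindspore.numpy; the GPU build and the device-target call are assumptions for illustration, not part of this patch:

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor, context

context.set_context(device_target="GPU")  # assumes a GPU-enabled MindSpore build

# With the int32 registration in place, no float round-trip is needed:
x = Tensor(onp.array([1, 2, 3], dtype=onp.int32))
y = mnp.broadcast_to(x, (4, 3))  # tile the length-3 row across 4 rows
print(y.shape)  # (4, 3)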
@@ -697,6 +697,12 @@ template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2,
 template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
                           const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
                           half *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const int16_t *input_addr,
+                          int16_t *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const int32_t *input_addr,
+                          int32_t *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const int64_t *input_addr,
+                          int64_t *output_addr, cudaStream_t stream);
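Note: these explicit instantiations pin the 4-D broadcast kernel for each new dtype. A rough pure-NumPy sketch of the per-element index mapping such a kernel performs — my paraphrase of standard broadcast indexing, not the CUDA source:

import numpy as onp

def broadcast_to_4d(inp, out_shape):
    # Each output coordinate maps back to an input coordinate by wrapping
    # (modulo) on dimensions where the input size is 1 -- valid whenever each
    # input dim is 1 or equal to the output dim. Sketch only, one element per
    # loop iteration standing in for one CUDA thread.
    i0, i1, i2, i3 = inp.shape
    out = onp.empty(out_shape, dtype=inp.dtype)
    for a in range(out_shape[0]):
        for b in range(out_shape[1]):
            for c in range(out_shape[2]):
                for d in range(out_shape[3]):
                    out[a, b, c, d] = inp[a % i0, b % i1, c % i2, d % i3]
    return out

x = onp.arange(3, dtype=onp.int16).reshape(1, 1, 1, 3)
assert (broadcast_to_4d(x, (1, 2, 2, 3)) == onp.broadcast_to(x, (1, 2, 2, 3))).all()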
@@ -48,6 +48,8 @@ MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      UnaryOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
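Note: registering Neg for kNumberTypeInt32 lets integer negation stay in integer dtype. A small sketch via mindspore.numpy; on a GPU device target (an assumption here) this would be served by the kernel registered above:

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor

x = Tensor(onp.array([1, -2, 3], dtype=onp.int32))
print(mnp.negative(x))  # [-1  2 -3], no cast to float required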
@@ -4512,6 +4512,8 @@ def digitize(x, bins, right=False):
         [1 3 3 4 5]
     """
     x, bins = _to_tensor(x, bins)
     if F.rank(bins) > 1:
         _raise_value_error('bins should be 1-dimensional')
+    if x.size == 0:
+        return x
     if bins.size == 0:
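Note: the added early return matches NumPy, where an empty input yields an empty index array rather than an error:

import numpy as onp

# Empty `x` in, empty result out -- the behavior the added guard reproduces.
print(onp.digitize(onp.array([]), [0.0, 1.0, 2.5]))  # []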
@@ -4542,7 +4544,7 @@ def bincount(x, weights=None, minlength=0, length=None):
         are treated as zeros instead.

     Args:
-        x (Union[int, float, bool, list, tuple, Tensor]): 1-d input array.
+        x (Union[list, tuple, Tensor]): 1-d input array.
         weights (Union[int, float, bool, list, tuple, Tensor], optional): Weights,
             array of the same shape as `x`.
         minlength (int, optional): A minimum number of bins for the output array.
@@ -4573,6 +4575,8 @@ def bincount(x, weights=None, minlength=0, length=None):
     x = _to_tensor(x)
     if F.rank(x) != 1:
         _raise_value_error('`x` should be one-dimensional')
+    if not _check_is_int(F.dtype(x)):
+        _raise_type_error('`x` should be an array of ints')
     x = clip(x, 0, None)
     if length is None:
         if F.isconstant(x):
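Note: the added dtype check mirrors NumPy's contract that bincount only counts integers:

import numpy as onp

print(onp.bincount(onp.array([0, 1, 1, 3])))  # [1 2 0 1]
try:
    onp.bincount(onp.array([0.5, 1.2]))  # non-integer input is rejected
except TypeError as err:
    print(err)  # cannot safely cast float64 to an integer type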
@@ -4781,7 +4785,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False):  # pyl
     if _get_device() == 'Ascend':
         stacked_indices = F.cast(stacked_indices, mstype.float32)
     flattened_hist = F.reduce_sum(stacked_indices.astype(mstype.float32), 0)
-    count = bincount(flattened_hist, weights, length=flattened_bin_size)
+    count = bincount(flattened_hist.astype(mstype.int32), weights, length=flattened_bin_size)
     count = F.reshape(count, nbin)
     slices = _list_comprehensions(ndim, F.make_slice(1, -1, 1), True)
     count = count[slices]
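Note: the cast is a consequence of the new check above — flattened_hist comes out of a float reduce_sum, while bincount now insists on integer input. In NumPy terms, with illustrative values:

import numpy as onp

# Per-sample flattened (ravelled) bin indices arrive as floats after the reduction...
flattened_hist = onp.array([0.0, 3.0, 3.0, 5.0])
# ...so they must be cast back to int before counting occurrences per bin:
print(onp.bincount(flattened_hist.astype(onp.int32), minlength=6))  # [1 0 0 2 0 1]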
@@ -5128,8 +5132,8 @@ def polymul(a1, a2):

     Examples:
         >>> import mindspore.numpy as np
-        >>> print(np.polyder([1, 1, 1, 1]))
-        [3 2 1]
+        >>> print(np.polymul([3, 1, 2], [2, 5]))
+        [ 6 17 9 10]
     """
     a1 = _to_poly1d(a1)
     a2 = _to_poly1d(a2)
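Note: the docstring previously showed a polyder example; the corrected one is easy to verify against NumPy, since (3x^2 + x + 2)(2x + 5) = 6x^3 + 17x^2 + 9x + 10:

import numpy as onp

print(onp.polymul([3, 1, 2], [2, 5]))  # [ 6 17  9 10]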
@@ -5176,7 +5180,7 @@ def polyint(p, m=1, k=None):
     k = atleast_1d(_to_tensor(k))
     if k.size == 1:
         k = F.tile(k, (m,))
-    k = F.reshape(k, (m, 1))
+    k = F.expand_dims(k, -1)
     for i in range(m):
         coeff = _to_tensor(F.make_range(_type_convert(int, p.size), 0, -1))
         p = concatenate((true_divide(p, coeff), k[i]))
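Note: for reference, each integration pass divides the coefficients by decreasing powers and appends one integration constant, which is why k is indexed per pass. NumPy's reference behavior:

import numpy as onp

# Integrate 3x^2 + 2x + 1 once: divide by [3, 2, 1], append constant 0.
print(onp.polyint([3, 2, 1]))  # [1. 1. 1. 0.]  i.e. x^3 + x^2 + x + 0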
@@ -553,7 +553,7 @@ def get_bprop_pow(self):

     def bprop(x, power, out, dout):
         bc_dx = power * pow_op(x, power - 1.0) * dout
-        x[x < 0] = 1
+        x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
         bc_dpower = out * ln(x) * dout
         return binop_grad_common(x, power, bc_dx, bc_dpower)

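Note: in-place boolean assignment like x[x < 0] = 1 does not lower in graph mode, so the fix rebuilds the same tensor functionally. The value 1 is deliberate: the power gradient is out * ln(x) * dout, ln is undefined for x < 0, and ln(1) = 0 zeroes exactly those entries. A NumPy sketch of the equivalence:

import numpy as onp

x = onp.array([-2.0, 0.5, 3.0])
x_safe = onp.where(x < 0, 1.0, x)  # functional form of `x[x < 0] = 1`
print(onp.log(x_safe))  # [ 0.        -0.6931...  1.0986...]; gradient zeroed where x < 0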
@@ -615,7 +615,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
     data_shape = F.shape(data)
     data_shape = const_utils.check_equal(data_shape, index_shape,
                                          "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
-    size = F.size(value)
+    size = F.shape_mul(F.shape(value))
     size = const_utils.check_equal(1, size,
                                    "When assign value is a tensor, its size should be {}, but current size is {}.")
     dtype = F.dtype(data)
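Note: F.shape_mul derives the element count as a product over the static shape tuple instead of querying the tensor with F.size; the arithmetic being swapped in is just:

from math import prod

shape = (1,)        # shape of the assigned value tensor
print(prod(shape))  # 1 -- the subsequent check requires exactly one element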
@@ -636,7 +636,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):


 def tensor_setitem_by_tensor_with_number(data, index, value):
-    value = F.fill(F.dtype(data), (1,), value)
+    value = F.fill(F.dtype(data), (), value)
     return tensor_setitem_by_tensor_with_tensor(data, index, value)

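Note: the changed fill shape swaps a one-element vector for a true scalar. The distinction in NumPy terms:

import numpy as onp

scalar = onp.full((), 7)    # rank-0, the new behavior
vector = onp.full((1,), 7)  # rank-1 of size 1, the old behavior
print(scalar.shape, vector.shape)  # () (1,)
# A rank-0 value broadcasts against any indexed region without adding an axis.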
@@ -868,8 +868,8 @@ def remove_expanded_dims(tuple_index, data_shape, value):
     expand_true = has_true and not(has_false or has_sequence)  # whether to expand dimension at True
     tensor_index_ndim = len(broadcast_shape)  # ndim of tensor indices
     rem_ndim = len(data_shape) - cur_dim  # number of remaining dimensions in data not indexed
-    not_expanded_dim = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
-                                                         rem_ndim, not_expanded_dim)
+    not_expanded_dim, idx_advanced = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
+                                                                       rem_ndim, not_expanded_dim)
     if not indices_out:
         indices_out = (True,)

@@ -781,7 +781,13 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim
     else:
         tensor_dims = (True,)*tensor_index_ndim
         not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
-    return not_expanded_dim + (True,)*rem_ndim
+    not_expanded_dim = not_expanded_dim + (True,)*rem_ndim
+
+    count_leading_false = 0
+    while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
+        count_leading_false += 1
+    idx_advanced = max(0, idx_advanced - count_leading_false)
+    return not_expanded_dim, idx_advanced


 @constexpr
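Note: the added loop is the heart of this fix — dimensions flagged False at the front of the tuple are dropped entirely, so the insertion point of the advanced-index block must shift left by that many positions. Replayed in plain Python with illustrative values (not taken from the test suite):

not_expanded_dim = (False, False, True, True)
idx_advanced = 2

count_leading_false = 0
while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
    count_leading_false += 1
idx_advanced = max(0, idx_advanced - count_leading_false)
print(count_leading_false, idx_advanced)  # 2 0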
@@ -2164,13 +2164,15 @@ def test_digitize():
 def test_bincount():
     x = onp.random.randint(0, 10, 20)
     weights = onp.random.randn(20)
     match_res(mnp.bincount, onp.bincount, x)
     match_res(mnp.bincount, onp.bincount, x, minlength=25)
-    match_res(mnp.bincount, onp.bincount, x, weights, error=3)
-    match_res(mnp.bincount, onp.bincount, x, weights, minlength=25, error=3)
     match_res(mnp.bincount, onp.bincount, x, dtype=mnp.int32)
     match_res(mnp.bincount, onp.bincount, x, minlength=25, dtype=mnp.int32)
+    match_all_arrays(mnp.bincount(to_tensor(x), to_tensor(weights)),
+                     onp.bincount(x, weights), error=3)
+    match_all_arrays(mnp.bincount(to_tensor(x), to_tensor(weights), minlength=25),
+                     onp.bincount(x, weights, minlength=25), error=3)


-@pytest.mark.level2
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.platform_x86_gpu_training
@@ -2193,7 +2195,7 @@ def test_histogram():
     match_all_arrays(mnp_res, onp_res, error=1)


-@pytest.mark.level2
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.platform_x86_gpu_training
@@ -2236,7 +2238,7 @@ def test_histogramdd():
     match_all_arrays(mnp_res[1], onp_res[1], error=3)


-@pytest.mark.level2
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.platform_x86_gpu_training