!16360 numpy-native bug fix

From: @jachua
Reviewed-by: @liangchenghui, @guoqi1024
Signed-off-by: @liangchenghui
Committed: mindspore-ci-bot via Gitee, 2021-05-18 10:05:44 +08:00
commit 44334dfb56
8 changed files with 44 additions and 18 deletions

View File

@@ -24,5 +24,11 @@ MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32)
BroadcastToGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastToGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
BroadcastToGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastToGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastToGpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore
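
Note: the three new registrations above extend BroadcastTo to int16/int32/int64 on GPU. A minimal usage sketch (hedged; assumes a GPU build and that mindspore.numpy.broadcast_to dispatches to this kernel):

from mindspore import context
import mindspore.numpy as mnp
context.set_context(device_target="GPU")            # assumption: a GPU device is available
x = mnp.array([[1], [2], [3]], dtype=mnp.int32)     # shape (3, 1), int32
print(mnp.broadcast_to(x, (3, 4)))                  # should now run on the int32 BroadcastTo GPU kernel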

View File

@@ -697,6 +697,12 @@ template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2,
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
half *output_addr, cudaStream_t stream);
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const int16_t *input_addr,
int16_t *output_addr, cudaStream_t stream);
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const int32_t *input_addr,
int32_t *output_addr, cudaStream_t stream);
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const int64_t *input_addr,
int64_t *output_addr, cudaStream_t stream);

View File

@@ -48,6 +48,8 @@ MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnaryOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
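
Note: the new Neg registration above enables int32 negation on GPU. A hedged sketch, assuming mindspore.numpy.negative routes to this kernel:

import mindspore.numpy as mnp
x = mnp.array([1, -2, 3], dtype=mnp.int32)
print(mnp.negative(x))                              # expected [-1  2 -3], served by the int32 Neg GPU kernel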

View File

@@ -4512,6 +4512,8 @@ def digitize(x, bins, right=False):
[1 3 3 4 5]
"""
x, bins = _to_tensor(x, bins)
if F.rank(bins) > 1:
_raise_value_error('bins should be 1-dimensional')
if x.size == 0:
return x
if bins.size == 0:
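
Note: the added rank check rejects multi-dimensional bins up front. A hedged sketch of the resulting behaviour:

import mindspore.numpy as mnp
x = mnp.array([1.2, 10.0, 12.4])
print(mnp.digitize(x, [0, 5, 10, 15, 20]))          # 1-D bins: works as before
mnp.digitize(x, [[0, 5], [10, 15]])                 # 2-D bins: now raises ValueError "bins should be 1-dimensional"
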
@@ -4542,7 +4544,7 @@ def bincount(x, weights=None, minlength=0, length=None):
are treated as zeros instead.
Args:
x (Union[int, float, bool, list, tuple, Tensor]): 1-d input array.
x (Union[list, tuple, Tensor]): 1-d input array.
weights (Union[int, float, bool, list, tuple, Tensor], optional): Weights,
array of the same shape as `x`.
minlength (int, optional): A minimum number of bins for the output array.
@@ -4573,6 +4575,8 @@ def bincount(x, weights=None, minlength=0, length=None):
x = _to_tensor(x)
if F.rank(x) != 1:
_raise_value_error('`x` should be one-dimensional')
if not _check_is_int(F.dtype(x)):
_raise_type_error('`x` should be an array of ints')
x = clip(x, 0, None)
if length is None:
if F.isconstant(x):
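
Note: with the dtype check, non-integer input fails early instead of silently miscounting. Hedged sketch:

import mindspore.numpy as mnp
print(mnp.bincount(mnp.array([0, 1, 1, 3, 2, 1, 7])))   # integer input: per-bin counts as before
mnp.bincount(mnp.array([0.5, 1.5]))                     # float input: now raises TypeError "`x` should be an array of ints"
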
@@ -4781,7 +4785,7 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=False): # pyl
if _get_device() == 'Ascend':
stacked_indices = F.cast(stacked_indices, mstype.float32)
flattened_hist = F.reduce_sum(stacked_indices.astype(mstype.float32), 0)
count = bincount(flattened_hist, weights, length=flattened_bin_size)
count = bincount(flattened_hist.astype(mstype.int32), weights, length=flattened_bin_size)
count = F.reshape(count, nbin)
slices = _list_comprehensions(ndim, F.make_slice(1, -1, 1), True)
count = count[slices]
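
Note: the astype keeps bincount working on integer flattened indices even after the float32 promotion done for Ascend. A hedged usage sketch (made-up sample; assumes NumPy-style return of (hist, bin_edges)):

import mindspore.numpy as mnp
sample = mnp.arange(30).reshape(10, 3).astype(mnp.float32)
hist, edges = mnp.histogramdd(sample, bins=4)
print(hist.shape)                                   # expected (4, 4, 4)
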
@@ -5128,8 +5132,8 @@ def polymul(a1, a2):
Examples:
>>> import mindspore.numpy as np
>>> print(np.polyder([1, 1, 1, 1]))
[3 2 1]
>>> print(np.polymul([3, 1, 2], [2, 5]))
[ 6 17 9 10]
"""
a1 = _to_poly1d(a1)
a2 = _to_poly1d(a2)
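
Note: the corrected docstring example checks out by hand: (3x^2 + x + 2) * (2x + 5) = 6x^3 + (15 + 2)x^2 + (5 + 4)x + 10 = 6x^3 + 17x^2 + 9x + 10, i.e. coefficients [6 17 9 10].
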
@@ -5176,7 +5180,7 @@ def polyint(p, m=1, k=None):
k = atleast_1d(_to_tensor(k))
if k.size == 1:
k = F.tile(k, (m,))
k = F.reshape(k, (m, 1))
k = F.expand_dims(k, -1)
for i in range(m):
coeff = _to_tensor(F.make_range(_type_convert(int, p.size), 0, -1))
p = concatenate((true_divide(p, coeff), k[i]))
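
Note: for reference, a hedged polyint usage sketch with multiple integration constants (one per integration order); exact output values are not asserted here:

import mindspore.numpy as mnp
p = mnp.array([1., 2., 3.])                 # x^2 + 2x + 3
print(mnp.polyint(p, m=2, k=[5., 3.]))      # integrate twice, constants 5. then 3.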

View File

@@ -553,7 +553,7 @@ def get_bprop_pow(self):
def bprop(x, power, out, dout):
bc_dx = power * pow_op(x, power - 1.0) * dout
x[x < 0] = 1
x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
bc_dpower = out * ln(x) * dout
return binop_grad_common(x, power, bc_dx, bc_dpower)
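
Note: the fix replaces the in-place boolean-mask assignment with a functional select, computing the same masking without mutating x. A stand-alone hedged sketch of the pattern (hypothetical values):

import mindspore as ms
from mindspore import Tensor, ops
x = Tensor([-2.0, 3.0], ms.float32)
safe_x = ops.Select()(x < 0, ops.Fill()(ms.float32, (2,), 1.0), x)   # [1., 3.]: negatives replaced by 1 before ln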

View File

@@ -615,7 +615,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = F.shape_mul(F.shape(value))
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
@@ -636,7 +636,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
def tensor_setitem_by_tensor_with_number(data, index, value):
value = F.fill(F.dtype(data), (1,), value)
value = F.fill(F.dtype(data), (), value)
return tensor_setitem_by_tensor_with_tensor(data, index, value)
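
Note: these helpers back tensor __setitem__; the user-visible path being fixed is assigning a scalar through a tensor index. A hedged end-to-end sketch (assumes PyNative mode and that this indexing form routes through the helpers above):

import mindspore as ms
from mindspore import Tensor
data = Tensor([[1.0, 2.0], [3.0, 4.0]], ms.float32)
data[data > 2] = 0.0
print(data)                                 # expected [[1. 2.] [0. 0.]]
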
@@ -868,8 +868,8 @@ def remove_expanded_dims(tuple_index, data_shape, value):
expand_true = has_true and not(has_false or has_sequence) # whether to expand dimension at True
tensor_index_ndim = len(broadcast_shape) # ndim of tensor indices
rem_ndim = len(data_shape) - cur_dim # number of remaining dimensions in data not indexed
not_expanded_dim = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
rem_ndim, not_expanded_dim)
not_expanded_dim, idx_advanced = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
rem_ndim, not_expanded_dim)
if not indices_out:
indices_out = (True,)

View File

@@ -781,7 +781,13 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim
else:
tensor_dims = (True,)*tensor_index_ndim
not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
return not_expanded_dim + (True,)*rem_ndim
not_expanded_dim = not_expanded_dim + (True,)*rem_ndim
count_leading_false = 0
while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
count_leading_false += 1
idx_advanced = max(0, idx_advanced - count_leading_false)
return not_expanded_dim, idx_advanced
@constexpr
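
Note: the extra bookkeeping shifts the advanced-index position left by the number of leading squeezed (False) dimensions. A stand-alone illustration with made-up values:

not_expanded_dim = (False, False, True, True)   # first two dims get squeezed out
idx_advanced = 3
count_leading_false = 0
while count_leading_false < len(not_expanded_dim) and not not_expanded_dim[count_leading_false]:
    count_leading_false += 1
print(max(0, idx_advanced - count_leading_false))   # 1: advanced index moves from dim 3 to dim 1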

View File

@@ -2164,13 +2164,15 @@ def test_digitize():
def test_bincount():
x = onp.random.randint(0, 10, 20)
weights = onp.random.randn(20)
match_res(mnp.bincount, onp.bincount, x)
match_res(mnp.bincount, onp.bincount, x, minlength=25)
match_res(mnp.bincount, onp.bincount, x, weights, error=3)
match_res(mnp.bincount, onp.bincount, x, weights, minlength=25, error=3)
match_res(mnp.bincount, onp.bincount, x, dtype=mnp.int32)
match_res(mnp.bincount, onp.bincount, x, minlength=25, dtype=mnp.int32)
match_all_arrays(mnp.bincount(to_tensor(x), to_tensor(weights)),
onp.bincount(x, weights), error=3)
match_all_arrays(mnp.bincount(to_tensor(x), to_tensor(weights), minlength=25),
onp.bincount(x, weights, minlength=25), error=3)
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@@ -2193,7 +2195,7 @@ def test_histogram():
match_all_arrays(mnp_res, onp_res, error=1)
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@@ -2236,7 +2238,7 @@ def test_histogramdd():
match_all_arrays(mnp_res[1], onp_res[1], error=3)
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training