This commit is contained in:
huangmengxi 2021-05-20 17:49:39 +08:00
parent beb516b4a2
commit ba2757b830
2 changed files with 9 additions and 8 deletions

View File

@ -2141,13 +2141,12 @@ def lcm(x1, x2, dtype=None):
def _lcm(x1, x2):
"""Calculates lcm without applying keyword arguments"""
common_divisor = _gcd(x1, x2)
dtype = _promote(F.dtype(x1), F.dtype(x2))
x1 = x1.astype(mstype.float32)
x2 = x2.astype(mstype.float32)
q1 = F.tensor_div(x1, common_divisor)
q2 = F.tensor_div(x2, common_divisor)
res = F.tensor_mul(F.tensor_mul(q1, q2), common_divisor)
dtype = F.dtype(res)
if not _check_is_float(dtype):
# F.absolute only supports float
res = F.cast(res, mstype.float32)
return F.absolute(res).astype(dtype)
return _apply_tensor_op(_lcm, x1, x2, dtype=dtype)
@ -4602,7 +4601,7 @@ def histogram(a, bins=10, range=None, weights=None, density=False): # pylint: di
Deprecated numpy argument `normed` is not supported.
Args:
x (Union[int, float, bool, list, tuple, Tensor]): Input data. The histogram
a (Union[int, float, bool, list, tuple, Tensor]): Input data. The histogram
is computed over the flattened array.
bins (Union[int, tuple, list, Tensor], optional): If `bins` is an int, it
defines the number of equal-width bins in the given range (10, by
@ -5278,6 +5277,8 @@ def unwrap(p, discont=3.141592653589793, axis=-1):
>>> print(np.unwrap(phase))
[ 0.0000000e+00 7.8539819e-01 1.5707964e+00 -7.8539848e-01 -4.7683716e-07]
"""
if not isinstance(discont, (int, float)):
_raise_type_error('discont should be a float')
p = _to_tensor(p)
ndim = F.rank(p)
axis = _check_axis_in_range(axis, ndim)

View File

@ -2172,7 +2172,7 @@ def test_bincount():
onp.bincount(x, weights, minlength=25), error=3)
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@ -2195,7 +2195,7 @@ def test_histogram():
match_all_arrays(mnp_res, onp_res, error=1)
@pytest.mark.level2
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@ -2238,7 +2238,7 @@ def test_histogramdd():
match_all_arrays(mnp_res[1], onp_res[1], error=3)
@pytest.mark.level2
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training