forked from mindspore-Ecosystem/mindspore
!12341 fix st error for gpu type support
From: @yanglf1121 Reviewed-by: @liangchenghui,@kingxian Signed-off-by: @liangchenghui,@kingxian
This commit is contained in:
commit
94054172ce
|
@ -1208,14 +1208,15 @@ def cumsum(a, axis=None, dtype=None):
|
|||
if _check_same_type(original_dtype, mstype.bool_) or \
|
||||
_check_same_type(original_dtype, mstype.int8) or \
|
||||
_check_same_type(original_dtype, mstype.int16):
|
||||
a = a.astype(mstype.int32)
|
||||
original_dtype = mstype.int32
|
||||
a = a.astype(mstype.float32)
|
||||
if axis is None:
|
||||
a = a.ravel()
|
||||
axis = 0
|
||||
_check_axis_in_range(axis, a.ndim)
|
||||
if dtype is not None and not _check_same_type(original_dtype, dtype):
|
||||
return _cumsum_default(a, axis).astype(dtype, copy=False)
|
||||
return _cumsum_default(a, axis)
|
||||
return _cumsum_default(a, axis).astype(original_dtype, copy=False)
|
||||
|
||||
|
||||
def _index(i, size, Cartesian=True):
|
||||
|
|
|
@ -369,8 +369,8 @@ def test_linspace():
|
|||
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
start = onp.random.random([2, 1, 4])
|
||||
stop = onp.random.random([1, 5, 1])
|
||||
start = onp.random.random([2, 1, 4]).astype("float32")
|
||||
stop = onp.random.random([1, 5, 1]).astype("float32")
|
||||
actual = onp.linspace(start, stop, num=20, retstep=True,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
|
||||
|
|
|
@ -1047,8 +1047,7 @@ def onp_split(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_split():
|
||||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
|
@ -1081,8 +1080,7 @@ def onp_vsplit(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_vsplit():
|
||||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
|
||||
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
|
@ -1115,8 +1113,7 @@ def onp_hsplit(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_hsplit():
|
||||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32'),
|
||||
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float64')
|
||||
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
|
@ -1149,8 +1146,7 @@ def onp_dsplit(input_array):
|
|||
@pytest.mark.env_onecard
|
||||
def test_dsplit():
|
||||
onp_arrs = [
|
||||
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32'),
|
||||
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float64')
|
||||
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32')
|
||||
]
|
||||
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
|
||||
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
|
||||
|
|
|
@ -235,7 +235,7 @@ def test_true_divide():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_power():
|
||||
run_binop_test(mnp_power, onp_power, test_case)
|
||||
run_binop_test(mnp_power, onp_power, test_case, error=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -245,7 +245,7 @@ def test_power():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_float_power():
|
||||
run_binop_test(mnp_float_power, onp_float_power, test_case)
|
||||
run_binop_test(mnp_float_power, onp_float_power, test_case, error=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -422,8 +422,8 @@ def onp_average(x):
|
|||
def test_average():
|
||||
arr1 = rand_int(2, 3, 4, 5)
|
||||
arr2 = rand_int(4, 5, 1, 3, 1)
|
||||
run_single_test(mnp_average, onp_average, arr1)
|
||||
run_single_test(mnp_average, onp_average, arr2)
|
||||
run_single_test(mnp_average, onp_average, arr1, error=1e-5)
|
||||
run_single_test(mnp_average, onp_average, arr2, error=1e-5)
|
||||
|
||||
|
||||
def mnp_count_nonzero(x):
|
||||
|
@ -784,14 +784,14 @@ def test_maximum():
|
|||
def mnp_clip(x):
|
||||
a = mnp.clip(x, mnp.asarray(10.0), mnp.asarray([2,]))
|
||||
b = mnp.clip(x, 0, 1)
|
||||
c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float64)
|
||||
c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float32)
|
||||
return a, b, c
|
||||
|
||||
|
||||
def onp_clip(x):
|
||||
a = onp.clip(x, onp.asarray(10.0), onp.asarray([2,]))
|
||||
b = onp.clip(x, 0, 1)
|
||||
c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float64)
|
||||
c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float32)
|
||||
return a, b, c
|
||||
|
||||
|
||||
|
@ -1029,8 +1029,8 @@ def test_fix():
|
|||
x = rand_int(2, 3)
|
||||
y = rand_int(2, 3)
|
||||
floats = onp.divide(onp.subtract(x, y), y)
|
||||
match_res(mnp_fix, onp_fix, floats)
|
||||
run_binop_test(mnp_fmod, onp_fmod, test_case)
|
||||
match_res(mnp_fix, onp_fix, floats, error=1e-5)
|
||||
run_binop_test(mnp_fmod, onp_fmod, test_case, error=1e-5)
|
||||
|
||||
|
||||
def mnp_trunc(x):
|
||||
|
@ -1051,7 +1051,7 @@ def test_trunc():
|
|||
x = rand_int(2, 3)
|
||||
y = rand_int(2, 3)
|
||||
floats = onp.divide(onp.subtract(x, y), y)
|
||||
match_res(mnp_trunc, onp_trunc, floats)
|
||||
match_res(mnp_trunc, onp_trunc, floats, error=1e-5)
|
||||
|
||||
|
||||
def mnp_exp(x):
|
||||
|
@ -1137,7 +1137,7 @@ def test_negative():
|
|||
for where in where_lst:
|
||||
onp_neg = onp_negative(arr, out=out, where=where)
|
||||
mnp_neg = mnp_negative(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where))
|
||||
match_array(mnp_neg.asnumpy(), onp_neg)
|
||||
match_array(mnp_neg.asnumpy(), onp_neg, 1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
|
|
@ -118,25 +118,25 @@ def match_meta(actual, expected):
|
|||
assert actual.dtype == expected.dtype
|
||||
|
||||
|
||||
def run_binop_test(mnp_fn, onp_fn, test_case):
|
||||
def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
|
||||
for arr in test_case.arrs:
|
||||
match_res(mnp_fn, onp_fn, arr, arr)
|
||||
match_res(mnp_fn, onp_fn, arr, arr, error=error)
|
||||
|
||||
for scalar in test_case.scalars:
|
||||
match_res(mnp_fn, onp_fn, arr, scalar)
|
||||
match_res(mnp_fn, onp_fn, scalar, arr)
|
||||
match_res(mnp_fn, onp_fn, arr, scalar, error=error)
|
||||
match_res(mnp_fn, onp_fn, scalar, arr, error=error)
|
||||
|
||||
for scalar1 in test_case.scalars:
|
||||
for scalar2 in test_case.scalars:
|
||||
match_res(mnp_fn, onp_fn, scalar1, scalar2)
|
||||
match_res(mnp_fn, onp_fn, scalar1, scalar2, error=error)
|
||||
|
||||
for expanded_arr1 in test_case.expanded_arrs:
|
||||
for expanded_arr2 in test_case.expanded_arrs:
|
||||
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2)
|
||||
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2, error=error)
|
||||
|
||||
for broadcastable1 in test_case.broadcastables:
|
||||
for broadcastable2 in test_case.broadcastables:
|
||||
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2)
|
||||
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2, error=error)
|
||||
|
||||
|
||||
def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
|
||||
|
|
Loading…
Reference in New Issue