!12341 fix st error for gpu type support

From: @yanglf1121
Reviewed-by: @liangchenghui,@kingxian
Signed-off-by: @liangchenghui,@kingxian
This commit is contained in:
mindspore-ci-bot 2021-02-18 14:22:19 +08:00 committed by Gitee
commit 94054172ce
5 changed files with 26 additions and 29 deletions

View File

@ -1208,14 +1208,15 @@ def cumsum(a, axis=None, dtype=None):
if _check_same_type(original_dtype, mstype.bool_) or \
_check_same_type(original_dtype, mstype.int8) or \
_check_same_type(original_dtype, mstype.int16):
a = a.astype(mstype.int32)
original_dtype = mstype.int32
a = a.astype(mstype.float32)
if axis is None:
a = a.ravel()
axis = 0
_check_axis_in_range(axis, a.ndim)
if dtype is not None and not _check_same_type(original_dtype, dtype):
return _cumsum_default(a, axis).astype(dtype, copy=False)
return _cumsum_default(a, axis)
return _cumsum_default(a, axis).astype(original_dtype, copy=False)
def _index(i, size, Cartesian=True):

View File

@ -369,8 +369,8 @@ def test_linspace():
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
match_array(actual, expected, error=6)
start = onp.random.random([2, 1, 4])
stop = onp.random.random([1, 5, 1])
start = onp.random.random([2, 1, 4]).astype("float32")
stop = onp.random.random([1, 5, 1]).astype("float32")
actual = onp.linspace(start, stop, num=20, retstep=True,
endpoint=False, dtype=onp.float32)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,

View File

@ -1047,8 +1047,7 @@ def onp_split(input_array):
@pytest.mark.env_onecard
def test_split():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
@ -1081,8 +1080,7 @@ def onp_vsplit(input_array):
@pytest.mark.env_onecard
def test_vsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
@ -1115,8 +1113,7 @@ def onp_hsplit(input_array):
@pytest.mark.env_onecard
def test_hsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32'),
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float64')
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
@ -1149,8 +1146,7 @@ def onp_dsplit(input_array):
@pytest.mark.env_onecard
def test_dsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32'),
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float64')
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):

View File

@ -235,7 +235,7 @@ def test_true_divide():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_power():
run_binop_test(mnp_power, onp_power, test_case)
run_binop_test(mnp_power, onp_power, test_case, error=1e-5)
@pytest.mark.level1
@ -245,7 +245,7 @@ def test_power():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_float_power():
run_binop_test(mnp_float_power, onp_float_power, test_case)
run_binop_test(mnp_float_power, onp_float_power, test_case, error=1e-5)
@pytest.mark.level1
@ -422,8 +422,8 @@ def onp_average(x):
def test_average():
arr1 = rand_int(2, 3, 4, 5)
arr2 = rand_int(4, 5, 1, 3, 1)
run_single_test(mnp_average, onp_average, arr1)
run_single_test(mnp_average, onp_average, arr2)
run_single_test(mnp_average, onp_average, arr1, error=1e-5)
run_single_test(mnp_average, onp_average, arr2, error=1e-5)
def mnp_count_nonzero(x):
@ -784,14 +784,14 @@ def test_maximum():
def mnp_clip(x):
a = mnp.clip(x, mnp.asarray(10.0), mnp.asarray([2,]))
b = mnp.clip(x, 0, 1)
c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float64)
c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float32)
return a, b, c
def onp_clip(x):
a = onp.clip(x, onp.asarray(10.0), onp.asarray([2,]))
b = onp.clip(x, 0, 1)
c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float64)
c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float32)
return a, b, c
@ -1029,8 +1029,8 @@ def test_fix():
x = rand_int(2, 3)
y = rand_int(2, 3)
floats = onp.divide(onp.subtract(x, y), y)
match_res(mnp_fix, onp_fix, floats)
run_binop_test(mnp_fmod, onp_fmod, test_case)
match_res(mnp_fix, onp_fix, floats, error=1e-5)
run_binop_test(mnp_fmod, onp_fmod, test_case, error=1e-5)
def mnp_trunc(x):
@ -1051,7 +1051,7 @@ def test_trunc():
x = rand_int(2, 3)
y = rand_int(2, 3)
floats = onp.divide(onp.subtract(x, y), y)
match_res(mnp_trunc, onp_trunc, floats)
match_res(mnp_trunc, onp_trunc, floats, error=1e-5)
def mnp_exp(x):
@ -1137,7 +1137,7 @@ def test_negative():
for where in where_lst:
onp_neg = onp_negative(arr, out=out, where=where)
mnp_neg = mnp_negative(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where))
match_array(mnp_neg.asnumpy(), onp_neg)
match_array(mnp_neg.asnumpy(), onp_neg, 1e-5)
@pytest.mark.level1

View File

@ -118,25 +118,25 @@ def match_meta(actual, expected):
assert actual.dtype == expected.dtype
def run_binop_test(mnp_fn, onp_fn, test_case):
def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
for arr in test_case.arrs:
match_res(mnp_fn, onp_fn, arr, arr)
match_res(mnp_fn, onp_fn, arr, arr, error=error)
for scalar in test_case.scalars:
match_res(mnp_fn, onp_fn, arr, scalar)
match_res(mnp_fn, onp_fn, scalar, arr)
match_res(mnp_fn, onp_fn, arr, scalar, error=error)
match_res(mnp_fn, onp_fn, scalar, arr, error=error)
for scalar1 in test_case.scalars:
for scalar2 in test_case.scalars:
match_res(mnp_fn, onp_fn, scalar1, scalar2)
match_res(mnp_fn, onp_fn, scalar1, scalar2, error=error)
for expanded_arr1 in test_case.expanded_arrs:
for expanded_arr2 in test_case.expanded_arrs:
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2)
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2, error=error)
for broadcastable1 in test_case.broadcastables:
for broadcastable2 in test_case.broadcastables:
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2)
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2, error=error)
def run_unary_test(mnp_fn, onp_fn, test_case, error=0):