!40915 unify tensor default dtype

Merge pull request !40915 from shaojunsong/fix/default_dtype
This commit is contained in:
i-robot 2022-08-30 02:06:07 +00:00 committed by Gitee
commit 6434a05c13
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 5 additions and 5 deletions

View File

@ -185,10 +185,10 @@ class Tensor(Tensor_):
def _set_default_dtype(input_data, dtype):
"""Set tensor default dtype"""
if isinstance(input_data, (float, list, tuple)):
if np.array(input_data).dtype == np.float64 or np.array(input_data).dtype == np.float32:
if np.array(input_data).dtype == np.float64:
return mstype.float32
elif isinstance(input_data, (int, list, tuple)):
if np.array(input_data).dtype == np.int64 or np.array(input_data).dtype == np.int32:
if isinstance(input_data, (int, list, tuple)):
if np.array(input_data).dtype == np.int32 or np.array(input_data).dtype == np.int64:
return mstype.int64
return dtype

View File

@ -58,7 +58,7 @@ def test_net(data_type, index_type, error):
dim0 = x.shape[0]
indices = np.random.randint(0, dim0, size=index_shape).astype(index_type)
segment_ids = np.random.randint(0, 2 * dim0, size=index_shape).astype(index_type)
segment_ids = sorted(segment_ids)
segment_ids = np.array(sorted(segment_ids)).astype(index_type)
np_out = sparse_segment_mean_numpy(x, indices, segment_ids)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

View File

@ -58,7 +58,7 @@ def test_net(data_type, index_type, error):
dim0 = x.shape[0]
indices = np.random.randint(0, dim0, size=index_shape).astype(index_type)
segment_ids = np.random.randint(0, 2 * dim0, size=index_shape).astype(index_type)
segment_ids = sorted(segment_ids)
segment_ids = np.array(sorted(segment_ids)).astype(index_type)
np_out = sparse_segment_mean_numpy(x, indices, segment_ids)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")