forked from mindspore-Ecosystem/mindspore
!40915 unify tensor default dtype
Merge pull request !40915 from shaojunsong/fix/default_dtype
This commit is contained in:
commit
6434a05c13
|
@ -185,10 +185,10 @@ class Tensor(Tensor_):
|
|||
def _set_default_dtype(input_data, dtype):
|
||||
"""Set tensor default dtype"""
|
||||
if isinstance(input_data, (float, list, tuple)):
|
||||
if np.array(input_data).dtype == np.float64 or np.array(input_data).dtype == np.float32:
|
||||
if np.array(input_data).dtype == np.float64:
|
||||
return mstype.float32
|
||||
elif isinstance(input_data, (int, list, tuple)):
|
||||
if np.array(input_data).dtype == np.int64 or np.array(input_data).dtype == np.int32:
|
||||
if isinstance(input_data, (int, list, tuple)):
|
||||
if np.array(input_data).dtype == np.int32 or np.array(input_data).dtype == np.int64:
|
||||
return mstype.int64
|
||||
return dtype
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_net(data_type, index_type, error):
|
|||
dim0 = x.shape[0]
|
||||
indices = np.random.randint(0, dim0, size=index_shape).astype(index_type)
|
||||
segment_ids = np.random.randint(0, 2 * dim0, size=index_shape).astype(index_type)
|
||||
segment_ids = sorted(segment_ids)
|
||||
segment_ids = np.array(sorted(segment_ids)).astype(index_type)
|
||||
np_out = sparse_segment_mean_numpy(x, indices, segment_ids)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_net(data_type, index_type, error):
|
|||
dim0 = x.shape[0]
|
||||
indices = np.random.randint(0, dim0, size=index_shape).astype(index_type)
|
||||
segment_ids = np.random.randint(0, 2 * dim0, size=index_shape).astype(index_type)
|
||||
segment_ids = sorted(segment_ids)
|
||||
segment_ids = np.array(sorted(segment_ids)).astype(index_type)
|
||||
np_out = sparse_segment_mean_numpy(x, indices, segment_ids)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
|
|
Loading…
Reference in New Issue