forked from mindspore-Ecosystem/mindspore
!26275 [API] Tensor: convert Python float to ms.float32
Merge pull request !26275 from kingxian/master
Commit 4f55a3f52f
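In brief, this PR makes Tensor default Python float data (and float lists/tuples) to mindspore.float32 instead of float64. Below is a minimal sketch of the resulting behaviour, inferred from the diff and the updated tests further down; the NumPy-array and explicit-dtype cases follow from my reading of the new _set_default_dtype helper rather than from the tests themselves:

    import numpy as np
    import mindspore as ms

    # Python floats and float lists/tuples now default to float32
    # instead of following NumPy's float64 default.
    assert ms.Tensor(0.1).dtype == ms.float32
    assert ms.Tensor([[1.0, 2, 3], [4, 5, 6]]).dtype == ms.float32

    # NumPy inputs are not handled by the new helper, so an explicit
    # float64 array should keep its dtype (assumption, see _set_default_dtype).
    assert ms.Tensor(np.zeros([1, 2, 3])).dtype == ms.float64

    # Passing dtype explicitly still overrides the default.
    assert ms.Tensor(0.1, dtype=ms.float64).dtype == ms.float64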
@@ -517,6 +517,13 @@ class _CellGraphExecutor:
         else:
             _set_dataset_mode_config('normal')

+    @staticmethod
+    def _use_vm_mode():
+        enable_ge = context.get_context("enable_ge")
+        enable_debug_runtime = context.get_context("enable_debug_runtime")
+        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
+        return not enable_ge or (enable_debug_runtime and exe_mode)
+
     def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
         """
         Compiles graph.
@@ -543,7 +550,7 @@ class _CellGraphExecutor:
         obj.arguments_key = str(key)
         phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key

-        if phase in obj.compile_cache:
+        if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
             return phase, False

@@ -556,9 +563,8 @@ class _CellGraphExecutor:
         args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
         _, args_list = _generate_pip_args(obj, *args_full)

-        enable_debug_runtime = context.get_context("enable_debug_runtime")
-        enable_ge = context.get_context("enable_ge")
-        use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
+        use_vm = self._use_vm_mode()
         result = self._graph_executor.compile(obj, args_list, phase, use_vm, self.queue_name, self.enable_tuple_broaden)
         obj.compile_cache.add(phase)
         if not result:
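The executor hunks above also tidy _CellGraphExecutor: the scattered context lookups in compile() are consolidated into the new _use_vm_mode() helper, and the compile-cache early return now additionally requires has_compiled(phase). A standalone sketch of the equivalent VM-mode check (context.get_context and context.PYNATIVE_MODE are existing MindSpore API; the free function name is only for illustration):

    from mindspore import context

    def use_vm_mode():
        # Run on the VM backend unless GE is enabled; the debug runtime
        # in PyNative mode falls back to the VM even when GE is on.
        enable_ge = context.get_context("enable_ge")
        enable_debug_runtime = context.get_context("enable_debug_runtime")
        pynative = context.get_context("mode") == context.PYNATIVE_MODE
        return not enable_ge or (enable_debug_runtime and pynative)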
@@ -130,17 +130,19 @@ class Tensor(Tensor_):
             if isinstance(input_data, (tuple, list)):
                 if np.array(input_data).dtype not in valid_dtypes:
                     raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
+
             if dtype is not None:
                 validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
+            else:
+                dtype = self._set_default_dtype(input_data, dtype)
+
             if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
                 input_data = np.ascontiguousarray(input_data)
-            if dtype is None:
-                Tensor_.__init__(self, input_data)
-            else:
-                Tensor_.__init__(self, input_data, dtype)
+
+            Tensor_.__init__(self, input_data, dtype)
         else:
             Tensor_.__init__(self, dtype, shape)

         self._virtual_flag = False
         self.init = init
         self.init_finished = True
@@ -151,6 +153,13 @@ class Tensor(Tensor_):
         self.parent_tensor_ = None
         self.index_of_parent_ = None

+    @staticmethod
+    def _set_default_dtype(input_data, dtype):
+        if isinstance(input_data, (float, list, tuple)):
+            if np.array(input_data).dtype == np.float64:
+                return mstype.float32
+        return dtype
+
     def __deepcopy__(self, memodict):
         new_obj = Tensor(self)
         new_obj.init = self.init
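The _set_default_dtype helper above is the core of the change; here is a quick sketch of how it behaves in isolation. The standalone function mirrors Tensor._set_default_dtype from the diff, with mstype.float32 written as ms.float32 (MindSpore re-exports the dtype at package level, as the tests below do):

    import numpy as np
    import mindspore as ms

    def set_default_dtype(input_data, dtype):
        # Mirrors Tensor._set_default_dtype: only Python floats, lists and
        # tuples that NumPy would read as float64 are retargeted to float32.
        if isinstance(input_data, (float, list, tuple)):
            if np.array(input_data).dtype == np.float64:
                return ms.float32
        return dtype

    assert set_default_dtype(0.1, None) == ms.float32        # Python float
    assert set_default_dtype([1.0, 2.0], None) == ms.float32  # float list
    assert set_default_dtype([1, 2, 3], None) is None          # integer data is left alone
    assert set_default_dtype(np.zeros(3), None) is None        # ndarray inputs are not handled here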
@@ -61,7 +61,7 @@ def test_tensor():

     t3 = ms.Tensor(0.1)
     assert isinstance(t3, ms.Tensor)
-    assert t3.dtype == ms.float64
+    assert t3.dtype == ms.float32

     t4 = ms.Tensor(1)
     assert isinstance(t4, ms.Tensor)
@@ -149,7 +149,7 @@ def test_tensor_type_float64():
     t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]])
     assert isinstance(t, ms.Tensor)
     assert t.shape == (2, 3)
-    assert t.dtype == ms.float64
+    assert t.dtype == ms.float32

     t_zero = ms.Tensor(np.zeros([1, 2, 3]))
     assert isinstance(t_zero, ms.Tensor)