!26275 [API] Tensor python float trans to ms float32

Merge pull request !26275 from kingxian/master
This commit is contained in:
i-robot 2021-11-22 11:13:15 +00:00 committed by Gitee
commit 4f55a3f52f
3 changed files with 24 additions and 9 deletions

View File

@ -517,6 +517,13 @@ class _CellGraphExecutor:
else:
_set_dataset_mode_config('normal')
@staticmethod
def _use_vm_mode():
enable_ge = context.get_context("enable_ge")
enable_debug_runtime = context.get_context("enable_debug_runtime")
exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
return not enable_ge or (enable_debug_runtime and exe_mode)
def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
@ -543,7 +550,7 @@ class _CellGraphExecutor:
obj.arguments_key = str(key)
phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
if phase in obj.compile_cache:
if phase in obj.compile_cache and self.has_compiled(phase):
logger.debug("%r graph has existed.", phase)
return phase, False
@ -556,9 +563,8 @@ class _CellGraphExecutor:
args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
_, args_list = _generate_pip_args(obj, *args_full)
enable_debug_runtime = context.get_context("enable_debug_runtime")
enable_ge = context.get_context("enable_ge")
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
use_vm = self._use_vm_mode()
result = self._graph_executor.compile(obj, args_list, phase, use_vm, self.queue_name, self.enable_tuple_broaden)
obj.compile_cache.add(phase)
if not result:

View File

@ -130,17 +130,19 @@ class Tensor(Tensor_):
if isinstance(input_data, (tuple, list)):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")
if dtype is not None:
validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
else:
dtype = self._set_default_dtype(input_data, dtype)
if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']):
input_data = np.ascontiguousarray(input_data)
if dtype is None:
Tensor_.__init__(self, input_data)
else:
Tensor_.__init__(self, input_data, dtype)
Tensor_.__init__(self, input_data, dtype)
else:
Tensor_.__init__(self, dtype, shape)
self._virtual_flag = False
self.init = init
self.init_finished = True
@ -151,6 +153,13 @@ class Tensor(Tensor_):
self.parent_tensor_ = None
self.index_of_parent_ = None
@staticmethod
def _set_default_dtype(input_data, dtype):
if isinstance(input_data, (float, list, tuple)):
if np.array(input_data).dtype == np.float64:
return mstype.float32
return dtype
def __deepcopy__(self, memodict):
new_obj = Tensor(self)
new_obj.init = self.init

View File

@ -61,7 +61,7 @@ def test_tensor():
t3 = ms.Tensor(0.1)
assert isinstance(t3, ms.Tensor)
assert t3.dtype == ms.float64
assert t3.dtype == ms.float32
t4 = ms.Tensor(1)
assert isinstance(t4, ms.Tensor)
@ -149,7 +149,7 @@ def test_tensor_type_float64():
t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]])
assert isinstance(t, ms.Tensor)
assert t.shape == (2, 3)
assert t.dtype == ms.float64
assert t.dtype == ms.float32
t_zero = ms.Tensor(np.zeros([1, 2, 3]))
assert isinstance(t_zero, ms.Tensor)