diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 7faa2089e04..2ede23a989d 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -517,6 +517,13 @@ class _CellGraphExecutor: else: _set_dataset_mode_config('normal') + @staticmethod + def _use_vm_mode(): + enable_ge = context.get_context("enable_ge") + enable_debug_runtime = context.get_context("enable_debug_runtime") + exe_mode = context.get_context("mode") == context.PYNATIVE_MODE + return not enable_ge or (enable_debug_runtime and exe_mode) + def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False): """ Compiles graph. @@ -543,7 +550,7 @@ class _CellGraphExecutor: obj.arguments_key = str(key) phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key - if phase in obj.compile_cache: + if phase in obj.compile_cache and self.has_compiled(phase): logger.debug("%r graph has existed.", phase) return phase, False @@ -556,9 +563,8 @@ class _CellGraphExecutor: args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank()) _, args_list = _generate_pip_args(obj, *args_full) - enable_debug_runtime = context.get_context("enable_debug_runtime") enable_ge = context.get_context("enable_ge") - use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE) + use_vm = self._use_vm_mode() result = self._graph_executor.compile(obj, args_list, phase, use_vm, self.queue_name, self.enable_tuple_broaden) obj.compile_cache.add(phase) if not result: diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index e30bd1a8d53..51475075749 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -130,17 +130,19 @@ class Tensor(Tensor_): if isinstance(input_data, (tuple, list)): if np.array(input_data).dtype not in valid_dtypes: raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") + if dtype is not None: validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor") + else: + dtype = self._set_default_dtype(input_data, dtype) if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): input_data = np.ascontiguousarray(input_data) - if dtype is None: - Tensor_.__init__(self, input_data) - else: - Tensor_.__init__(self, input_data, dtype) + + Tensor_.__init__(self, input_data, dtype) else: Tensor_.__init__(self, dtype, shape) + self._virtual_flag = False self.init = init self.init_finished = True @@ -151,6 +153,13 @@ class Tensor(Tensor_): self.parent_tensor_ = None self.index_of_parent_ = None + @staticmethod + def _set_default_dtype(input_data, dtype): + if isinstance(input_data, (float, list, tuple)): + if np.array(input_data).dtype == np.float64: + return mstype.float32 + return dtype + def __deepcopy__(self, memodict): new_obj = Tensor(self) new_obj.init = self.init diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 9acb4254996..110772b8ed2 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -61,7 +61,7 @@ def test_tensor(): t3 = ms.Tensor(0.1) assert isinstance(t3, ms.Tensor) - assert t3.dtype == ms.float64 + assert t3.dtype == ms.float32 t4 = ms.Tensor(1) assert isinstance(t4, ms.Tensor) @@ -149,7 +149,7 @@ def test_tensor_type_float64(): t = ms.Tensor([[1.0, 2, 3], [4, 5, 6]]) assert isinstance(t, ms.Tensor) assert t.shape == (2, 3) - assert t.dtype == ms.float64 + assert t.dtype == ms.float32 t_zero = ms.Tensor(np.zeros([1, 2, 3])) assert isinstance(t_zero, ms.Tensor)