From d9390afad102d61171eec8bf1c4317df347867d7 Mon Sep 17 00:00:00 2001 From: lianliguang Date: Thu, 10 Mar 2022 15:04:25 +0800 Subject: [PATCH] convert MS C++ dtype to MS python --- .../python/mindspore/_extends/parse/parser.py | 3 +- mindspore/python/mindspore/common/__init__.py | 6 +-- mindspore/python/mindspore/common/api.py | 37 ++++++++++++------- .../test_graph_fallback_python_builtin_abs.py | 1 - ...t_graph_fallback_python_builtin_all_any.py | 9 ++--- ..._fallback_python_builtin_bool_int_float.py | 3 -- 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index e70bd4adff5..85d2ddb9df9 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -31,7 +31,7 @@ from mindspore import Tensor from mindspore import log as logger from mindspore import nn from mindspore import ops -from mindspore.common.api import _MindsporeFunctionExecutor +from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data from mindspore.common.dtype import pytype_to_dtype from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT @@ -559,6 +559,7 @@ def eval_script(exp_str, params): global_params = params[0] local_params = params[1] try: + local_params = _convert_data(local_params) obj = eval(exp_str, global_params, local_params) except Exception as e: error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \ diff --git a/mindspore/python/mindspore/common/__init__.py b/mindspore/python/mindspore/common/__init__.py index 9c795336c2a..8cee874cf16 100644 --- a/mindspore/python/mindspore/common/__init__.py +++ b/mindspore/python/mindspore/common/__init__.py @@ -14,7 +14,7 @@ # ============================================================================ """Top-level reference to dtype of common module.""" from . import dtype -from .api import ms_function, ms_memory_recycle +from .api import ms_function, ms_memory_recycle, _convert_data from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \ uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \ float32, single, float64, double, bool_, float_, list_, tuple_, int_, \ @@ -53,10 +53,10 @@ __all__ = [ ] __all__.extend([ - "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor + "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor 'ms_function', # api 'Parameter', 'ParameterTuple', # parameter - "dtype", + "dtype", "_convert_data", "set_seed", "get_seed", # random seed "set_dump", "ms_memory_recycle" diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index b8b9d199645..656149e13af 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -47,6 +47,29 @@ cells_compile_cache = {} BROADCAST_PHASE = "_broadcast_" +def _convert_data(data): + """ + Convert C++ data to python. + + Args: + data : The data need be convert. + + Returns: + data, a data convert C++ to python + """ + if isinstance(data, Tensor) and not isinstance(data, MsTensor): + return MsTensor(data, internal=True) + if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor): + return MsCSRTensor(csr_tensor=data) + if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor): + return MsCOOTensor(coo_tensor=data) + if isinstance(data, tuple): + return tuple(_convert_data(x) for x in data) + if isinstance(data, list): + return list(_convert_data(x) for x in data) + if isinstance(data, dict): + return dict((_convert_data(key), _convert_data(value)) for key, value in data.items()) + return data def _wrap_func(fn): """ @@ -62,20 +85,6 @@ def _wrap_func(fn): @wraps(fn) def wrapper(*arg, **kwargs): results = fn(*arg, **kwargs) - - def _convert_data(data): - if isinstance(data, Tensor) and not isinstance(data, MsTensor): - return MsTensor(data, internal=True) - if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor): - return MsCSRTensor(csr_tensor=data) - if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor): - return MsCOOTensor(coo_tensor=data) - if isinstance(data, tuple): - return tuple(_convert_data(x) for x in data) - if isinstance(data, list): - return list(_convert_data(x) for x in data) - return data - return _convert_data(results) return wrapper diff --git a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_abs.py b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_abs.py index 83d62dddeb4..28f627e143d 100644 --- a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_abs.py +++ b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_abs.py @@ -82,7 +82,6 @@ def test_fallback_abs_numpy(): assert np.all(foo().asnumpy() == abs(np.array([-1, 2, -3]))) -@pytest.mark.skip("Not Supported yet need to convert C++ Tensor To python") def test_fallback_abs_cell_construct_tensor(): """ Feature: JIT Fallback diff --git a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_all_any.py b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_all_any.py index 104dd379d61..723d631d0ea 100644 --- a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_all_any.py +++ b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_all_any.py @@ -14,7 +14,6 @@ # ============================================================================ """ test graph fallback """ import numpy as np -import pytest from mindspore import ms_function, context, Tensor @@ -87,11 +86,10 @@ def test_fallback_all_tensor(): assert (not x) and y -@pytest.mark.skip("Not support yet should convert C++ Tensor to python") def test_fallback_all_tensor_construct(): """ Feature: JIT Fallback - Description: Test all(numpy.array) in graph mode + Description: Test all(Tensor) in graph mode Expectation: No exception """ @@ -102,7 +100,7 @@ def test_fallback_all_tensor_construct(): return all(x), all(y) x, y = foo() - assert (not x) and not y + assert (not x) and y def test_fallback_any_tuple(): @@ -171,7 +169,6 @@ def test_fallback_any_tensor(): assert (not x) and y -@pytest.mark.skip("Not support yet should convert C++ Tensor to python") def test_fallback_any_tensor_construct(): """ Feature: JIT Fallback @@ -186,4 +183,4 @@ def test_fallback_any_tensor_construct(): return any(x), any(y) x, y = foo() - assert (not x) and not y + assert (not x) and y diff --git a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_bool_int_float.py b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_bool_int_float.py index b6b890bd0de..ad10b4e4a8b 100644 --- a/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_bool_int_float.py +++ b/tests/ut/python/fallback/python_builtin/test_graph_fallback_python_builtin_bool_int_float.py @@ -15,7 +15,6 @@ """ test graph fallback """ import math import numpy as np -import pytest from mindspore import ms_function, Tensor @@ -120,7 +119,6 @@ def test_fallback_bool_tensor(): assert x and not y -@pytest.mark.skip("Not support yet should convert C++ Tensor to Python") def test_fallback_bool_tensor_construct(): """ Feature : JIT Fallback @@ -207,7 +205,6 @@ def test_fallback_float_tensor(): assert math.isclose(foo(), 1.5, abs_tol=1e-5) -@pytest.mark.skip("Not supported need to convert C++ Tensor to py") def test_fallback_float_tensor_construct(): """ Feature : JIT Fallback