convert MS C++ dtype to MS python

This commit is contained in:
lianliguang 2022-03-10 15:04:25 +08:00
parent 91ddeedef5
commit d9390afad1
6 changed files with 31 additions and 28 deletions

View File

@ -31,7 +31,7 @@ from mindspore import Tensor
from mindspore import log as logger
from mindspore import nn
from mindspore import ops
from mindspore.common.api import _MindsporeFunctionExecutor
from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data
from mindspore.common.dtype import pytype_to_dtype
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
@ -559,6 +559,7 @@ def eval_script(exp_str, params):
global_params = params[0]
local_params = params[1]
try:
local_params = _convert_data(local_params)
obj = eval(exp_str, global_params, local_params)
except Exception as e:
error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \

View File

@ -14,7 +14,7 @@
# ============================================================================
"""Top-level reference to dtype of common module."""
from . import dtype
from .api import ms_function, ms_memory_recycle
from .api import ms_function, ms_memory_recycle, _convert_data
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
@ -53,10 +53,10 @@ __all__ = [
]
__all__.extend([
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype",
"dtype", "_convert_data",
"set_seed", "get_seed", # random seed
"set_dump",
"ms_memory_recycle"

View File

@ -47,6 +47,29 @@ cells_compile_cache = {}
BROADCAST_PHASE = "_broadcast_"
def _convert_data(data):
"""
Convert C++ data to python.
Args:
data : The data need be convert.
Returns:
data, a data convert C++ to python
"""
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
return MsTensor(data, internal=True)
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
return MsCSRTensor(csr_tensor=data)
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
return MsCOOTensor(coo_tensor=data)
if isinstance(data, tuple):
return tuple(_convert_data(x) for x in data)
if isinstance(data, list):
return list(_convert_data(x) for x in data)
if isinstance(data, dict):
return dict((_convert_data(key), _convert_data(value)) for key, value in data.items())
return data
def _wrap_func(fn):
"""
@ -62,20 +85,6 @@ def _wrap_func(fn):
@wraps(fn)
def wrapper(*arg, **kwargs):
results = fn(*arg, **kwargs)
def _convert_data(data):
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
return MsTensor(data, internal=True)
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
return MsCSRTensor(csr_tensor=data)
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
return MsCOOTensor(coo_tensor=data)
if isinstance(data, tuple):
return tuple(_convert_data(x) for x in data)
if isinstance(data, list):
return list(_convert_data(x) for x in data)
return data
return _convert_data(results)
return wrapper

View File

@ -82,7 +82,6 @@ def test_fallback_abs_numpy():
assert np.all(foo().asnumpy() == abs(np.array([-1, 2, -3])))
@pytest.mark.skip("Not Supported yet need to convert C++ Tensor To python")
def test_fallback_abs_cell_construct_tensor():
"""
Feature: JIT Fallback

View File

@ -14,7 +14,6 @@
# ============================================================================
""" test graph fallback """
import numpy as np
import pytest
from mindspore import ms_function, context, Tensor
@ -87,11 +86,10 @@ def test_fallback_all_tensor():
assert (not x) and y
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
def test_fallback_all_tensor_construct():
"""
Feature: JIT Fallback
Description: Test all(numpy.array) in graph mode
Description: Test all(Tensor) in graph mode
Expectation: No exception
"""
@ -102,7 +100,7 @@ def test_fallback_all_tensor_construct():
return all(x), all(y)
x, y = foo()
assert (not x) and not y
assert (not x) and y
def test_fallback_any_tuple():
@ -171,7 +169,6 @@ def test_fallback_any_tensor():
assert (not x) and y
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
def test_fallback_any_tensor_construct():
"""
Feature: JIT Fallback
@ -186,4 +183,4 @@ def test_fallback_any_tensor_construct():
return any(x), any(y)
x, y = foo()
assert (not x) and not y
assert (not x) and y

View File

@ -15,7 +15,6 @@
""" test graph fallback """
import math
import numpy as np
import pytest
from mindspore import ms_function, Tensor
@ -120,7 +119,6 @@ def test_fallback_bool_tensor():
assert x and not y
@pytest.mark.skip("Not support yet should convert C++ Tensor to Python")
def test_fallback_bool_tensor_construct():
"""
Feature : JIT Fallback
@ -207,7 +205,6 @@ def test_fallback_float_tensor():
assert math.isclose(foo(), 1.5, abs_tol=1e-5)
@pytest.mark.skip("Not supported need to convert C++ Tensor to py")
def test_fallback_float_tensor_construct():
"""
Feature : JIT Fallback