forked from mindspore-Ecosystem/mindspore
convert MS C++ dtype to MS python
This commit is contained in:
parent
91ddeedef5
commit
d9390afad1
|
@ -31,7 +31,7 @@ from mindspore import Tensor
|
|||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore import ops
|
||||
from mindspore.common.api import _MindsporeFunctionExecutor
|
||||
from mindspore.common.api import _MindsporeFunctionExecutor, _convert_data
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
|
@ -559,6 +559,7 @@ def eval_script(exp_str, params):
|
|||
global_params = params[0]
|
||||
local_params = params[1]
|
||||
try:
|
||||
local_params = _convert_data(local_params)
|
||||
obj = eval(exp_str, global_params, local_params)
|
||||
except Exception as e:
|
||||
error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Top-level reference to dtype of common module."""
|
||||
from . import dtype
|
||||
from .api import ms_function, ms_memory_recycle
|
||||
from .api import ms_function, ms_memory_recycle, _convert_data
|
||||
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||
|
@ -53,10 +53,10 @@ __all__ = [
|
|||
]
|
||||
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
||||
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype",
|
||||
"dtype", "_convert_data",
|
||||
"set_seed", "get_seed", # random seed
|
||||
"set_dump",
|
||||
"ms_memory_recycle"
|
||||
|
|
|
@ -47,6 +47,29 @@ cells_compile_cache = {}
|
|||
|
||||
BROADCAST_PHASE = "_broadcast_"
|
||||
|
||||
def _convert_data(data):
|
||||
"""
|
||||
Convert C++ data to python.
|
||||
|
||||
Args:
|
||||
data : The data need be convert.
|
||||
|
||||
Returns:
|
||||
data, a data convert C++ to python
|
||||
"""
|
||||
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
|
||||
return MsTensor(data, internal=True)
|
||||
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
|
||||
return MsCSRTensor(csr_tensor=data)
|
||||
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
|
||||
return MsCOOTensor(coo_tensor=data)
|
||||
if isinstance(data, tuple):
|
||||
return tuple(_convert_data(x) for x in data)
|
||||
if isinstance(data, list):
|
||||
return list(_convert_data(x) for x in data)
|
||||
if isinstance(data, dict):
|
||||
return dict((_convert_data(key), _convert_data(value)) for key, value in data.items())
|
||||
return data
|
||||
|
||||
def _wrap_func(fn):
|
||||
"""
|
||||
|
@ -62,20 +85,6 @@ def _wrap_func(fn):
|
|||
@wraps(fn)
|
||||
def wrapper(*arg, **kwargs):
|
||||
results = fn(*arg, **kwargs)
|
||||
|
||||
def _convert_data(data):
|
||||
if isinstance(data, Tensor) and not isinstance(data, MsTensor):
|
||||
return MsTensor(data, internal=True)
|
||||
if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
|
||||
return MsCSRTensor(csr_tensor=data)
|
||||
if isinstance(data, COOTensor) and not isinstance(data, MsCOOTensor):
|
||||
return MsCOOTensor(coo_tensor=data)
|
||||
if isinstance(data, tuple):
|
||||
return tuple(_convert_data(x) for x in data)
|
||||
if isinstance(data, list):
|
||||
return list(_convert_data(x) for x in data)
|
||||
return data
|
||||
|
||||
return _convert_data(results)
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -82,7 +82,6 @@ def test_fallback_abs_numpy():
|
|||
assert np.all(foo().asnumpy() == abs(np.array([-1, 2, -3])))
|
||||
|
||||
|
||||
@pytest.mark.skip("Not Supported yet need to convert C++ Tensor To python")
|
||||
def test_fallback_abs_cell_construct_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
|
@ -87,11 +86,10 @@ def test_fallback_all_tensor():
|
|||
assert (not x) and y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
|
||||
def test_fallback_all_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test all(numpy.array) in graph mode
|
||||
Description: Test all(Tensor) in graph mode
|
||||
Expectation: No exception
|
||||
"""
|
||||
|
||||
|
@ -102,7 +100,7 @@ def test_fallback_all_tensor_construct():
|
|||
return all(x), all(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and not y
|
||||
assert (not x) and y
|
||||
|
||||
|
||||
def test_fallback_any_tuple():
|
||||
|
@ -171,7 +169,6 @@ def test_fallback_any_tensor():
|
|||
assert (not x) and y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to python")
|
||||
def test_fallback_any_tensor_construct():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -186,4 +183,4 @@ def test_fallback_any_tensor_construct():
|
|||
return any(x), any(y)
|
||||
|
||||
x, y = foo()
|
||||
assert (not x) and not y
|
||||
assert (not x) and y
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
""" test graph fallback """
|
||||
import math
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import ms_function, Tensor
|
||||
|
||||
|
@ -120,7 +119,6 @@ def test_fallback_bool_tensor():
|
|||
assert x and not y
|
||||
|
||||
|
||||
@pytest.mark.skip("Not support yet should convert C++ Tensor to Python")
|
||||
def test_fallback_bool_tensor_construct():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
|
@ -207,7 +205,6 @@ def test_fallback_float_tensor():
|
|||
assert math.isclose(foo(), 1.5, abs_tol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.skip("Not supported need to convert C++ Tensor to py")
|
||||
def test_fallback_float_tensor_construct():
|
||||
"""
|
||||
Feature : JIT Fallback
|
||||
|
|
Loading…
Reference in New Issue