!49349 [StubTensor]Optimize StubTensor

Merge pull request !49349 from jiaoy1224/fixbug2
This commit is contained in:
i-robot 2023-02-27 06:16:24 +00:00 committed by Gitee
commit 4a16ad175d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 191 additions and 152 deletions

View File

@ -463,7 +463,9 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePt
MS_EXCEPTION_IF_NULL(tensors);
ValuePtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = PyTensorCast(input_object);
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (IsStubTensor(input_object)) {
tensor_ptr = ConvertStubTensor(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);

View File

@ -34,8 +34,8 @@ namespace mindspore {
py::object AnyToPyData(const Any &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
// Convert python (stub) tensor to c++ tensor.
COMMON_EXPORT tensor::TensorPtr PyTensorCast(const py::handle &obj);
COMMON_EXPORT bool IsStubTensor(const py::handle &obj);
COMMON_EXPORT tensor::TensorPtr ConvertStubTensor(const py::handle &obj);
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);
} // namespace mindspore

View File

@ -31,6 +31,7 @@
namespace mindspore {
namespace stub {
constexpr auto PY_ATTR_STUB = "stub";
constexpr auto PY_ATTR_TENSOR = "tensor";
namespace py = pybind11;
class StubNode;

View File

@ -599,7 +599,9 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
static const std::vector<DataConverterPtr> data_converters{
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
std::make_shared<ByTypeDataConverter<Tensor>>(PyTensorCast),
std::make_shared<ByFuncDataConverter>([](const py::object &obj) -> bool { return IsStubTensor(obj); },
[](const py::object &obj) -> ValuePtr { return ConvertStubTensor(obj); }),
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),

View File

@ -50,6 +50,7 @@
#include "utils/crypto.h"
#include "utils/phase.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/stub_tensor.h"
#include "utils/interpret_node_recorder.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
@ -210,7 +211,7 @@ bool CheckArgValid(const py::handle &arg) {
}
if (py::isinstance<Tensor>(arg)) {
TensorPtr tensor = PyTensorCast(arg);
auto tensor = py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {
MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
<< "operator compilation failure. For more details, please refer to the FAQ at "
@ -218,9 +219,9 @@ bool CheckArgValid(const py::handle &arg) {
}
}
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
py::isinstance<Number>(arg) || py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) ||
py::isinstance<COOTensor>(arg);
return IsStubTensor(arg) || py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) ||
py::isinstance<py::none>(arg) || py::isinstance<Number>(arg) || py::isinstance<Tensor>(arg) ||
py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg);
}
std::string GetCompileExceptionInfo() {
@ -475,13 +476,22 @@ py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple
size_t count = 0;
for (auto arg_obj : inputs) {
std::shared_ptr<Tensor> m_tensor = nullptr;
bool is_tensor = false;
if (py::isinstance<Tensor>(arg_obj)) {
m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
is_tensor = true;
} else if (IsStubTensor(arg_obj)) {
m_tensor = ConvertStubTensor(arg_obj);
is_tensor = true;
}
if (is_tensor && m_tensor == nullptr) {
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
return false;
}
if (m_tensor != nullptr) {
MS_LOG(DEBUG) << "Verify Tensor";
std::shared_ptr<Tensor> m_tensor = PyTensorCast(arg_obj);
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
return false;
}
auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
ShapeVector sig_shape = sig->shape();
TypePtr sig_type = sig->Dtype();

View File

@ -74,11 +74,13 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"BasicLSTMCellCStateGradV2", kTupleTensor2},
{"BasicLSTMCellInputGrad", kTupleTensor2},
{"BasicLSTMCellWeightGrad", kTupleTensor2},
{"BatchNorm", kTuple},
{"BatchNorm", kTupleTensor5},
{"BatchNormFold2GradD", kTupleTensor4},
{"BatchNormFold2GradReduce", kTupleTensor2},
{"BatchNormFoldD", kTupleTensor7},
{"BatchNormGrad", kTuple},
{"BatchNormGrad", kTupleTensor3},
{"BatchNormGradWithActivation", kTupleTensor3},
{"BatchNormGradWithAddAndActivation", kTupleTensor4},
{"BatchNormGradGrad", kTupleTensor3},
{"BiasDropoutAdd", kTupleTensor2},
{"CSRSparseMatrixToSparseTensor", kTupleTensor3},
@ -127,6 +129,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"FusedSparseProximalAdagrad", kTupleTensor2},
{"GRU", kTupleTensor4},
{"GRUV2", kTupleTensor4},
{"GRUV2Grad", kTupleTensor3},
{"GRUV2HiddenGrad", kTupleTensor3},
{"GRUV2HiddenGradCell", kTupleTensor3},
{"Geqrf", kTupleTensor2},
@ -139,6 +142,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"InstanceNormGrad", kTupleTensor3},
{"InstanceNormV2", kTupleTensor3},
{"InstanceNormV2Grad", kTupleTensor3},
{"InvertPermutation", kAnyType},
{"LSTM", kTupleTensor5},
{"LSTMGrad", kTupleTensor4},
{"LSTMGradData", kTupleTensor3},
@ -183,8 +187,6 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"NMSWithMask", kTupleTensor3},
{"PReLUGrad", kTupleTensor2},
{"PSROIPooling", kAnyType},
{"PriorityReplayBufferDestroy", kTupleTensor5},
{"PriorityReplayBufferPush", kTupleTensor2},
{"PriorityReplayBufferSample", kTuple},
{"Qr", kTupleTensor2},
{"RNNTLoss", kTupleTensor2},

View File

@ -60,6 +60,21 @@ std::string GetObjIdFromPython(const py::handle &obj) {
return out.cast<std::string>();
}
std::string GetIdForPyTupleOrList(const py::handle &obj) {
auto p_list = py::cast<py::tuple>(obj);
string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
if (p_list.empty()) {
prefix = "Empty:";
} else {
for (size_t i = 0; i < p_list.size(); ++i) {
prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
}
}
prefix.pop_back();
prefix += ">";
return prefix;
}
std::string GetFnInfoByPyObj(const py::object &obj) {
std::string fn_info = obj.attr("__module__").cast<std::string>();
fn_info += "_" + obj.attr("__name__").cast<std::string>();
@ -299,7 +314,9 @@ void Common::ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph) {
std::string PyParser::GetIdByPyObj(const py::object &obj) {
if (py::isinstance<tensor::Tensor>(obj)) {
return PyTensorCast(obj)->id();
return obj.cast<tensor::TensorPtr>()->id();
} else if (IsStubTensor(obj)) {
return ConvertStubTensor(obj)->id();
} else if (py::isinstance<Cell>(obj)) {
return obj.cast<CellPtr>()->id();
} else if (py::isinstance<mindspore::Type>(obj)) {
@ -318,18 +335,7 @@ std::string PyParser::GetIdByPyObj(const py::object &obj) {
} else if (py::isinstance<py::ellipsis>(obj)) {
return "Ellipsis";
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
auto p_list = py::cast<py::tuple>(obj);
string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
if (p_list.empty()) {
prefix = "Empty:";
} else {
for (size_t i = 0; i < p_list.size(); ++i) {
prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
}
}
prefix.pop_back();
prefix += ">";
return prefix;
return GetIdForPyTupleOrList(obj);
} else if (py::isinstance<py::function>(obj)) {
return GetFnInfoByPyObj(obj);
}

View File

@ -215,8 +215,8 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args,
<< (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
}
for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) {
if (py::isinstance<tensor::Tensor>(py_args[i]) || IsStubTensor(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i]) && !IsStubTensor(grads[i])) {
MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
<< "th return value(gradient of the " << i << "th argument) should be Tensor, but got "
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
@ -313,8 +313,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
}
auto actual_out_tensor = PyTensorCast(grad_out);
auto expected_out_tensor = PyTensorCast(expected_grad_out);
tensor::TensorPtr actual_out_tensor =
IsStubTensor(grad_out) ? ConvertStubTensor(grad_out) : py::cast<tensor::TensorPtr>(grad_out);
tensor::TensorPtr expected_out_tensor = IsStubTensor(expected_grad_out)
? ConvertStubTensor(expected_grad_out)
: py::cast<tensor::TensorPtr>(expected_grad_out);
MS_EXCEPTION_IF_NULL(actual_out_tensor);
MS_EXCEPTION_IF_NULL(expected_out_tensor);
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {

View File

@ -685,20 +685,15 @@ py::object MakeCOOTensor(const VectorRef &value_list) {
return ret[0];
}
tensor::TensorPtr PyTensorCast(const py::handle &obj) {
if (!py::isinstance<tensor::Tensor>(obj)) {
return nullptr;
bool IsStubTensor(const py::handle &obj) { return py::hasattr(obj, stub::PY_ATTR_STUB); }
tensor::TensorPtr ConvertStubTensor(const py::handle &obj) {
auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
auto stub = py_stub.cast<stub::StubNodePtr>();
if (stub == nullptr) {
return py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
}
auto is_stub_tensor = py::hasattr(obj, stub::PY_ATTR_STUB);
if (!is_stub_tensor) {
return py::cast<tensor::TensorPtr>(obj);
}
auto stub_node = py::getattr(obj, stub::PY_ATTR_STUB);
auto is_stub_tensor_sync = py::isinstance<stub::StubNode>(stub_node);
if (!is_stub_tensor_sync) {
return py::cast<tensor::TensorPtr>(obj);
}
auto stub = py::getattr(obj, stub::PY_ATTR_STUB).cast<stub::StubNodePtr>();
return stub->WaitValue()->cast<tensor::TensorPtr>();
auto res = stub->WaitValue()->cast<tensor::TensorPtr>();
return res;
}
} // namespace mindspore

View File

@ -308,6 +308,10 @@ def get_log2_size(size):
return cast_res
def is_stub_tensor(tensor):
return hasattr(tensor, "stub")
class Validator:
"""validator for checking input parameters"""
@ -622,7 +626,7 @@ class Validator:
hit = False
for template_type in template_types:
if isinstance(template_type, mstype.Type):
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
hit = True
break
elif type_ is template_type:
@ -1021,9 +1025,9 @@ class Validator:
@staticmethod
def check_sparse_tensor_input(indices, values, shape):
"""Common input check for SparseTensors."""
if not isinstance(indices, Tensor_):
if not isinstance(indices, Tensor_) and not is_stub_tensor(indices):
raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.")
if not isinstance(values, Tensor_):
if not isinstance(values, Tensor_) and not is_stub_tensor(values):
raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.")
if not isinstance(shape, tuple):
raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
@ -1031,7 +1035,7 @@ class Validator:
@staticmethod
def check_csr_tensor_input(indptr, indices, values, shape):
"""Checks inputs type for CSRTensor."""
if not isinstance(indptr, Tensor_):
if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr):
raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
Validator.check_sparse_tensor_input(indices, values, shape)
@ -1069,13 +1073,13 @@ class Validator:
err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
raise ValueError(err_msg1 + err_msg2)
if len(values_shp) + 1 != len(csr_shp):
raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"\
f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "\
f"{len(csr_shp)}")
if values_shp[1: ] != csr_shp[2: ]:
raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"\
f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"\
f"got: {values_shp[1: ]}")
raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"
f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "
f"{len(csr_shp)}")
if values_shp[1:] != csr_shp[2:]:
raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"
f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"
f"got: {values_shp[1: ]}")
@staticmethod
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):

View File

@ -14,46 +14,70 @@
# ============================================================================
"""Stub Tensor implementation."""
import inspect
from functools import reduce
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import type_size_in_bytes
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import TensorNode, SequenceNode
from mindspore.common.api import _convert_python_data
class StubTensor(Tensor):
def _stub_member(var, init):
def getx(stub):
return init if stub.tensor is None else getattr(stub.tensor, var)
def setx(stub, value):
setattr(stub.stub_sync(), var, value)
return property(getx, setx)
def _stub_method(method):
def fun(*arg, **kwargs):
stub = arg[0]
arg = (stub.stub_sync(),) + arg[1:]
return method(*arg, **kwargs)
return fun
class StubTensor:
"""stub tensor for async op run."""
const_arg = _stub_member("const_arg", None)
init = _stub_member("init", None)
init_finished = _stub_member("init_finished", False)
virtual_flag = _stub_member("virtual_flag", False)
parent_tensor_ = _stub_member("parent_tensor_", None)
index_of_parent_ = _stub_member("index_of_parent_", None)
slice_num_of_persistent_data_ = _stub_member("slice_num_of_persistent_data_", None)
slice_shape_of_persistent_data_ = _stub_member("slice_shape_of_persistent_data_", None)
def __init__(self, stub):
Tensor.__init__(self, internal=True)
self.stub = stub
self.tensor = None
def __repr__(self):
self.stub_sync()
return super().__repr__()
__repr__ = _stub_method(Tensor.__repr__)
__str__ = _stub_method(Tensor.__str__)
__setitem__ = _stub_method(Tensor.__setitem__)
def __str__(self):
self.stub_sync()
return super().__str__()
def __setitem__(self, index, value):
self.stub_sync()
return super().__setitem__(index, value)
__lt__ = Tensor.__lt__
__le__ = Tensor.__le__
__gt__ = Tensor.__gt__
__ge__ = Tensor.__ge__
__eq__ = Tensor.__eq__
__ne__ = Tensor.__ne__
@property
def shape(self):
"""shape stub."""
if self.stub:
return self.stub.get_shape()
return super().shape
return self.tensor.shape
@property
def dtype(self):
"""dtype stub."""
if self.stub:
return self.stub.get_dtype()
return super().dtype
return self.tensor.dtype
@property
def size(self):
@ -76,31 +100,20 @@ class StubTensor(Tensor):
"""ndim stub."""
return len(self.shape)
@property
def adapter_flag(self):
return False
@property
def strides(self):
"""strides stub."""
self.stub_sync()
return super().strides
return self.stub_sync().strides
@property
def has_init(self):
"""has_init stub."""
return False
@property
def adapter_flag(self):
"""adapter_flag stub."""
if self.stub:
return False
return super().adapter_flag
def stub_sync(self):
"""data sync to get real tensor"""
if self.stub:
val = self.stub.get_value()
Tensor_.__init__(self, val)
self.stub = None
def ndimension(self):
r"""
Alias for :func:`mindspore.Tensor.ndim`.
@ -113,45 +126,30 @@ class StubTensor(Tensor):
"""
return self.ndim
def asnumpy(self):
"""api stub."""
self.stub_sync()
return super().asnumpy()
asnumpy = _stub_method(Tensor.asnumpy)
is_persistent_data = _stub_method(Tensor.is_persistent_data)
asnumpy_of_slice_persistent_data = _stub_method(Tensor.asnumpy_of_slice_persistent_data)
slice_num_of_persistent_data = _stub_method(Tensor.slice_num_of_persistent_data)
slice_shape_of_persistent_data = _stub_method(Tensor.slice_shape_of_persistent_data)
flush_from_cache = _stub_method(Tensor.flush_from_cache)
def is_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.is_persistent_data`.
"""
self.stub_sync()
super().is_persistent_data()
def stub_sync(self):
if self.stub:
val = self.stub.get_value()
self.tensor = Tensor(val, internal=True)
self.stub = None
return self.tensor
def asnumpy_of_slice_persistent_data(self, param_key, slice_index):
"""
For details, please refer to :`mindspore.common.tensor.asnumpy_of_slice_persistent_data`.
"""
self.stub_sync()
return super().asnumpy_of_slice_persistent_data(param_key, slice_index)
def slice_num_of_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.slice_num_of_persistent_data`.
"""
self.stub_sync()
return super().slice_num_of_persistent_data()
def _init_stub_tensor_api():
stub_func = dir(StubTensor)
for attr in dir(Tensor):
if attr not in stub_func:
func = inspect.getattr_static(Tensor, attr)
setattr(StubTensor, attr, func)
def slice_shape_of_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.slice_shape_of_persistent_data`.
"""
self.stub_sync()
return super().slice_shape_of_persistent_data()
def flush_from_cache(self):
"""
For details, please refer to :`mindspore.common.tensor.flush_from_cache`.
"""
self.stub_sync()
super().flush_from_cache()
_init_stub_tensor_api()
def _convert_stub(stub):

View File

@ -82,9 +82,5 @@ def dict_setitem(dic, key, val):
return dic
def is_stub_tensor(tensor):
return getattr(tensor, "stub", False)
def raise_func(type, script):
raise type(script)

View File

@ -42,7 +42,7 @@ from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
_get_pipeline_stages, _is_in_auto_parallel_mode
from mindspore._checkparam import Validator
from mindspore._checkparam import Validator, is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
@ -106,7 +106,8 @@ def _wrap_func(fn):
def _check_all_tensor(sequence):
for element in sequence:
if not isinstance(element, Tensor) and not (isinstance(element, tuple) and _check_all_tensor(element)):
if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple)
and _check_all_tensor(element)):
return False
return True
@ -359,7 +360,7 @@ class _MindsporeFunctionExecutor:
compile_args = _restore_mutable_attr(args, compile_args)
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
str(self.fn.__code__.co_firstlineno)
str(self.fn.__code__.co_firstlineno)
if _pynative_executor.grad_flag():
generate_name = generate_name + ".grad"
if _is_pynative_parallel():
@ -690,7 +691,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
... closure_fn(inputs, func)
"""
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " \
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
"Please use 'mindspore.jit' instead.")
return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)
@ -836,7 +837,7 @@ def ms_class(cls):
20
"""
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " \
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
"Please use 'mindspore.jit_class' instead.")
# Check if cls is of type class.
@ -1584,6 +1585,7 @@ def _bind_device_context():
"""Bind device context to current thread"""
_bind_device_ctx()
_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PyNativeExecutor()

View File

@ -22,13 +22,13 @@ from typing import Tuple
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common._utils import is_stub_tensor
from mindspore.common.tensor import Tensor
from mindspore._c_expression import COOTensor as COOTensor_
from mindspore._c_expression import CSRTensor as CSRTensor_
from mindspore._c_expression import RowTensor as RowTensor_
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import is_stub_tensor
class RowTensorInner(RowTensor_):
@ -49,7 +49,7 @@ class RowTensorInner(RowTensor_):
# Init a RowTensor from indices, values and shape
else:
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
RowTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -189,9 +189,9 @@ class SparseTensor(COOTensor_):
if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
COOTensor_.__init__(self, indices, values, shape)
@property
@ -280,9 +280,9 @@ class COOTensor(COOTensor_):
validator.check_coo_tensor_dtype(indices.dtype)
indices = tensor_operator_registry.get('stop_gradient')(indices)
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -617,11 +617,11 @@ class CSRTensor(CSRTensor_):
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
indices = tensor_operator_registry.get('stop_gradient')(indices)
if is_stub_tensor(indptr):
indptr.stub_sync()
indptr = indptr.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True

View File

@ -16,12 +16,13 @@
__all__ = ['Tensor']
import abc
import math
import numbers
import numpy as np
from mindspore.communication.management import get_group_size
from mindspore.common._utils import is_shape_unknown, is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.seed import get_seed
from mindspore import context
from mindspore import log as logger
@ -30,7 +31,7 @@ from mindspore.common import dtype as mstype
from mindspore.common._utils import get_slice_num
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Rel, check_is_number
from mindspore._checkparam import Rel, check_is_number, is_stub_tensor
from mindspore._checkparam import Validator as validator
from mindspore._check_jit_forbidden_api import jit_forbidden_register
@ -42,7 +43,7 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
def _check_input_data_type(input_data):
"""Check the type of input_data for Tensor"""
validator.check_value_type('input_data', input_data,
(Tensor_, np.ndarray, np.str_, list, tuple, float, int, bool, complex),
(Tensor_, Tensor, np.ndarray, np.str_, list, tuple, float, int, bool, complex),
'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
@ -67,7 +68,13 @@ def _check_input_data_type(input_data):
f"For Tensor, the input_data is {input_data} that contain unsupported element.")
class Tensor(Tensor_):
class _TensorMeta(type(Tensor_), abc.ABCMeta):
"""
Meta class for Tensor. Used internally.
"""
class Tensor(Tensor_, metaclass=_TensorMeta):
"""
Tensor is a data structure that stores an n-dimensional array.
@ -151,12 +158,13 @@ class Tensor(Tensor_):
def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
self.init_finished = False
if internal:
if input_data is not None:
Tensor_.__init__(self, input_data)
else:
if is_stub_tensor(input_data):
input_data.stub_sync()
input_data = input_data.stub_sync()
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
@ -174,8 +182,8 @@ class Tensor(Tensor_):
else:
_check_input_data_type(input_data)
if dtype is not None:
validator.check_type_name(
'dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
validator.check_type_name('dtype', dtype, mstype.number_type +
(mstype.bool_, mstype.string), "Tensor")
else:
dtype = self._set_default_dtype(input_data, dtype)
@ -186,8 +194,8 @@ class Tensor(Tensor_):
Tensor_.__init__(self, input_data, dtype)
else:
Tensor_.__init__(self, input_data)
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
self.const_arg = const_arg
self.virtual_flag = False
self.init = init
@ -202,6 +210,16 @@ class Tensor(Tensor_):
self.slice_num_of_persistent_data_ = None
self.slice_shape_of_persistent_data_ = None
@classmethod
def __subclasshook__(cls, sub):
"""
Subclass with stub_sync attr will be instance of Tensor
"""
if cls is Tensor:
if any("stub_sync" in s.__dict__ for s in sub.__mro__):
return True
return NotImplemented
@staticmethod
def _set_default_dtype(input_data, dtype):
"""Set tensor default dtype"""
@ -401,8 +419,6 @@ class Tensor(Tensor_):
def __setitem__(self, index, value):
out = tensor_operator_registry.get('__setitem__')(self, index, value)
if is_stub_tensor(out):
out.stub_sync()
self.assign_value(out)
if self.parent_tensor_ is not None and self.index_of_parent_ is not None:
self.parent_tensor_.__setitem__(self.index_of_parent_, self)
@ -587,6 +603,8 @@ class Tensor(Tensor_):
Returns:
Tensor, Tensor that's been assigned.
"""
if is_stub_tensor(value):
value = value.stub_sync()
self.assign_value_cpp(value)
return self