From f8da54b48f20066741e5b729fb9fe0fce4aa0a0e Mon Sep 17 00:00:00 2001 From: Yang Jiao Date: Fri, 24 Feb 2023 17:24:27 +0800 Subject: [PATCH] opytimize stubtensor --- .../ccsrc/backend/graph_compiler/backend.cc | 4 +- .../include/common/utils/convert_utils_py.h | 4 +- .../ccsrc/include/common/utils/stub_tensor.h | 1 + .../pipeline/jit/parse/data_converter.cc | 4 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 28 ++-- .../pipeline/pynative/predict_out_type_map.h | 10 +- .../ccsrc/pipeline/pynative/pynative_utils.cc | 32 +++-- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 11 +- mindspore/ccsrc/utils/convert_utils_py.cc | 23 ++-- mindspore/python/mindspore/_checkparam.py | 26 ++-- .../python/mindspore/common/_stub_tensor.py | 128 +++++++++--------- mindspore/python/mindspore/common/_utils.py | 4 - mindspore/python/mindspore/common/api.py | 12 +- .../python/mindspore/common/sparse_tensor.py | 18 +-- mindspore/python/mindspore/common/tensor.py | 38 ++++-- 15 files changed, 191 insertions(+), 152 deletions(-) diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc index 9d0b3cd2aaa..96fa5c12b20 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.cc +++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc @@ -463,7 +463,9 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector(input_object)) { - tensor_ptr = PyTensorCast(input_object); + tensor_ptr = py::cast(input_object); + } else if (IsStubTensor(input_object)) { + tensor_ptr = ConvertStubTensor(input_object); } else if (py::isinstance(input_object)) { double input_value = py::cast(input_object); tensor_ptr = std::make_shared(input_value, kFloat32); diff --git a/mindspore/ccsrc/include/common/utils/convert_utils_py.h b/mindspore/ccsrc/include/common/utils/convert_utils_py.h index 16e20d0c62f..6120b0066d9 100644 --- a/mindspore/ccsrc/include/common/utils/convert_utils_py.h +++ b/mindspore/ccsrc/include/common/utils/convert_utils_py.h @@ -34,8 +34,8 @@ namespace mindspore { py::object AnyToPyData(const Any &value); COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr); COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr); -// Convert python (stub) tensor to c++ tensor. -COMMON_EXPORT tensor::TensorPtr PyTensorCast(const py::handle &obj); +COMMON_EXPORT bool IsStubTensor(const py::handle &obj); +COMMON_EXPORT tensor::TensorPtr ConvertStubTensor(const py::handle &obj); COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, const std::shared_ptr &ret_val); } // namespace mindspore diff --git a/mindspore/ccsrc/include/common/utils/stub_tensor.h b/mindspore/ccsrc/include/common/utils/stub_tensor.h index 6a88cee84e8..e22b517e007 100644 --- a/mindspore/ccsrc/include/common/utils/stub_tensor.h +++ b/mindspore/ccsrc/include/common/utils/stub_tensor.h @@ -31,6 +31,7 @@ namespace mindspore { namespace stub { constexpr auto PY_ATTR_STUB = "stub"; +constexpr auto PY_ATTR_TENSOR = "tensor"; namespace py = pybind11; class StubNode; diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index e0c81082b2b..90a192bfeb4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -599,7 +599,9 @@ static const std::vector &GetDataConverters() { static const std::vector data_converters{ // AdapterTensor needs to be processed before Tensor because it inherits from Tensor. std::make_shared(IsAdapterTensor, ConvertAdapterTensor), - std::make_shared>(PyTensorCast), + std::make_shared([](const py::object &obj) -> bool { return IsStubTensor(obj); }, + [](const py::object &obj) -> ValuePtr { return ConvertStubTensor(obj); }), + std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 2adf504170d..16fa1658d08 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -49,6 +49,7 @@ #include "utils/crypto.h" #include "utils/phase.h" #include "include/common/utils/comm_manager.h" +#include "include/common/utils/stub_tensor.h" #include "utils/interpret_node_recorder.h" #include "include/common/debug/anf_ir_dump.h" #include "include/common/debug/dump_proto.h" @@ -208,7 +209,7 @@ bool CheckArgValid(const py::handle &arg) { } if (py::isinstance(arg)) { - TensorPtr tensor = PyTensorCast(arg); + auto tensor = py::cast(arg); if (tensor->data_type() == kNumberTypeBool) { MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause " << "operator compilation failure. For more details, please refer to the FAQ at " @@ -216,9 +217,9 @@ bool CheckArgValid(const py::handle &arg) { } } - return py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || - py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || - py::isinstance(arg); + return IsStubTensor(arg) || py::isinstance(arg) || py::isinstance(arg) || + py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || + py::isinstance(arg) || py::isinstance(arg); } std::string GetCompileExceptionInfo() { @@ -473,13 +474,22 @@ py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple size_t count = 0; for (auto arg_obj : inputs) { + std::shared_ptr m_tensor = nullptr; + bool is_tensor = false; if (py::isinstance(arg_obj)) { + m_tensor = arg_obj.cast>(); + is_tensor = true; + } else if (IsStubTensor(arg_obj)) { + m_tensor = ConvertStubTensor(arg_obj); + is_tensor = true; + } + if (is_tensor && m_tensor == nullptr) { + MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; + return false; + } + + if (m_tensor != nullptr) { MS_LOG(DEBUG) << "Verify Tensor"; - std::shared_ptr m_tensor = PyTensorCast(arg_obj); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; - return false; - } auto sig = input_signature[count].cast>(); ShapeVector sig_shape = sig->shape(); TypePtr sig_type = sig->Dtype(); diff --git a/mindspore/ccsrc/pipeline/pynative/predict_out_type_map.h b/mindspore/ccsrc/pipeline/pynative/predict_out_type_map.h index 4e58cd89bb5..056189aff40 100644 --- a/mindspore/ccsrc/pipeline/pynative/predict_out_type_map.h +++ b/mindspore/ccsrc/pipeline/pynative/predict_out_type_map.h @@ -74,11 +74,13 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4 {"BasicLSTMCellCStateGradV2", kTupleTensor2}, {"BasicLSTMCellInputGrad", kTupleTensor2}, {"BasicLSTMCellWeightGrad", kTupleTensor2}, - {"BatchNorm", kTuple}, + {"BatchNorm", kTupleTensor5}, {"BatchNormFold2GradD", kTupleTensor4}, {"BatchNormFold2GradReduce", kTupleTensor2}, {"BatchNormFoldD", kTupleTensor7}, - {"BatchNormGrad", kTuple}, + {"BatchNormGrad", kTupleTensor3}, + {"BatchNormGradWithActivation", kTupleTensor3}, + {"BatchNormGradWithAddAndActivation", kTupleTensor4}, {"BatchNormGradGrad", kTupleTensor3}, {"BiasDropoutAdd", kTupleTensor2}, {"CSRSparseMatrixToSparseTensor", kTupleTensor3}, @@ -127,6 +129,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4 {"FusedSparseProximalAdagrad", kTupleTensor2}, {"GRU", kTupleTensor4}, {"GRUV2", kTupleTensor4}, + {"GRUV2Grad", kTupleTensor3}, {"GRUV2HiddenGrad", kTupleTensor3}, {"GRUV2HiddenGradCell", kTupleTensor3}, {"Geqrf", kTupleTensor2}, @@ -139,6 +142,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4 {"InstanceNormGrad", kTupleTensor3}, {"InstanceNormV2", kTupleTensor3}, {"InstanceNormV2Grad", kTupleTensor3}, + {"InvertPermutation", kAnyType}, {"LSTM", kTupleTensor5}, {"LSTMGrad", kTupleTensor4}, {"LSTMGradData", kTupleTensor3}, @@ -183,8 +187,6 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4 {"NMSWithMask", kTupleTensor3}, {"PReLUGrad", kTupleTensor2}, {"PSROIPooling", kAnyType}, - {"PriorityReplayBufferDestroy", kTupleTensor5}, - {"PriorityReplayBufferPush", kTupleTensor2}, {"PriorityReplayBufferSample", kTuple}, {"Qr", kTupleTensor2}, {"RNNTLoss", kTupleTensor2}, diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_utils.cc b/mindspore/ccsrc/pipeline/pynative/pynative_utils.cc index 2e83be94f1b..d441ef01496 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_utils.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_utils.cc @@ -48,6 +48,21 @@ std::string GetObjIdFromPython(const py::handle &obj) { return out.cast(); } +std::string GetIdForPyTupleOrList(const py::handle &obj) { + auto p_list = py::cast(obj); + string prefix = py::isinstance(obj) ? "Tuple<" : "List<"; + if (p_list.empty()) { + prefix = "Empty:"; + } else { + for (size_t i = 0; i < p_list.size(); ++i) { + prefix += PyParser::GetIdByPyObj(p_list[i]) + ":"; + } + } + prefix.pop_back(); + prefix += ">"; + return prefix; +} + std::string GetFnInfoByPyObj(const py::object &obj) { std::string fn_info = obj.attr("__module__").cast(); fn_info += "_" + obj.attr("__name__").cast(); @@ -287,7 +302,9 @@ void Common::ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph) { std::string PyParser::GetIdByPyObj(const py::object &obj) { if (py::isinstance(obj)) { - return PyTensorCast(obj)->id(); + return obj.cast()->id(); + } else if (IsStubTensor(obj)) { + return ConvertStubTensor(obj)->id(); } else if (py::isinstance(obj)) { return obj.cast()->id(); } else if (py::isinstance(obj)) { @@ -306,18 +323,7 @@ std::string PyParser::GetIdByPyObj(const py::object &obj) { } else if (py::isinstance(obj)) { return "Ellipsis"; } else if (py::isinstance(obj) || py::isinstance(obj)) { - auto p_list = py::cast(obj); - string prefix = py::isinstance(obj) ? "Tuple<" : "List<"; - if (p_list.empty()) { - prefix = "Empty:"; - } else { - for (size_t i = 0; i < p_list.size(); ++i) { - prefix += PyParser::GetIdByPyObj(p_list[i]) + ":"; - } - } - prefix.pop_back(); - prefix += ">"; - return prefix; + return GetIdForPyTupleOrList(obj); } else if (py::isinstance(obj)) { return GetFnInfoByPyObj(obj); } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 397a645c2af..66ddfd75612 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -215,8 +215,8 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, << (py_args.size() - filter_args_size) << ", but got:" << grads.size() << "."; } for (size_t i = 0; i < grads.size(); i++) { - if (py::isinstance(py_args[i])) { - if (!py::isinstance(grads[i])) { + if (py::isinstance(py_args[i]) || IsStubTensor(py_args[i])) { + if (!py::isinstance(grads[i]) && !IsStubTensor(grads[i])) { MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i << "th return value(gradient of the " << i << "th argument) should be Tensor, but got " << py::cast(grads[i].attr("__class__").attr("__name__")) @@ -313,8 +313,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got " << py::cast(grad_out.attr("__class__").attr("__name__")) << "."; } - auto actual_out_tensor = PyTensorCast(grad_out); - auto expected_out_tensor = PyTensorCast(expected_grad_out); + tensor::TensorPtr actual_out_tensor = + IsStubTensor(grad_out) ? ConvertStubTensor(grad_out) : py::cast(grad_out); + tensor::TensorPtr expected_out_tensor = IsStubTensor(expected_grad_out) + ? ConvertStubTensor(expected_grad_out) + : py::cast(expected_grad_out); MS_EXCEPTION_IF_NULL(actual_out_tensor); MS_EXCEPTION_IF_NULL(expected_out_tensor); if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) { diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 2dda016f687..daac9d63131 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -685,20 +685,15 @@ py::object MakeCOOTensor(const VectorRef &value_list) { return ret[0]; } -tensor::TensorPtr PyTensorCast(const py::handle &obj) { - if (!py::isinstance(obj)) { - return nullptr; +bool IsStubTensor(const py::handle &obj) { return py::hasattr(obj, stub::PY_ATTR_STUB); } + +tensor::TensorPtr ConvertStubTensor(const py::handle &obj) { + auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB); + auto stub = py_stub.cast(); + if (stub == nullptr) { + return py::getattr(obj, stub::PY_ATTR_TENSOR).cast(); } - auto is_stub_tensor = py::hasattr(obj, stub::PY_ATTR_STUB); - if (!is_stub_tensor) { - return py::cast(obj); - } - auto stub_node = py::getattr(obj, stub::PY_ATTR_STUB); - auto is_stub_tensor_sync = py::isinstance(stub_node); - if (!is_stub_tensor_sync) { - return py::cast(obj); - } - auto stub = py::getattr(obj, stub::PY_ATTR_STUB).cast(); - return stub->WaitValue()->cast(); + auto res = stub->WaitValue()->cast(); + return res; } } // namespace mindspore diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index fde2383d1f6..007df6248bb 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -308,6 +308,10 @@ def get_log2_size(size): return cast_res +def is_stub_tensor(tensor): + return hasattr(tensor, "stub") + + class Validator: """validator for checking input parameters""" @@ -622,7 +626,7 @@ class Validator: hit = False for template_type in template_types: if isinstance(template_type, mstype.Type): - if mstype._issubclass_(type_, template_type): # pylint: disable=W0212 + if mstype._issubclass_(type_, template_type): # pylint: disable=W0212 hit = True break elif type_ is template_type: @@ -1021,9 +1025,9 @@ class Validator: @staticmethod def check_sparse_tensor_input(indices, values, shape): """Common input check for SparseTensors.""" - if not isinstance(indices, Tensor_): + if not isinstance(indices, Tensor_) and not is_stub_tensor(indices): raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.") - if not isinstance(values, Tensor_): + if not isinstance(values, Tensor_) and not is_stub_tensor(values): raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.") if not isinstance(shape, tuple): raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.") @@ -1031,7 +1035,7 @@ class Validator: @staticmethod def check_csr_tensor_input(indptr, indices, values, shape): """Checks inputs type for CSRTensor.""" - if not isinstance(indptr, Tensor_): + if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr): raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.") Validator.check_sparse_tensor_input(indices, values, shape) @@ -1069,13 +1073,13 @@ class Validator: err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}." raise ValueError(err_msg1 + err_msg2) if len(values_shp) + 1 != len(csr_shp): - raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"\ - f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "\ - f"{len(csr_shp)}") - if values_shp[1: ] != csr_shp[2: ]: - raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"\ - f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"\ - f"got: {values_shp[1: ]}") + raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got" + f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: " + f"{len(csr_shp)}") + if values_shp[1:] != csr_shp[2:]: + raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," + f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]" + f"got: {values_shp[1: ]}") @staticmethod def check_csr_tensor_dtype(indptr_dtype, indices_dtype): diff --git a/mindspore/python/mindspore/common/_stub_tensor.py b/mindspore/python/mindspore/common/_stub_tensor.py index 6089da4e14c..d1a69e3decf 100644 --- a/mindspore/python/mindspore/common/_stub_tensor.py +++ b/mindspore/python/mindspore/common/_stub_tensor.py @@ -14,46 +14,70 @@ # ============================================================================ """Stub Tensor implementation.""" +import inspect from functools import reduce from mindspore.common.tensor import Tensor from mindspore.common.dtype import type_size_in_bytes -from mindspore._c_expression import Tensor as Tensor_ from mindspore._c_expression import TensorNode, SequenceNode from mindspore.common.api import _convert_python_data -class StubTensor(Tensor): +def _stub_member(var, init): + def getx(stub): + return init if stub.tensor is None else getattr(stub.tensor, var) + + def setx(stub, value): + setattr(stub.stub_sync(), var, value) + return property(getx, setx) + + +def _stub_method(method): + def fun(*arg, **kwargs): + stub = arg[0] + arg = (stub.stub_sync(),) + arg[1:] + return method(*arg, **kwargs) + return fun + + +class StubTensor: """stub tensor for async op run.""" + const_arg = _stub_member("const_arg", None) + init = _stub_member("init", None) + init_finished = _stub_member("init_finished", False) + virtual_flag = _stub_member("virtual_flag", False) + parent_tensor_ = _stub_member("parent_tensor_", None) + index_of_parent_ = _stub_member("index_of_parent_", None) + slice_num_of_persistent_data_ = _stub_member("slice_num_of_persistent_data_", None) + slice_shape_of_persistent_data_ = _stub_member("slice_shape_of_persistent_data_", None) def __init__(self, stub): - Tensor.__init__(self, internal=True) self.stub = stub + self.tensor = None - def __repr__(self): - self.stub_sync() - return super().__repr__() + __repr__ = _stub_method(Tensor.__repr__) + __str__ = _stub_method(Tensor.__str__) + __setitem__ = _stub_method(Tensor.__setitem__) - def __str__(self): - self.stub_sync() - return super().__str__() - - def __setitem__(self, index, value): - self.stub_sync() - return super().__setitem__(index, value) + __lt__ = Tensor.__lt__ + __le__ = Tensor.__le__ + __gt__ = Tensor.__gt__ + __ge__ = Tensor.__ge__ + __eq__ = Tensor.__eq__ + __ne__ = Tensor.__ne__ @property def shape(self): """shape stub.""" if self.stub: return self.stub.get_shape() - return super().shape + return self.tensor.shape @property def dtype(self): """dtype stub.""" if self.stub: return self.stub.get_dtype() - return super().dtype + return self.tensor.dtype @property def size(self): @@ -76,31 +100,20 @@ class StubTensor(Tensor): """ndim stub.""" return len(self.shape) + @property + def adapter_flag(self): + return False + @property def strides(self): """strides stub.""" - self.stub_sync() - return super().strides + return self.stub_sync().strides @property def has_init(self): """has_init stub.""" return False - @property - def adapter_flag(self): - """adapter_flag stub.""" - if self.stub: - return False - return super().adapter_flag - - def stub_sync(self): - """data sync to get real tensor""" - if self.stub: - val = self.stub.get_value() - Tensor_.__init__(self, val) - self.stub = None - def ndimension(self): r""" Alias for :func:`mindspore.Tensor.ndim`. @@ -113,45 +126,30 @@ class StubTensor(Tensor): """ return self.ndim - def asnumpy(self): - """api stub.""" - self.stub_sync() - return super().asnumpy() + asnumpy = _stub_method(Tensor.asnumpy) + is_persistent_data = _stub_method(Tensor.is_persistent_data) + asnumpy_of_slice_persistent_data = _stub_method(Tensor.asnumpy_of_slice_persistent_data) + slice_num_of_persistent_data = _stub_method(Tensor.slice_num_of_persistent_data) + slice_shape_of_persistent_data = _stub_method(Tensor.slice_shape_of_persistent_data) + flush_from_cache = _stub_method(Tensor.flush_from_cache) - def is_persistent_data(self): - """ - For details, please refer to :`mindspore.common.tensor.is_persistent_data`. - """ - self.stub_sync() - super().is_persistent_data() + def stub_sync(self): + if self.stub: + val = self.stub.get_value() + self.tensor = Tensor(val, internal=True) + self.stub = None + return self.tensor - def asnumpy_of_slice_persistent_data(self, param_key, slice_index): - """ - For details, please refer to :`mindspore.common.tensor.asnumpy_of_slice_persistent_data`. - """ - self.stub_sync() - return super().asnumpy_of_slice_persistent_data(param_key, slice_index) - def slice_num_of_persistent_data(self): - """ - For details, please refer to :`mindspore.common.tensor.slice_num_of_persistent_data`. - """ - self.stub_sync() - return super().slice_num_of_persistent_data() +def _init_stub_tensor_api(): + stub_func = dir(StubTensor) + for attr in dir(Tensor): + if attr not in stub_func: + func = inspect.getattr_static(Tensor, attr) + setattr(StubTensor, attr, func) - def slice_shape_of_persistent_data(self): - """ - For details, please refer to :`mindspore.common.tensor.slice_shape_of_persistent_data`. - """ - self.stub_sync() - return super().slice_shape_of_persistent_data() - def flush_from_cache(self): - """ - For details, please refer to :`mindspore.common.tensor.flush_from_cache`. - """ - self.stub_sync() - super().flush_from_cache() +_init_stub_tensor_api() def _convert_stub(stub): diff --git a/mindspore/python/mindspore/common/_utils.py b/mindspore/python/mindspore/common/_utils.py index 20e85f11591..de6dc577e8d 100644 --- a/mindspore/python/mindspore/common/_utils.py +++ b/mindspore/python/mindspore/common/_utils.py @@ -82,9 +82,5 @@ def dict_setitem(dic, key, val): return dic -def is_stub_tensor(tensor): - return getattr(tensor, "stub", False) - - def raise_func(type, script): raise type(script) diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index f8c1720b205..a10ea64ea8e 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -42,7 +42,7 @@ from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor from mindspore.parallel._ps_context import _is_role_sched from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \ _get_pipeline_stages, _is_in_auto_parallel_mode -from mindspore._checkparam import Validator +from mindspore._checkparam import Validator, is_stub_tensor from mindspore.common._utils import is_shape_unknown from mindspore.common.mutable import mutable from mindspore.common._register_for_adapter import ms_adapter_registry @@ -106,7 +106,8 @@ def _wrap_func(fn): def _check_all_tensor(sequence): for element in sequence: - if not isinstance(element, Tensor) and not (isinstance(element, tuple) and _check_all_tensor(element)): + if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple) + and _check_all_tensor(element)): return False return True @@ -359,7 +360,7 @@ class _MindsporeFunctionExecutor: compile_args = _restore_mutable_attr(args, compile_args) generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \ - str(self.fn.__code__.co_firstlineno) + str(self.fn.__code__.co_firstlineno) if _pynative_executor.grad_flag(): generate_name = generate_name + ".grad" if _is_pynative_parallel(): @@ -690,7 +691,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None): ... closure_fn(inputs, func) """ - logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " \ + logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " "Please use 'mindspore.jit' instead.") return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config) @@ -836,7 +837,7 @@ def ms_class(cls): 20 """ - logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " \ + logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " "Please use 'mindspore.jit_class' instead.") # Check if cls is of type class. @@ -1584,6 +1585,7 @@ def _bind_device_context(): """Bind device context to current thread""" _bind_device_ctx() + _cell_graph_executor = _CellGraphExecutor() _pynative_executor = _PyNativeExecutor() diff --git a/mindspore/python/mindspore/common/sparse_tensor.py b/mindspore/python/mindspore/common/sparse_tensor.py index 00f6ed72dac..9faa3fc0750 100644 --- a/mindspore/python/mindspore/common/sparse_tensor.py +++ b/mindspore/python/mindspore/common/sparse_tensor.py @@ -22,13 +22,13 @@ from typing import Tuple from mindspore import log as logger from mindspore.common import dtype as mstype from mindspore.common._register_for_tensor import tensor_operator_registry -from mindspore.common._utils import is_stub_tensor from mindspore.common.tensor import Tensor from mindspore._c_expression import COOTensor as COOTensor_ from mindspore._c_expression import CSRTensor as CSRTensor_ from mindspore._c_expression import RowTensor as RowTensor_ from mindspore._c_expression import Tensor as Tensor_ from mindspore._checkparam import Validator as validator +from mindspore._checkparam import is_stub_tensor class RowTensorInner(RowTensor_): @@ -49,7 +49,7 @@ class RowTensorInner(RowTensor_): # Init a RowTensor from indices, values and shape else: if is_stub_tensor(values): - values.stub_sync() + values = values.stub_sync() RowTensor_.__init__(self, indices, values, shape) self.init_finished = True @@ -189,9 +189,9 @@ class SparseTensor(COOTensor_): if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)): raise TypeError("Inputs must follow: COOTensor(indices, values, shape).") if is_stub_tensor(indices): - indices.stub_sync() + indices = indices.stub_sync() if is_stub_tensor(values): - values.stub_sync() + values = values.stub_sync() COOTensor_.__init__(self, indices, values, shape) @property @@ -280,9 +280,9 @@ class COOTensor(COOTensor_): validator.check_coo_tensor_dtype(indices.dtype) indices = tensor_operator_registry.get('stop_gradient')(indices) if is_stub_tensor(indices): - indices.stub_sync() + indices = indices.stub_sync() if is_stub_tensor(values): - values.stub_sync() + values = values.stub_sync() COOTensor_.__init__(self, indices, values, shape) self.init_finished = True @@ -617,11 +617,11 @@ class CSRTensor(CSRTensor_): indptr = tensor_operator_registry.get('stop_gradient')(indptr) indices = tensor_operator_registry.get('stop_gradient')(indices) if is_stub_tensor(indptr): - indptr.stub_sync() + indptr = indptr.stub_sync() if is_stub_tensor(values): - values.stub_sync() + values = values.stub_sync() if is_stub_tensor(indices): - indices.stub_sync() + indices = indices.stub_sync() CSRTensor_.__init__(self, indptr, indices, values, shape) self.init_finished = True diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 4b298e21528..a785a691b77 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -16,12 +16,13 @@ __all__ = ['Tensor'] +import abc import math import numbers import numpy as np from mindspore.communication.management import get_group_size -from mindspore.common._utils import is_shape_unknown, is_stub_tensor +from mindspore.common._utils import is_shape_unknown from mindspore.common.seed import get_seed from mindspore import context from mindspore import log as logger @@ -30,7 +31,7 @@ from mindspore.common import dtype as mstype from mindspore.common._utils import get_slice_num from mindspore.common._register_for_tensor import tensor_operator_registry from mindspore._c_expression import Tensor as Tensor_ -from mindspore._checkparam import Rel, check_is_number +from mindspore._checkparam import Rel, check_is_number, is_stub_tensor from mindspore._checkparam import Validator as validator from mindspore._check_jit_forbidden_api import jit_forbidden_register @@ -42,7 +43,7 @@ np_types = (np.int8, np.int16, np.int32, np.int64, def _check_input_data_type(input_data): """Check the type of input_data for Tensor""" validator.check_value_type('input_data', input_data, - (Tensor_, np.ndarray, np.str_, list, tuple, float, int, bool, complex), + (Tensor_, Tensor, np.ndarray, np.str_, list, tuple, float, int, bool, complex), 'Tensor') valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128) @@ -67,7 +68,13 @@ def _check_input_data_type(input_data): f"For Tensor, the input_data is {input_data} that contain unsupported element.") -class Tensor(Tensor_): +class _TensorMeta(type(Tensor_), abc.ABCMeta): + """ + Meta class for Tensor. Used internally. + """ + + +class Tensor(Tensor_, metaclass=_TensorMeta): """ Tensor is a data structure that stores an n-dimensional array. @@ -151,12 +158,13 @@ class Tensor(Tensor_): def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False): self.init_finished = False + if internal: if input_data is not None: Tensor_.__init__(self, input_data) else: if is_stub_tensor(input_data): - input_data.stub_sync() + input_data = input_data.stub_sync() # If input data is numpy number, convert it to np array if isinstance(input_data, np_types): @@ -174,8 +182,8 @@ class Tensor(Tensor_): else: _check_input_data_type(input_data) if dtype is not None: - validator.check_type_name( - 'dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor") + validator.check_type_name('dtype', dtype, mstype.number_type + + (mstype.bool_, mstype.string), "Tensor") else: dtype = self._set_default_dtype(input_data, dtype) @@ -186,8 +194,8 @@ class Tensor(Tensor_): Tensor_.__init__(self, input_data, dtype) else: Tensor_.__init__(self, input_data) + validator.check_value_type('const_arg', const_arg, bool, 'Tensor') - validator.check_value_type('const_arg', const_arg, bool, 'Tensor') self.const_arg = const_arg self.virtual_flag = False self.init = init @@ -202,6 +210,16 @@ class Tensor(Tensor_): self.slice_num_of_persistent_data_ = None self.slice_shape_of_persistent_data_ = None + @classmethod + def __subclasshook__(cls, sub): + """ + Subclass with stub_sync attr will be instance of Tensor + """ + if cls is Tensor: + if any("stub_sync" in s.__dict__ for s in sub.__mro__): + return True + return NotImplemented + @staticmethod def _set_default_dtype(input_data, dtype): """Set tensor default dtype""" @@ -401,8 +419,6 @@ class Tensor(Tensor_): def __setitem__(self, index, value): out = tensor_operator_registry.get('__setitem__')(self, index, value) - if is_stub_tensor(out): - out.stub_sync() self.assign_value(out) if self.parent_tensor_ is not None and self.index_of_parent_ is not None: self.parent_tensor_.__setitem__(self.index_of_parent_, self) @@ -587,6 +603,8 @@ class Tensor(Tensor_): Returns: Tensor, Tensor that's been assigned. """ + if is_stub_tensor(value): + value = value.stub_sync() self.assign_value_cpp(value) return self