forked from mindspore-Ecosystem/mindspore
!49349 [StubTensor]Optimize StubTensor
Merge pull request !49349 from jiaoy1224/fixbug2
commit 4a16ad175d

@@ -463,7 +463,9 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePt
MS_EXCEPTION_IF_NULL(tensors);
ValuePtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = PyTensorCast(input_object);
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (IsStubTensor(input_object)) {
tensor_ptr = ConvertStubTensor(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);

@@ -34,8 +34,8 @@ namespace mindspore {
py::object AnyToPyData(const Any &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
// Convert python (stub) tensor to c++ tensor.
COMMON_EXPORT tensor::TensorPtr PyTensorCast(const py::handle &obj);
COMMON_EXPORT bool IsStubTensor(const py::handle &obj);
COMMON_EXPORT tensor::TensorPtr ConvertStubTensor(const py::handle &obj);
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);
} // namespace mindspore

@@ -31,6 +31,7 @@
namespace mindspore {
namespace stub {
constexpr auto PY_ATTR_STUB = "stub";
constexpr auto PY_ATTR_TENSOR = "tensor";

namespace py = pybind11;
class StubNode;

@@ -599,7 +599,9 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
static const std::vector<DataConverterPtr> data_converters{
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
std::make_shared<ByTypeDataConverter<Tensor>>(PyTensorCast),
std::make_shared<ByFuncDataConverter>([](const py::object &obj) -> bool { return IsStubTensor(obj); },
[](const py::object &obj) -> ValuePtr { return ConvertStubTensor(obj); }),
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),

@@ -50,6 +50,7 @@
#include "utils/crypto.h"
#include "utils/phase.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/stub_tensor.h"
#include "utils/interpret_node_recorder.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
@@ -210,7 +211,7 @@ bool CheckArgValid(const py::handle &arg) {
}

if (py::isinstance<Tensor>(arg)) {
TensorPtr tensor = PyTensorCast(arg);
auto tensor = py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {
MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
<< "operator compilation failure. For more details, please refer to the FAQ at "
@@ -218,9 +219,9 @@ bool CheckArgValid(const py::handle &arg) {
}
}

return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
py::isinstance<Number>(arg) || py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) ||
py::isinstance<COOTensor>(arg);
return IsStubTensor(arg) || py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) ||
py::isinstance<py::none>(arg) || py::isinstance<Number>(arg) || py::isinstance<Tensor>(arg) ||
py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg);
}

std::string GetCompileExceptionInfo() {
@@ -475,13 +476,22 @@ py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple

size_t count = 0;
for (auto arg_obj : inputs) {
std::shared_ptr<Tensor> m_tensor = nullptr;
bool is_tensor = false;
if (py::isinstance<Tensor>(arg_obj)) {
m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
is_tensor = true;
} else if (IsStubTensor(arg_obj)) {
m_tensor = ConvertStubTensor(arg_obj);
is_tensor = true;
}
if (is_tensor && m_tensor == nullptr) {
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
return false;
}

if (m_tensor != nullptr) {
MS_LOG(DEBUG) << "Verify Tensor";
std::shared_ptr<Tensor> m_tensor = PyTensorCast(arg_obj);
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
return false;
}
auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
ShapeVector sig_shape = sig->shape();
TypePtr sig_type = sig->Dtype();

@@ -74,11 +74,13 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"BasicLSTMCellCStateGradV2", kTupleTensor2},
{"BasicLSTMCellInputGrad", kTupleTensor2},
{"BasicLSTMCellWeightGrad", kTupleTensor2},
{"BatchNorm", kTuple},
{"BatchNorm", kTupleTensor5},
{"BatchNormFold2GradD", kTupleTensor4},
{"BatchNormFold2GradReduce", kTupleTensor2},
{"BatchNormFoldD", kTupleTensor7},
{"BatchNormGrad", kTuple},
{"BatchNormGrad", kTupleTensor3},
{"BatchNormGradWithActivation", kTupleTensor3},
{"BatchNormGradWithAddAndActivation", kTupleTensor4},
{"BatchNormGradGrad", kTupleTensor3},
{"BiasDropoutAdd", kTupleTensor2},
{"CSRSparseMatrixToSparseTensor", kTupleTensor3},
@@ -127,6 +129,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"FusedSparseProximalAdagrad", kTupleTensor2},
{"GRU", kTupleTensor4},
{"GRUV2", kTupleTensor4},
{"GRUV2Grad", kTupleTensor3},
{"GRUV2HiddenGrad", kTupleTensor3},
{"GRUV2HiddenGradCell", kTupleTensor3},
{"Geqrf", kTupleTensor2},
@@ -139,6 +142,7 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"InstanceNormGrad", kTupleTensor3},
{"InstanceNormV2", kTupleTensor3},
{"InstanceNormV2Grad", kTupleTensor3},
{"InvertPermutation", kAnyType},
{"LSTM", kTupleTensor5},
{"LSTMGrad", kTupleTensor4},
{"LSTMGradData", kTupleTensor3},
@@ -183,8 +187,6 @@ inline static PredictOutTypeMap out_type_prediction = {{"ActsULQ", kTupleTensor4
{"NMSWithMask", kTupleTensor3},
{"PReLUGrad", kTupleTensor2},
{"PSROIPooling", kAnyType},
{"PriorityReplayBufferDestroy", kTupleTensor5},
{"PriorityReplayBufferPush", kTupleTensor2},
{"PriorityReplayBufferSample", kTuple},
{"Qr", kTupleTensor2},
{"RNNTLoss", kTupleTensor2},

@@ -60,6 +60,21 @@ std::string GetObjIdFromPython(const py::handle &obj) {
return out.cast<std::string>();
}

std::string GetIdForPyTupleOrList(const py::handle &obj) {
auto p_list = py::cast<py::tuple>(obj);
string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
if (p_list.empty()) {
prefix = "Empty:";
} else {
for (size_t i = 0; i < p_list.size(); ++i) {
prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
}
}
prefix.pop_back();
prefix += ">";
return prefix;
}

std::string GetFnInfoByPyObj(const py::object &obj) {
std::string fn_info = obj.attr("__module__").cast<std::string>();
fn_info += "_" + obj.attr("__name__").cast<std::string>();
@@ -299,7 +314,9 @@ void Common::ReplaceCNodeWithValueNode(const FuncGraphPtr &bprop_graph) {

std::string PyParser::GetIdByPyObj(const py::object &obj) {
if (py::isinstance<tensor::Tensor>(obj)) {
return PyTensorCast(obj)->id();
return obj.cast<tensor::TensorPtr>()->id();
} else if (IsStubTensor(obj)) {
return ConvertStubTensor(obj)->id();
} else if (py::isinstance<Cell>(obj)) {
return obj.cast<CellPtr>()->id();
} else if (py::isinstance<mindspore::Type>(obj)) {
@@ -318,18 +335,7 @@ std::string PyParser::GetIdByPyObj(const py::object &obj) {
} else if (py::isinstance<py::ellipsis>(obj)) {
return "Ellipsis";
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
auto p_list = py::cast<py::tuple>(obj);
string prefix = py::isinstance<py::tuple>(obj) ? "Tuple<" : "List<";
if (p_list.empty()) {
prefix = "Empty:";
} else {
for (size_t i = 0; i < p_list.size(); ++i) {
prefix += PyParser::GetIdByPyObj(p_list[i]) + ":";
}
}
prefix.pop_back();
prefix += ">";
return prefix;
return GetIdForPyTupleOrList(obj);
} else if (py::isinstance<py::function>(obj)) {
return GetFnInfoByPyObj(obj);
}

@@ -215,8 +215,8 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args,
<< (py_args.size() - filter_args_size) << ", but got:" << grads.size() << ".";
}
for (size_t i = 0; i < grads.size(); i++) {
if (py::isinstance<tensor::Tensor>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i])) {
if (py::isinstance<tensor::Tensor>(py_args[i]) || IsStubTensor(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(grads[i]) && !IsStubTensor(grads[i])) {
MS_EXCEPTION(ValueError) << "For user defined method 'bprop' of net '" << bprop_cls_name << "', the " << i
<< "th return value(gradient of the " << i << "th argument) should be Tensor, but got "
<< py::cast<std::string>(grads[i].attr("__class__").attr("__name__"))
@@ -313,8 +313,11 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
}
auto actual_out_tensor = PyTensorCast(grad_out);
auto expected_out_tensor = PyTensorCast(expected_grad_out);
tensor::TensorPtr actual_out_tensor =
IsStubTensor(grad_out) ? ConvertStubTensor(grad_out) : py::cast<tensor::TensorPtr>(grad_out);
tensor::TensorPtr expected_out_tensor = IsStubTensor(expected_grad_out)
? ConvertStubTensor(expected_grad_out)
: py::cast<tensor::TensorPtr>(expected_grad_out);
MS_EXCEPTION_IF_NULL(actual_out_tensor);
MS_EXCEPTION_IF_NULL(expected_out_tensor);
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {

@@ -685,20 +685,15 @@ py::object MakeCOOTensor(const VectorRef &value_list) {
return ret[0];
}

tensor::TensorPtr PyTensorCast(const py::handle &obj) {
if (!py::isinstance<tensor::Tensor>(obj)) {
return nullptr;
bool IsStubTensor(const py::handle &obj) { return py::hasattr(obj, stub::PY_ATTR_STUB); }

tensor::TensorPtr ConvertStubTensor(const py::handle &obj) {
auto py_stub = py::getattr(obj, stub::PY_ATTR_STUB);
auto stub = py_stub.cast<stub::StubNodePtr>();
if (stub == nullptr) {
return py::getattr(obj, stub::PY_ATTR_TENSOR).cast<tensor::TensorPtr>();
}
auto is_stub_tensor = py::hasattr(obj, stub::PY_ATTR_STUB);
if (!is_stub_tensor) {
return py::cast<tensor::TensorPtr>(obj);
}
auto stub_node = py::getattr(obj, stub::PY_ATTR_STUB);
auto is_stub_tensor_sync = py::isinstance<stub::StubNode>(stub_node);
if (!is_stub_tensor_sync) {
return py::cast<tensor::TensorPtr>(obj);
}
auto stub = py::getattr(obj, stub::PY_ATTR_STUB).cast<stub::StubNodePtr>();
return stub->WaitValue()->cast<tensor::TensorPtr>();
auto res = stub->WaitValue()->cast<tensor::TensorPtr>();
return res;
}
} // namespace mindspore

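Note: an illustrative Python-level paraphrase of the contract the two C++ helpers above implement (the names IsStubTensor/ConvertStubTensor and the "stub"/"tensor" attributes come from this diff; the bodies below are a sketch, not MindSpore source). A stub tensor is recognized purely by its "stub" attribute; conversion waits on the pending stub node, or falls back to the already materialized "tensor" attribute once the stub has been consumed.

def is_stub_tensor(obj):
    # mirrors IsStubTensor: duck-typed check for the "stub" attribute
    return hasattr(obj, "stub")

def convert_stub_tensor(obj):
    # mirrors ConvertStubTensor: block on the pending stub node if it is still
    # alive, otherwise return the value already synchronized into "tensor"
    stub = getattr(obj, "stub", None)
    if stub is None:
        return obj.tensor
    return stub.get_value()
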
@@ -308,6 +308,10 @@ def get_log2_size(size):
return cast_res


def is_stub_tensor(tensor):
return hasattr(tensor, "stub")


class Validator:
"""validator for checking input parameters"""

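Note: a small, self-contained usage sketch of the helper added above (the toy class is hypothetical, not MindSpore API). Detection is duck-typed on the presence of a "stub" attribute, so this module never needs to import StubTensor.

def is_stub_tensor(tensor):
    # same body as the helper added in this hunk
    return hasattr(tensor, "stub")

class _ToyStubTensor:  # hypothetical stand-in for StubTensor
    def __init__(self, stub):
        self.stub = stub

assert is_stub_tensor(_ToyStubTensor(stub=object()))
assert not is_stub_tensor(3.14)
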
@@ -622,7 +626,7 @@ class Validator:
hit = False
for template_type in template_types:
if isinstance(template_type, mstype.Type):
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
hit = True
break
elif type_ is template_type:
@@ -1021,9 +1025,9 @@ class Validator:
@staticmethod
def check_sparse_tensor_input(indices, values, shape):
"""Common input check for SparseTensors."""
if not isinstance(indices, Tensor_):
if not isinstance(indices, Tensor_) and not is_stub_tensor(indices):
raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.")
if not isinstance(values, Tensor_):
if not isinstance(values, Tensor_) and not is_stub_tensor(values):
raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.")
if not isinstance(shape, tuple):
raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
@@ -1031,7 +1035,7 @@ class Validator:
@staticmethod
def check_csr_tensor_input(indptr, indices, values, shape):
"""Checks inputs type for CSRTensor."""
if not isinstance(indptr, Tensor_):
if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr):
raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
Validator.check_sparse_tensor_input(indices, values, shape)

@@ -1069,13 +1073,13 @@ class Validator:
err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
raise ValueError(err_msg1 + err_msg2)
if len(values_shp) + 1 != len(csr_shp):
raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"\
f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "\
f"{len(csr_shp)}")
if values_shp[1: ] != csr_shp[2: ]:
raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"\
f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"\
f"got: {values_shp[1: ]}")
raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got"
f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "
f"{len(csr_shp)}")
if values_shp[1:] != csr_shp[2:]:
raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ],"
f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]"
f"got: {values_shp[1: ]}")

@staticmethod
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):

@@ -14,46 +14,70 @@
# ============================================================================
"""Stub Tensor implementation."""

import inspect
from functools import reduce
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import type_size_in_bytes
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import TensorNode, SequenceNode
from mindspore.common.api import _convert_python_data


class StubTensor(Tensor):
def _stub_member(var, init):
def getx(stub):
return init if stub.tensor is None else getattr(stub.tensor, var)

def setx(stub, value):
setattr(stub.stub_sync(), var, value)
return property(getx, setx)


def _stub_method(method):
def fun(*arg, **kwargs):
stub = arg[0]
arg = (stub.stub_sync(),) + arg[1:]
return method(*arg, **kwargs)
return fun


class StubTensor:
"""stub tensor for async op run."""
const_arg = _stub_member("const_arg", None)
init = _stub_member("init", None)
init_finished = _stub_member("init_finished", False)
virtual_flag = _stub_member("virtual_flag", False)
parent_tensor_ = _stub_member("parent_tensor_", None)
index_of_parent_ = _stub_member("index_of_parent_", None)
slice_num_of_persistent_data_ = _stub_member("slice_num_of_persistent_data_", None)
slice_shape_of_persistent_data_ = _stub_member("slice_shape_of_persistent_data_", None)

def __init__(self, stub):
Tensor.__init__(self, internal=True)
self.stub = stub
self.tensor = None

def __repr__(self):
self.stub_sync()
return super().__repr__()
__repr__ = _stub_method(Tensor.__repr__)
__str__ = _stub_method(Tensor.__str__)
__setitem__ = _stub_method(Tensor.__setitem__)

def __str__(self):
self.stub_sync()
return super().__str__()

def __setitem__(self, index, value):
self.stub_sync()
return super().__setitem__(index, value)
__lt__ = Tensor.__lt__
__le__ = Tensor.__le__
__gt__ = Tensor.__gt__
__ge__ = Tensor.__ge__
__eq__ = Tensor.__eq__
__ne__ = Tensor.__ne__

@property
def shape(self):
"""shape stub."""
if self.stub:
return self.stub.get_shape()
return super().shape
return self.tensor.shape

@property
def dtype(self):
"""dtype stub."""
if self.stub:
return self.stub.get_dtype()
return super().dtype
return self.tensor.dtype

@property
def size(self):
@@ -76,31 +100,20 @@ class StubTensor(Tensor):
"""ndim stub."""
return len(self.shape)

@property
def adapter_flag(self):
return False

@property
def strides(self):
"""strides stub."""
self.stub_sync()
return super().strides
return self.stub_sync().strides

@property
def has_init(self):
"""has_init stub."""
return False

@property
def adapter_flag(self):
"""adapter_flag stub."""
if self.stub:
return False
return super().adapter_flag

def stub_sync(self):
"""data sync to get real tensor"""
if self.stub:
val = self.stub.get_value()
Tensor_.__init__(self, val)
self.stub = None

def ndimension(self):
r"""
Alias for :func:`mindspore.Tensor.ndim`.
@@ -113,45 +126,30 @@ class StubTensor(Tensor):
"""
return self.ndim

def asnumpy(self):
"""api stub."""
self.stub_sync()
return super().asnumpy()
asnumpy = _stub_method(Tensor.asnumpy)
is_persistent_data = _stub_method(Tensor.is_persistent_data)
asnumpy_of_slice_persistent_data = _stub_method(Tensor.asnumpy_of_slice_persistent_data)
slice_num_of_persistent_data = _stub_method(Tensor.slice_num_of_persistent_data)
slice_shape_of_persistent_data = _stub_method(Tensor.slice_shape_of_persistent_data)
flush_from_cache = _stub_method(Tensor.flush_from_cache)

def is_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.is_persistent_data`.
"""
self.stub_sync()
super().is_persistent_data()
def stub_sync(self):
if self.stub:
val = self.stub.get_value()
self.tensor = Tensor(val, internal=True)
self.stub = None
return self.tensor

def asnumpy_of_slice_persistent_data(self, param_key, slice_index):
"""
For details, please refer to :`mindspore.common.tensor.asnumpy_of_slice_persistent_data`.
"""
self.stub_sync()
return super().asnumpy_of_slice_persistent_data(param_key, slice_index)

def slice_num_of_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.slice_num_of_persistent_data`.
"""
self.stub_sync()
return super().slice_num_of_persistent_data()
def _init_stub_tensor_api():
stub_func = dir(StubTensor)
for attr in dir(Tensor):
if attr not in stub_func:
func = inspect.getattr_static(Tensor, attr)
setattr(StubTensor, attr, func)

def slice_shape_of_persistent_data(self):
"""
For details, please refer to :`mindspore.common.tensor.slice_shape_of_persistent_data`.
"""
self.stub_sync()
return super().slice_shape_of_persistent_data()

def flush_from_cache(self):
"""
For details, please refer to :`mindspore.common.tensor.flush_from_cache`.
"""
self.stub_sync()
super().flush_from_cache()
_init_stub_tensor_api()


def _convert_stub(stub):

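Note: the hunks above replace the inheritance-based StubTensor with a standalone proxy whose attributes and methods are forwarded to the lazily materialized Tensor through _stub_member, _stub_method and _init_stub_tensor_api. The following is an illustrative, self-contained sketch of that delegation pattern with toy classes (not MindSpore API):

class _Real:
    # stand-in for the materialized Tensor
    def __init__(self, data):
        self.data = data

    def total(self):
        return sum(self.data)


class _LazyProxy:
    # stand-in for StubTensor: holds a pending computation until first use
    def __init__(self, thunk):
        self.thunk = thunk   # plays the role of the stub node
        self.real = None     # plays the role of self.tensor

    def sync(self):
        # equivalent of stub_sync(): compute once, cache, return the real object
        if self.thunk is not None:
            self.real = _Real(self.thunk())
            self.thunk = None
        return self.real


def _forward(method):
    # equivalent of _stub_method: sync first, then call the real method
    def wrapper(self, *args, **kwargs):
        return method(self.sync(), *args, **kwargs)
    return wrapper


_LazyProxy.total = _forward(_Real.total)

proxy = _LazyProxy(lambda: [1, 2, 3])
assert proxy.total() == 6  # first access triggers the sync
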
@@ -82,9 +82,5 @@ def dict_setitem(dic, key, val):
return dic


def is_stub_tensor(tensor):
return getattr(tensor, "stub", False)


def raise_func(type, script):
raise type(script)

@@ -42,7 +42,7 @@ from mindspore._c_expression import GraphExecutor_, Tensor, CSRTensor, RowTensor
from mindspore.parallel._ps_context import _is_role_sched
from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_pynative_parallel, \
_get_pipeline_stages, _is_in_auto_parallel_mode
from mindspore._checkparam import Validator
from mindspore._checkparam import Validator, is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
from mindspore.common._register_for_adapter import ms_adapter_registry
@@ -106,7 +106,8 @@ def _wrap_func(fn):

def _check_all_tensor(sequence):
for element in sequence:
if not isinstance(element, Tensor) and not (isinstance(element, tuple) and _check_all_tensor(element)):
if not isinstance(element, Tensor) and not is_stub_tensor(element) and not (isinstance(element, tuple)
and _check_all_tensor(element)):
return False
return True

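Note: a self-contained sketch (placeholder predicate, not MindSpore API) of the structural rule _check_all_tensor enforces after this change — every leaf of a possibly nested tuple must be tensor-like, i.e. a Tensor or a StubTensor:

def _check_all_tensor_like(sequence, is_tensor_like):
    # simplified mirror of _check_all_tensor; the predicate is a parameter so the
    # sketch runs on plain values instead of real tensors
    for element in sequence:
        if is_tensor_like(element):
            continue
        if isinstance(element, tuple) and _check_all_tensor_like(element, is_tensor_like):
            continue
        return False
    return True

# treating ints as stand-ins for tensors:
assert _check_all_tensor_like((1, (2, 3)), lambda x: isinstance(x, int))
assert not _check_all_tensor_like((1, "not a tensor"), lambda x: isinstance(x, int))
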
@@ -359,7 +360,7 @@ class _MindsporeFunctionExecutor:
compile_args = _restore_mutable_attr(args, compile_args)

generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
str(self.fn.__code__.co_firstlineno)
str(self.fn.__code__.co_firstlineno)
if _pynative_executor.grad_flag():
generate_name = generate_name + ".grad"
if _is_pynative_parallel():
@@ -690,7 +691,7 @@ def ms_function(fn=None, input_signature=None, hash_args=None, jit_config=None):
... closure_fn(inputs, func)
"""

logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. " \
logger.warning("'mindspore.ms_function' will be deprecated and removed in a future version. "
"Please use 'mindspore.jit' instead.")
return jit(fn=fn, input_signature=input_signature, hash_args=hash_args, jit_config=jit_config)

@@ -836,7 +837,7 @@ def ms_class(cls):
20
"""

logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. " \
logger.warning("'mindspore.ms_class' will be deprecated and removed in a future version. "
"Please use 'mindspore.jit_class' instead.")

# Check if cls is of type class.
@@ -1584,6 +1585,7 @@ def _bind_device_context():
"""Bind device context to current thread"""
_bind_device_ctx()


_cell_graph_executor = _CellGraphExecutor()
_pynative_executor = _PyNativeExecutor()

@@ -22,13 +22,13 @@ from typing import Tuple
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common._utils import is_stub_tensor
from mindspore.common.tensor import Tensor
from mindspore._c_expression import COOTensor as COOTensor_
from mindspore._c_expression import CSRTensor as CSRTensor_
from mindspore._c_expression import RowTensor as RowTensor_
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import is_stub_tensor


class RowTensorInner(RowTensor_):
@@ -49,7 +49,7 @@ class RowTensorInner(RowTensor_):
# Init a RowTensor from indices, values and shape
else:
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
RowTensor_.__init__(self, indices, values, shape)
self.init_finished = True

@@ -189,9 +189,9 @@ class SparseTensor(COOTensor_):
if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
COOTensor_.__init__(self, indices, values, shape)

@property
@@ -280,9 +280,9 @@ class COOTensor(COOTensor_):
validator.check_coo_tensor_dtype(indices.dtype)
indices = tensor_operator_registry.get('stop_gradient')(indices)
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True

@@ -617,11 +617,11 @@ class CSRTensor(CSRTensor_):
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
indices = tensor_operator_registry.get('stop_gradient')(indices)
if is_stub_tensor(indptr):
indptr.stub_sync()
indptr = indptr.stub_sync()
if is_stub_tensor(values):
values.stub_sync()
values = values.stub_sync()
if is_stub_tensor(indices):
indices.stub_sync()
indices = indices.stub_sync()
CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True

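Note: the repeated change above (values.stub_sync() becoming values = values.stub_sync()) follows from StubTensor no longer inheriting from the C++ tensor: stub_sync() now returns the materialized Tensor instead of converting the object in place, so callers must rebind the name. A self-contained sketch of the idiom (toy helper names, not MindSpore API):

def _is_stub(value):
    # mirrors is_stub_tensor from this diff
    return hasattr(value, "stub")

def _materialize(value):
    # rebinding idiom used by the sparse-tensor constructors above: sync a stub
    # before handing it to the C++ initializer, pass real tensors through untouched
    return value.stub_sync() if _is_stub(value) else value
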
@@ -16,12 +16,13 @@

__all__ = ['Tensor']

import abc
import math
import numbers
import numpy as np

from mindspore.communication.management import get_group_size
from mindspore.common._utils import is_shape_unknown, is_stub_tensor
from mindspore.common._utils import is_shape_unknown
from mindspore.common.seed import get_seed
from mindspore import context
from mindspore import log as logger
@@ -30,7 +31,7 @@ from mindspore.common import dtype as mstype
from mindspore.common._utils import get_slice_num
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Rel, check_is_number
from mindspore._checkparam import Rel, check_is_number, is_stub_tensor
from mindspore._checkparam import Validator as validator
from mindspore._check_jit_forbidden_api import jit_forbidden_register

@@ -42,7 +43,7 @@ np_types = (np.int8, np.int16, np.int32, np.int64,
def _check_input_data_type(input_data):
"""Check the type of input_data for Tensor"""
validator.check_value_type('input_data', input_data,
(Tensor_, np.ndarray, np.str_, list, tuple, float, int, bool, complex),
(Tensor_, Tensor, np.ndarray, np.str_, list, tuple, float, int, bool, complex),
'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
@@ -67,7 +68,13 @@ def _check_input_data_type(input_data):
f"For Tensor, the input_data is {input_data} that contain unsupported element.")


class Tensor(Tensor_):
class _TensorMeta(type(Tensor_), abc.ABCMeta):
"""
Meta class for Tensor. Used internally.
"""


class Tensor(Tensor_, metaclass=_TensorMeta):
"""
Tensor is a data structure that stores an n-dimensional array.

@@ -151,12 +158,13 @@ class Tensor(Tensor_):

def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False, const_arg=False):
self.init_finished = False

if internal:
if input_data is not None:
Tensor_.__init__(self, input_data)
else:
if is_stub_tensor(input_data):
input_data.stub_sync()
input_data = input_data.stub_sync()

# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
@@ -174,8 +182,8 @@ class Tensor(Tensor_):
else:
_check_input_data_type(input_data)
if dtype is not None:
validator.check_type_name(
'dtype', dtype, mstype.number_type + (mstype.bool_, mstype.string), "Tensor")
validator.check_type_name('dtype', dtype, mstype.number_type +
(mstype.bool_, mstype.string), "Tensor")
else:
dtype = self._set_default_dtype(input_data, dtype)

@@ -186,8 +194,8 @@ class Tensor(Tensor_):
Tensor_.__init__(self, input_data, dtype)
else:
Tensor_.__init__(self, input_data)
validator.check_value_type('const_arg', const_arg, bool, 'Tensor')

validator.check_value_type('const_arg', const_arg, bool, 'Tensor')
self.const_arg = const_arg
self.virtual_flag = False
self.init = init
@@ -202,6 +210,16 @@ class Tensor(Tensor_):
self.slice_num_of_persistent_data_ = None
self.slice_shape_of_persistent_data_ = None

@classmethod
def __subclasshook__(cls, sub):
"""
Subclass with stub_sync attr will be instance of Tensor
"""
if cls is Tensor:
if any("stub_sync" in s.__dict__ for s in sub.__mro__):
return True
return NotImplemented

@staticmethod
def _set_default_dtype(input_data, dtype):
"""Set tensor default dtype"""

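Note: a self-contained illustration (toy classes, not the MindSpore ones) of why the _TensorMeta/abc.ABCMeta change together with __subclasshook__ above works: ABCMeta consults __subclasshook__ during isinstance/issubclass, so any class that defines stub_sync is treated as a virtual subclass of Tensor without inheriting from it — which is what the new standalone StubTensor relies on.

import abc

class Base(metaclass=abc.ABCMeta):
    @classmethod
    def __subclasshook__(cls, sub):
        # same idea as Tensor.__subclasshook__ in the hunk above:
        # duck-type on a marker method in the candidate's MRO
        if cls is Base and any("stub_sync" in s.__dict__ for s in sub.__mro__):
            return True
        return NotImplemented

class Proxy:  # does NOT inherit from Base
    def stub_sync(self):
        return "real value"

assert isinstance(Proxy(), Base)   # accepted via the hook
assert issubclass(Proxy, Base)
assert not isinstance(object(), Base)
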
@@ -401,8 +419,6 @@ class Tensor(Tensor_):

def __setitem__(self, index, value):
out = tensor_operator_registry.get('__setitem__')(self, index, value)
if is_stub_tensor(out):
out.stub_sync()
self.assign_value(out)
if self.parent_tensor_ is not None and self.index_of_parent_ is not None:
self.parent_tensor_.__setitem__(self.index_of_parent_, self)
@@ -587,6 +603,8 @@ class Tensor(Tensor_):
Returns:
Tensor, Tensor that's been assigned.
"""
if is_stub_tensor(value):
value = value.stub_sync()
self.assign_value_cpp(value)
return self
