!48077 update stub tensor api

Merge pull request !48077 from yangsijia/stubtensor-api
i-robot 2023-02-03 09:02:31 +00:00 committed by Gitee
commit 77f58edd7a
16 changed files with 421 additions and 122 deletions

View File

@ -33,6 +33,7 @@
#include "pybind_api/pybind_patch.h"
#include "include/common/utils/callbacks.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
@ -457,7 +458,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePt
MS_EXCEPTION_IF_NULL(tensors);
ValuePtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
-    tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
+    tensor_ptr = PyTensorCast(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
double input_value = py::cast<py::float_>(input_object);
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);

View File

@ -25,6 +25,7 @@
#include "base/base_ref.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "include/common/visible.h"
namespace py = pybind11;
@ -33,7 +34,8 @@ namespace mindspore {
py::object AnyToPyData(const Any &value);
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
+// Convert a Python (stub) tensor to a C++ tensor.
+COMMON_EXPORT tensor::TensorPtr PyTensorCast(const py::handle &obj);
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);
} // namespace mindspore

View File

@ -28,6 +28,7 @@
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/convert_utils_py.h"
namespace mindspore {
namespace parse {
@ -597,7 +598,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
static const std::vector<DataConverterPtr> data_converters{
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
-    std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
+    std::make_shared<ByTypeDataConverter<Tensor>>(PyTensorCast),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),

View File

@ -92,6 +92,8 @@
#include "ir/cell.h"
#endif
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
@ -204,7 +206,7 @@ bool CheckArgValid(const py::handle &arg) {
}
if (py::isinstance<Tensor>(arg)) {
-    auto tensor = py::cast<TensorPtr>(arg);
+    TensorPtr tensor = PyTensorCast(arg);
if (tensor->data_type() == kNumberTypeBool) {
MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
<< "operator compilation failure. For more details, please refer to the FAQ at "
@ -458,7 +460,7 @@ py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple
for (auto arg_obj : inputs) {
if (py::isinstance<Tensor>(arg_obj)) {
MS_LOG(DEBUG) << "Verify Tensor";
-      auto m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
+      std::shared_ptr<Tensor> m_tensor = PyTensorCast(arg_obj);
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
return false;

View File

@ -24,7 +24,10 @@
#include "pipeline/jit/pass.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "pipeline/jit/parse/data_converter.h"
#include "ir/cell.h"
#include "abstract/utils.h"
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore::pynative {
std::shared_ptr<PyNativeExecutor> PyNativeExecutor::executor_ = nullptr;
@ -33,6 +36,7 @@ GradExecutorPtr PyNativeExecutor::grad_executor_ = nullptr;
std::mutex PyNativeExecutor::instance_lock_;
namespace {
+enum class AsyncRunOpArgsEnum : size_t { PY_PRIM = 0, PY_INPUTS, PY_ARGS_NUM };
template <typename T, typename... Args>
T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Args &... args) {
const auto &inst = PyNativeExecutor::GetInstance();
@ -82,50 +86,20 @@ T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Arg
}
} // namespace
-py::object PyNativeExecutor::RunOpAsync(const py::object &prim, const py::tuple &args) const {
-  const auto &adapter = prim.cast<PrimitivePyAdapterPtr>();
-  py::list val_args;
-  for (auto &arg : args) {
-    auto attr = py::getattr(arg, "stub", nullptr);
-    if (attr && py::isinstance<StubNode>(attr)) {
-      StubPtr node = py::cast<StubPtr>(attr);
-      val_args.append(PyNativeAlgo::DataConvert::ValueToPyObj(node->value));
-    } else {
-      val_args.append(arg);
-    }
+py::object PyNativeExecutor::RunOpAsync(const py::args &args) const {
+  if (args.size() != static_cast<size_t>(AsyncRunOpArgsEnum::PY_ARGS_NUM)) {
+    MS_LOG(EXCEPTION) << "Two args are needed by RunOp";
  }
-  auto run_args = py::make_tuple(prim, adapter->name(), val_args);
+  auto prim = args[static_cast<size_t>(AsyncRunOpArgsEnum::PY_PRIM)];
+  auto input_args = args[static_cast<size_t>(AsyncRunOpArgsEnum::PY_INPUTS)];
+  const auto &adapter = prim.cast<PrimitivePyAdapterPtr>();
+  auto run_args = py::make_tuple(prim, adapter->name(), input_args);
  FrontendOpRunInfoPtr op_run_info = forward_executor()->GenerateOpRunInfo(run_args);
  PyNativeExecutorTry(forward_executor()->RunOpS, op_run_info);
-  auto stub = std::make_shared<StubNode>();
-  stub->value = op_run_info->out_value;
-  if (!stub->value->isa<tensor::Tensor>()) {
-    MS_LOG(EXCEPTION) << "Only Tensor output is supported by RunOpAsync now: " << stub->value->type_name();
-  }
-  stub->abs = stub->value->ToAbstract();
-  return py::make_tuple(static_cast<int>(StubNode::TENSOR), py::cast(stub));
-}
-py::object PyNativeExecutor::GetStubValue(const StubPtr &stub) const {
-  return PyNativeAlgo::DataConvert::ValueToPyObj(stub->value);
-}
-py::object PyNativeExecutor::GetStubShape(const StubPtr &stub) const {
-  auto base = stub->abs->BuildShape();
-  auto shape = base->cast<abstract::ShapePtr>();
-  if (!shape) {
-    MS_LOG(EXCEPTION) << "Only Tensor shape is supported by Stub now: " << base->ToString();
-  }
-  return py::cast(shape->shape());
-}
-py::object PyNativeExecutor::GetStubDtype(const StubPtr &stub) const {
-  auto base = stub->abs->BuildType();
-  auto type = base->cast<TensorTypePtr>();
-  if (!type) {
-    MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
-  }
-  return py::cast(type->element());
+  auto converter = std::make_shared<StubOutConverter>();
+  auto ret_obj = converter->Convert(op_run_info->base_op_run_info.abstract, op_run_info->out_value);
+  auto ret_type = converter->GetRootType();
+  return py::make_tuple(ret_type, ret_obj);
}
py::object PyNativeExecutor::RealRunOp(const py::args &args) const {
@ -263,7 +237,8 @@ void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &a
}
void RegPyNativeExecutor(const py::module *m) {
-  (void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode");
+  RegPyNativeAsyncStub(m);
(void)py::class_<PyNativeExecutor, std::shared_ptr<PyNativeExecutor>>(*m, "PyNativeExecutor_")
.def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.")
.def("is_first_cell", &PyNativeExecutor::IsFirstCell, "check if the first cell.")
@ -290,9 +265,6 @@ void RegPyNativeExecutor(const py::module *m) {
"set ms_funciton compile status.")
.def("real_run_op", &PyNativeExecutor::RealRunOp, "Run op pynatively.")
.def("run_op_async", &PyNativeExecutor::RunOpAsync, "run op asynchronously")
.def("get_stub_value", &PyNativeExecutor::GetStubValue, "get output value of async stub")
.def("get_stub_shape", &PyNativeExecutor::GetStubShape, "get output shape of async stub")
.def("get_stub_dtype", &PyNativeExecutor::GetStubDtype, "get output dtype of async stub")
.def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive");
}
} // namespace mindspore::pynative
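Taken together, the new contract is: Python hands (prim, inputs) to run_op_async, C++ executes the op and replies with a (root_type, object) pair, and Python rewraps that pair into stub objects. A minimal sketch of the calling convention, assuming a build that contains this commit (the module paths are assumptions, since this diff view strips file names):

from mindspore.common.api import _pynative_executor
from mindspore.common._stub_tensor import _convert_stub

def run_async(prim, args):
    # C++ RunOpAsync returns py::make_tuple(ret_type, ret_obj): an int tag from
    # the StubNode enum plus a StubNode, a nested tuple of them, or a plain
    # Python object that was already converted in C++ (the SCALAR case).
    ret_type, ret_obj = _pynative_executor.run_op_async(prim, args)
    # _convert_stub maps the tag through _stub_map to StubTensor/StubTuple.
    return _convert_stub(ret_type, ret_obj)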

View File

@ -22,19 +22,13 @@
#include <vector>
#include "pipeline/pynative/forward/forward.h"
#include "pipeline/pynative/grad/grad.h"
#include "pybind11/pybind11.h"
#include "frontend/operator/composite/composite.h"
namespace mindspore::pynative {
namespace py = pybind11;
-struct StubNode {
-  enum { TENSOR = 0, CSR_TENSOR, COO_TENSOR, ROW_TENSOR, TUPLE };
-  AbstractBasePtr abs;
-  ValuePtr value;
-};
-using StubPtr = std::shared_ptr<StubNode>;
class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
public:
static std::shared_ptr<PyNativeExecutor> GetInstance() {
@ -58,11 +52,7 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
return forward_executor_;
}
-  py::object RunOpAsync(const py::object &prim, const py::tuple &args) const;
-  py::object GetStubValue(const StubPtr &stub) const;
-  py::object GetStubShape(const StubPtr &stub) const;
-  py::object GetStubDtype(const StubPtr &stub) const;
+  py::object RunOpAsync(const py::args &args) const;
py::object RealRunOp(const py::args &args) const;
py::object CallConstantFolding(const py::args &args) const;
bool grad_flag() const;

View File

@ -25,6 +25,7 @@
#include "include/common/utils/convert_utils_py.h"
#include "include/common/debug/anf_ir_dump.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore {
namespace pynative {
@ -199,7 +200,7 @@ std::string PyParser::GetPyObjId(const py::handle &obj) {
std::string PyParser::GetIdByPyObj(const py::object &obj) {
if (py::isinstance<tensor::Tensor>(obj)) {
-    return obj.cast<tensor::TensorPtr>()->id();
+    return PyTensorCast(obj)->id();
} else if (py::isinstance<Cell>(obj)) {
return obj.cast<CellPtr>()->id();
} else if (py::isinstance<mindspore::Type>(obj)) {

View File

@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@
#include "include/common/utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore {
namespace {
@ -313,8 +314,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
}
-  auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
-  auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
+  tensor::TensorPtr actual_out_tensor = mindspore::PyTensorCast(grad_out);
+  tensor::TensorPtr expected_out_tensor = mindspore::PyTensorCast(expected_grad_out);
MS_EXCEPTION_IF_NULL(actual_out_tensor);
MS_EXCEPTION_IF_NULL(expected_out_tensor);
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {

View File

@ -0,0 +1,132 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind_api/utils/stub_tensor_py.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/pynative/pynative_utils.h"
namespace mindspore {
namespace py = pybind11;
void RegPyNativeAsyncStub(const py::module *m) {
(void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode")
.def("get_value", &StubNode::GetValue, "get output value of async stub.")
.def("get_shape", &StubNode::GetShape, "get output shape of async stub.")
.def("get_dtype", &StubNode::GetDtype, "get output dtype of async stub.");
}
ShapeVector StubNode::GetShapeVector() {
auto base = abs->BuildShape();
auto shape = base->cast<abstract::ShapePtr>();
if (!shape) {
MS_LOG(EXCEPTION) << "Only Tensor shape is supported by Stub now: " << base->ToString();
}
return shape->shape();
}
TypePtr StubNode::GetTypePtr() {
auto base = abs->BuildType();
auto type = base->cast<TensorTypePtr>();
if (!type) {
MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
}
return type->element();
}
py::object StubNode::GetValue() { return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(value); }
py::object StubNode::GetShape() {
auto shape_vector = GetShapeVector();
auto ret = py::tuple(shape_vector.size());
for (size_t i = 0; i < shape_vector.size(); ++i) {
ret[i] = shape_vector[i];
}
return ret;
}
py::object StubNode::GetDtype() { return py::cast(GetTypePtr()); }
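These three getters back the get_value/get_shape/get_dtype bindings registered above; shape and dtype are answered from the inferred abstract alone. A rough usage sketch (stub is hypothetical, taken from a StubTensor's stub attribute):

# shape and dtype come from abs->BuildShape()/BuildType(), so neither call
# has to wait for the computed value; get_value is what materializes it.
shape = stub.get_shape()   # e.g. (2, 3), returned as a Python tuple
dtype = stub.get_dtype()   # e.g. mindspore.float32
value = stub.get_value()   # the real Tensor produced by the op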
py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value) {
py::object result;
if (abs->isa<abstract::AbstractTensor>()) {
result = ConvertTensor(abs->cast<abstract::AbstractTensorPtr>(), value);
root_type_ = static_cast<int>(StubNode::TENSOR);
} else if (abs->isa<abstract::AbstractTuple>()) {
result = ConvertTuple(abs->cast<abstract::AbstractTuplePtr>(), value);
root_type_ = static_cast<int>(StubNode::TUPLE);
} else if (abs->isa<abstract::AbstractList>()) {
result = ConvertList(abs->cast<abstract::AbstractListPtr>(), value);
// Should we create StubNode::LIST? Otherwise, this list output will be cast to a tuple in Python.
root_type_ = static_cast<int>(StubNode::TUPLE);
} else if (abs->isa<abstract::AbstractScalar>() || abs->isa<abstract::AbstractType>() ||
abs->isa<abstract::AbstractSlice>()) {
// These are the abstract types for which `output_get_by_infer_value == true`.
result = ConvertScalar(abs);
root_type_ = static_cast<int>(StubNode::SCALAR);
} else {
MS_LOG(EXCEPTION) << "StubOutConverter cannot handle this type of abstract: " << abs->ToString();
}
return result;
}
py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value) {
auto stub = std::make_shared<StubNode>();
stub->value = value;
stub->abs = tensor_abs;
return py::cast(stub);
}
py::object StubOutConverter::ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value) {
auto elements = seq_abs->elements();
py::tuple out(elements.size());
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
<< seq_abs->ToString();
}
auto seq_value = value->cast<ValueTuplePtr>();
if (seq_value->size() != seq_abs->size()) {
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
<< seq_abs->size();
}
for (size_t i = 0; i < elements.size(); ++i) {
out[i] = Convert(elements[i], seq_value->value()[i]);
}
return out;
}
py::object StubOutConverter::ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value) {
auto elements = seq_abs->elements();
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<ValueList>()) {
MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
<< seq_abs->ToString();
}
auto seq_value = value->cast<ValueListPtr>();
if (seq_value->size() != seq_abs->size()) {
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
<< seq_abs->size();
}
py::list out(elements.size());
for (size_t i = 0; i < elements.size(); ++i) {
out[i] = Convert(elements[i], seq_value->value()[i]);
}
return out;
}
py::object StubOutConverter::ConvertScalar(const abstract::AbstractBasePtr &abs) {
return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(abs->BuildValue());
}
} // namespace mindspore
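The converter walks the abstract and the value in lockstep; a toy Python rendering of the same recursive dispatch (illustrative only: strings stand in for the AbstractBase subclass checks, and make_stub/to_python are hypothetical stand-ins):

def make_stub(value):
    # Stand-in for ConvertTensor: the real code wraps value + abstract in a StubNode.
    return ("stub", value)

def to_python(value):
    # Stand-in for ConvertScalar / ValueToPyObj: eager conversion.
    return value

def convert(abs_node, value):
    """Toy mirror of StubOutConverter::Convert's recursive dispatch."""
    kind, children = abs_node  # e.g. ("tensor", None) or ("tuple", [abs, ...])
    if kind == "tensor":
        return make_stub(value)
    if kind in ("tuple", "list"):
        if len(children) != len(value):
            raise RuntimeError("value and abs size do not match")
        return tuple(convert(c, v) for c, v in zip(children, value))
    if kind in ("scalar", "type", "slice"):
        return to_python(value)
    raise RuntimeError("cannot handle this type of abstract: " + kind)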

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
#define MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
#include <string>
#include <functional>
#include <memory>
#include "pybind11/pybind11.h"
#include "base/base.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "mindapi/base/shape_vector.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace py = pybind11;
namespace stub {
constexpr auto PY_ATTR_STUB = "stub";
} // namespace stub
void RegPyNativeAsyncStub(const py::module *m);
class StubNode {
public:
enum { TENSOR = 0, CSR_TENSOR, COO_TENSOR, ROW_TENSOR, TUPLE, SCALAR, NOT_SUPPORT };
StubNode() {}
~StubNode() {}
abstract::AbstractBasePtr abs;
ValuePtr value;
// API for the Python StubTensor object
py::object GetValue();
py::object GetShape();
py::object GetDtype();
private:
ShapeVector GetShapeVector();
TypePtr GetTypePtr();
};
using StubNodePtr = std::shared_ptr<StubNode>;
class StubOutConverter {
public:
StubOutConverter() {}
~StubOutConverter() {}
py::object Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value = nullptr);
int GetRootType() { return root_type_; }
private:
py::object ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value);
py::object ConvertScalar(const abstract::AbstractBasePtr &scalar_abs);
py::object ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value);
py::object ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value);
int root_type_{static_cast<int>(StubNode::NOT_SUPPORT)};
};
using StubOutConverterPtr = std::shared_ptr<StubOutConverter>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_

View File

@ -33,6 +33,7 @@
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "pybind_api/ir/base_ref_py.h"
#include "pybind_api/utils/stub_tensor_py.h"
#include "ir/dtype/tensor_type.h"
#include "utils/ms_context.h"
#include "include/common/utils/convert_utils.h"
@ -677,4 +678,21 @@ py::object MakeCOOTensor(const VectorRef &value_list) {
ret[0] = std::make_shared<tensor::COOTensor>(indices, values, shape);
return ret[0];
}
+tensor::TensorPtr PyTensorCast(const py::handle &obj) {
+  if (!py::isinstance<tensor::Tensor>(obj)) {
+    return nullptr;
+  }
+  auto is_stub_tensor = py::hasattr(obj, stub::PY_ATTR_STUB);
+  if (!is_stub_tensor) {
+    return py::cast<tensor::TensorPtr>(obj);
+  }
+  auto stub_node = py::getattr(obj, stub::PY_ATTR_STUB);
+  auto is_stub_tensor_sync = !py::isinstance<StubNode>(stub_node);
+  if (is_stub_tensor_sync) {
+    return py::cast<tensor::TensorPtr>(obj);
+  }
+  auto stub = py::getattr(obj, stub::PY_ATTR_STUB).cast<StubNodePtr>();
+  return stub->value->cast<tensor::TensorPtr>();
+}
} // namespace mindspore
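PyTensorCast is now the single entry point used wherever a Python argument may be a StubTensor rather than a plain Tensor. Its decision tree, restated as a hedged Python sketch (the "stub" attribute name matches stub::PY_ATTR_STUB; the helper itself is illustrative, not part of the commit):

from mindspore import Tensor

def py_tensor_cast(obj):
    """Python rendering of PyTensorCast's three cases."""
    if not isinstance(obj, Tensor):   # py::isinstance<tensor::Tensor> fails
        return None
    stub = getattr(obj, "stub", None)
    if stub is None:                  # a plain Tensor, or a StubTensor that has
        return obj                    # already synced (its stub was reset to None)
    return stub.get_value()           # unwrap the StubNode's computed Tensor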

View File

@ -14,8 +14,10 @@
# ============================================================================
"""Stub Tensor implementation."""
+from functools import reduce
from mindspore.common.tensor import Tensor
from mindspore.common.api import _pynative_executor, _convert_python_data
+from mindspore.common.dtype import type_size_in_bytes
+from mindspore._c_expression import Tensor as Tensor_
class StubTensor(Tensor):
@ -24,45 +26,163 @@ class StubTensor(Tensor):
    def __init__(self, stub):
        Tensor.__init__(self, internal=True)
        self.stub = stub
-        self.tensor = None
    def __repr__(self):
-        return self.data().__repr__()
+        self.stub_sync()
+        return super().__repr__()
    def __str__(self):
-        return self.data().__str__()
+        self.stub_sync()
+        return super().__str__()
+    def __setitem__(self, index, value):
+        self.stub_sync()
+        return super().__setitem__(index, value)
    @property
    def shape(self):
        """shape stub."""
-        if self.tensor:
-            return self.tensor.shape
-        return _pynative_executor.get_stub_shape(self.stub)
+        if self.stub:
+            return self.stub.get_shape()
+        return super().shape
    @property
    def dtype(self):
        """dtype stub."""
-        if self.tensor:
-            return self.tensor.dtype
-        return _pynative_executor.get_stub_dtype(self.stub)
+        if self.stub:
+            return self.stub.get_dtype()
+        return super().dtype
-    def data(self):
-        """get real tensor data."""
-        if self.tensor is None:
-            val = _pynative_executor.get_stub_value(self.stub)
-            self.tensor = _convert_python_data(val)
-        return self.tensor
+    @property
+    def size(self):
+        """size stub."""
+        shape = self.shape
+        return reduce((lambda x, y: x * y), shape) if shape else 1
+    @property
+    def itemsize(self):
+        """itemsize stub."""
+        return type_size_in_bytes(self.dtype)
+    @property
+    def nbytes(self):
+        """nbytes stub."""
+        return self.size * self.itemsize
+    @property
+    def ndim(self):
+        """ndim stub."""
+        return len(self.shape)
+    @property
+    def strides(self):
+        """strides stub."""
+        self.stub_sync()
+        return super().strides
+    @property
+    def has_init(self):
+        """has_init stub."""
+        self.stub_sync()
+        return super().has_init
+    def stub_sync(self):
+        """data sync to get real tensor"""
+        if self.stub:
+            val = self.stub.get_value()
+            Tensor_.__init__(self, val)
+            self.stub = None
+    def ndimension(self):
+        r"""
+        Alias for :func:`mindspore.Tensor.ndim`.
+        """
+        return self.ndim
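The derived properties above are pure arithmetic over the inferred shape and dtype, so they also avoid a sync; only strides and has_init fall back to stub_sync() because they need the materialized tensor. A worked example (illustrative numbers, t being a float32 StubTensor of shape (2, 3)):

assert t.shape == (2, 3)    # from the stub's inferred abstract, no sync
assert t.size == 6          # reduce(mul, (2, 3))
assert t.itemsize == 4      # type_size_in_bytes(float32) -> 4 bytes
assert t.nbytes == 24       # size * itemsize
assert t.ndim == 2          # len(shape); ndimension() is an alias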
    def asnumpy(self):
        """api stub."""
-        return self.data().asnumpy()
+        self.stub_sync()
+        return super().asnumpy()
+    def is_persistent_data(self):
+        """
+        For details, please refer to :func:`mindspore.common.tensor.is_persistent_data`.
+        """
+        self.stub_sync()
+        return super().is_persistent_data()
+    def asnumpy_of_slice_persistent_data(self, param_key, slice_index):
+        """
+        For details, please refer to :func:`mindspore.common.tensor.asnumpy_of_slice_persistent_data`.
+        """
+        self.stub_sync()
+        return super().asnumpy_of_slice_persistent_data(param_key, slice_index)
+    def slice_num_of_persistent_data(self):
+        """
+        For details, please refer to :func:`mindspore.common.tensor.slice_num_of_persistent_data`.
+        """
+        self.stub_sync()
+        return super().slice_num_of_persistent_data()
+    def slice_shape_of_persistent_data(self):
+        """
+        For details, please refer to :func:`mindspore.common.tensor.slice_shape_of_persistent_data`.
+        """
+        self.stub_sync()
+        return super().slice_shape_of_persistent_data()
+    def flush_from_cache(self):
+        """
+        For details, please refer to :func:`mindspore.common.tensor.flush_from_cache`.
+        """
+        self.stub_sync()
+        super().flush_from_cache()
+class StubTuple(tuple):
+    """tuple that may contain stub tensors from an async op run."""
+    def __new__(cls, stub_tuple):
+        """Do some pre-processing before creating a StubTuple.
+        Args:
+            stub_tuple (tuple): a tuple of c_expression objects that may contain `StubNode`
+        Returns:
+            StubTuple: a tuple of Python objects, in which each `StubNode` is converted to a `StubTensor`
+        """
+        new_tuple = StubTuple._dfs_convert_stubnode(stub_tuple)
+        ret = super(StubTuple, cls).__new__(cls, new_tuple)
+        return ret
+    @staticmethod
+    def _is_c_expression_stubnode(node):
+        """Currently we simply use the methods defined in `py::class_<StubNode>` to identify a stub node."""
+        return hasattr(node, "get_value")
+    @staticmethod
+    def _dfs_convert_stubnode(node):
+        if isinstance(node, (tuple, list)):
+            res = [StubTuple._dfs_convert_stubnode(o) for o in node]
+            return type(node)(res)
+        if StubTuple._is_c_expression_stubnode(node):
+            # Identify and handle CSR/COO/ROW tensors here; we can use `_stub_map`.
+            return StubTensor(node)
+        return node
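A quick illustration of the depth-first rewrap (the fake node below is illustrative, and the call assumes a build with this commit; anything exposing get_value is treated as a c_expression StubNode):

class _FakeStubNode:
    def get_value(self):
        return None   # a real StubNode would return the computed output

nested = (_FakeStubNode(), [_FakeStubNode(), 3], "x")
converted = StubTuple._dfs_convert_stubnode(nested)
# -> (StubTensor, [StubTensor, 3], "x"): containers keep their own type,
#    stub nodes become StubTensor, everything else passes through untouched.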
_stub_map = [
-    StubTensor
+    StubTensor,
+    StubTensor,  # CSR_TENSOR
+    StubTensor,  # COO_TENSOR
+    StubTensor,  # ROW_TENSOR
+    StubTuple,
]
def _convert_stub(stub_type, stub):
    """convert stub node to stub tensor."""
+    if stub_type >= len(_stub_map):  # objects already converted in C++, e.g. Scalar
+        return stub
    return _stub_map[stub_type](stub)
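The integer produced by StubOutConverter::GetRootType() indexes straight into this table; tags at or past the end of the table (SCALAR, NOT_SUPPORT) were already converted to plain Python objects in C++. For instance (tag values follow the C++ StubNode enum):

out = _convert_stub(0, node)       # TENSOR -> StubTensor(node)
outs = _convert_stub(4, nodes)     # TUPLE  -> StubTuple(nodes)
plain = _convert_stub(5, 3.14)     # SCALAR -> 3.14, returned unchanged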

View File

@ -80,3 +80,7 @@ def get_slice_shape(dtype, shape):
def dict_setitem(dic, key, val):
dic.__setitem__(key, val)
return dic
+def is_stub_tensor(tensor):
+    return getattr(tensor, "stub", False)

View File

@ -1026,42 +1026,6 @@ class _PyNativeExecutor:
"""
return self._executor.run_op_async(prim, args)
-    def get_stub_value(self, stub):
-        """
-        Get output value of stub node.
-        Args:
-            stub(StubNode): stub node
-        Return:
-            Value, result of run op.
-        """
-        return self._executor.get_stub_value(stub)
-    def get_stub_shape(self, stub):
-        """
-        Get output shape of stub node.
-        Args:
-            stub(StubNode): stub node
-        Return:
-            output shape.
-        """
-        return self._executor.get_stub_shape(stub)
-    def get_stub_dtype(self, stub):
-        """
-        Get output dtype of stub node.
-        Args:
-            stub(StubNode): stub node
-        Return:
-            output dtype.
-        """
-        return self._executor.get_stub_dtype(stub)
def new_graph(self, obj, *args, **kwargs):
"""
Initialize resources for building forward and backward graph.

View File

@ -22,6 +22,7 @@ from typing import Tuple
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore.common._utils import is_stub_tensor
from mindspore.common.tensor import Tensor
from mindspore._c_expression import COOTensor as COOTensor_
from mindspore._c_expression import CSRTensor as CSRTensor_
@ -271,6 +272,10 @@ class COOTensor(COOTensor_):
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
validator.check_coo_tensor_dtype(indices.dtype)
indices = tensor_operator_registry.get('stop_gradient')(indices)
+if is_stub_tensor(indices):
+    indices.stub_sync()
+if is_stub_tensor(values):
+    values.stub_sync()
COOTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -604,6 +609,12 @@ class CSRTensor(CSRTensor_):
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
indices = tensor_operator_registry.get('stop_gradient')(indices)
+if is_stub_tensor(indptr):
+    indptr.stub_sync()
+if is_stub_tensor(values):
+    values.stub_sync()
+if is_stub_tensor(indices):
+    indices.stub_sync()
CSRTensor_.__init__(self, indptr, indices, values, shape)
self.init_finished = True
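The COOTensor_/CSRTensor_ constructors on the C++ side expect fully materialized Tensor handles, which is why each stub argument is synced before the call. The repeated guard can be read as one pattern (the helper name is hypothetical, not in the commit):

def _sync_if_stub(*tensors):
    # Materialize any pending stub before handing data to a C++ constructor.
    for t in tensors:
        if is_stub_tensor(t):
            t.stub_sync()

# e.g. in CSRTensor.__init__: _sync_if_stub(indptr, indices, values)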

View File

@ -21,11 +21,12 @@ import numbers
import numpy as np
from mindspore.communication.management import get_rank, get_group_size
-from mindspore.common._utils import is_shape_unknown
+from mindspore.common._utils import is_shape_unknown, is_stub_tensor
from mindspore.common.seed import get_seed
from mindspore import context
from mindspore import log as logger
from mindspore.common import dtype as mstype
from mindspore.common._utils import get_slice_num
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore._c_expression import Tensor as Tensor_
@ -153,6 +154,9 @@ class Tensor(Tensor_):
if input_data is not None:
Tensor_.__init__(self, input_data)
else:
+if is_stub_tensor(input_data):
+    input_data.stub_sync()
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
input_data = np.array(input_data)
@ -396,6 +400,8 @@ class Tensor(Tensor_):
def __setitem__(self, index, value):
out = tensor_operator_registry.get('__setitem__')(self, index, value)
+if is_stub_tensor(out):
+    out.stub_sync()
self.assign_value(out)
if self.parent_tensor_ is not None and self.index_of_parent_ is not None:
self.parent_tensor_.__setitem__(self.index_of_parent_, self)
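The same guard appears in Tensor.__setitem__: the registered setitem op may now return a StubTensor, and assign_value needs real data, so the result is synced before being copied back. In sketch form (mirroring the lines above):

out = tensor_operator_registry.get('__setitem__')(self, index, value)
if is_stub_tensor(out):
    out.stub_sync()        # materialize the async result first
self.assign_value(out)     # then copy it back into this tensor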
@ -1328,7 +1334,6 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('log1p')(self)
def logit(self, eps=None):
r"""
For details, please refer to :func:`mindspore.ops.logit`.
"""