forked from mindspore-Ecosystem/mindspore
!48077 update stub tensor api
Merge pull request !48077 from yangsijia/stubtensor-api
This commit is contained in:
commit
77f58edd7a
|
@ -33,6 +33,7 @@
|
|||
#include "pybind_api/pybind_patch.h"
|
||||
#include "include/common/utils/callbacks.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
|
@ -457,7 +458,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, std::vector<ValuePt
|
|||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
ValuePtr tensor_ptr = nullptr;
|
||||
if (py::isinstance<tensor::Tensor>(input_object)) {
|
||||
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
|
||||
tensor_ptr = PyTensorCast(input_object);
|
||||
} else if (py::isinstance<py::float_>(input_object)) {
|
||||
double input_value = py::cast<py::float_>(input_object);
|
||||
tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "base/base_ref.h"
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
@ -33,7 +34,8 @@ namespace mindspore {
|
|||
py::object AnyToPyData(const Any &value);
|
||||
COMMON_EXPORT py::object BaseRefToPyData(const BaseRef &value, const AbstractBasePtr &abs = nullptr);
|
||||
COMMON_EXPORT py::object ValueToPyData(const ValuePtr &value, const AbstractBasePtr &abs = nullptr);
|
||||
|
||||
// Convert python (stub) tensor to c++ tensor.
|
||||
COMMON_EXPORT tensor::TensorPtr PyTensorCast(const py::handle &obj);
|
||||
COMMON_EXPORT bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
const std::shared_ptr<py::object> &ret_val);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "utils/symbolic.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
@ -597,7 +598,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
|
|||
static const std::vector<DataConverterPtr> data_converters{
|
||||
// AdapterTensor needs to be processed before Tensor because it inherits from Tensor.
|
||||
std::make_shared<ByFuncDataConverter>(IsAdapterTensor, ConvertAdapterTensor),
|
||||
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<Tensor>>(PyTensorCast),
|
||||
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
|
||||
|
|
|
@ -92,6 +92,8 @@
|
|||
#include "ir/cell.h"
|
||||
#endif
|
||||
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support intermediate representation definition
|
||||
namespace pipeline {
|
||||
|
@ -204,7 +206,7 @@ bool CheckArgValid(const py::handle &arg) {
|
|||
}
|
||||
|
||||
if (py::isinstance<Tensor>(arg)) {
|
||||
auto tensor = py::cast<TensorPtr>(arg);
|
||||
TensorPtr tensor = PyTensorCast(arg);
|
||||
if (tensor->data_type() == kNumberTypeBool) {
|
||||
MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
|
||||
<< "operator compilation failure. For more details, please refer to the FAQ at "
|
||||
|
@ -458,7 +460,7 @@ py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple
|
|||
for (auto arg_obj : inputs) {
|
||||
if (py::isinstance<Tensor>(arg_obj)) {
|
||||
MS_LOG(DEBUG) << "Verify Tensor";
|
||||
auto m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
|
||||
std::shared_ptr<Tensor> m_tensor = PyTensorCast(arg_obj);
|
||||
if (m_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
|
||||
return false;
|
||||
|
|
|
@ -24,7 +24,10 @@
|
|||
#include "pipeline/jit/pass.h"
|
||||
#include "runtime/pynative/op_executor.h"
|
||||
#include "runtime/pynative/op_compiler.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "ir/cell.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
|
||||
namespace mindspore::pynative {
|
||||
std::shared_ptr<PyNativeExecutor> PyNativeExecutor::executor_ = nullptr;
|
||||
|
@ -33,6 +36,7 @@ GradExecutorPtr PyNativeExecutor::grad_executor_ = nullptr;
|
|||
std::mutex PyNativeExecutor::instance_lock_;
|
||||
|
||||
namespace {
|
||||
enum class AsyncRunOpArgsEnum : size_t { PY_PRIM = 0, PY_INPUTS, PY_ARGS_NUM };
|
||||
template <typename T, typename... Args>
|
||||
T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Args &... args) {
|
||||
const auto &inst = PyNativeExecutor::GetInstance();
|
||||
|
@ -82,50 +86,20 @@ T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Arg
|
|||
}
|
||||
} // namespace
|
||||
|
||||
py::object PyNativeExecutor::RunOpAsync(const py::object &prim, const py::tuple &args) const {
|
||||
const auto &adapter = prim.cast<PrimitivePyAdapterPtr>();
|
||||
py::list val_args;
|
||||
for (auto &arg : args) {
|
||||
auto attr = py::getattr(arg, "stub", nullptr);
|
||||
if (attr && py::isinstance<StubNode>(attr)) {
|
||||
StubPtr node = py::cast<StubPtr>(attr);
|
||||
val_args.append(PyNativeAlgo::DataConvert::ValueToPyObj(node->value));
|
||||
} else {
|
||||
val_args.append(arg);
|
||||
}
|
||||
py::object PyNativeExecutor::RunOpAsync(const py::args &args) const {
|
||||
if (args.size() != static_cast<size_t>(AsyncRunOpArgsEnum::PY_ARGS_NUM)) {
|
||||
MS_LOG(EXCEPTION) << "Two args are needed by RunOp";
|
||||
}
|
||||
auto run_args = py::make_tuple(prim, adapter->name(), val_args);
|
||||
auto prim = args[static_cast<size_t>(AsyncRunOpArgsEnum::PY_PRIM)];
|
||||
auto input_args = args[static_cast<size_t>(AsyncRunOpArgsEnum::PY_INPUTS)];
|
||||
const auto &adapter = prim.cast<PrimitivePyAdapterPtr>();
|
||||
auto run_args = py::make_tuple(prim, adapter->name(), input_args);
|
||||
FrontendOpRunInfoPtr op_run_info = forward_executor()->GenerateOpRunInfo(run_args);
|
||||
PyNativeExecutorTry(forward_executor()->RunOpS, op_run_info);
|
||||
auto stub = std::make_shared<StubNode>();
|
||||
stub->value = op_run_info->out_value;
|
||||
if (!stub->value->isa<tensor::Tensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Only Tensor output is supported by RunOpAsync now: " << stub->value->type_name();
|
||||
}
|
||||
stub->abs = stub->value->ToAbstract();
|
||||
return py::make_tuple(static_cast<int>(StubNode::TENSOR), py::cast(stub));
|
||||
}
|
||||
|
||||
py::object PyNativeExecutor::GetStubValue(const StubPtr &stub) const {
|
||||
return PyNativeAlgo::DataConvert::ValueToPyObj(stub->value);
|
||||
}
|
||||
|
||||
py::object PyNativeExecutor::GetStubShape(const StubPtr &stub) const {
|
||||
auto base = stub->abs->BuildShape();
|
||||
auto shape = base->cast<abstract::ShapePtr>();
|
||||
if (!shape) {
|
||||
MS_LOG(EXCEPTION) << "Only Tensor shape is supported by Stub now: " << base->ToString();
|
||||
}
|
||||
return py::cast(shape->shape());
|
||||
}
|
||||
|
||||
py::object PyNativeExecutor::GetStubDtype(const StubPtr &stub) const {
|
||||
auto base = stub->abs->BuildType();
|
||||
auto type = base->cast<TensorTypePtr>();
|
||||
if (!type) {
|
||||
MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
|
||||
}
|
||||
return py::cast(type->element());
|
||||
auto converter = std::make_shared<StubOutConverter>();
|
||||
auto ret_obj = converter->Convert(op_run_info->base_op_run_info.abstract, op_run_info->out_value);
|
||||
auto ret_type = converter->GetRootType();
|
||||
return py::make_tuple(ret_type, ret_obj);
|
||||
}
|
||||
|
||||
py::object PyNativeExecutor::RealRunOp(const py::args &args) const {
|
||||
|
@ -263,7 +237,8 @@ void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &a
|
|||
}
|
||||
|
||||
void RegPyNativeExecutor(const py::module *m) {
|
||||
(void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode");
|
||||
RegPyNativeAsyncStub(m);
|
||||
|
||||
(void)py::class_<PyNativeExecutor, std::shared_ptr<PyNativeExecutor>>(*m, "PyNativeExecutor_")
|
||||
.def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.")
|
||||
.def("is_first_cell", &PyNativeExecutor::IsFirstCell, "check if the first cell.")
|
||||
|
@ -290,9 +265,6 @@ void RegPyNativeExecutor(const py::module *m) {
|
|||
"set ms_funciton compile status.")
|
||||
.def("real_run_op", &PyNativeExecutor::RealRunOp, "Run op pynatively.")
|
||||
.def("run_op_async", &PyNativeExecutor::RunOpAsync, "run op asynchronously")
|
||||
.def("get_stub_value", &PyNativeExecutor::GetStubValue, "get output value of async stub")
|
||||
.def("get_stub_shape", &PyNativeExecutor::GetStubShape, "get output shape of async stub")
|
||||
.def("get_stub_dtype", &PyNativeExecutor::GetStubDtype, "get output dtype of async stub")
|
||||
.def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive");
|
||||
}
|
||||
} // namespace mindspore::pynative
|
||||
|
|
|
@ -22,19 +22,13 @@
|
|||
#include <vector>
|
||||
#include "pipeline/pynative/forward/forward.h"
|
||||
#include "pipeline/pynative/grad/grad.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
|
||||
namespace mindspore::pynative {
|
||||
namespace py = pybind11;
|
||||
|
||||
struct StubNode {
|
||||
enum { TENSOR = 0, CSR_TENSOR, COO_TENSOR, ROW_TENSOR, TUPLE };
|
||||
AbstractBasePtr abs;
|
||||
ValuePtr value;
|
||||
};
|
||||
using StubPtr = std::shared_ptr<StubNode>;
|
||||
|
||||
class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
|
||||
public:
|
||||
static std::shared_ptr<PyNativeExecutor> GetInstance() {
|
||||
|
@ -58,11 +52,7 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
|
|||
return forward_executor_;
|
||||
}
|
||||
|
||||
py::object RunOpAsync(const py::object &prim, const py::tuple &args) const;
|
||||
py::object GetStubValue(const StubPtr &stub) const;
|
||||
py::object GetStubShape(const StubPtr &stub) const;
|
||||
py::object GetStubDtype(const StubPtr &stub) const;
|
||||
|
||||
py::object RunOpAsync(const py::args &args) const;
|
||||
py::object RealRunOp(const py::args &args) const;
|
||||
py::object CallConstantFolding(const py::args &args) const;
|
||||
bool grad_flag() const;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
@ -199,7 +200,7 @@ std::string PyParser::GetPyObjId(const py::handle &obj) {
|
|||
|
||||
std::string PyParser::GetIdByPyObj(const py::object &obj) {
|
||||
if (py::isinstance<tensor::Tensor>(obj)) {
|
||||
return obj.cast<tensor::TensorPtr>()->id();
|
||||
return PyTensorCast(obj)->id();
|
||||
} else if (py::isinstance<Cell>(obj)) {
|
||||
return obj.cast<CellPtr>()->id();
|
||||
} else if (py::isinstance<mindspore::Type>(obj)) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@
|
|||
#include "include/common/utils/primitive_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
|
@ -313,8 +314,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
|
|||
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
|
||||
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
|
||||
}
|
||||
auto actual_out_tensor = py::cast<tensor::TensorPtr>(grad_out);
|
||||
auto expected_out_tensor = py::cast<tensor::TensorPtr>(expected_grad_out);
|
||||
tensor::TensorPtr actual_out_tensor = mindspore::PyTensorCast(grad_out);
|
||||
tensor::TensorPtr expected_out_tensor = mindspore::PyTensorCast(expected_grad_out);
|
||||
MS_EXCEPTION_IF_NULL(actual_out_tensor);
|
||||
MS_EXCEPTION_IF_NULL(expected_out_tensor);
|
||||
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/pynative/pynative_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
void RegPyNativeAsyncStub(const py::module *m) {
|
||||
(void)py::class_<StubNode, std::shared_ptr<StubNode>>(*m, "StubNode")
|
||||
.def("get_value", &StubNode::GetValue, "get output value of async stub.")
|
||||
.def("get_shape", &StubNode::GetShape, "get output shape of async stub.")
|
||||
.def("get_dtype", &StubNode::GetDtype, "get output dtype of async stub.");
|
||||
}
|
||||
|
||||
ShapeVector StubNode::GetShapeVector() {
|
||||
auto base = abs->BuildShape();
|
||||
auto shape = base->cast<abstract::ShapePtr>();
|
||||
if (!shape) {
|
||||
MS_LOG(EXCEPTION) << "Only Tensor shape is supported by Stub now: " << base->ToString();
|
||||
}
|
||||
return shape->shape();
|
||||
}
|
||||
|
||||
TypePtr StubNode::GetTypePtr() {
|
||||
auto base = abs->BuildType();
|
||||
auto type = base->cast<TensorTypePtr>();
|
||||
if (!type) {
|
||||
MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
|
||||
}
|
||||
return type->element();
|
||||
}
|
||||
|
||||
py::object StubNode::GetValue() { return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(value); }
|
||||
|
||||
py::object StubNode::GetShape() {
|
||||
auto shape_vector = GetShapeVector();
|
||||
auto ret = py::tuple(shape_vector.size());
|
||||
for (size_t i = 0; i < shape_vector.size(); ++i) {
|
||||
ret[i] = shape_vector[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object StubNode::GetDtype() { return py::cast(GetTypePtr()); }
|
||||
|
||||
py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value) {
|
||||
py::object result;
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
result = ConvertTensor(abs->cast<abstract::AbstractTensorPtr>(), value);
|
||||
root_type_ = static_cast<int>(StubNode::TENSOR);
|
||||
} else if (abs->isa<abstract::AbstractTuple>()) {
|
||||
result = ConvertTuple(abs->cast<abstract::AbstractTuplePtr>(), value);
|
||||
root_type_ = static_cast<int>(StubNode::TUPLE);
|
||||
} else if (abs->isa<abstract::AbstractList>()) {
|
||||
result = ConvertList(abs->cast<abstract::AbstractListPtr>(), value);
|
||||
// Should we create StubNode::LIST? Otherwise, this list output will be cast to tuple in python.
|
||||
root_type_ = static_cast<int>(StubNode::TUPLE);
|
||||
} else if (abs->isa<abstract::AbstractScalar>() || abs->isa<abstract::AbstractType>() ||
|
||||
abs->isa<abstract::AbstractSlice>()) {
|
||||
// Here are some types that `output_get_by_infer_value == true`
|
||||
result = ConvertScalar(abs);
|
||||
root_type_ = static_cast<int>(StubNode::SCALAR);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "StubOutConverter cannot handle this type of abstract: " << abs->ToString();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value) {
|
||||
auto stub = std::make_shared<StubNode>();
|
||||
stub->value = value;
|
||||
stub->abs = tensor_abs;
|
||||
return py::cast(stub);
|
||||
}
|
||||
|
||||
py::object StubOutConverter::ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value) {
|
||||
auto elements = seq_abs->elements();
|
||||
py::tuple out(elements.size());
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
|
||||
<< seq_abs->ToString();
|
||||
}
|
||||
auto seq_value = value->cast<ValueTuplePtr>();
|
||||
if (seq_value->size() != seq_abs->size()) {
|
||||
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
|
||||
<< seq_abs->size();
|
||||
}
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
out[i] = Convert(elements[i], seq_value->value()[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
py::object StubOutConverter::ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value) {
|
||||
auto elements = seq_abs->elements();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<ValueList>()) {
|
||||
MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
|
||||
<< seq_abs->ToString();
|
||||
}
|
||||
auto seq_value = value->cast<ValueListPtr>();
|
||||
if (seq_value->size() != seq_abs->size()) {
|
||||
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
|
||||
<< seq_abs->size();
|
||||
}
|
||||
py::list out(elements.size());
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
out[i] = Convert(elements[i], seq_value->value()[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
py::object StubOutConverter::ConvertScalar(const abstract::AbstractBasePtr &abs) {
|
||||
return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(abs->BuildValue());
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "base/base.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "mindapi/base/shape_vector.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace stub {
|
||||
constexpr auto PY_ATTR_STUB = "stub";
|
||||
} // namespace stub
|
||||
|
||||
void RegPyNativeAsyncStub(const py::module *m);
|
||||
|
||||
class StubNode {
|
||||
public:
|
||||
enum { TENSOR = 0, CSR_TENSOR, COO_TENSOR, ROW_TENSOR, TUPLE, SCALAR, NOT_SUPPORT };
|
||||
StubNode() {}
|
||||
~StubNode() {}
|
||||
abstract::AbstractBasePtr abs;
|
||||
ValuePtr value;
|
||||
|
||||
// Api for python StubTensor object
|
||||
py::object GetValue();
|
||||
py::object GetShape();
|
||||
py::object GetDtype();
|
||||
|
||||
private:
|
||||
ShapeVector GetShapeVector();
|
||||
TypePtr GetTypePtr();
|
||||
};
|
||||
using StubNodePtr = std::shared_ptr<StubNode>;
|
||||
|
||||
class StubOutConverter {
|
||||
public:
|
||||
StubOutConverter() {}
|
||||
~StubOutConverter() {}
|
||||
py::object Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value = nullptr);
|
||||
int GetRootType() { return root_type_; }
|
||||
|
||||
private:
|
||||
py::object ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value);
|
||||
py::object ConvertScalar(const abstract::AbstractBasePtr &scalar_abs);
|
||||
py::object ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value);
|
||||
py::object ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value);
|
||||
int root_type_{static_cast<int>(StubNode::NOT_SUPPORT)};
|
||||
};
|
||||
|
||||
using StubOutConverterPtr = std::shared_ptr<StubOutConverter>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_UTILS_STUB_TENSOR_PY_H_
|
|
@ -33,6 +33,7 @@
|
|||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "pybind_api/ir/base_ref_py.h"
|
||||
#include "pybind_api/utils/stub_tensor_py.h"
|
||||
#include "ir/dtype/tensor_type.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
|
@ -677,4 +678,21 @@ py::object MakeCOOTensor(const VectorRef &value_list) {
|
|||
ret[0] = std::make_shared<tensor::COOTensor>(indices, values, shape);
|
||||
return ret[0];
|
||||
}
|
||||
|
||||
tensor::TensorPtr PyTensorCast(const py::handle &obj) {
|
||||
if (!py::isinstance<tensor::Tensor>(obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto is_stub_tensor = py::hasattr(obj, stub::PY_ATTR_STUB);
|
||||
if (!is_stub_tensor) {
|
||||
return py::cast<tensor::TensorPtr>(obj);
|
||||
}
|
||||
auto stub_node = py::getattr(obj, stub::PY_ATTR_STUB);
|
||||
auto is_stub_tensor_sync = !py::isinstance<StubNode>(stub_node);
|
||||
if (is_stub_tensor_sync) {
|
||||
return py::cast<tensor::TensorPtr>(obj);
|
||||
}
|
||||
auto stub = py::getattr(obj, stub::PY_ATTR_STUB).cast<StubNodePtr>();
|
||||
return stub->value->cast<tensor::TensorPtr>();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# ============================================================================
|
||||
"""Stub Tensor implementation."""
|
||||
|
||||
from functools import reduce
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.api import _pynative_executor, _convert_python_data
|
||||
from mindspore.common.dtype import type_size_in_bytes
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
||||
|
||||
class StubTensor(Tensor):
|
||||
|
@ -24,45 +26,163 @@ class StubTensor(Tensor):
|
|||
def __init__(self, stub):
|
||||
Tensor.__init__(self, internal=True)
|
||||
self.stub = stub
|
||||
self.tensor = None
|
||||
|
||||
def __repr__(self):
|
||||
return self.data().__repr__()
|
||||
self.stub_sync()
|
||||
return super().__repr__()
|
||||
|
||||
def __str__(self):
|
||||
return self.data().__str__()
|
||||
self.stub_sync()
|
||||
return super().__str__()
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
self.stub_sync()
|
||||
return super().__setitem__(index, value)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""shape stub."""
|
||||
if self.tensor:
|
||||
return self.tensor.shape
|
||||
return _pynative_executor.get_stub_shape(self.stub)
|
||||
if self.stub:
|
||||
return self.stub.get_shape()
|
||||
return super().shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""dtype stub."""
|
||||
if self.tensor:
|
||||
return self.tensor.dtype
|
||||
return _pynative_executor.get_stub_dtype(self.stub)
|
||||
if self.stub:
|
||||
return self.stub.get_dtype()
|
||||
return super().dtype
|
||||
|
||||
def data(self):
|
||||
"""get real tensor data."""
|
||||
if self.tensor is None:
|
||||
val = _pynative_executor.get_stub_value(self.stub)
|
||||
self.tensor = _convert_python_data(val)
|
||||
return self.tensor
|
||||
@property
|
||||
def size(self):
|
||||
"""size stub."""
|
||||
shape = self.shape
|
||||
return reduce((lambda x, y: x * y), shape) if shape else 1
|
||||
|
||||
@property
|
||||
def itemsize(self):
|
||||
"""itemsize stub."""
|
||||
return type_size_in_bytes(self.dtype)
|
||||
|
||||
@property
|
||||
def nbytes(self):
|
||||
"""nbytes stub."""
|
||||
return self.size * self.itemsize
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
"""ndim stub."""
|
||||
return len(self.shape)
|
||||
|
||||
@property
|
||||
def strides(self):
|
||||
"""strides stub."""
|
||||
self.stub_sync()
|
||||
return super().strides
|
||||
|
||||
@property
|
||||
def has_init(self):
|
||||
"""has_init stub."""
|
||||
self.stub_sync()
|
||||
return super().has_init
|
||||
|
||||
def stub_sync(self):
|
||||
"""data sync to get real tensor"""
|
||||
if self.stub:
|
||||
val = self.stub.get_value()
|
||||
Tensor_.__init__(self, val)
|
||||
self.stub = None
|
||||
|
||||
def ndimension(self):
|
||||
r"""
|
||||
Alias for :func:`mindspore.Tensor.ndim`.
|
||||
"""
|
||||
return self.ndim
|
||||
|
||||
def asnumpy(self):
|
||||
"""api stub."""
|
||||
return self.data().asnumpy()
|
||||
self.stub_sync()
|
||||
return super().asnumpy()
|
||||
|
||||
def is_persistent_data(self):
|
||||
"""
|
||||
For details, please refer to :`mindspore.common.tensor.is_persistent_data`.
|
||||
"""
|
||||
self.stub_sync()
|
||||
super().is_persistent_data()
|
||||
|
||||
def asnumpy_of_slice_persistent_data(self, param_key, slice_index):
|
||||
"""
|
||||
For details, please refer to :`mindspore.common.tensor.asnumpy_of_slice_persistent_data`.
|
||||
"""
|
||||
self.stub_sync()
|
||||
return super().asnumpy_of_slice_persistent_data(param_key, slice_index)
|
||||
|
||||
def slice_num_of_persistent_data(self):
|
||||
"""
|
||||
For details, please refer to :`mindspore.common.tensor.slice_num_of_persistent_data`.
|
||||
"""
|
||||
self.stub_sync()
|
||||
return super().slice_num_of_persistent_data()
|
||||
|
||||
def slice_shape_of_persistent_data(self):
|
||||
"""
|
||||
For details, please refer to :`mindspore.common.tensor.slice_shape_of_persistent_data`.
|
||||
"""
|
||||
self.stub_sync()
|
||||
return super().slice_shape_of_persistent_data()
|
||||
|
||||
def flush_from_cache(self):
|
||||
"""
|
||||
For details, please refer to :`mindspore.common.tensor.flush_from_cache`.
|
||||
"""
|
||||
self.stub_sync()
|
||||
super().flush_from_cache()
|
||||
|
||||
|
||||
class StubTuple(tuple):
|
||||
"""tuple that may contain stub tensor for async op run."""
|
||||
|
||||
def __new__(cls, stub_tuple):
|
||||
"""Do some pre-process before creating StubTuple
|
||||
|
||||
Args:
|
||||
stub_tuple (tuple): a tuple of c_expression object that may contain `StubNode`
|
||||
|
||||
Returns:
|
||||
StubTuple: a tuple of python object, in which `StubNode` is converted to `StubTensor`
|
||||
"""
|
||||
new_tuple = StubTuple._dfs_convert_stubnode(stub_tuple)
|
||||
ret = super(StubTuple, cls).__new__(cls, new_tuple)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _is_c_expression_stubnode(node):
|
||||
"""Currently we just simply use the function that defined in `py::class_<StubNode>`"""
|
||||
return hasattr(node, "get_value")
|
||||
|
||||
@staticmethod
|
||||
def _dfs_convert_stubnode(node):
|
||||
if isinstance(node, (tuple, list)):
|
||||
res = [StubTuple._dfs_convert_stubnode(o) for o in node]
|
||||
return type(node)(res)
|
||||
if StubTuple._is_c_expression_stubnode(node):
|
||||
# Identify and handle CSR/COO/ROW Tensor here, we can use `_stub_map`
|
||||
return StubTensor(node)
|
||||
return node
|
||||
|
||||
|
||||
_stub_map = [
|
||||
StubTensor
|
||||
StubTensor,
|
||||
StubTensor, # CSR_TENSOR
|
||||
StubTensor, # COO_TENSOR
|
||||
StubTensor, # ROW_TENSOR
|
||||
StubTuple,
|
||||
]
|
||||
|
||||
|
||||
def _convert_stub(stub_type, stub):
|
||||
"""convert stub node to stub tensor."""
|
||||
if stub_type >= len(_stub_map): # obj that already convert in c++, e.g. Scalar
|
||||
return stub
|
||||
return _stub_map[stub_type](stub)
|
||||
|
|
|
@ -80,3 +80,7 @@ def get_slice_shape(dtype, shape):
|
|||
def dict_setitem(dic, key, val):
|
||||
dic.__setitem__(key, val)
|
||||
return dic
|
||||
|
||||
|
||||
def is_stub_tensor(tensor):
|
||||
return getattr(tensor, "stub", False)
|
||||
|
|
|
@ -1026,42 +1026,6 @@ class _PyNativeExecutor:
|
|||
"""
|
||||
return self._executor.run_op_async(prim, args)
|
||||
|
||||
def get_stub_value(self, stub):
|
||||
"""
|
||||
Get output value of stub node.
|
||||
|
||||
Args:
|
||||
stub(StubNode): stub node
|
||||
|
||||
Return:
|
||||
Value, result of run op.
|
||||
"""
|
||||
return self._executor.get_stub_value(stub)
|
||||
|
||||
def get_stub_shape(self, stub):
|
||||
"""
|
||||
Get output shape of stub node.
|
||||
|
||||
Args:
|
||||
stub(StubNode): stub node
|
||||
|
||||
Return:
|
||||
output shape.
|
||||
"""
|
||||
return self._executor.get_stub_shape(stub)
|
||||
|
||||
def get_stub_dtype(self, stub):
|
||||
"""
|
||||
Get output dtype of stub node.
|
||||
|
||||
Args:
|
||||
stub(StubNode): stub node
|
||||
|
||||
Return:
|
||||
output dtype.
|
||||
"""
|
||||
return self._executor.get_stub_dtype(stub)
|
||||
|
||||
def new_graph(self, obj, *args, **kwargs):
|
||||
"""
|
||||
Initialize resources for building forward and backward graph.
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import Tuple
|
|||
from mindspore import log as logger
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore.common._utils import is_stub_tensor
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._c_expression import COOTensor as COOTensor_
|
||||
from mindspore._c_expression import CSRTensor as CSRTensor_
|
||||
|
@ -271,6 +272,10 @@ class COOTensor(COOTensor_):
|
|||
validator.check_coo_tensor_shape(indices.shape, values.shape, shape)
|
||||
validator.check_coo_tensor_dtype(indices.dtype)
|
||||
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||
if is_stub_tensor(indices):
|
||||
indices.stub_sync()
|
||||
if is_stub_tensor(values):
|
||||
values.stub_sync()
|
||||
COOTensor_.__init__(self, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
|
@ -604,6 +609,12 @@ class CSRTensor(CSRTensor_):
|
|||
validator.check_csr_tensor_dtype(indptr.dtype, indices.dtype)
|
||||
indptr = tensor_operator_registry.get('stop_gradient')(indptr)
|
||||
indices = tensor_operator_registry.get('stop_gradient')(indices)
|
||||
if is_stub_tensor(indptr):
|
||||
indptr.stub_sync()
|
||||
if is_stub_tensor(values):
|
||||
values.stub_sync()
|
||||
if is_stub_tensor(indices):
|
||||
indices.stub_sync()
|
||||
CSRTensor_.__init__(self, indptr, indices, values, shape)
|
||||
self.init_finished = True
|
||||
|
||||
|
|
|
@ -21,11 +21,12 @@ import numbers
|
|||
import numpy as np
|
||||
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from mindspore.common._utils import is_shape_unknown
|
||||
from mindspore.common._utils import is_shape_unknown, is_stub_tensor
|
||||
from mindspore.common.seed import get_seed
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from mindspore.common._utils import get_slice_num
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
|
@ -153,6 +154,9 @@ class Tensor(Tensor_):
|
|||
if input_data is not None:
|
||||
Tensor_.__init__(self, input_data)
|
||||
else:
|
||||
if is_stub_tensor(input_data):
|
||||
input_data.stub_sync()
|
||||
|
||||
# If input data is numpy number, convert it to np array
|
||||
if isinstance(input_data, np_types):
|
||||
input_data = np.array(input_data)
|
||||
|
@ -396,6 +400,8 @@ class Tensor(Tensor_):
|
|||
|
||||
def __setitem__(self, index, value):
|
||||
out = tensor_operator_registry.get('__setitem__')(self, index, value)
|
||||
if is_stub_tensor(out):
|
||||
out.stub_sync()
|
||||
self.assign_value(out)
|
||||
if self.parent_tensor_ is not None and self.index_of_parent_ is not None:
|
||||
self.parent_tensor_.__setitem__(self.index_of_parent_, self)
|
||||
|
@ -1328,7 +1334,6 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get('log1p')(self)
|
||||
|
||||
def logit(self, eps=None):
|
||||
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.logit`.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue