!48549 fix stub tensor bug

Merge pull request !48549 from yangsijia/fix-stub-tensor
This commit is contained in:
i-robot 2023-02-08 09:50:34 +00:00 committed by Gitee
commit 680f9d3b48
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 39 additions and 25 deletions

View File

@ -90,7 +90,6 @@
#include "include/common/debug/rdr/recorder_manager.h"
#include "ir/cell.h"
#endif
#include "pybind_api/utils/stub_tensor_py.h"
#include "pybind_api/ir/log_adapter_py.h" // Only include one-time in the whole project.
#include "pybind_api/ir/py_execute_py.h" // Only include one-time in the whole project.

View File

@ -25,7 +25,6 @@
#include "include/common/utils/convert_utils_py.h"
#include "include/common/debug/anf_ir_dump.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore {
namespace pynative {

View File

@ -30,7 +30,6 @@
#include "include/common/utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pybind_api/utils/stub_tensor_py.h"
namespace mindspore {
namespace {
@ -314,8 +313,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
}
tensor::TensorPtr actual_out_tensor = mindspore::PyTensorCast(grad_out);
tensor::TensorPtr expected_out_tensor = mindspore::PyTensorCast(expected_grad_out);
auto actual_out_tensor = PyTensorCast(grad_out);
auto expected_out_tensor = PyTensorCast(expected_grad_out);
MS_EXCEPTION_IF_NULL(actual_out_tensor);
MS_EXCEPTION_IF_NULL(expected_out_tensor);
if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {

View File

@ -38,11 +38,10 @@ ShapeVector StubNode::GetShapeVector() {
TypePtr StubNode::GetTypePtr() {
auto base = abs->BuildType();
auto type = base->cast<TensorTypePtr>();
if (!type) {
MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
if (base->isa<TensorType>()) {
return base->cast<TensorTypePtr>()->element();
}
return type->element();
return base;
}
py::object StubNode::GetValue() { return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(value); }
@ -60,8 +59,9 @@ py::object StubNode::GetDtype() { return py::cast(GetTypePtr()); }
py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value) {
py::object result;
if (abs->isa<abstract::AbstractTensor>()) {
result = ConvertTensor(abs->cast<abstract::AbstractTensorPtr>(), value);
if (abs->isa<abstract::AbstractTensor>() || (value && value->isa<tensor::Tensor>())) {
// In `TensorArray` case, abstract is AbstractScalar and value is Tensor.
result = ConvertTensor(abs, value);
root_type_ = static_cast<int>(StubNode::TENSOR);
} else if (abs->isa<abstract::AbstractTuple>()) {
result = ConvertTuple(abs->cast<abstract::AbstractTuplePtr>(), value);
@ -81,7 +81,7 @@ py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const
return result;
}
py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value) {
py::object StubOutConverter::ConvertTensor(const abstract::AbstractBasePtr &tensor_abs, const ValuePtr &value) {
auto stub = std::make_shared<StubNode>();
stub->value = value;
stub->abs = tensor_abs;
@ -90,18 +90,18 @@ py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &te
py::object StubOutConverter::ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value) {
auto elements = seq_abs->elements();
py::tuple out(elements.size());
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
<< seq_abs->ToString();
}
auto seq_value = value->cast<ValueTuplePtr>();
if (seq_value->size() != seq_abs->size()) {
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
<< seq_abs->size();
if (seq_value->size() > seq_abs->size()) {
MS_LOG(EXCEPTION) << "Cannot convert, abstract size must greater or equal to value size: " << seq_value->size()
<< " vs " << seq_abs->size();
}
for (size_t i = 0; i < elements.size(); ++i) {
py::tuple out(seq_value->size());
for (size_t i = 0; i < seq_value->size(); ++i) {
out[i] = Convert(elements[i], seq_value->value()[i]);
}
return out;
@ -115,12 +115,12 @@ py::object StubOutConverter::ConvertList(const abstract::AbstractListPtr &seq_ab
<< seq_abs->ToString();
}
auto seq_value = value->cast<ValueListPtr>();
if (seq_value->size() != seq_abs->size()) {
MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
<< seq_abs->size();
if (seq_value->size() > seq_abs->size()) {
MS_LOG(EXCEPTION) << "Cannot convert, abstract size must greater or equal to value size: " << seq_value->size()
<< " vs " << seq_abs->size();
}
py::list out(elements.size());
for (size_t i = 0; i < elements.size(); ++i) {
py::list out(seq_value->size());
for (size_t i = 0; i < seq_value->size(); ++i) {
out[i] = Convert(elements[i], seq_value->value()[i]);
}
return out;

View File

@ -63,7 +63,7 @@ class StubOutConverter {
int GetRootType() { return root_type_; }
private:
py::object ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value);
py::object ConvertTensor(const abstract::AbstractBasePtr &tensor_abs, const ValuePtr &value);
py::object ConvertScalar(const abstract::AbstractBasePtr &scalar_abs);
py::object ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value);
py::object ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value);

View File

@ -86,6 +86,13 @@ class StubTensor(Tensor):
self.stub_sync()
return super().has_init
@property
def adapter_flag(self):
"""adapter_flag stub."""
if self.stub:
return False
return super().adapter_flag
def stub_sync(self):
"""data sync to get real tensor"""
if self.stub:
@ -99,6 +106,13 @@ class StubTensor(Tensor):
"""
return self.ndim
def dim(self):
r"""
Alias for :func:`mindspore.Tensor.ndim`.
"""
return self.ndim
def asnumpy(self):
"""api stub."""
self.stub_sync()

View File

@ -48,6 +48,8 @@ class RowTensorInner(RowTensor_):
RowTensor_.__init__(self, row_tensor)
# Init a RowTensor from indices, values and shape
else:
if is_stub_tensor(values):
values.stub_sync()
RowTensor_.__init__(self, indices, values, shape)
self.init_finished = True
@ -122,6 +124,7 @@ class RowTensor(RowTensorInner):
>>> print(x.dense_shape)
(3, 2)
"""
def __init__(self, indices=None, values=None, shape=None, row_tensor=None):
"""Init RowTensor"""
logger.warning("'RowTensor' is deprecated from version 1.7 and will be removed in a future version.")

View File

@ -2208,7 +2208,7 @@ class Cell(Cell_):
if not isinstance(net_input, Tensor):
raise TypeError(
f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
if set_input.dtype is not net_input.dtype:
if set_input.dtype != net_input.dtype:
raise ValueError(
f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")

View File

@ -805,7 +805,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
return deco
_RUN_OP_ASYNC = False
_RUN_OP_ASYNC = True
def _run_op(obj, op_name, args):