forked from mindspore-Ecosystem/mindspore
!48549 fix stub tensor bug
Merge pull request !48549 from yangsijia/fix-stub-tensor
commit 680f9d3b48
@@ -90,7 +90,6 @@
 #include "include/common/debug/rdr/recorder_manager.h"
 #include "ir/cell.h"
 #endif
-#include "pybind_api/utils/stub_tensor_py.h"

 #include "pybind_api/ir/log_adapter_py.h"  // Only include one-time in the whole project.
 #include "pybind_api/ir/py_execute_py.h"  // Only include one-time in the whole project.

@@ -25,7 +25,6 @@
 #include "include/common/utils/convert_utils_py.h"
 #include "include/common/debug/anf_ir_dump.h"
 #include "pipeline/jit/parse/data_converter.h"
-#include "pybind_api/utils/stub_tensor_py.h"

 namespace mindspore {
 namespace pynative {

@@ -30,7 +30,6 @@
 #include "include/common/utils/primitive_utils.h"
 #include "utils/check_convert_utils.h"
 #include "pipeline/pynative/pynative_execute.h"
-#include "pybind_api/utils/stub_tensor_py.h"

 namespace mindspore {
 namespace {

@@ -314,8 +313,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
     MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
                             << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";
   }
-  tensor::TensorPtr actual_out_tensor = mindspore::PyTensorCast(grad_out);
-  tensor::TensorPtr expected_out_tensor = mindspore::PyTensorCast(expected_grad_out);
+  auto actual_out_tensor = PyTensorCast(grad_out);
+  auto expected_out_tensor = PyTensorCast(expected_grad_out);
   MS_EXCEPTION_IF_NULL(actual_out_tensor);
   MS_EXCEPTION_IF_NULL(expected_out_tensor);
   if (actual_out_tensor->GetShapeAndDataTypeInfo() != expected_out_tensor->GetShapeAndDataTypeInfo()) {
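Note on the PyTensorCast change: the consistency check must see a real tensor even when the hook returns a stub whose data may still be in flight. A rough Python analogue of what a stub-aware cast has to do (illustrative names only; the actual implementation is the C++ PyTensorCast):

    # Illustrative sketch, not the real API: force a possibly-stub tensor to a
    # concrete one before comparing shape/dtype metadata.
    def to_real_tensor(obj):
        if is_stub_tensor(obj):   # helper from this patch's Python side
            obj.stub_sync()       # block until the async result is materialized
        return obj
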
@@ -38,11 +38,10 @@ ShapeVector StubNode::GetShapeVector() {

 TypePtr StubNode::GetTypePtr() {
   auto base = abs->BuildType();
-  auto type = base->cast<TensorTypePtr>();
-  if (!type) {
-    MS_LOG(EXCEPTION) << "Only Tensor dtype is supported by Stub now: " << base->ToString();
+  if (base->isa<TensorType>()) {
+    return base->cast<TensorTypePtr>()->element();
   }
-  return type->element();
+  return base;
 }

 py::object StubNode::GetValue() { return pynative::PyNativeAlgo::DataConvert::ValueToPyObj(value); }
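The rewritten GetTypePtr stops raising for non-tensor abstracts: tensor types are unwrapped to their element dtype, and everything else now passes through. A minimal Python sketch of the same control flow (the class names here mirror the C++ ones and are assumptions):

    # Sketch of the new behavior: unwrap Tensor[T] -> T, return other types as-is.
    def get_type(base):
        if isinstance(base, TensorType):   # e.g. Tensor[Float32]
            return base.element()          # -> Float32
        return base                        # scalars, sequences, etc. no longer raise
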
@@ -60,8 +59,9 @@ py::object StubNode::GetDtype() { return py::cast(GetTypePtr()); }

 py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const ValuePtr &value) {
   py::object result;
-  if (abs->isa<abstract::AbstractTensor>()) {
-    result = ConvertTensor(abs->cast<abstract::AbstractTensorPtr>(), value);
+  if (abs->isa<abstract::AbstractTensor>() || (value && value->isa<tensor::Tensor>())) {
+    // In `TensorArray` case, abstract is AbstractScalar and value is Tensor.
+    result = ConvertTensor(abs, value);
     root_type_ = static_cast<int>(StubNode::TENSOR);
   } else if (abs->isa<abstract::AbstractTuple>()) {
     result = ConvertTuple(abs->cast<abstract::AbstractTuplePtr>(), value);
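The widened predicate handles the TensorArray case named in the new comment, where the abstract is an AbstractScalar but the runtime value is already a Tensor, so dispatch has to consult the value as well. In Python terms, the condition is roughly:

    # Sketch of the new dispatch condition (names mirror the C++ classes).
    def is_tensor_output(abs_, value):
        return isinstance(abs_, AbstractTensor) or (value is not None and isinstance(value, Tensor))
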
@@ -81,7 +81,7 @@ py::object StubOutConverter::Convert(const abstract::AbstractBasePtr &abs, const
   return result;
 }

-py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value) {
+py::object StubOutConverter::ConvertTensor(const abstract::AbstractBasePtr &tensor_abs, const ValuePtr &value) {
   auto stub = std::make_shared<StubNode>();
   stub->value = value;
   stub->abs = tensor_abs;
@@ -90,18 +90,18 @@ py::object StubOutConverter::ConvertTensor(const abstract::AbstractTensorPtr &te

 py::object StubOutConverter::ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value) {
   auto elements = seq_abs->elements();
-  py::tuple out(elements.size());
   MS_EXCEPTION_IF_NULL(value);
   if (!value->isa<ValueTuple>()) {
     MS_LOG(EXCEPTION) << "value and abs not match: value " << value->ToString() << " vs abstract "
                       << seq_abs->ToString();
   }
   auto seq_value = value->cast<ValueTuplePtr>();
-  if (seq_value->size() != seq_abs->size()) {
-    MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
-                      << seq_abs->size();
+  if (seq_value->size() > seq_abs->size()) {
+    MS_LOG(EXCEPTION) << "Cannot convert, abstract size must greater or equal to value size: " << seq_value->size()
+                      << " vs " << seq_abs->size();
   }
-  for (size_t i = 0; i < elements.size(); ++i) {
+  py::tuple out(seq_value->size());
+  for (size_t i = 0; i < seq_value->size(); ++i) {
     out[i] = Convert(elements[i], seq_value->value()[i]);
   }
   return out;
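ConvertTuple now sizes the output from the value instead of the abstract and only rejects a value longer than its abstract, which tolerates abstracts that over-describe the runtime result. A hedged sketch of the new invariant:

    # Sketch: the abstract may describe more elements than the value carries;
    # iterate over the value, pairing each element with its abstract.
    def convert_tuple(elements, values):
        if len(values) > len(elements):
            raise ValueError(f"abstract size must be >= value size: {len(values)} vs {len(elements)}")
        return tuple(convert(a, v) for a, v in zip(elements, values))  # zip stops at len(values)
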
@@ -115,12 +115,12 @@ py::object StubOutConverter::ConvertList(const abstract::AbstractListPtr &seq_ab
                       << seq_abs->ToString();
   }
   auto seq_value = value->cast<ValueListPtr>();
-  if (seq_value->size() != seq_abs->size()) {
-    MS_LOG(EXCEPTION) << "value and abs size not match: value " << seq_value->size() << " vs abstract "
-                      << seq_abs->size();
+  if (seq_value->size() > seq_abs->size()) {
+    MS_LOG(EXCEPTION) << "Cannot convert, abstract size must greater or equal to value size: " << seq_value->size()
+                      << " vs " << seq_abs->size();
   }
-  py::list out(elements.size());
-  for (size_t i = 0; i < elements.size(); ++i) {
+  py::list out(seq_value->size());
+  for (size_t i = 0; i < seq_value->size(); ++i) {
     out[i] = Convert(elements[i], seq_value->value()[i]);
   }
   return out;

@@ -63,7 +63,7 @@ class StubOutConverter {
   int GetRootType() { return root_type_; }

  private:
-  py::object ConvertTensor(const abstract::AbstractTensorPtr &tensor_abs, const ValuePtr &value);
+  py::object ConvertTensor(const abstract::AbstractBasePtr &tensor_abs, const ValuePtr &value);
   py::object ConvertScalar(const abstract::AbstractBasePtr &scalar_abs);
   py::object ConvertTuple(const abstract::AbstractTuplePtr &seq_abs, const ValuePtr &value);
   py::object ConvertList(const abstract::AbstractListPtr &seq_abs, const ValuePtr &value);

@@ -86,6 +86,13 @@ class StubTensor(Tensor):
         self.stub_sync()
         return super().has_init

+    @property
+    def adapter_flag(self):
+        """adapter_flag stub."""
+        if self.stub:
+            return False
+        return super().adapter_flag
+
     def stub_sync(self):
         """data sync to get real tensor"""
         if self.stub:
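The new adapter_flag property follows the same pattern as the class's other stubbed attributes: answer locally while the stub is still pending, and defer to the materialized Tensor afterwards. A usage sketch (the producing op is hypothetical):

    t = some_async_op(x)     # hypothetical call returning a StubTensor
    print(t.adapter_flag)    # False while the stub is pending, no blocking sync
    t.stub_sync()            # materialize the real tensor
    print(t.adapter_flag)    # now falls through to super().adapter_flag
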
@@ -99,6 +106,13 @@ class StubTensor(Tensor):
         """
         return self.ndim

+    def dim(self):
+        r"""
+        Alias for :func:`mindspore.Tensor.ndim`.
+        """
+        return self.ndim
+
+
     def asnumpy(self):
         """api stub."""
         self.stub_sync()
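dim() is a plain alias for ndim, mirroring the corresponding Tensor method so stub tensors stay drop-in compatible. Usage, assuming the alias as added here:

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.ones((2, 3)))
    assert x.dim() == x.ndim == 2   # dim() simply forwards to ndim
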
@@ -48,6 +48,8 @@ class RowTensorInner(RowTensor_):
             RowTensor_.__init__(self, row_tensor)
         # Init a RowTensor from indices, values and shape
         else:
+            if is_stub_tensor(values):
+                values.stub_sync()
             RowTensor_.__init__(self, indices, values, shape)
         self.init_finished = True

@@ -122,6 +124,7 @@ class RowTensor(RowTensorInner):
         >>> print(x.dense_shape)
         (3, 2)
     """

     def __init__(self, indices=None, values=None, shape=None, row_tensor=None):
         """Init RowTensor"""
+        logger.warning("'RowTensor' is deprecated from version 1.7 and will be removed in a future version.")

@@ -2208,7 +2208,7 @@ class Cell(Cell_):
                 if not isinstance(net_input, Tensor):
                     raise TypeError(
                         f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
-                if set_input.dtype is not net_input.dtype:
+                if set_input.dtype != net_input.dtype:
                     raise ValueError(
                         f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
                         f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
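The `is not` → `!=` change is the notable fix here: `is` compares object identity, and two equivalent dtype objects are not guaranteed to be the same Python object (for example when one side arrives via a stub tensor), so the identity test could reject matching dtypes. A plain-Python illustration of the pitfall:

    a = [1, 2]
    b = [1, 2]
    print(a == b)   # True  -- equality compares values
    print(a is b)   # False -- identity compares objects
    # Hence `set_input.dtype != net_input.dtype`, which stays correct even when
    # the two dtype objects are distinct but equivalent instances.
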
@@ -805,7 +805,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
     return deco


-_RUN_OP_ASYNC = False
+_RUN_OP_ASYNC = True


 def _run_op(obj, op_name, args):
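Flipping `_RUN_OP_ASYNC` to True enables the stub-tensor path by default: `_run_op` can hand back a stub immediately and let the backend fill it in later. A hedged sketch of how such a flag typically gates dispatch (not the actual `_run_op` body; the launch helpers are hypothetical):

    _RUN_OP_ASYNC = True

    def _run_op_sketch(op, args):
        if _RUN_OP_ASYNC:
            return launch_async(op, args)   # hypothetical: returns a StubTensor placeholder
        return launch_sync(op, args)        # hypothetical: blocks, returns a real Tensor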