Add the transformation to python tensor and adapter tensor for empty tensor created by FuncGraphBuilder
This commit is contained in:
parent
db0c12c97b
commit
d7c780681f
|
@ -29,6 +29,10 @@
|
|||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr auto kPiJitPyObjKey = "pi_jit_py_obj";
|
||||
constexpr auto kTensorModule = "mindspore.common";
|
||||
constexpr auto kAdapterFlag = "adapter_flag";
|
||||
constexpr auto kInnerOpsModule = "mindspore.ops.operations._inner_ops";
|
||||
|
||||
bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
|
||||
static HashSet<std::string> prims_should_fallback_in_runtime = {kListInplaceExtendOpName,
|
||||
kListInplaceInsertOpName,
|
||||
|
@ -143,6 +147,19 @@ py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs) {
|
|||
if (py::isinstance<py::none>(py_obj)) {
|
||||
return py::object();
|
||||
}
|
||||
|
||||
bool is_adapter_tensor = py::hasattr(py_obj, kAdapterFlag) && py::cast<bool>(py::getattr(py_obj, kAdapterFlag));
|
||||
// Create python tensor.
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
py::module mod = python_adapter::GetPyModule(kTensorModule);
|
||||
py_obj = python_adapter::CallPyModFn(mod, "Tensor", py_obj, py::none(), py::none(), py::none(), true);
|
||||
}
|
||||
// Create adapter tensor.
|
||||
if (is_adapter_tensor) {
|
||||
py::module mod = python_adapter::GetPyModule(kInnerOpsModule);
|
||||
py_obj = python_adapter::CallPyModFn(mod, "convert_to_adapter_tensor", py_obj);
|
||||
}
|
||||
|
||||
return py_obj;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue