Add the transformation to python tensor and adapter tensor for empty tensor created by FuncGraphBuilder

This commit is contained in:
yujianfeng 2024-02-04 15:33:04 +08:00 committed by r1chardf1d0
parent db0c12c97b
commit d7c780681f
1 changed files with 17 additions and 0 deletions

View File

@ -29,6 +29,10 @@
namespace mindspore {
namespace {
constexpr auto kPiJitPyObjKey = "pi_jit_py_obj";
constexpr auto kTensorModule = "mindspore.common";
constexpr auto kAdapterFlag = "adapter_flag";
constexpr auto kInnerOpsModule = "mindspore.ops.operations._inner_ops";
bool ShouldFallBackInRuntime(const PrimitivePtr &prim) {
static HashSet<std::string> prims_should_fallback_in_runtime = {kListInplaceExtendOpName,
kListInplaceInsertOpName,
@ -143,6 +147,19 @@ py::object FuncGraphBuilder::ConvertToPyObj(const AbstractBasePtr &abs) {
if (py::isinstance<py::none>(py_obj)) {
return py::object();
}
bool is_adapter_tensor = py::hasattr(py_obj, kAdapterFlag) && py::cast<bool>(py::getattr(py_obj, kAdapterFlag));
// Create python tensor.
if (abs->isa<abstract::AbstractTensor>()) {
py::module mod = python_adapter::GetPyModule(kTensorModule);
py_obj = python_adapter::CallPyModFn(mod, "Tensor", py_obj, py::none(), py::none(), py::none(), true);
}
// Create adapter tensor.
if (is_adapter_tensor) {
py::module mod = python_adapter::GetPyModule(kInnerOpsModule);
py_obj = python_adapter::CallPyModFn(mod, "convert_to_adapter_tensor", py_obj);
}
return py_obj;
}