From 7e384a92cdeb13af272057e5909a6e23c8829460 Mon Sep 17 00:00:00 2001 From: huangchengnuo Date: Thu, 2 Mar 2023 10:19:24 +0800 Subject: [PATCH] Add StubTensor check in C++ --- mindspore/ccsrc/pipeline/jit/parse/data_converter.cc | 3 +++ mindspore/ccsrc/pipeline/jit/pipeline.cc | 8 ++++---- .../ascend/hal/hardware/ascend_deprecated_interface.cc | 4 ++++ .../device/cpu/kernel/pyexecute/py_execute_cpu_kernel.cc | 5 +++++ mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 8 ++++++-- mindspore/ccsrc/pybind_api/ir/py_execute_py.h | 4 ++-- 6 files changed, 24 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 90a192bfeb4..fed37007027 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -379,6 +379,9 @@ ValuePtr ConvertSlice(const py::object &obj) { if (py::isinstance(py_attr)) { return py::cast(py_attr); } + if (IsStubTensor(py_attr)) { + return ConvertStubTensor(py_attr); + } MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj) << " should be int or Tensor with Int type but got " << py::str(py_attr); }; diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 5fdc22144b1..7cf9fa873f4 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -210,8 +210,8 @@ bool CheckArgValid(const py::handle &arg) { return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); }); } - if (py::isinstance(arg)) { - auto tensor = py::cast(arg); + if (py::isinstance(arg) || IsStubTensor(arg)) { + auto tensor = IsStubTensor(arg) ? ConvertStubTensor(arg) : py::cast(arg); if (tensor->data_type() == kNumberTypeBool) { MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause " << "operator compilation failure. For more details, please refer to the FAQ at " @@ -1320,8 +1320,8 @@ void GraphExecutorPy::TerminateDebugger() { std::pair GraphExecutorPy::GetPyExecuteOutputFromAddress(const py::object &res, const BaseRef &value) { - if (py::isinstance(res)) { - auto res_tensor = res.cast(); + if (py::isinstance(res) || IsStubTensor(res)) { + auto res_tensor = IsStubTensor(res) ? ConvertStubTensor(res) : res.cast(); MS_EXCEPTION_IF_NULL(res_tensor); if (res_tensor->device_address() != nullptr) { auto tensor_address = std::dynamic_pointer_cast(res_tensor->device_address()); diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc index 15268d80256..34a165b814f 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_deprecated_interface.cc @@ -16,6 +16,7 @@ #include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h" #include +#include "mindspore/ccsrc/include/common/utils/convert_utils_py.h" #include "plugin/device/ascend/hal/hardware/ge_device_context.h" #include "include/transform/graph_ir/types.h" #include "include/transform/graph_ir/utils.h" @@ -66,6 +67,9 @@ void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *con } else if (py::isinstance(item.second.attr("data"))) { // cast tensor tensor = py::cast>(item.second.attr("data")); + } else if (IsStubTensor(item.second.attr("data"))) { + // cast stub_tensor + tensor = ConvertStubTensor(item.second.attr("data")); } if (tensor == nullptr) { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.cc index 690b58e8c11..13e0a7e169d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.cc @@ -26,6 +26,7 @@ #include "plugin/device/cpu/hal/device/cpu_common.h" #include "include/common/fallback.h" #include "include/common/utils/python_adapter.h" +#include "include/common/utils/convert_utils_py.h" #include "include/common/utils/python_fallback_running.h" #include "plugin/factory/ms_factory.h" #include "mindspore/ccsrc/pipeline/jit/parse/resolve.h" @@ -305,6 +306,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector &inputs, const MS_LOG(DEBUG) << "Python *prebuilt* output type: " << output_type << ", output: " << output; if (py::isinstance(output)) { TensorToRawMemory(output.cast(), outputs[0]); + } else if (IsStubTensor(output)) { + TensorToRawMemory(ConvertStubTensor(output), outputs[0]); } AttachPyOutputData(output); return true; @@ -326,6 +329,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector &inputs, const MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output; if (py::isinstance(output)) { TensorToRawMemory(output.cast(), outputs[0]); + } else if (IsStubTensor(output)) { + TensorToRawMemory(ConvertStubTensor(output), outputs[0]); } AttachPyOutputData(output); } catch (const py::error_already_set &e) { diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 66ddfd75612..422f5c7404c 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -52,6 +52,10 @@ void SyncData(const py::object &arg) { auto tensor = py::cast(arg); tensor->data_sync(); } + if (IsStubTensor(arg)) { + auto tensor = ConvertStubTensor(arg); + tensor->data_sync(); + } } void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) { @@ -307,8 +311,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj } } - if (py::isinstance(expected_grad_out)) { - if (!py::isinstance(grad_out)) { + if (py::isinstance(expected_grad_out) || IsStubTensor(expected_grad_out)) { + if (!py::isinstance(grad_out) && !IsStubTensor(grad_out)) { hook_grad_.clear(); MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got " << py::cast(grad_out.attr("__class__").attr("__name__")) << "."; diff --git a/mindspore/ccsrc/pybind_api/ir/py_execute_py.h b/mindspore/ccsrc/pybind_api/ir/py_execute_py.h index c5aa054ffad..359d86e09ca 100644 --- a/mindspore/ccsrc/pybind_api/ir/py_execute_py.h +++ b/mindspore/ccsrc/pybind_api/ir/py_execute_py.h @@ -125,8 +125,8 @@ class PyExecuteInitializer { const auto &output = parse::data_converter::CallPythonScript(py_script, params); MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output; PushPyExecuteOutput(script_str, output); - if (py::isinstance(output)) { - const auto &tensor = output.cast(); + if (py::isinstance(output) || IsStubTensor(output)) { + const auto &tensor = IsStubTensor(output) ? ConvertStubTensor(output) : output.cast(); const auto &infer_shape = std::make_shared(tensor->shape()); return abstract::MakeAbstract(infer_shape, tensor->Dtype()); }