From 0ef5e5e586c00fae32ef8558dc6d6f3459e11639 Mon Sep 17 00:00:00 2001 From: caifubi Date: Fri, 2 Sep 2022 10:40:20 +0800 Subject: [PATCH] Convert python tensor to c++ tensor --- mindspore/ccsrc/backend/graph_compiler/backend.cc | 12 ++++++++++++ mindspore/ccsrc/pybind_api/ir/tensor_py.cc | 2 ++ mindspore/core/ir/tensor.h | 5 +++++ 3 files changed, 19 insertions(+) diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc index fc2bc611fce..b08e5766585 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.cc +++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc @@ -793,6 +793,18 @@ void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs, MS_EXCEPTION_IF_NULL(graph); const auto &output_nodes = op_compiler_info->graph_output_nodes_; + auto input_tensors = op_run_info->base_op_run_info.input_tensor; + for (auto &input_tensor : input_tensors) { + MS_EXCEPTION_IF_NULL(input_tensor); + auto data = input_tensor->data_ptr(); + if (data != nullptr && data->is_from_numpy()) { + // Convert python tensor to cpp tensor + py::gil_scoped_acquire gil; + input_tensor->AssignValue(tensor::Tensor(*input_tensor, input_tensor->data_type())); + data = nullptr; + } + } + runtime::UpdateDeviceAddress(graph, GetTensorWithoutValueMask(op_run_info), op_compiler_info->device_context_); UpdateOutput(output_nodes, outputs); diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 340a8868f26..011a17daf75 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -179,6 +179,8 @@ class TensorDataNumpy : public TensorData { bool has_sub_data() const override { return false; } + bool is_from_numpy() const override { return true; } + /// To string. std::string ToString(const TypeId, const ShapeVector &, bool use_comma) const override { if (use_comma) { diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index f05eed1f440..aa46b1124ee 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -99,6 +99,11 @@ class MS_CORE_API TensorData { /// \return True if this tensor data has sub data, otherwise false. virtual bool has_sub_data() const = 0; + /// \brief Get whether this tensor data is from numpy. + /// + /// \return Whether this tensor data is from numpy. + virtual bool is_from_numpy() const { return false; } + /// \brief Whether the data are equal. /// /// \param[in] other Another TensorData.