!49634 Add StubTensor check in C++

Merge pull request !49634 from NaCN/fix_tensor_in_check
This commit is contained in:
i-robot 2023-03-03 17:17:56 +00:00 committed by Gitee
commit 8547db71ea
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 24 additions and 8 deletions

View File

@ -379,6 +379,9 @@ ValuePtr ConvertSlice(const py::object &obj) {
if (py::isinstance<Tensor>(py_attr)) {
return py::cast<TensorPtr>(py_attr);
}
if (IsStubTensor(py_attr)) {
return ConvertStubTensor(py_attr);
}
MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj)
<< " should be int or Tensor with Int type but got " << py::str(py_attr);
};

View File

@ -210,8 +210,8 @@ bool CheckArgValid(const py::handle &arg) {
return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
}
if (py::isinstance<Tensor>(arg)) {
auto tensor = py::cast<TensorPtr>(arg);
if (py::isinstance<Tensor>(arg) || IsStubTensor(arg)) {
auto tensor = IsStubTensor(arg) ? ConvertStubTensor(arg) : py::cast<TensorPtr>(arg);
if (tensor->data_type() == kNumberTypeBool) {
MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
<< "operator compilation failure. For more details, please refer to the FAQ at "
@ -1320,8 +1320,8 @@ void GraphExecutorPy::TerminateDebugger() {
std::pair<py::object, bool> GraphExecutorPy::GetPyExecuteOutputFromAddress(const py::object &res,
const BaseRef &value) {
if (py::isinstance<tensor::Tensor>(res)) {
auto res_tensor = res.cast<tensor::TensorPtr>();
if (py::isinstance<tensor::Tensor>(res) || IsStubTensor(res)) {
auto res_tensor = IsStubTensor(res) ? ConvertStubTensor(res) : res.cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(res_tensor);
if (res_tensor->device_address() != nullptr) {
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(res_tensor->device_address());

View File

@ -16,6 +16,7 @@
#include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
#include <algorithm>
#include "mindspore/ccsrc/include/common/utils/convert_utils_py.h"
#include "plugin/device/ascend/hal/hardware/ge_device_context.h"
#include "include/transform/graph_ir/types.h"
#include "include/transform/graph_ir/utils.h"
@ -66,6 +67,9 @@ void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *con
} else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
// cast tensor
tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
} else if (IsStubTensor(item.second.attr("data"))) {
// cast stub_tensor
tensor = ConvertStubTensor(item.second.attr("data"));
}
if (tensor == nullptr) {

View File

@ -26,6 +26,7 @@
#include "plugin/device/cpu/hal/device/cpu_common.h"
#include "include/common/fallback.h"
#include "include/common/utils/python_adapter.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/python_fallback_running.h"
#include "plugin/factory/ms_factory.h"
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
@ -305,6 +306,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
MS_LOG(DEBUG) << "Python *prebuilt* output type: " << output_type << ", output: " << output;
if (py::isinstance<tensor::Tensor>(output)) {
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
} else if (IsStubTensor(output)) {
TensorToRawMemory(ConvertStubTensor(output), outputs[0]);
}
AttachPyOutputData(output);
return true;
@ -326,6 +329,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
if (py::isinstance<tensor::Tensor>(output)) {
TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
} else if (IsStubTensor(output)) {
TensorToRawMemory(ConvertStubTensor(output), outputs[0]);
}
AttachPyOutputData(output);
} catch (const py::error_already_set &e) {

View File

@ -52,6 +52,10 @@ void SyncData(const py::object &arg) {
auto tensor = py::cast<tensor::TensorPtr>(arg);
tensor->data_sync();
}
if (IsStubTensor(arg)) {
auto tensor = ConvertStubTensor(arg);
tensor->data_sync();
}
}
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) {
@ -307,8 +311,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
}
}
if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
if (!py::isinstance<tensor::Tensor>(grad_out)) {
if (py::isinstance<tensor::Tensor>(expected_grad_out) || IsStubTensor(expected_grad_out)) {
if (!py::isinstance<tensor::Tensor>(grad_out) && !IsStubTensor(grad_out)) {
hook_grad_.clear();
MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
<< py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";

View File

@ -125,8 +125,8 @@ class PyExecuteInitializer {
const auto &output = parse::data_converter::CallPythonScript(py_script, params);
MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
PushPyExecuteOutput(script_str, output);
if (py::isinstance<tensor::Tensor>(output)) {
const auto &tensor = output.cast<tensor::TensorPtr>();
if (py::isinstance<tensor::Tensor>(output) || IsStubTensor(output)) {
const auto &tensor = IsStubTensor(output) ? ConvertStubTensor(output) : output.cast<tensor::TensorPtr>();
const auto &infer_shape = std::make_shared<abstract::Shape>(tensor->shape());
return abstract::MakeAbstract(infer_shape, tensor->Dtype());
}