forked from mindspore-Ecosystem/mindspore
!49634 Add StubTensor check in C++
Merge pull request !49634 from NaCN/fix_tensor_in_check
commit 8547db71ea
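The diff applies one recurring pattern at every touched call site: wherever a Python object was only checked with py::isinstance<tensor::Tensor>(), it now also accepts a StubTensor and converts it with ConvertStubTensor before use, so a StubTensor is handled like a Tensor instead of falling through to the error or default branch. A minimal sketch of that pattern follows; the wrapper name GetTensorArg and the include path for IsStubTensor/ConvertStubTensor are illustrative assumptions, not part of this commit.

// Sketch only: a hypothetical helper showing the guard this commit repeats at
// each call site. IsStubTensor/ConvertStubTensor are the utilities used in the
// hunks below; the include path and the helper name are assumptions.
#include "pybind11/pybind11.h"
#include "include/common/utils/convert_utils_py.h"  // assumed to declare IsStubTensor/ConvertStubTensor

namespace py = pybind11;

namespace mindspore {
tensor::TensorPtr GetTensorArg(const py::object &obj) {
  // A StubTensor is not an instance of the C++ tensor::Tensor, so a plain
  // isinstance check misses it; materialize it into a real TensorPtr first.
  if (IsStubTensor(obj)) {
    return ConvertStubTensor(obj);
  }
  if (py::isinstance<tensor::Tensor>(obj)) {
    return py::cast<tensor::TensorPtr>(obj);
  }
  return nullptr;  // caller decides how to report the unsupported type
}
}  // namespace mindspore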
@@ -379,6 +379,9 @@ ValuePtr ConvertSlice(const py::object &obj) {
     if (py::isinstance<Tensor>(py_attr)) {
       return py::cast<TensorPtr>(py_attr);
     }
+    if (IsStubTensor(py_attr)) {
+      return ConvertStubTensor(py_attr);
+    }
     MS_LOG(EXCEPTION) << "Attribute '" << attr << "' of " << py::str(obj)
                       << " should be int or Tensor with Int type but got " << py::str(py_attr);
   };

@@ -210,8 +210,8 @@ bool CheckArgValid(const py::handle &arg) {
     return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
   }

-  if (py::isinstance<Tensor>(arg)) {
-    auto tensor = py::cast<TensorPtr>(arg);
+  if (py::isinstance<Tensor>(arg) || IsStubTensor(arg)) {
+    auto tensor = IsStubTensor(arg) ? ConvertStubTensor(arg) : py::cast<TensorPtr>(arg);
     if (tensor->data_type() == kNumberTypeBool) {
       MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
                    << "operator compilation failure. For more details, please refer to the FAQ at "
@@ -1320,8 +1320,8 @@ void GraphExecutorPy::TerminateDebugger() {

 std::pair<py::object, bool> GraphExecutorPy::GetPyExecuteOutputFromAddress(const py::object &res,
                                                                            const BaseRef &value) {
-  if (py::isinstance<tensor::Tensor>(res)) {
-    auto res_tensor = res.cast<tensor::TensorPtr>();
+  if (py::isinstance<tensor::Tensor>(res) || IsStubTensor(res)) {
+    auto res_tensor = IsStubTensor(res) ? ConvertStubTensor(res) : res.cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(res_tensor);
     if (res_tensor->device_address() != nullptr) {
       auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(res_tensor->device_address());

@@ -16,6 +16,7 @@

 #include "plugin/device/ascend/hal/hardware/ascend_deprecated_interface.h"
 #include <algorithm>
+#include "mindspore/ccsrc/include/common/utils/convert_utils_py.h"
 #include "plugin/device/ascend/hal/hardware/ge_device_context.h"
 #include "include/transform/graph_ir/types.h"
 #include "include/transform/graph_ir/utils.h"
@@ -66,6 +67,9 @@ void ConvertObjectToTensors(const py::dict &dict, transform::TensorOrderMap *con
     } else if (py::isinstance<tensor::Tensor>(item.second.attr("data"))) {
       // cast tensor
       tensor = py::cast<std::shared_ptr<tensor::Tensor>>(item.second.attr("data"));
+    } else if (IsStubTensor(item.second.attr("data"))) {
+      // cast stub_tensor
+      tensor = ConvertStubTensor(item.second.attr("data"));
     }

     if (tensor == nullptr) {

@@ -26,6 +26,7 @@
 #include "plugin/device/cpu/hal/device/cpu_common.h"
 #include "include/common/fallback.h"
 #include "include/common/utils/python_adapter.h"
+#include "include/common/utils/convert_utils_py.h"
 #include "include/common/utils/python_fallback_running.h"
 #include "plugin/factory/ms_factory.h"
 #include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
@@ -305,6 +306,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
     MS_LOG(DEBUG) << "Python *prebuilt* output type: " << output_type << ", output: " << output;
     if (py::isinstance<tensor::Tensor>(output)) {
       TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
+    } else if (IsStubTensor(output)) {
+      TensorToRawMemory(ConvertStubTensor(output), outputs[0]);
     }
     AttachPyOutputData(output);
     return true;
@@ -326,6 +329,8 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
     MS_LOG(DEBUG) << "Python output type: " << output_type << ", output: " << output;
     if (py::isinstance<tensor::Tensor>(output)) {
       TensorToRawMemory(output.cast<tensor::TensorPtr>(), outputs[0]);
+    } else if (IsStubTensor(output)) {
+      TensorToRawMemory(ConvertStubTensor(output), outputs[0]);
     }
     AttachPyOutputData(output);
   } catch (const py::error_already_set &e) {

@@ -52,6 +52,10 @@ void SyncData(const py::object &arg) {
     auto tensor = py::cast<tensor::TensorPtr>(arg);
     tensor->data_sync();
   }
+  if (IsStubTensor(arg)) {
+    auto tensor = ConvertStubTensor(arg);
+    tensor->data_sync();
+  }
 }

 void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) {
@@ -307,8 +311,8 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
     }
   }

-  if (py::isinstance<tensor::Tensor>(expected_grad_out)) {
-    if (!py::isinstance<tensor::Tensor>(grad_out)) {
+  if (py::isinstance<tensor::Tensor>(expected_grad_out) || IsStubTensor(expected_grad_out)) {
+    if (!py::isinstance<tensor::Tensor>(grad_out) && !IsStubTensor(grad_out)) {
       hook_grad_.clear();
       MS_EXCEPTION(TypeError) << "The output type of:" << py::str(co_name) << " should be a tensor but got "
                               << py::cast<std::string>(grad_out.attr("__class__").attr("__name__")) << ".";

@@ -125,8 +125,8 @@ class PyExecuteInitializer {
     const auto &output = parse::data_converter::CallPythonScript(py_script, params);
     MS_LOG(DEBUG) << "Python output type: " << py::str(output.get_type()) << ", output: " << output;
     PushPyExecuteOutput(script_str, output);
-    if (py::isinstance<tensor::Tensor>(output)) {
-      const auto &tensor = output.cast<tensor::TensorPtr>();
+    if (py::isinstance<tensor::Tensor>(output) || IsStubTensor(output)) {
+      const auto &tensor = IsStubTensor(output) ? ConvertStubTensor(output) : output.cast<tensor::TensorPtr>();
       const auto &infer_shape = std::make_shared<abstract::Shape>(tensor->shape());
       return abstract::MakeAbstract(infer_shape, tensor->Dtype());
     }