Support None is output of subgraph in control flow.

This commit is contained in:
Margaret_wangrui 2023-02-20 17:41:56 +08:00
parent 770c9fc703
commit bba47d6374
5 changed files with 43 additions and 5 deletions

View File

@ -107,6 +107,7 @@ using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using DeviceTensor = mindspore::device::DeviceAddress;
const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
@ -1372,6 +1373,28 @@ void GraphExecutorPy::TerminateDebugger() {
}
#endif
std::pair<py::object, bool> GraphExecutorPy::GetPyExecuteOutputFromAddress(const py::object &res,
const BaseRef &value) {
if (py::isinstance<tensor::Tensor>(res)) {
auto res_tensor = res.cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(res_tensor);
if (res_tensor->device_address() != nullptr) {
auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(res_tensor->device_address());
MS_LOG(DEBUG) << "res tensor_address:" << tensor_address;
AnfNodePtr real_node = AnfNodePtr(tensor_address->node_index().first);
if (real_node != nullptr) {
MS_LOG(DEBUG) << "real_node:" << real_node->DebugString();
const auto &[py_res, has_real_output] = GetPyExecuteOutput(real_node, value);
if (has_real_output) {
MS_LOG(DEBUG) << "py_res:" << py_res;
return {py_res, true};
}
}
}
}
return {py::none(), false};
}
py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_obj) {
// init for dynamic-obfuscated model infer
(void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
@ -1448,7 +1471,7 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
}
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
py::object res;
MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
MS_LOG(DEBUG) << "Eval run " << ms_context->backend_policy();
const auto &output = execute_info->func_graph->output();
MS_EXCEPTION_IF_NULL(output);
@ -1458,6 +1481,12 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
for (int64_t i = 0; i < vm_loop; i++) {
value = (*run)(execute_info->arg_list);
res = BaseRefToPyData(value, output_abs);
// If crossing the graph, may not get PyExecuteOutputUserData in the parent graph.
// Get PyExecuteOutputUserData by device_address bound AnfNode which is in sub graph.
const auto &[py_res, has_real_node_address] = GetPyExecuteOutputFromAddress(res, value);
if (has_real_node_address) {
return py_res;
}
}
// Replace the output if it's not Tensor, but Python data.
const auto &[py_res, has_real_output] = GetPyExecuteOutput(output, value);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -88,6 +88,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
// for pynative mode when use_vm is on
py::object Run(const py::tuple &args, const py::object &phase_obj);
std::pair<py::object, bool> GetPyExecuteOutputFromAddress(const py::object &res, const BaseRef &value);
ResourcePtr GetResource(const std::string &phase);
FuncGraphPtr GetFuncGraph(const std::string &phase);
FuncGraphPtr GetGradGraph(const std::string &phase);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -245,6 +245,8 @@ class DeviceAddress : public mindspore::DeviceSync {
void UpdateFlag(size_t flag) { SET_FLAG(flag_, flag); }
void ClearFlag(size_t flag) { CLEAR_FLAG(flag_, flag); }
std::pair<AnfNodeWeakPtr, size_t> node_index() const { return node_index_; }
protected:
const void *ptr() const { return ptr_; }
size_t size() const { return size_; }

View File

@ -493,14 +493,21 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
}
return ref_tuple;
}
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
// If FALLBACK_RUNTIME is not enable
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
size_t ref_idx = 0;
for (size_t i = 0; i < seq_abs->size(); i++) {
auto elem_abs = seq_abs->elements()[i];
if (elem_abs->isa<abstract::AbstractNone>() && !support_fallback_runtime) {
continue;
}
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
ref_idx++;
}
if (ref_idx != value_size) {
MS_LOG(EXCEPTION) << "The size of elements should be equal to " << value_size << ", but got " << ref_idx;
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
<< ref_idx;
}
ret = ref_tuple;
return ret;

View File

@ -291,7 +291,6 @@ def test_none_is_condition():
assert res is None
@pytest.mark.skip(reason="No support None in control flow.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training