!29782 Fix the host device bug in control flow

Merge pull request !29782 from limingqi107/r1.6
i-robot 2022-02-09 02:21:13 +00:00 committed by Gitee
commit ebbd3d2bc8
11 changed files with 37 additions and 41 deletions
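
The thread running through all eleven files is one refactor: FuncGraph::ContainMultiTarget(), which re-sorted and re-scanned the graph on every call, is replaced by a compute-once SetMultiTarget() plus a cached flag read through exist_multi_target(); judging by the new propagation loop in func_graph.cc, sub-graphs of a control-flow model could previously disagree about the flag. A minimal sketch of the interface change, condensed from the hunks below (the default initializer is an assumption, not shown in the diff):

    // Condensed sketch of the FuncGraph API change in this commit; not the full class.
    class FuncGraph {
     public:
      // Removed: bool ContainMultiTarget();  // recomputed the answer on every call
      // Added: compute once over all reachable graphs, then read the cached flag.
      void SetMultiTarget();
      bool exist_multi_target() const { return exist_multi_target_; }
      void set_exist_multi_target(bool exist_multi_target) { exist_multi_target_ = exist_multi_target; }

     private:
      bool exist_multi_target_ = false;  // assumed default; the member itself appears in the diff
    };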


@@ -657,7 +657,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   MS_EXCEPTION_IF_NULL(graph_inputs);
   ParameterPtr new_parameter = nullptr;
   auto func_graph = anf->func_graph();
-  if (func_graph->manager() != nullptr && func_graph->IsMultiTarget() &&
+  if (func_graph->manager() != nullptr && func_graph->exist_multi_target() &&
       graph->device_target() == device::DeviceAddressType::kCPU) {
     auto iter = default_param_map_.find(anf);
     if (iter != default_param_map_.end()) {


@@ -117,7 +117,7 @@ void DisableMindRT(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(func_graph);
   bool enable_old_runtime = (common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "0");
   if (enable_old_runtime ||
-      (func_graph != nullptr && func_graph->ContainMultiTarget() && IsDynamicShapeGraph(func_graph))) {
+      (func_graph != nullptr && func_graph->exist_multi_target() && IsDynamicShapeGraph(func_graph))) {
     // Heterogeneous scenario + dynamic_shape runs in MsBackend.
     MS_LOG(INFO) << "Disable mindRT in the heterogeneous + dynamic shape scenario.";
     context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
@@ -783,21 +783,14 @@ void SetRunMode(const ResourcePtr &res) {
   }
   // Heterogeneous scenario + ControlFlow : KernelByKernel path in MindRT.
-  if (func_graph->ContainMultiTarget() && ExistControlNode(func_graph)) {
-    MS_LOG(INFO) << "Run graph mode with kernelbykernel.";
-    set_ctx(false, false, false);
-    return;
-  }
-  // Heterogeneous scenario + ControlFlow : KernelByKernel path in MindRT.
-  if (func_graph->ContainMultiTarget() && ExistControlNode(func_graph)) {
+  if (func_graph->exist_multi_target() && ExistControlNode(func_graph)) {
     MS_LOG(INFO) << "Run graph mode with kernelbykernel.";
     set_ctx(false, false, false);
     return;
   }
   // GRAPH | Heterogeneous scenario : SubGraph path in MindRT.
-  if (func_graph->ContainMultiTarget()) {
+  if (func_graph->exist_multi_target()) {
     MS_LOG(INFO) << "Run graph mode with subgraph sink.";
     set_ctx(true, false, false);
     return;
@@ -817,7 +810,7 @@ void OriginSetRunMode(const ResourcePtr &res) {
   std::string backend = MsContext::GetInstance()->backend_policy();
   MS_EXCEPTION_IF_NULL(context_ptr);
   auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
-  if (func_graph->ContainMultiTarget() || !task_sink) {
+  if (func_graph->exist_multi_target() || !task_sink) {
    bc_ptr->set_is_multi_graph_sink(false);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
@@ -845,13 +838,16 @@ void OriginSetRunMode(const ResourcePtr &res) {
 bool TaskEmitAction(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);
-  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
-      CheckGraphOutputConstOrParameter(res->func_graph())) {
-    return true;
-  }
-  if (res->func_graph() == nullptr) {
+  FuncGraphPtr func_graph = res->func_graph();
+  if (func_graph == nullptr) {
     MS_LOG(EXCEPTION) << "TaskEmit args error";
   }
+  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
+      CheckGraphOutputConstOrParameter(func_graph)) {
+    return true;
+  }
+  func_graph->SetMultiTarget();
   DisableMindRT(res);
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -861,8 +857,6 @@ bool TaskEmitAction(const ResourcePtr &res) {
     OriginSetRunMode(res);
   }
-  FuncGraphPtr func_graph = res->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
   auto bc_ptr = res->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(bc_ptr);
   std::string backend = context_ptr->backend_policy();
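
Stitched together, the two hunks above give TaskEmitAction this shape: validate the graph before first use (previously CheckGraphOutputConstOrParameter dereferenced it before the null check), take the early-out, then compute the multi-target flag exactly once before any run-mode decision. A reconstruction from the added lines; code between the hunks is elided:

    bool TaskEmitAction(const ResourcePtr &res) {
      MS_EXCEPTION_IF_NULL(res);
      // Null-check before the graph is touched.
      FuncGraphPtr func_graph = res->func_graph();
      if (func_graph == nullptr) {
        MS_LOG(EXCEPTION) << "TaskEmit args error";
      }
      if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
          CheckGraphOutputConstOrParameter(func_graph)) {
        return true;
      }
      // Compute the multi-target flag once; DisableMindRT and SetRunMode read the cached value.
      func_graph->SetMultiTarget();
      DisableMindRT(res);
      // ... run-mode selection and backend setup continue as in the hunks above ...
    }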


@@ -438,11 +438,6 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
                       << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                       << py::str(obj);
     }
-    if (target != kCPUDevice && target != kGPUDevice) {
-      auto context_ptr = MsContext::GetInstance();
-      MS_EXCEPTION_IF_NULL(context_ptr);
-      context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
-    }
   }
   attrs_[attr_name] = converted_ret;


@@ -401,10 +401,12 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
                           " position:" + std::to_string(formal_parameter_position) + " copy failed.";
         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
       }
-      output_data->data_ = device_tensor.get();
+      DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
+      output_data->data_ = device_tensor.get();
       continue;
     }
+    // Ref node may use the ptr of device tensor as the output address, so need set the ptr from data.
     device_tensor->set_ptr(data->GetMutablePtr());
     MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
                   << " for the ref formal parameter: " << formal_parameter.first->DebugString()


@@ -1944,8 +1944,8 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
     // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
     if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
-      MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
-                   << ", type:" << device_context->GetDeviceAddressType();
+      MS_LOG(WARNING) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
+                      << ", type:" << device_context->GetDeviceAddressType();
       auto other_type_device_tensor = device_context->CreateDeviceAddress(
         nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
       other_type_device_tensor->SetNodeIndex(input_node, 0);


@@ -601,22 +601,22 @@ BackendPtr CreateBackend() {
 void SetMindRTEnable() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (context_ptr->get_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT)) {
-    return;
-  }
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PSContext::instance()->is_ps_mode()) {
+    context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
     return;
   }
 #endif
   std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   if (common::GetEnv("DISABLE_ASCEND_MINDRT") == "1" && target == kAscendDevice) {
+    context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
     return;
   }
 #if defined(_WIN32) || defined(_WIN64)
+  context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
   return;
 #endif


@@ -754,18 +754,25 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
   return parameter;
 }
-bool FuncGraph::ContainMultiTarget() {
+void FuncGraph::SetMultiTarget() {
   auto graph_manager = manager();
   MS_EXCEPTION_IF_NULL(graph_manager);
   FuncGraphSet graphs = graph_manager->func_graphs();
+  std::vector<AnfNodePtr> all_nodes;
   for (auto &g : graphs) {
     auto nodes = mindspore::TopoSort(g->get_return());
-    if (mindspore::ContainMultiTarget(nodes)) {
-      exist_multi_target_ = true;
-      return true;
-    }
+    (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(all_nodes));
   }
+  bool exist_multi_target = false;
+  if (mindspore::ContainMultiTarget(all_nodes)) {
+    exist_multi_target = true;
+    MS_LOG(INFO) << "The graph " << ToString() << " exists the multi target.";
+  }
+  for (auto &g : graphs) {
+    g->set_exist_multi_target(exist_multi_target);
+  }
-  return false;
 }
 void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes) {
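
Because the hunk interleaves removed and added lines, here is the new FuncGraph::SetMultiTarget() reassembled as plain code. The change in shape is the point of the fix: instead of early-returning from the first sub-graph that contains multiple targets, it pools the nodes of every reachable graph, decides once, and writes the same flag back to all of them, so control-flow sub-graphs can no longer disagree:

    void FuncGraph::SetMultiTarget() {
      auto graph_manager = manager();
      MS_EXCEPTION_IF_NULL(graph_manager);
      FuncGraphSet graphs = graph_manager->func_graphs();

      // Collect the nodes of every graph reachable from this one.
      std::vector<AnfNodePtr> all_nodes;
      for (auto &g : graphs) {
        auto nodes = mindspore::TopoSort(g->get_return());
        (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(all_nodes));
      }

      // Decide once over the combined node list ...
      bool exist_multi_target = false;
      if (mindspore::ContainMultiTarget(all_nodes)) {
        exist_multi_target = true;
        MS_LOG(INFO) << "The graph " << ToString() << " exists the multi target.";
      }

      // ... and propagate the same answer to every sub-graph.
      for (auto &g : graphs) {
        g->set_exist_multi_target(exist_multi_target);
      }
    }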


@@ -348,8 +348,9 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
   void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) {
     switch_layer_input_ = switch_layer_input;
   }
-  bool ContainMultiTarget();
-  bool IsMultiTarget() const { return exist_multi_target_; }
+  void SetMultiTarget();
+  bool exist_multi_target() const { return exist_multi_target_; }
+  void set_exist_multi_target(bool exist_multi_target) { exist_multi_target_ = exist_multi_target; }
   int64_t stage() const { return stage_; }
   void set_stage(int64_t stage) { stage_ = stage; }


@@ -93,7 +93,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
   set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
-  set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
   set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);


@@ -89,7 +89,6 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_INFER_OPT,
   MS_CTX_GRAD_FOR_SCALAR,
   MS_CTX_ENABLE_MINDRT,
-  MS_CTX_ALREADY_SET_ENABLE_MINDRT,
   MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
   MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
   MS_CTX_ENABLE_MEM_SCHEDULER,


@@ -76,7 +76,6 @@ def test_forward():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
-@pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard