!29782 fix the bug of host device in the control flow
Merge pull request !29782 from limingqi107/r1.6
commit ebbd3d2bc8

@@ -657,7 +657,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   MS_EXCEPTION_IF_NULL(graph_inputs);
   ParameterPtr new_parameter = nullptr;
   auto func_graph = anf->func_graph();
-  if (func_graph->manager() != nullptr && func_graph->IsMultiTarget() &&
+  if (func_graph->manager() != nullptr && func_graph->exist_multi_target() &&
       graph->device_target() == device::DeviceAddressType::kCPU) {
     auto iter = default_param_map_.find(anf);
     if (iter != default_param_map_.end()) {

@@ -117,7 +117,7 @@ void DisableMindRT(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(func_graph);
   bool enable_old_runtime = (common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "0");
   if (enable_old_runtime ||
-      (func_graph != nullptr && func_graph->ContainMultiTarget() && IsDynamicShapeGraph(func_graph))) {
+      (func_graph != nullptr && func_graph->exist_multi_target() && IsDynamicShapeGraph(func_graph))) {
     // Heterogeneous scenario + dynamic_shape runs in MsBackend.
     MS_LOG(INFO) << "Disable mindRT in the heterogeneous + dynamic shape scenario.";
     context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);

@@ -783,21 +783,14 @@ void SetRunMode(const ResourcePtr &res) {
   }
 
-  // Heterogeneous scenario + ControlFlow : KernelByKernel path in MindRT.
-  if (func_graph->ContainMultiTarget() && ExistControlNode(func_graph)) {
-    MS_LOG(INFO) << "Run graph mode with kernelbykernel.";
-    set_ctx(false, false, false);
-    return;
-  }
-
   // Heterogeneous scenario + ControlFlow : KernelByKernel path in MindRT.
-  if (func_graph->ContainMultiTarget() && ExistControlNode(func_graph)) {
+  if (func_graph->exist_multi_target() && ExistControlNode(func_graph)) {
     MS_LOG(INFO) << "Run graph mode with kernelbykernel.";
     set_ctx(false, false, false);
     return;
   }
 
   // GRAPH | Heterogeneous scenario : SubGraph path in MindRT.
-  if (func_graph->ContainMultiTarget()) {
+  if (func_graph->exist_multi_target()) {
     MS_LOG(INFO) << "Run graph mode with subgraph sink.";
     set_ctx(true, false, false);
     return;

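Note: besides switching both checks to the cached exist_multi_target() flag, this hunk drops a duplicated heterogeneous + ControlFlow block. The ordering that remains matters: the heterogeneous + control-flow case has to be tested before the plain heterogeneous case, otherwise every mixed-target graph would fall into subgraph sink. A minimal standalone sketch of that dispatch order (GraphInfo and SelectRunMode are illustrative names, not MindSpore API):

    // Standalone sketch of the run-mode dispatch order kept by this hunk.
    #include <iostream>
    #include <string>

    struct GraphInfo {
      bool multi_target;  // graph mixes host (CPU) ops with device ops
      bool control_flow;  // graph contains control-flow nodes
    };

    std::string SelectRunMode(const GraphInfo &g) {
      if (g.multi_target && g.control_flow) {
        return "kernel_by_kernel";  // heterogeneous + control flow
      }
      if (g.multi_target) {
        return "subgraph_sink";     // heterogeneous only
      }
      return "whole_graph_sink";
    }

    int main() {
      std::cout << SelectRunMode({true, true}) << "\n";    // kernel_by_kernel
      std::cout << SelectRunMode({true, false}) << "\n";   // subgraph_sink
      std::cout << SelectRunMode({false, false}) << "\n";  // whole_graph_sink
      return 0;
    }
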
@@ -817,7 +810,7 @@ void OriginSetRunMode(const ResourcePtr &res) {
   std::string backend = MsContext::GetInstance()->backend_policy();
   MS_EXCEPTION_IF_NULL(context_ptr);
   auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
-  if (func_graph->ContainMultiTarget() || !task_sink) {
+  if (func_graph->exist_multi_target() || !task_sink) {
     bc_ptr->set_is_multi_graph_sink(false);
     context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
     context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);

@@ -845,13 +838,16 @@ void OriginSetRunMode(const ResourcePtr &res) {
 
 bool TaskEmitAction(const ResourcePtr &res) {
   MS_EXCEPTION_IF_NULL(res);
-  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
-      CheckGraphOutputConstOrParameter(res->func_graph())) {
-    return true;
-  }
-  if (res->func_graph() == nullptr) {
+  FuncGraphPtr func_graph = res->func_graph();
+  if (func_graph == nullptr) {
     MS_LOG(EXCEPTION) << "TaskEmit args error";
   }
+  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
+      CheckGraphOutputConstOrParameter(func_graph)) {
+    return true;
+  }
+
+  func_graph->SetMultiTarget();
   DisableMindRT(res);
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);

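Note: the reworked TaskEmitAction resolves res->func_graph() once, moves the null check ahead of the first use (previously CheckGraphOutputConstOrParameter received the graph before it was checked), and calls SetMultiTarget() a single time so every later stage only reads the cached flag. A standalone sketch of that call order, assuming a hypothetical Graph type rather than FuncGraphPtr:

    // Standalone sketch: validate once, compute the expensive property once,
    // then let every later stage read the cached answer.
    #include <stdexcept>

    struct Graph {
      void SetMultiTarget() { multi_target_ = ExpensiveScan(); }  // compute once
      bool exist_multi_target() const { return multi_target_; }   // cheap query afterwards

     private:
      bool ExpensiveScan() const { return true; }  // stand-in for the whole-graph scan
      bool multi_target_ = false;
    };

    bool TaskEmitLikeAction(Graph *graph) {
      if (graph == nullptr) {  // validate before the first use
        throw std::runtime_error("TaskEmit args error");
      }
      graph->SetMultiTarget();  // the expensive scan runs exactly once per action
      // Later stages (DisableMindRT and run-mode selection in the real code) read only
      // the cached flag; returning it here just keeps the sketch self-contained.
      return graph->exist_multi_target();
    }
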
@@ -861,8 +857,6 @@ bool TaskEmitAction(const ResourcePtr &res) {
     OriginSetRunMode(res);
   }
 
-  FuncGraphPtr func_graph = res->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
   auto bc_ptr = res->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(bc_ptr);
   std::string backend = context_ptr->backend_policy();

@@ -438,11 +438,6 @@ void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
                         << "' failed, value of attribute 'primitive_target' must be CPU|GPU|Ascend but got "
                         << py::str(obj);
     }
-    if (target != kCPUDevice && target != kGPUDevice) {
-      auto context_ptr = MsContext::GetInstance();
-      MS_EXCEPTION_IF_NULL(context_ptr);
-      context_ptr->set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, true);
-    }
   }
 
   attrs_[attr_name] = converted_ret;

@@ -401,10 +401,12 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
                                " position:" + std::to_string(formal_parameter_position) + " copy failed.";
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
     }
-    output_data->data_ = device_tensor.get();
+    DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
+    output_data->data_ = device_tensor.get();
     continue;
   }
 
   // Ref node may use the ptr of device tensor as the output address, so need set the ptr from data.
   device_tensor->set_ptr(data->GetMutablePtr());
   MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
                 << " for the ref formal parameter: " << formal_parameter.first->DebugString()

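Note: this is the core of the host-device fix in control flow. When the control actor has to copy a parameter into a freshly created device tensor, the copy and the original data are now registered in DeviceTensorCopyStore before the output is rewired to the copy, so later synchronizations can locate every tensor that holds the same value. A minimal standalone sketch of such a copy store (CopyStore is illustrative, not the MindSpore class):

    // Illustrative copy store: remembers which tensors are copies of the same value
    // so a later write or sync can be propagated to all of them.
    #include <unordered_map>
    #include <unordered_set>

    template <typename Tensor>
    class CopyStore {
     public:
      static CopyStore &GetInstance() {
        static CopyStore instance;
        return instance;
      }
      // Record that `copy` holds the same logical value as `original`.
      void Insert(Tensor *copy, Tensor *original) {
        peers_[copy].insert(original);
        peers_[original].insert(copy);
      }
      // All tensors known to share a value with `tensor` (empty set if none).
      const std::unordered_set<Tensor *> &Fetch(Tensor *tensor) { return peers_[tensor]; }

     private:
      std::unordered_map<Tensor *, std::unordered_set<Tensor *>> peers_;
    };

The sketch keeps both directions of the relation so a lookup from either the copy or the original finds its peers; the real store is keyed by DeviceTensor pointers.
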
@@ -1944,8 +1944,8 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
 
     // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
     if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
-      MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
-                   << ", type:" << device_context->GetDeviceAddressType();
+      MS_LOG(WARNING) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
+                      << ", type:" << device_context->GetDeviceAddressType();
       auto other_type_device_tensor = device_context->CreateDeviceAddress(
         nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
       other_type_device_tensor->SetNodeIndex(input_node, 0);

@@ -601,22 +601,22 @@ BackendPtr CreateBackend() {
 void SetMindRTEnable() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (context_ptr->get_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT)) {
-    return;
-  }
-
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PSContext::instance()->is_ps_mode()) {
     context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
     return;
   }
 #endif
 
+  std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  if (common::GetEnv("DISABLE_ASCEND_MINDRT") == "1" && target == kAscendDevice) {
+    context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
+    return;
+  }
+
 #if defined(_WIN32) || defined(_WIN64)
   context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
   return;
 #endif

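Note: with MS_CTX_ALREADY_SET_ENABLE_MINDRT removed (see the AddPyAttr hunk above and the MsContext hunks below), keeping MindRT on for Ascend is no longer latched as a side effect of setting a primitive's primitive_target attribute; it is governed by an explicit environment kill switch. A standalone sketch of that kind of check, using std::getenv in place of MindSpore's common::GetEnv:

    // Returns false when the user has explicitly disabled MindRT for the given
    // target via an environment variable, e.g. DISABLE_ASCEND_MINDRT=1 on Ascend.
    #include <cstdlib>
    #include <string>

    bool MindRTEnabledFor(const std::string &target) {
      const char *flag = std::getenv("DISABLE_ASCEND_MINDRT");
      if (flag != nullptr && std::string(flag) == "1" && target == "Ascend") {
        return false;
      }
      return true;
    }
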
@@ -754,18 +754,25 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
   return parameter;
 }
 
-bool FuncGraph::ContainMultiTarget() {
+void FuncGraph::SetMultiTarget() {
   auto graph_manager = manager();
   MS_EXCEPTION_IF_NULL(graph_manager);
   FuncGraphSet graphs = graph_manager->func_graphs();
+  std::vector<AnfNodePtr> all_nodes;
   for (auto &g : graphs) {
     auto nodes = mindspore::TopoSort(g->get_return());
-    if (mindspore::ContainMultiTarget(nodes)) {
-      exist_multi_target_ = true;
-      return true;
-    }
+    (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(all_nodes));
   }
+
+  bool exist_multi_target = false;
+  if (mindspore::ContainMultiTarget(all_nodes)) {
+    exist_multi_target = true;
+    MS_LOG(INFO) << "The graph " << ToString() << " exists the multi target.";
+  }
+
+  for (auto &g : graphs) {
+    g->set_exist_multi_target(exist_multi_target);
+  }
-  return false;
 }
 
 void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes) {

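Note: ContainMultiTarget() rescanned every managed graph on each call and only latched exist_multi_target_ on the graph it happened to be invoked on. SetMultiTarget() now gathers the nodes of all managed graphs, decides once whether they mix targets, and pushes the result into every graph through set_exist_multi_target(), so the exist_multi_target() calls introduced in the session, action and backend hunks above are constant-time reads. A minimal standalone sketch of the compute-once / cache-everywhere pattern (Node and MiniGraph are illustrative, not the MindSpore classes):

    // Illustrative types standing in for AnfNode / FuncGraph.
    #include <string>
    #include <utility>
    #include <vector>

    struct Node {
      std::string target;  // e.g. "CPU", "GPU", "Ascend"
    };

    class MiniGraph {
     public:
      explicit MiniGraph(std::vector<Node> nodes) : nodes_(std::move(nodes)) {}
      const std::vector<Node> &nodes() const { return nodes_; }
      void set_exist_multi_target(bool v) { exist_multi_target_ = v; }
      bool exist_multi_target() const { return exist_multi_target_; }

     private:
      std::vector<Node> nodes_;
      bool exist_multi_target_ = false;
    };

    // Compute the property once over all graphs, then cache it on every graph.
    void SetMultiTarget(std::vector<MiniGraph> *graphs) {
      std::vector<Node> all_nodes;
      for (const auto &g : *graphs) {
        all_nodes.insert(all_nodes.end(), g.nodes().begin(), g.nodes().end());
      }
      bool multi = false;
      for (const auto &node : all_nodes) {
        if (node.target != all_nodes.front().target) {  // only reached when non-empty
          multi = true;  // at least two different targets appear
          break;
        }
      }
      for (auto &g : *graphs) {
        g.set_exist_multi_target(multi);
      }
    }

The real change keys the decision on node target attributes via mindspore::ContainMultiTarget; the sketch just compares a per-node target string to stay self-contained.
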
@@ -348,8 +348,9 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
   void set_switch_layer_input(const std::shared_ptr<bool> &switch_layer_input) {
     switch_layer_input_ = switch_layer_input;
   }
-  bool ContainMultiTarget();
-  bool IsMultiTarget() const { return exist_multi_target_; }
+  void SetMultiTarget();
+  bool exist_multi_target() const { return exist_multi_target_; }
+  void set_exist_multi_target(bool exist_multi_target) { exist_multi_target_ = exist_multi_target; }
   int64_t stage() const { return stage_; }
   void set_stage(int64_t stage) { stage_ = stage; }

@@ -93,7 +93,6 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
   set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
-  set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
   set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);

@@ -89,7 +89,6 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_INFER_OPT,
   MS_CTX_GRAD_FOR_SCALAR,
   MS_CTX_ENABLE_MINDRT,
-  MS_CTX_ALREADY_SET_ENABLE_MINDRT,
   MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
   MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
   MS_CTX_ENABLE_MEM_SCHEDULER,

@@ -76,7 +76,6 @@ def test_forward():
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
-@pytest.mark.platform_x86_cpu
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard