From bfab67a206ef383db4369cf84f28810107af2a02 Mon Sep 17 00:00:00 2001
From: kswang
Date: Fri, 25 Jun 2021 11:29:41 +0800
Subject: [PATCH] optimize node get target

---
 .../ccsrc/backend/session/gpu_session.cc      |   5 +-
 .../device/ascend/kernel_select_ascend.cc     |  18 +-
 mindspore/core/ir/anf.cc                      | 236 +++++++++++-------
 mindspore/core/ir/anf.h                       |   1 +
 4 files changed, 154 insertions(+), 106 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index c1e0c93bd19..1d006c61004 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -339,7 +339,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
         }
       }
       if (need_sync) {
-        if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) || UpdatedByAssign(kernel_graph, input_node) ||
+        if (AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
             ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
           tensor->set_device_address(device_address);
         }
@@ -349,6 +349,9 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                               tensor->data_c())) {
           MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
         }
+        if (kernel_graph->IsUpdatedParameter(pk_node)) {
+          tensor->SetIsUpdateByDevice();
+        }
       }
     }
     tensor->set_sync_status(kNoNeedSync);
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
index 37243754ef3..507bcbb6c1e 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
@@ -360,9 +360,15 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<std::string> outp
   if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
     return;
   }
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool need_convert = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
+  if (need_convert) {
+    need_convert =
+      trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end();
+  }
   // if not find in host convert format map means the host has not registered the convert function of this format
-  if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT &&
-      trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end()) {
+  if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT && !need_convert) {
     output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
   }
   auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@@ -414,8 +420,6 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string
   }
 }
 }  // namespace
 void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(kernel_node);
   auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
   MS_EXCEPTION_IF_NULL(selected_kernel_info);
@@ -434,12 +438,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
     }
     auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
     std::vector<std::string> output_format = {refresh_format};
-    // if not find in host convert format map means the host has not registered the convert function of this format
-    if ((trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() ||
-         !context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) &&
-        refresh_format != kOpFormat_DEFAULT) {
-      output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
-    }
     SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
   }
 }
diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc
index a021a8799c1..32230204b89 100644
--- a/mindspore/core/ir/anf.cc
+++ b/mindspore/core/ir/anf.cc
@@ -369,61 +369,9 @@ std::string get_id(const AnfNodePtr &node) {
 void reset_id() { node_ids.clear(); }
 }  // namespace id_generator
-
+auto constexpr kTargetUnDefined = "kTargetUnDefined";
+auto constexpr kPrimitiveTarget = "primitive_target";
 namespace {
-std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto func_graph = cnode->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto users = manager->node_users()[cnode];
-  std::string first_user_target = GetCNodeTarget(users.back().first);
-  bool is_used_by_different_target =
-    std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool {
-      return GetCNodeTarget(u.first) != first_user_target;
-    });
-  if (!is_used_by_different_target) {
-    return first_user_target;
-  }
-
-  auto inputs = cnode->inputs();
-  std::vector<AnfNodePtr> real_inputs;
-  std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
-  std::string first_input_target = GetCNodeTarget(real_inputs[0]);
-  bool is_from_different_target =
-    std::any_of(std::begin(real_inputs), std::end(real_inputs),
-                [&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; });
-  if (!is_from_different_target) {
-    return first_input_target;
-  }
-
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  return default_target;
-}
-
-std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input,
-                          const std::string &primitive_target, const std::string &default_target) {
-  if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
-      IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
-      IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
-      IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) {
-    primitive->EraseAttr(primitive_target);
-    return default_target;
-  }
-  MS_EXCEPTION_IF_NULL(att_target);
-  if (!att_target->isa<StringImm>()) {
-    MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
-  }
-  auto target = GetValue<std::string>(att_target);
-  if (kTargetSet.find(target) == kTargetSet.end()) {
-    MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
-  }
-  return target;
-}
-
 PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
   if (node == nullptr) {
     return nullptr;
@@ -438,53 +386,151 @@ PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
   }
   return value->cast<PrimitivePtr>();
 }
+
+std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto &inputs = cnode->inputs();
+  if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
+      IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
+    if (inputs.size() > 1) {
+      return GetOriginNodeTarget(inputs[1]);
+    }
+  } else if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
+    const size_t node_inputs_num = 3;
+    if (inputs.size() >= node_inputs_num) {
+      size_t use_index = 1;
+      if (!inputs[use_index]->isa<CNode>()) {
+        use_index = 2;
+      }
+      return GetOriginNodeTarget(inputs[use_index]);
+    }
+  } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
+    const size_t node_inputs_num = 3;
+    if (inputs.size() >= node_inputs_num) {
+      return GetOriginNodeTarget(inputs[2]);
+    }
+  } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+    std::vector<AnfNodePtr> real_inputs;
+    std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
+    std::string first_input_target = kTargetUnDefined;
+    bool has_same_target =
+      std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
+        auto target = GetOriginNodeTarget(n);
+        if (target != kTargetUnDefined) {
+          first_input_target = target;
+        }
+        return target == first_input_target;
+      });
+    if (has_same_target) {
+      return first_input_target;
+    }
+  } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
+    return GetOriginNodeTarget(cnode->input(1));
+  }
+  return kTargetUnDefined;
+}
+
+std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto func_graph = cnode->func_graph();
+  if (func_graph == nullptr) {
+    return kTargetUnDefined;
+  }
+  auto manager = func_graph->manager();
+  if (manager == nullptr) {
+    return kTargetUnDefined;
+  }
+  auto users = manager->node_users()[cnode];
+  std::string first_user_target = kTargetUnDefined;
+  bool has_same_target =
+    std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
+      auto target = GetOriginNodeTarget(u.first);
+      if (target != kTargetUnDefined) {
+        first_user_target = target;
+      }
+      return target == first_user_target;
+    });
+  if (has_same_target) {
+    return first_user_target;
+  }
+  return kTargetUnDefined;
+}
+
+std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kTargetUnDefined));
+  auto target = GetVirtualNodeTargetFromInputs(node);
+  node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
+  if (target != kTargetUnDefined) {
+    return target;
+  }
+  target = GetVirtualNodeTargetFromUsers(node);
+  node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
+  return target;
+}
+
+std::string GetTargetFromAttr(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto attr_input = cnode->input(0);
+  auto primitive = GetPrimitiveFromValueNode(attr_input);
+  if (primitive == nullptr) {
+    return kTargetUnDefined;
+  }
+  auto att_target = primitive->GetAttr(kPrimitiveTarget);
+  if (att_target != nullptr) {
+    if (!att_target->isa<StringImm>()) {
+      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
+    }
+    auto target = GetValue<std::string>(att_target);
+    if (kTargetSet.find(target) == kTargetSet.end()) {
+      MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
+    }
+    return target;
+  }
+  return kTargetUnDefined;
+}
 }  // namespace
+std::string GetOriginNodeTarget(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return kTargetUnDefined;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
+  if (ud_target != nullptr) {
+    return *ud_target.get();
+  }
+  auto target = GetTargetFromAttr(node);
+  if (target != kTargetUnDefined) {
+    return target;
+  }
+  if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
+      IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
+      IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
+      IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
+      IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
+    return GetVirtualNodeTarget(node);
+  }
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+}
+
 std::string GetCNodeTarget(const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (!node->isa<CNode>()) {
-    return default_target;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  const std::string primitive_target = "primitive_target";
-  auto ud_target = cnode->user_data<std::string>(primitive_target);
-  if (ud_target != nullptr) {
-    return *ud_target.get();
-  }
-  auto attr_input = cnode->input(0);
-  auto primitive = GetPrimitiveFromValueNode(attr_input);
-  if (primitive == nullptr) {
-    return default_target;
-  }
-  auto att_target = primitive->GetAttr(primitive_target);
-  if (att_target != nullptr) {
-    return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target);
-  }
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
-    const size_t depend_node_inputs_num = 3;
-    auto &inputs = cnode->inputs();
-    if (inputs.size() >= depend_node_inputs_num) {
-      size_t use_index = 1;
-      if (!inputs[use_index]->isa<CNode>()) {
-        use_index = 2;
-      }
-      if (!IsPrimitiveCNode(inputs[use_index], prim::kPrimMakeTuple)) {
-        return GetCNodeTarget(inputs[use_index]);
-      }
-    }
-  } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
-    const size_t update_state_node_inputs_num = 3;
-    auto &inputs = cnode->inputs();
-    if (inputs.size() >= update_state_node_inputs_num && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
-      return GetCNodeTarget(inputs[2]);
-    }
-  } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
-    return GetMaketupleNodeTarget(cnode);
-  } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
-    return GetCNodeTarget(cnode->input(1));
+  auto target = GetOriginNodeTarget(node);
+  if (target != kTargetUnDefined) {
+    return target;
   }
   return default_target;
 }
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 84a9d31b4b1..d1c731d49f1 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -642,6 +642,7 @@ void reset_id();
 using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
 using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
 std::string GetCNodeTarget(const AnfNodePtr &node);
+std::string GetOriginNodeTarget(const AnfNodePtr &node);
 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
 struct GraphSegment {
   GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
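
Note on the core change: the new GetOriginNodeTarget resolves the target of a virtual node (Depend, Load, UpdateState, MakeTuple, TupleGetItem and the summary ops) from its inputs or users and stores the result in the node's user data under kPrimitiveTarget, so repeated queries return the cached string instead of re-walking the graph the way GetMaketupleNodeTarget did on every call. The sketch below is only a minimal, self-contained illustration of that memoization pattern; Node, GetTarget, attr_target and the inherit-from-first-input rule are simplified stand-ins, not MindSpore's AnfNode API.

// Toy illustration only (hypothetical types, not MindSpore code): compute a
// node's target once, cache it in the node's user data, and serve later
// queries from the cache.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

constexpr const char *kPrimitiveTarget = "primitive_target";

// Stand-in for an ANF node: input edges plus a user-data map.
struct Node {
  std::vector<std::shared_ptr<Node>> inputs;
  std::map<std::string, std::shared_ptr<std::string>> user_data;
  std::string attr_target;  // mimics the primitive_target attribute; may be empty
};
using NodePtr = std::shared_ptr<Node>;

std::string GetTarget(const NodePtr &node, const std::string &default_target) {
  // Return the cached target from a previous call, if any.
  auto it = node->user_data.find(kPrimitiveTarget);
  if (it != node->user_data.end()) {
    return *it->second;
  }
  // Otherwise an explicit attribute wins; a node without one inherits the
  // target of its first input (a much simplified "virtual node" rule).
  std::string target = node->attr_target;
  if (target.empty()) {
    target = node->inputs.empty() ? default_target : GetTarget(node->inputs.front(), default_target);
  }
  node->user_data[kPrimitiveTarget] = std::make_shared<std::string>(target);
  return target;
}

int main() {
  auto producer = std::make_shared<Node>();
  producer->attr_target = "GPU";
  auto consumer = std::make_shared<Node>();
  consumer->inputs.push_back(producer);             // no target of its own
  std::cout << GetTarget(consumer, "CPU") << "\n";  // resolves to GPU and caches it
  std::cout << GetTarget(consumer, "CPU") << "\n";  // answered from user_data
  return 0;
}

The patch goes one step further and seeds the cache with kTargetUnDefined before recursing into inputs and users, so a node that is reached again through its own neighbors returns immediately instead of recursing without bound.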