forked from mindspore-Ecosystem/mindspore
!18858 optimize-get-node-target
Merge pull request !18858 from kisnwang/optimize-get-node-target
This commit is contained in:
commit
2acae09a92
|
@ -339,7 +339,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
}
|
||||
}
|
||||
if (need_sync) {
|
||||
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) || UpdatedByAssign(kernel_graph, input_node) ||
|
||||
if (AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
|
||||
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
|
@ -349,6 +349,9 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
|||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
if (kernel_graph->IsUpdatedParameter(pk_node)) {
|
||||
tensor->SetIsUpdateByDevice();
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor->set_sync_status(kNoNeedSync);
|
||||
|
|
|
@ -360,9 +360,15 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
|
|||
if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
|
||||
return;
|
||||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool need_convert = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
|
||||
if (need_convert) {
|
||||
need_convert =
|
||||
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end();
|
||||
}
|
||||
// if not find in host convert format map means the host has not registered the convert function of this format
|
||||
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT &&
|
||||
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end()) {
|
||||
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT && !need_convert) {
|
||||
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
|
@ -414,8 +420,6 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string
|
|||
}
|
||||
} // namespace
|
||||
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
||||
|
@ -434,12 +438,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|||
}
|
||||
auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
|
||||
std::vector<std::string> output_format = {refresh_format};
|
||||
// if not find in host convert format map means the host has not registered the convert function of this format
|
||||
if ((trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() ||
|
||||
!context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) &&
|
||||
refresh_format != kOpFormat_DEFAULT) {
|
||||
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
}
|
||||
SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -369,61 +369,9 @@ std::string get_id(const AnfNodePtr &node) {
|
|||
|
||||
void reset_id() { node_ids.clear(); }
|
||||
} // namespace id_generator
|
||||
|
||||
auto constexpr kTargetUnDefined = "kTargetUnDefined";
|
||||
auto constexpr kPrimitiveTarget = "primitive_target";
|
||||
namespace {
|
||||
std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto users = manager->node_users()[cnode];
|
||||
std::string first_user_target = GetCNodeTarget(users.back().first);
|
||||
bool is_used_by_different_target =
|
||||
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool {
|
||||
return GetCNodeTarget(u.first) != first_user_target;
|
||||
});
|
||||
if (!is_used_by_different_target) {
|
||||
return first_user_target;
|
||||
}
|
||||
|
||||
auto inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
|
||||
std::string first_input_target = GetCNodeTarget(real_inputs[0]);
|
||||
bool is_from_different_target =
|
||||
std::any_of(std::begin(real_inputs), std::end(real_inputs),
|
||||
[&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; });
|
||||
if (!is_from_different_target) {
|
||||
return first_input_target;
|
||||
}
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
return default_target;
|
||||
}
|
||||
|
||||
std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input,
|
||||
const std::string &primitive_target, const std::string &default_target) {
|
||||
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
|
||||
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
|
||||
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
|
||||
IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) {
|
||||
primitive->EraseAttr(primitive_target);
|
||||
return default_target;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(att_target);
|
||||
if (!att_target->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
auto target = GetValue<std::string>(att_target);
|
||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
|
||||
}
|
||||
return target;
|
||||
}
|
||||
|
||||
PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -438,53 +386,151 @@ PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
|
|||
}
|
||||
return value->cast<PrimitivePtr>();
|
||||
}
|
||||
|
||||
std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
|
||||
if (inputs.size() > 1) {
|
||||
return GetOriginNodeTarget(inputs[1]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
const size_t node_inputs_num = 3;
|
||||
if (inputs.size() >= node_inputs_num) {
|
||||
size_t use_index = 1;
|
||||
if (!inputs[use_index]->isa<CNode>()) {
|
||||
use_index = 2;
|
||||
}
|
||||
return GetOriginNodeTarget(inputs[use_index]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
const size_t node_inputs_num = 3;
|
||||
if (inputs.size() >= node_inputs_num) {
|
||||
return GetOriginNodeTarget(inputs[2]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
|
||||
std::string first_input_target = kTargetUnDefined;
|
||||
bool has_same_target =
|
||||
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
|
||||
auto target = GetOriginNodeTarget(n);
|
||||
if (target != kTargetUnDefined) {
|
||||
first_input_target = target;
|
||||
}
|
||||
return target == first_input_target;
|
||||
});
|
||||
if (has_same_target) {
|
||||
return first_input_target;
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return GetOriginNodeTarget(cnode->input(1));
|
||||
}
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
|
||||
std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto func_graph = cnode->func_graph();
|
||||
if (func_graph == nullptr) {
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
auto users = manager->node_users()[cnode];
|
||||
std::string first_user_target = kTargetUnDefined;
|
||||
bool has_same_target =
|
||||
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
|
||||
auto target = GetOriginNodeTarget(u.first);
|
||||
if (target != kTargetUnDefined) {
|
||||
first_user_target = target;
|
||||
}
|
||||
return target == first_user_target;
|
||||
});
|
||||
if (has_same_target) {
|
||||
return first_user_target;
|
||||
}
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
|
||||
std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kTargetUnDefined));
|
||||
auto target = GetVirtualNodeTargetFromInputs(node);
|
||||
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
|
||||
if (target != kTargetUnDefined) {
|
||||
return target;
|
||||
}
|
||||
target = GetVirtualNodeTargetFromUsers(node);
|
||||
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
|
||||
return target;
|
||||
}
|
||||
|
||||
std::string GetTargetFromAttr(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto attr_input = cnode->input(0);
|
||||
auto primitive = GetPrimitiveFromValueNode(attr_input);
|
||||
if (primitive == nullptr) {
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
auto att_target = primitive->GetAttr(kPrimitiveTarget);
|
||||
if (att_target != nullptr) {
|
||||
if (!att_target->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||
}
|
||||
auto target = GetValue<std::string>(att_target);
|
||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
|
||||
}
|
||||
return target;
|
||||
}
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::string GetOriginNodeTarget(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
|
||||
if (ud_target != nullptr) {
|
||||
return *ud_target.get();
|
||||
}
|
||||
auto target = GetTargetFromAttr(node);
|
||||
if (target != kTargetUnDefined) {
|
||||
return target;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return GetVirtualNodeTarget(node);
|
||||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
}
|
||||
|
||||
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (!node->isa<CNode>()) {
|
||||
return default_target;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const std::string primitive_target = "primitive_target";
|
||||
auto ud_target = cnode->user_data<std::string>(primitive_target);
|
||||
if (ud_target != nullptr) {
|
||||
return *ud_target.get();
|
||||
}
|
||||
auto attr_input = cnode->input(0);
|
||||
auto primitive = GetPrimitiveFromValueNode(attr_input);
|
||||
if (primitive == nullptr) {
|
||||
return default_target;
|
||||
}
|
||||
auto att_target = primitive->GetAttr(primitive_target);
|
||||
if (att_target != nullptr) {
|
||||
return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target);
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
const size_t depend_node_inputs_num = 3;
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.size() >= depend_node_inputs_num) {
|
||||
size_t use_index = 1;
|
||||
if (!inputs[use_index]->isa<CNode>()) {
|
||||
use_index = 2;
|
||||
}
|
||||
if (!IsPrimitiveCNode(inputs[use_index], prim::kPrimMakeTuple)) {
|
||||
return GetCNodeTarget(inputs[use_index]);
|
||||
}
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
const size_t update_state_node_inputs_num = 3;
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.size() >= update_state_node_inputs_num && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
|
||||
return GetCNodeTarget(inputs[2]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
return GetMaketupleNodeTarget(cnode);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return GetCNodeTarget(cnode->input(1));
|
||||
auto target = GetOriginNodeTarget(node);
|
||||
if (target != kTargetUnDefined) {
|
||||
return target;
|
||||
}
|
||||
return default_target;
|
||||
}
|
||||
|
|
|
@ -642,6 +642,7 @@ void reset_id();
|
|||
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
|
||||
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
|
||||
std::string GetCNodeTarget(const AnfNodePtr &node);
|
||||
std::string GetOriginNodeTarget(const AnfNodePtr &node);
|
||||
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
|
||||
struct GraphSegment {
|
||||
GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
|
||||
|
|
Loading…
Reference in New Issue