!18858 optimize-get-node-target

Merge pull request !18858 from kisnwang/optimize-get-node-target
This commit is contained in:
i-robot 2021-06-28 02:11:09 +00:00 committed by Gitee
commit 2acae09a92
4 changed files with 154 additions and 106 deletions

View File

@ -339,7 +339,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
}
}
if (need_sync) {
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) || UpdatedByAssign(kernel_graph, input_node) ||
if (AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
tensor->set_device_address(device_address);
}
@ -349,6 +349,9 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (kernel_graph->IsUpdatedParameter(pk_node)) {
tensor->SetIsUpdateByDevice();
}
}
}
tensor->set_sync_status(kNoNeedSync);

View File

@ -360,9 +360,15 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
return;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool need_convert = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
if (need_convert) {
need_convert =
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end();
}
// if not find in host convert format map means the host has not registered the convert function of this format
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT &&
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end()) {
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT && !need_convert) {
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@ -414,8 +420,6 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string
}
} // namespace
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(kernel_node);
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info);
@ -434,12 +438,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
}
auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
std::vector<std::string> output_format = {refresh_format};
// if not find in host convert format map means the host has not registered the convert function of this format
if ((trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() ||
!context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) &&
refresh_format != kOpFormat_DEFAULT) {
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
}
SetWeightFormat(real_input_node, output_format, kernel_node, input_index);
}
}

View File

@ -369,61 +369,9 @@ std::string get_id(const AnfNodePtr &node) {
void reset_id() { node_ids.clear(); }
} // namespace id_generator
auto constexpr kTargetUnDefined = "kTargetUnDefined";
auto constexpr kPrimitiveTarget = "primitive_target";
namespace {
std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto func_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto users = manager->node_users()[cnode];
std::string first_user_target = GetCNodeTarget(users.back().first);
bool is_used_by_different_target =
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool {
return GetCNodeTarget(u.first) != first_user_target;
});
if (!is_used_by_different_target) {
return first_user_target;
}
auto inputs = cnode->inputs();
std::vector<AnfNodePtr> real_inputs;
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
std::string first_input_target = GetCNodeTarget(real_inputs[0]);
bool is_from_different_target =
std::any_of(std::begin(real_inputs), std::end(real_inputs),
[&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; });
if (!is_from_different_target) {
return first_input_target;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
return default_target;
}
std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input,
const std::string &primitive_target, const std::string &default_target) {
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) {
primitive->EraseAttr(primitive_target);
return default_target;
}
MS_EXCEPTION_IF_NULL(att_target);
if (!att_target->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(att_target);
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
}
return target;
}
PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
@ -438,53 +386,151 @@ PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
}
return value->cast<PrimitivePtr>();
}
std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
if (inputs.size() > 1) {
return GetOriginNodeTarget(inputs[1]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
const size_t node_inputs_num = 3;
if (inputs.size() >= node_inputs_num) {
size_t use_index = 1;
if (!inputs[use_index]->isa<CNode>()) {
use_index = 2;
}
return GetOriginNodeTarget(inputs[use_index]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
const size_t node_inputs_num = 3;
if (inputs.size() >= node_inputs_num) {
return GetOriginNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> real_inputs;
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
std::string first_input_target = kTargetUnDefined;
bool has_same_target =
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
auto target = GetOriginNodeTarget(n);
if (target != kTargetUnDefined) {
first_input_target = target;
}
return target == first_input_target;
});
if (has_same_target) {
return first_input_target;
}
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetOriginNodeTarget(cnode->input(1));
}
return kTargetUnDefined;
}
std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto func_graph = cnode->func_graph();
if (func_graph == nullptr) {
return kTargetUnDefined;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
return kTargetUnDefined;
}
auto users = manager->node_users()[cnode];
std::string first_user_target = kTargetUnDefined;
bool has_same_target =
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
auto target = GetOriginNodeTarget(u.first);
if (target != kTargetUnDefined) {
first_user_target = target;
}
return target == first_user_target;
});
if (has_same_target) {
return first_user_target;
}
return kTargetUnDefined;
}
std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kTargetUnDefined));
auto target = GetVirtualNodeTargetFromInputs(node);
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
if (target != kTargetUnDefined) {
return target;
}
target = GetVirtualNodeTargetFromUsers(node);
node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
return target;
}
std::string GetTargetFromAttr(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto attr_input = cnode->input(0);
auto primitive = GetPrimitiveFromValueNode(attr_input);
if (primitive == nullptr) {
return kTargetUnDefined;
}
auto att_target = primitive->GetAttr(kPrimitiveTarget);
if (att_target != nullptr) {
if (!att_target->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
}
auto target = GetValue<std::string>(att_target);
if (kTargetSet.find(target) == kTargetSet.end()) {
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
}
return target;
}
return kTargetUnDefined;
}
} // namespace
std::string GetOriginNodeTarget(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return kTargetUnDefined;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
if (ud_target != nullptr) {
return *ud_target.get();
}
auto target = GetTargetFromAttr(node);
if (target != kTargetUnDefined) {
return target;
}
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetVirtualNodeTarget(node);
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
}
std::string GetCNodeTarget(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (!node->isa<CNode>()) {
return default_target;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const std::string primitive_target = "primitive_target";
auto ud_target = cnode->user_data<std::string>(primitive_target);
if (ud_target != nullptr) {
return *ud_target.get();
}
auto attr_input = cnode->input(0);
auto primitive = GetPrimitiveFromValueNode(attr_input);
if (primitive == nullptr) {
return default_target;
}
auto att_target = primitive->GetAttr(primitive_target);
if (att_target != nullptr) {
return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target);
}
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const size_t depend_node_inputs_num = 3;
auto &inputs = cnode->inputs();
if (inputs.size() >= depend_node_inputs_num) {
size_t use_index = 1;
if (!inputs[use_index]->isa<CNode>()) {
use_index = 2;
}
if (!IsPrimitiveCNode(inputs[use_index], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[use_index]);
}
}
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
const size_t update_state_node_inputs_num = 3;
auto &inputs = cnode->inputs();
if (inputs.size() >= update_state_node_inputs_num && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return GetMaketupleNodeTarget(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetCNodeTarget(cnode->input(1));
auto target = GetOriginNodeTarget(node);
if (target != kTargetUnDefined) {
return target;
}
return default_target;
}

View File

@ -642,6 +642,7 @@ void reset_id();
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
std::string GetCNodeTarget(const AnfNodePtr &node);
std::string GetOriginNodeTarget(const AnfNodePtr &node);
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
struct GraphSegment {
GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}