forked from mindspore-Ecosystem/mindspore
refactor buffer fusion
This commit is contained in:
parent
da5b10b699
commit
2e2e7a28ae
|
@ -261,23 +261,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
|
|||
return buffer_fusion_kernel;
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list_in,
|
||||
const std::vector<AnfNodePtr> &inputs_list,
|
||||
kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list,
|
||||
const std::vector<AnfNodePtr> &outputs_list) {
|
||||
MS_LOG(DEBUG) << "Start Create Kernel Info";
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
// inputs format and data type
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_data_type;
|
||||
for (auto node : inputs_list_in) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) {
|
||||
if (std::find(inputs_list.begin(), inputs_list.end(), inputs[input_index]) != inputs_list.end()) {
|
||||
inputs_format.push_back(AnfAlgo::GetInputFormat(node, input_index - 1));
|
||||
inputs_data_type.push_back(AnfAlgo::GetInputDeviceDataType(node, input_index - 1));
|
||||
}
|
||||
for (const auto &input : inputs_list) {
|
||||
if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) {
|
||||
auto tuple_getitem = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
inputs_format.push_back(AnfAlgo::GetOutputFormat(
|
||||
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
|
||||
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
|
||||
tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
|
||||
} else {
|
||||
inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0));
|
||||
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0));
|
||||
}
|
||||
}
|
||||
// outputs format and data type
|
||||
|
@ -360,62 +361,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
|
|||
}
|
||||
}
|
||||
|
||||
void GetInputList(const CNodePtr &node, const int32_t cur_fusion_id, std::vector<AnfNodePtr> *inputs_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inputs_list);
|
||||
auto &inputs = node->inputs();
|
||||
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) {
|
||||
auto input = inputs[input_index];
|
||||
if (AnfAlgo::IsRealCNodeKernel(input)) {
|
||||
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) {
|
||||
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input, kOpAttrFusionId);
|
||||
if (fusion_id != cur_fusion_id) {
|
||||
inputs_list->push_back(input);
|
||||
}
|
||||
} else {
|
||||
inputs_list->push_back(input);
|
||||
}
|
||||
} else if (input->isa<CNode>()) {
|
||||
for (auto &input_in : input->cast<CNodePtr>()->inputs()) {
|
||||
if (AnfAlgo::IsRealCNodeKernel(input_in)) {
|
||||
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) {
|
||||
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input_in, kOpAttrFusionId);
|
||||
if (fusion_id != cur_fusion_id) {
|
||||
inputs_list->push_back(input);
|
||||
}
|
||||
} else {
|
||||
inputs_list->push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
inputs_list->push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
||||
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
||||
if ((*buffer_fusion_infos).find(cur_fusion_id) == (*buffer_fusion_infos).end()) {
|
||||
BufferFusionInfo_t buffer_fusion_info;
|
||||
(*buffer_fusion_infos)[cur_fusion_id] = buffer_fusion_info;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs_list;
|
||||
GetInputList(node, cur_fusion_id, &inputs_list);
|
||||
if (!inputs_list.empty()) {
|
||||
if (!(*buffer_fusion_infos)[cur_fusion_id].inputs_list.empty()) {
|
||||
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list.insert(
|
||||
(*buffer_fusion_infos)[cur_fusion_id].inputs_list.end(), inputs_list.begin(), inputs_list.end());
|
||||
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.insert(
|
||||
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.end(), node);
|
||||
} else {
|
||||
(*buffer_fusion_infos)[cur_fusion_id].inputs_list = inputs_list;
|
||||
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
||||
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
||||
|
@ -429,6 +374,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
|
|||
}
|
||||
}
|
||||
|
||||
// Fills inputs_list of every fusion scope in *buffer_fusion_infos.
//
// For each node in a scope's anf_nodes, every input (index 0 is skipped) is resolved with
// AnfAlgo::VisitKernel(input, 0). When the resolved node is not itself a member of the scope's anf_nodes,
// the original (unresolved) input is appended to the scope's inputs_list, unless it is already there.
// NOTE(review): VisitKernel presumably looks through non-kernel wrappers to the producing kernel;
// confirm against AnfRuntimeAlgorithm.
//
// Throws through MS_EXCEPTION_IF_NULL when `kernel_graph`, `buffer_fusion_infos`, the graph's manager,
// or any scope node (or its CNode cast) is null.
void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph,
                                 std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
  auto manager = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  for (auto &buffer_fusion_info : *buffer_fusion_infos) {
    // Bind by reference: copying the record would duplicate all of its node vectors for every scope,
    // and re-indexing the map by fusion id for each input is redundant while iterating that same map.
    auto &fusion_info = buffer_fusion_info.second;
    const auto &scope_nodes = fusion_info.anf_nodes;
    auto &scope_inputs = fusion_info.inputs_list;

    for (const auto &node : scope_nodes) {
      MS_EXCEPTION_IF_NULL(node);
      auto cnode = node->cast<CNodePtr>();
      // cast<> yields null for a non-CNode; fail with a diagnostic instead of dereferencing it.
      MS_EXCEPTION_IF_NULL(cnode);

      for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
        const auto &input = cnode->input(idx);
        auto real_input = AnfAlgo::VisitKernel(input, 0);
        const bool produced_in_scope =
          std::find(scope_nodes.begin(), scope_nodes.end(), real_input.first) != scope_nodes.end();
        if (produced_in_scope) {
          continue;
        }
        // Keep inputs_list free of duplicates when several scope nodes consume the same input.
        if (std::find(scope_inputs.begin(), scope_inputs.end(), input) == scope_inputs.end()) {
          scope_inputs.push_back(input);
        }
      }
    }
  }
}
|
||||
|
||||
bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
MS_EXCEPTION_IF_NULL(node1);
|
||||
MS_EXCEPTION_IF_NULL(node2);
|
||||
auto getitem1 = node1->cast<CNodePtr>();
|
||||
auto getitem2 = node2->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem1);
|
||||
MS_EXCEPTION_IF_NULL(getitem2);
|
||||
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
|
||||
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
|
||||
return output_idx1 < output_idx2;
|
||||
}
|
||||
|
||||
void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -454,14 +438,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
|||
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
|
||||
std::back_inserter(tuple_getitem_nodes),
|
||||
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
|
||||
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(),
|
||||
[](const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
auto getitem1 = node1->cast<CNodePtr>();
|
||||
auto getitem2 = node2->cast<CNodePtr>();
|
||||
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
|
||||
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
|
||||
return output_idx1 < output_idx2;
|
||||
});
|
||||
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
|
||||
for (auto getitem : tuple_getitem_nodes) {
|
||||
auto getitem_ptr = getitem->cast<CNodePtr>();
|
||||
auto input2 = getitem_ptr->input(2);
|
||||
|
@ -634,24 +611,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|||
// Populates *buffer_fusion_infos for `kernel_graph`: it runs the compute/input/output node-list
// collectors and then stores a kernel build info in every record.
// Throws through MS_EXCEPTION_IF_NULL when `buffer_fusion_infos` is null.
// NOTE(review): this span looks like a rendered diff hunk ("@ -634,24 +611,12 @") in which removed and
// added lines are interleaved without +/- markers; confirm which lines are live before relying on it.
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
|
||||
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
||||
// NOTE(review): kernel_graph is dereferenced here without a null check.
// TopoSort presumably returns the nodes reachable from the return node in topological order; confirm.
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
// Nodes that are not real CNode kernels are ignored.
if (!AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Only kernels tagged with a fusion id are handed to CheckCurrentNodeIsInput.
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
|
||||
auto cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
|
||||
CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos);
|
||||
}
|
||||
}
|
||||
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
|
||||
GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos);
|
||||
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
|
||||
// Attach a kernel build info to every fusion record.
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
|
||||
buffer_fusion_info.second.kernel_build_info =
|
||||
// NOTE(review): the three-argument call below and the two-argument call after it look like the old
// and new right-hand side of the same assignment; presumably only one of them is meant to remain.
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list,
|
||||
buffer_fusion_info.second.outputs_list);
|
||||
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ namespace opt {
|
|||
// Bookkeeping for one fusion scope; records are stored in an unordered_map keyed by the int32_t fusion id.
struct BufferFusionInfo_t {
|
||||
// Nodes that make up the fusion scope (iterated by GetFusionScopeInputNodeList).
std::vector<AnfNodePtr> anf_nodes;
|
||||
// Nodes collected as inputs of the fusion scope.
std::vector<AnfNodePtr> inputs_list;
|
||||
// Nodes for which CheckCurrentNodeIsInput collected a non-empty input list.
// NOTE(review): the hunk header "@ -30,7 +30,6 @" suggests one line of this struct is removed by the
// commit, presumably this member; confirm.
std::vector<AnfNodePtr> inputs_list_in;
|
||||
// Nodes passed to CreateFusionOpKernelInfo as the scope's outputs.
std::vector<AnfNodePtr> outputs_list;
|
||||
// Build info returned by CreateFusionOpKernelInfo for the fused kernel.
kernel::KernelBuildInfoPtr kernel_build_info;
|
||||
};
|
||||
|
|
|
@ -816,6 +816,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
|
|||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
|
||||
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
|
||||
{kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
|
||||
{kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
|
||||
{prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
|
||||
{prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
|
||||
{prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
|
||||
|
|
|
@ -122,6 +122,8 @@ constexpr auto kSendOpName = "Send";
|
|||
constexpr auto kRecvOpName = "Recv";
|
||||
constexpr auto kReluV2OpName = "ReLUV2";
|
||||
constexpr auto kReluGradV2OpName = "ReluGradV2";
|
||||
constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2";
|
||||
constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2";
|
||||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
|
Loading…
Reference in New Issue