!55337 [graph kernel] add GraphKernelOpCombiner pass

Merge pull request !55337 from yangsijia/graph-rewrite
This commit is contained in:
i-robot 2023-06-27 08:13:56 +00:00 committed by Gitee
commit 2e7652bf66
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 1266 additions and 7 deletions

View File

@ -62,6 +62,7 @@
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/compact_tensor_liveness.h"
#include "backend/common/graph_kernel/core/convert_op_input_attr.h"
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
#ifdef ENABLE_AKG
#include "backend/common/graph_kernel/graph_kernel_build.h"
@ -110,6 +111,10 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
// Expand complex basic kernels to composite kernels
pm->Add(std::make_shared<GraphKernelExpanderWithPy>(), OptLevel_1);
if (GraphKernelFlags::GetInstance().enable_parallel_op_combine) {
pm->Add(std::make_shared<GraphKernelOpCombiner>(), OptLevel_2);
}
// Cluster basic kernels and composite kernels
pm->Add(std::make_shared<GraphKernelCluster>(), OptLevel_1);

View File

@ -0,0 +1,39 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
namespace mindspore::graphkernel {
bool GraphKernelOpCombiner::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto cb = Callback::Instance();
auto nodes = TopoSort(func_graph->get_return());
auto changed = false;
for (auto node : nodes) {
if (node->cast<CNodePtr>() == nullptr || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto node_name = node->fullname_with_scope();
auto res = ConcatParallelMatMul(node, min_ops_to_combine_, default_layout_to_combine_, func_graph);
if (res != nullptr) {
changed = true;
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,39 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_
#include <memory>
#include <string>
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
namespace mindspore::graphkernel {
class GraphKernelOpCombiner : public opt::Pass {
public:
GraphKernelOpCombiner() : Pass("graph_kernel_op_combiner") {}
~GraphKernelOpCombiner() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
uint64_t min_ops_to_combine_{2};
std::string default_layout_to_combine_{kOpFormat_NCHW};
};
using GraphKernelOpCombinerPtr = std::shared_ptr<GraphKernelOpCombiner>;
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_

View File

@ -316,15 +316,20 @@ inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
auto todos = TopoSort(func_graph->output());
const auto &params = func_graph->parameters();
auto cb = Callback::Instance();
auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) {
auto shape = cb->GetOutputShape(node, 0);
auto type = cb->GetOutputType(node, 0);
auto format = cb->GetOutputFormat(node, 0);
return inner::NodeBase({shape, type, format});
auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) -> inner::NodeBaseList {
inner::NodeBaseList listinfo;
size_t output_num = AnfUtils::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
auto shape = cb->GetOutputShape(node, i);
auto type = cb->GetOutputType(node, i);
auto format = cb->GetOutputFormat(node, i);
listinfo.push_back(inner::NodeBase({shape, type, format}));
}
return listinfo;
};
// set inputs
for (auto &p : params) {
node_map[p] = gb.Parameter(ExtractBuildInfo(p));
node_map[p] = gb.Parameter(ExtractBuildInfo(p)[0]);
}
// set ops
for (auto node : todos) {
@ -332,7 +337,7 @@ inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
if (cnode == nullptr) {
continue;
}
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
if (node == func_graph->output() && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
break;
}
auto prim = GetCNodePrimitive(cnode);

View File

@ -0,0 +1,163 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
#include "base/base.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
namespace {
MMAttr GetMatMulTransposeAttr(const CNodePtr &matmul) {
auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
if (mm_attrs.count(kTransposeA) == 0 || mm_attrs.count(kTransposeB) == 0) {
MS_LOG(WARNING) << "Can not find attr 'transpose_a' or 'transpose_b' in node " << matmul->fullname_with_scope();
return std::make_pair(false, false);
}
auto trans_a = GetValue<bool>(mm_attrs[kTransposeA]);
auto trans_b = GetValue<bool>(mm_attrs[kTransposeB]);
return std::make_pair(trans_a, trans_b);
}
CNodePtr NewMatMulNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs, const CNodePtr &orig_matmul,
ShapeVector new_out_shape) {
auto matmul = func_graph->NewCNode(matmul_inputs);
func_graph->AddNode(matmul);
MS_EXCEPTION_IF_NULL(matmul);
MS_EXCEPTION_IF_NULL(matmul_inputs[1]);
auto orig_cnode = matmul_inputs[1]->cast<CNodePtr>();
if (orig_cnode != nullptr && orig_cnode->HasAttr(kOutputsFormat)) {
auto input_format = GetValue<std::vector<std::string>>(orig_cnode->GetAttr(kOutputsFormat))[0];
std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(matmul), input_format);
matmul->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
}
auto [trans_a, trans_b] = GetMatMulTransposeAttr(orig_matmul);
matmul->AddAttr(kTransposeA, MakeValue(trans_a));
matmul->AddAttr(kTransposeB, MakeValue(trans_b));
std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(matmul_inputs[1], 0)};
std::vector<ShapeVector> shapes = {new_out_shape};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, matmul.get());
matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
return matmul;
}
BMNK GetBatchMNKV1(const CNodePtr &matmul) {
size_t b, m, n, k = 0;
auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
if (shape_a.size() == kDim3 && shape_b.size() == kDim3 && shape_a[kIndex0] == shape_b[kIndex0]) {
b = shape_a[kIndex0];
shape_a.erase(shape_a.begin());
shape_b.erase(shape_b.begin());
} else {
b = 1;
}
m = trans_a ? shape_a[kIndex1] : shape_a[kIndex0];
k = trans_a ? shape_a[kIndex0] : shape_a[kIndex1];
n = trans_b ? shape_b[kIndex0] : shape_b[kIndex1];
return std::tuple(b, m, n, k);
}
} // namespace
ConcatenatePlan ParallelMatMulConcatenater::Analyse(const Group &branches) {
ConcatenatePlan target_op_res;
Branch b0 = branches[kIndex0];
AnfNodePtr shared_input = b0.GetRootData();
target_op_res.in_shape = Callback::Instance()->GetOutputInferShape(shared_input, kIndex0);
auto matmul = b0.GetTargetOp()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(matmul);
bool is_a_shared = false;
for (size_t i = 1; i < matmul->size(); ++i) {
auto in = matmul->input(i);
if (in == shared_input) {
is_a_shared = i == kIndex1;
break;
}
}
auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
int64_t b, m, n, k;
std::tie(b, m, n, k) = GetBatchMNKV1(matmul);
if (is_a_shared) {
auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
size_t rank_b = shape_b.size();
auto n_idx = trans_b ? rank_b - kIndex2 : rank_b - kIndex1;
target_op_res.concat_in_idx = n_idx;
target_op_res.split_out_idx = rank_b - kIndex1;
int64_t new_n = n * branches.size();
if (rank_b == kDim3) {
target_op_res.out_shape = ShapeVector({b, m, new_n});
} else {
target_op_res.out_shape = ShapeVector({m, new_n});
}
} else {
auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
size_t rank_a = shape_a.size();
auto m_idx = trans_a ? rank_a - kIndex1 : rank_a - kIndex2;
target_op_res.concat_in_idx = m_idx;
target_op_res.split_out_idx = rank_a - kIndex2;
int64_t new_m = m * branches.size();
if (rank_a == kDim3) {
target_op_res.out_shape = ShapeVector({b, new_m, n});
} else {
target_op_res.out_shape = ShapeVector({new_m, n});
}
}
return target_op_res;
}
bool ParallelMatMulConcatenater::CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) {
auto matmul1 = a->cast<CNodePtr>();
auto matmul2 = b->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(matmul1);
MS_EXCEPTION_IF_NULL(matmul2);
auto [trans_a1, trans_b1] = GetMatMulTransposeAttr(matmul1);
auto [trans_a2, trans_b2] = GetMatMulTransposeAttr(matmul2);
return trans_a1 == trans_a2 && trans_b1 == trans_b2;
}
bool ParallelMatMulConcatenater::IsSupportedOp(const AnfNodePtr n) {
if (n->cast<CNodePtr>() == nullptr || unsupported_ops_.count(GetCNodePrimitive(n)->name())) {
return false;
}
return true;
}
AnfNodePtr ParallelMatMulConcatenater::MakeCombinedOp(const Group &branches) {
Branch b1 = branches[0];
AnfNodePtr shared_input = b1.GetRootData();
auto matmul_op = b1.GetTargetOp()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(matmul_op);
auto plan = Analyse(branches);
plans_.push_back(plan);
auto overall_inputs = ReloadInputs(branches, b1.target_op_pos, shared_input);
auto matmul = NewMatMulNode(main_graph_, overall_inputs, matmul_op, plan.out_shape);
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(matmul), "AutoUpdateInfo fail");
return matmul;
}
bool ParallelMatMulConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) { return true; }
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
const FuncGraphPtr &func_graph) {
if (layout == kOpFormat_NCHW) {
auto res = ParallelMatMulConcatenater(min_num_branches, layout).Combine(root, func_graph);
return res;
}
MS_LOG(WARNING) << "Not supported combine for layout " << layout;
return root;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,53 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_
#include <vector>
#include <string>
#include <set>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <memory>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/core/parallel_op_concatenate.h"
namespace mindspore::graphkernel {
using BMNK = std::tuple<size_t, size_t, size_t, size_t>;
using MMAttr = std::pair<bool, bool>;
class ParallelMatMulConcatenater : public ParallelOpConcatenater {
public:
explicit ParallelMatMulConcatenater(uint64_t min_num_branches, const std::string &layout)
: ParallelOpConcatenater("MatMul", min_num_branches, layout) {}
protected:
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b);
virtual bool IsSupportedOp(const AnfNodePtr n);
virtual AnfNodePtr MakeCombinedOp(const Group &branches);
bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) override;
private:
ConcatenatePlan Analyse(const Group &branches);
};
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
const FuncGraphPtr &func_graph = nullptr);
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_

View File

@ -0,0 +1,443 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/graph_kernel/core/parallel_op_combine.h"
#include <vector>
#include <string>
#include <set>
#include <deque>
#include <utility>
#include <algorithm>
#include <unordered_set>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "include/backend/kernel_graph.h"
#include "utils/anf_utils.h"
#include "include/common/utils/utils.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "utils/ms_context.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"
namespace mindspore::graphkernel {
namespace {
constexpr auto kPerm = "perm";
constexpr auto kShape = "shape";
std::vector<int64_t> GetTransposePerm(const PrimitivePtr &primitive) {
ValuePtr perm = primitive->GetAttr(kPerm);
MS_EXCEPTION_IF_NULL(perm);
auto perm_val = perm->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(perm_val);
auto perm_val_data = perm_val->value();
std::vector<int64_t> perm_int;
(void)std::transform(perm_val_data.begin(), perm_val_data.end(), std::back_inserter(perm_int),
[=](const ValuePtr &e) -> int64_t {
if (e->isa<Int64Imm>()) {
return GetValue<int64_t>(e);
} else if (e->isa<Int32Imm>()) {
return GetValue<int>(e);
} else {
MS_LOG(EXCEPTION) << "Perm must be int";
return -1;
}
});
return perm_int;
}
} // namespace
BranchGroupFinder::BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
: op_name_(op_name), fis_supported_op_(fis_supported_op), fare_compatible_ops_(fare_compatible_ops) {}
AnfNodeIndexSet BranchGroupFinder::GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer) {
AnfNodeIndexSet consumers;
auto users = mng->node_users()[producer];
for (auto it : users) {
auto user = it.first;
if (user->cast<CNodePtr>() && AnfUtils::IsRealKernel(user) && fis_supported_op_(user)) {
consumers.add(CNodeIndexPair(it.first, it.second));
children_map_[producer].insert(user);
}
}
return consumers;
}
std::vector<Group> BranchGroupFinder::Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph) {
auto graph_kernel_fg = func_graph == nullptr ? common::AnfAlgo::GetCNodeFuncGraphPtr(start_node) : func_graph;
MS_EXCEPTION_IF_NULL(graph_kernel_fg);
auto mng = graph_kernel_fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto cnode = start_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::deque<AnfNodePtr> init_consumer;
std::transform(graph_kernel_fg->parameters().begin(), graph_kernel_fg->parameters().end(),
std::back_inserter(init_consumer), [](const AnfNodePtr &global_in) { return global_in; });
for (size_t i = 1; i < cnode->size(); ++i) {
init_consumer.push_back(cnode->input(i));
}
while (!init_consumer.empty()) {
auto new_node = init_consumer.front();
init_consumer.pop_front();
auto new_consumer = GetConsumers(mng, new_node);
std::transform(new_consumer.begin(), new_consumer.end(), std::back_inserter(init_consumer),
[](const CNodeIndexPair &index_pair) { return index_pair.first; });
}
for (auto it : children_map_) {
if (it.second.size() > 1) {
op_roots_.insert(it.first);
}
}
std::vector<Group> groups;
for (const auto &root : op_roots_) {
size_t ngroups = groups.size();
auto childrens = children_map_.at(root);
for (auto child : childrens) {
auto prim_name = GetCNodePrimitive(child)->name();
// Branch should start with target node that specified by `op_name_`
if (prim_name != op_name_) {
continue;
}
auto branch = CreateBranch(child);
branch.SetDataRoot(root);
// // position index less than 0 means we didn't find target op in this branch
// if (branch.target_op_pos < 0) {
// groups.emplace_back();
// groups.back().push_back(branch);
// continue;
// }
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [this, &branch](const Group &group) {
MS_EXCEPTION_IF_CHECK_FAIL(!group.empty() && !group[0].ops.empty(), "group empty or group[0] empty");
auto top_branch = group[0];
return (branch.target_op_pos == top_branch.target_op_pos) &&
fare_compatible_ops_(branch.GetTargetOp(), top_branch.GetTargetOp());
});
if (it != groups.end()) {
it->push_back(branch);
} else {
groups.emplace_back();
groups.back().push_back(branch);
}
}
}
return groups;
}
Branch BranchGroupFinder::CreateBranch(AnfNodePtr lead_op) {
AnfNodePtrList ops{lead_op};
int root_idx = GetCNodePrimitive(lead_op)->name() == op_name_ ? 0 : -1;
auto it = children_map_.find(lead_op);
while (it != children_map_.end() && it->second.size() == 1) {
auto node = *(it->second).begin();
ops.push_back(node);
auto prim_name = GetCNodePrimitive(node)->name();
if (prim_name == op_name_) {
root_idx = ops.size();
}
it = children_map_.find(node);
}
return Branch(ops, root_idx);
}
ParallelOpCombiner::ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout)
: op_name_(op_name), min_num_branches_(min_num_branches), layout_(layout) {}
AnfNodePtr ParallelOpCombiner::Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(root);
if (func_graph) {
main_graph_ = func_graph;
} else {
main_graph_ = common::AnfAlgo::GetCNodeFuncGraphPtr(root);
}
MS_EXCEPTION_IF_NULL(main_graph_);
auto finder = BranchGroupFinder(
op_name_, [&](const AnfNodePtr n) { return IsSupportedOp(n); },
[&](const AnfNodePtr a, const AnfNodePtr b) { return CanOpsBeCombined(a, b); });
auto groups = finder.Find(root, main_graph_);
children_map_ = std::move(finder.children_map_);
for (const Group &group : groups) {
if (group.size() < min_num_branches_) {
MS_LOG(INFO) << "group size = " << group.size() << " < " << min_num_branches_ << ", skip.";
continue;
}
CombineBranches(group);
}
return combined_;
}
void ParallelOpCombiner::CombineBranches(const Group &branches) {
auto combined = MakeCombinedOp(branches);
auto it = std::min_element(branches.begin(), branches.end(), [](const Branch &branch_a, const Branch &branch_b) {
return branch_a.ops.size() < branch_b.ops.size();
});
size_t depth = it->ops.size();
int pos;
for (pos = 0; pos < static_cast<int>(depth); ++pos) {
if (pos == it->target_op_pos) {
continue;
}
if (!CheckLevel(branches, pos)) {
break;
}
combined = MakeCombinedAnfNodePtrFromFollowingOps(combined, branches, pos);
}
UpdateGroupOutput(combined, branches, pos - 1);
combined_ = combined;
}
bool ParallelOpCombiner::CheckLevel(const Group &branches, size_t depth) {
auto repr = branches[0].ops[depth];
auto repr_prim_name = GetCNodePrimitive(repr)->name();
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch &branch = *it;
auto node = branch.ops[depth];
auto prim_name = GetCNodePrimitive(node)->name();
if (prim_name != repr_prim_name) {
MS_LOG(INFO) << "Prim not compatible!" << prim_name << " vs " << repr_prim_name;
return false;
}
if (unsupported_ops_.find(prim_name) != unsupported_ops_.end()) {
MS_LOG(INFO) << "Op " << prim_name << " not supported for combination for now, stop.";
return false;
}
if (!IsArgCompatible(repr, node)) {
return false;
}
}
MS_LOG(DEBUG) << "Op " << repr_prim_name << " can be combined at depth " << depth;
return true;
}
bool ParallelOpCombiner::AutoUpdateInfo(const CNodePtr &to_update, size_t out_size) {
if (to_update->size() < 2) {
MS_LOG(ERROR) << "Cannot auto update for " << to_update->fullname_with_scope() << " with input size "
<< to_update->size();
return false;
}
auto rep_input = to_update->input(1);
// NOTE: We assume the inputs' formats and types are consistent with outputs'.
std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW;
TypeId input_type = TypeId::kTypeUnknown;
auto UpdateBoth = [&input_type, &input_format](const CNodePtr &cnode) -> bool {
if (cnode == nullptr) {
return false;
}
input_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
if (!cnode->HasAttr(kOutputsFormat)) {
return false;
}
auto prev_of = GetValue<std::vector<std::string>>(cnode->GetAttr(kOutputsFormat));
if (prev_of.size() > 0) {
input_format = prev_of[0];
return true;
}
return false;
};
if (AnfUtils::IsRealKernel(rep_input)) {
UpdateBoth(rep_input->cast<CNodePtr>());
}
if (input_format.empty()) {
auto it = children_map_.find(rep_input);
if (it != children_map_.end()) {
for (auto orig_user : it->second) {
if (UpdateBoth(orig_user->cast<CNodePtr>())) {
break;
}
}
}
}
if (input_format.empty() && Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
input_format = layout_;
}
if (input_format.empty() ||
(Callback::Instance()->GetTargetFromContext() != kAscendDevice && input_type == TypeId::kTypeUnknown)) {
MS_LOG(WARNING) << "Cannot auto update input format for node " << to_update->fullname_with_scope();
return false;
}
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
if (Callback::Instance()->GetTargetFromContext() != kAscendDevice) {
to_update->set_kernel_info(std::make_shared<device::KernelInfo>());
std::vector<std::string> to_update_in_formats(to_update->size(), input_format);
std::vector<TypeId> to_update_in_types(to_update->size(), input_type);
std::vector<std::string> to_update_out_formats(out_size, input_format);
std::vector<TypeId> to_update_out_types(out_size, input_type);
auto graph_sel_info = BuildSelectKernelBuildInfo(to_update_in_formats, to_update_in_types, to_update_out_formats,
to_update_out_types, kernel::GetProcessorFromContext());
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, to_update.get());
return true;
}
#endif
std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format);
to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
return true;
}
std::map<size_t, AnfNodePtrList> ParallelOpCombiner::GetUniqueInputs(const Group &branches, size_t depth) {
std::map<size_t, AnfNodePtrList> unique_inputs;
AnfNodePtrList parent_in_branch;
if (depth >= 1) {
std::transform(branches.begin(), branches.end(), std::back_inserter(parent_in_branch),
[&depth](const Branch &br) { return br.ops[depth - 1]; });
} else {
Branch b1 = branches[0];
parent_in_branch.push_back(b1.GetRootData());
}
for (auto br : branches) {
auto op = br.ops[depth];
auto cnode = op->cast<CNodePtr>();
// Here we can know for sure that op's arg length are the same (check before)
for (size_t i = 1; i < cnode->size(); ++i) {
auto in = cnode->input(i);
if (std::any_of(parent_in_branch.begin(), parent_in_branch.end(),
[&in](const AnfNodePtr &p) { return in == p; })) {
continue;
}
unique_inputs[i].push_back(in);
}
}
return unique_inputs;
}
CNodePtr GraphBuilder::NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node,
size_t concat_dim, size_t input_num) {
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
auto maketuple = NewTupleNode(func_graph, input_node);
concat_inputs.push_back(maketuple);
} else {
for (size_t i = 0; i < input_node.size(); ++i) {
auto n = input_node[i];
concat_inputs.push_back(n);
}
}
auto concat = func_graph->NewCNode(concat_inputs);
MS_EXCEPTION_IF_NULL(concat);
func_graph->AddNode(concat);
std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node[0], 0)};
auto shape = common::AnfAlgo::GetOutputInferShape(input_node[0], 0);
shape[concat_dim] *= input_num;
std::vector<ShapeVector> shapes(1, shape);
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(concat_dim)), concat);
common::AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(static_cast<int64_t>(input_num)), concat);
common::AnfAlgo::SetNodeAttr(kAttrN, MakeValue(static_cast<int64_t>(input_num)), concat);
return concat;
}
CNodePtr GraphBuilder::NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs) {
auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
AbstractBasePtrList abs_list;
for (auto in : shared_inputs) {
mk_inputs.push_back(in);
abs_list.push_back(in->abstract());
}
auto make_tuple_node = func_graph->NewCNode(mk_inputs);
func_graph->AddNode(make_tuple_node);
make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
return make_tuple_node;
}
CNodePtr GraphBuilder::NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
size_t split_num) {
if (split_num == 0) {
MS_LOG(EXCEPTION) << "split_num should not be zero.";
}
MS_EXCEPTION_IF_NULL(input_node);
std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
input_node};
auto split = func_graph->NewCNode(split_inputs);
func_graph->AddNode(split);
MS_EXCEPTION_IF_NULL(split);
auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
std::vector<TypeId> dtypes(split_num, dtype);
auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
shape[split_dim] /= split_num;
std::vector<ShapeVector> shapes(split_num, shape);
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
common::AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_num), split);
return split;
}
CNodePtr GraphBuilder::NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
const AnfNodePtr &orig_node) {
auto node = func_graph->NewCNode(inputs);
func_graph->AddNode(node);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
std::vector<ShapeVector> shapes = {common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0)};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
return node;
}
CNodePtr GraphBuilder::NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
const AnfNodePtr &orig_node) {
auto node = func_graph->NewCNode(inputs);
func_graph->AddNode(node);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
auto orig_shape_in = common::AnfAlgo::GetPrevNodeOutputInferShape(orig_node, 0);
auto orig_shape_out = common::AnfAlgo::GetOutputInferShape(orig_node, 0);
auto new_out_shape = InferReshapeOut(orig_shape_in, orig_shape_out, new_shape_in);
GetCNodePrimitive(node)->set_attr(kShape, MakeValue(new_out_shape));
std::vector<ShapeVector> shapes = {new_out_shape};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
return node;
}
CNodePtr GraphBuilder::NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
const AnfNodePtr &orig_node) {
auto node = func_graph->NewCNode(inputs);
func_graph->AddNode(node);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
auto perm_int = GetTransposePerm(GetCNodePrimitive(node));
auto new_out_shape = InferTransposeOut(new_shape_in, perm_int);
std::vector<ShapeVector> shapes = {new_out_shape};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
return node;
}
ShapeVector GraphBuilder::InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
const ShapeVector &new_reshape_in) {
ShapeVector new_shape_out;
if (orig_reshape_in.size() == new_reshape_in.size()) {
return InferConcatReshapeOut(orig_reshape_in, orig_reshape_out, new_reshape_in);
} else {
MS_LOG(EXCEPTION) << "Stack combiner infer for reshape not impl yet";
}
return new_shape_out;
}
ShapeVector GraphBuilder::InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm) {
ShapeVector out_shape;
for (int64_t i : perm) {
auto idx = LongToSize(i);
out_shape.push_back(in_shape[idx]);
}
return out_shape;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,136 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either AnfNodePtress or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "mindspore/core/ops/array_ops.h"
namespace mindspore::graphkernel {
struct Branch {
Branch(AnfNodePtrList lst, int pos) : ops(lst), target_op_pos(pos) {}
AnfNodePtrList ops;
int target_op_pos; // -1 means no target op in this branch
AnfNodePtr root_data{nullptr};
size_t size() { return ops.size(); }
AnfNodePtr GetTargetOp() { return GetOp(target_op_pos); }
AnfNodePtr GetOp(int depth) {
if (depth < 0 || depth >= static_cast<int>(ops.size())) {
return nullptr;
}
return ops[depth];
}
AnfNodePtr GetRootData() { return root_data; }
void SetDataRoot(AnfNodePtr data) { root_data = data; }
std::string ToString() {
std::string res;
res += "RootData: ";
res += root_data->fullname_with_scope();
res += "; Ops: [";
for (size_t i = 0; i < ops.size(); ++i) {
auto op = ops[i];
res += op->fullname_with_scope();
if (static_cast<int>(i) == target_op_pos) {
res += "(LEAD OP)";
}
res += ", ";
}
res += "]";
return res;
}
};
using Group = std::vector<Branch>;
using FIsSupportedOp = std::function<bool(const AnfNodePtr &n)>;
using FAreCompatibleOps = std::function<bool(const AnfNodePtr &a, const AnfNodePtr &b)>;
using AnfNodePtrSubstMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
using AnfNodePtrSet = std::unordered_set<AnfNodePtr>;
class BranchGroupFinder {
public:
BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops);
std::vector<Group> Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph = nullptr);
std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;
private:
std::string op_name_;
AnfNodePtrSet op_roots_;
FIsSupportedOp fis_supported_op_;
FAreCompatibleOps fare_compatible_ops_;
Branch CreateBranch(AnfNodePtr lead_op);
AnfNodeIndexSet GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer);
};
class ParallelOpCombiner {
public:
explicit ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout);
AnfNodePtr Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph = nullptr);
protected:
virtual bool IsSupportedOp(const AnfNodePtr n) = 0;
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0;
virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0;
virtual bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) = 0;
virtual AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
size_t depth) = 0;
virtual void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) = 0;
bool AutoUpdateInfo(const CNodePtr &to_update, size_t out_size = 1);
std::map<size_t, AnfNodePtrList> GetUniqueInputs(const Group &branches, size_t depth);
FuncGraphPtr main_graph_;
AnfNodePtr combined_;
std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;
std::unordered_set<std::string> unsupported_ops_{prim::kTranspose, prim::kReshape};
private:
void CombineBranches(const Group &branches);
bool CheckLevel(const Group &branches, size_t depth);
std::string op_name_;
uint64_t min_num_branches_{2};
std::string layout_;
};
class GraphBuilder {
public:
static CNodePtr NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs);
static CNodePtr NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
size_t split_num);
static CNodePtr NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node, size_t concat_dim,
size_t input_num);
static CNodePtr NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
const AnfNodePtr &orig_node);
static CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
const AnfNodePtr &orig_node);
static CNodePtr NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
const AnfNodePtr &orig_node);
static ShapeVector InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
const ShapeVector &new_reshape_in);
static ShapeVector InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
const ShapeVector &new_reshape_in);
static ShapeVector InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm);
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_

View File

@ -0,0 +1,195 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/graph_kernel/core/parallel_op_concatenate.h"
#include <vector>
#include <string>
#include <set>
#include <unordered_set>
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"
namespace mindspore::graphkernel {
ParallelOpConcatenater::ParallelOpConcatenater(const std::string &op_name, uint64_t min_num_branches,
const std::string &layout)
: ParallelOpCombiner(op_name, min_num_branches, layout) {}
bool ParallelOpConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) {
auto cnode_a = a->cast<CNodePtr>();
auto cnode_b = b->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_a);
MS_EXCEPTION_IF_NULL(cnode_b);
auto arg_size = cnode_a->size();
if (arg_size != cnode_b->size()) {
MS_LOG(DEBUG) << "Args size not compatible: " << arg_size << " vs " << cnode_b->size();
return false;
}
auto cb = Callback::Instance();
for (size_t i = 1; i < arg_size; ++i) {
auto shape_a = cb->GetInputInferShape(a, i);
auto shape_b = cb->GetInputInferShape(b, i);
if (shape_a != shape_b) {
MS_LOG(ERROR) << "Args shape not compatible:" << shape_a << " vs " << shape_b;
return false;
}
}
return true;
}
AnfNodePtr ParallelOpConcatenater::MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
size_t depth) {
auto ew_plan = GetElemWiseFollowingPlan(branches, depth);
plans_.push_back(ew_plan);
auto overall_inputs = ReloadInputs(branches, depth, data);
if (branches.empty()) {
MS_LOG(EXCEPTION) << "Fail to sample ops in a empty group.";
}
// Since all the ops of same depth in group should be the same, we just sample op in first branch.
Branch b0 = branches[0];
auto orig_node = b0.GetOp(static_cast<int>(depth));
MS_EXCEPTION_IF_NULL(orig_node);
CNodePtr new_node;
if (GetCNodePrimitive(orig_node)->name() == prim::kReshape) {
new_node = GraphBuilder::NewReshapeNode(main_graph_, overall_inputs, orig_node);
} else if (GetCNodePrimitive(orig_node)->name() == prim::kTranspose) {
new_node = GraphBuilder::NewTransposeNode(main_graph_, overall_inputs, orig_node);
} else {
new_node = GraphBuilder::NewElemwiseNoAttrNode(main_graph_, overall_inputs, orig_node);
}
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(new_node), "AutoUpdateInfo fail");
return new_node;
}
std::map<size_t, AnfNodePtr> ParallelOpConcatenater::ConcatUniqueInputs(std::map<size_t, AnfNodePtrList> unique_inputs,
size_t concat_idx) {
std::map<size_t, AnfNodePtr> concated_inputs;
for (auto it : unique_inputs) {
size_t input_idx = it.first;
auto local_inputs = it.second;
if (local_inputs.size() < kDim2) {
MS_LOG(WARNING) << "Concat Op needs at least 2 inputs, while got " << local_inputs.size();
continue;
}
auto concat_node = GraphBuilder::NewConcatNode(main_graph_, local_inputs, concat_idx, local_inputs.size());
MS_EXCEPTION_IF_NULL(concat_node);
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(concat_node), "AutoUpdateInfo fail");
concated_inputs[input_idx] = concat_node;
}
return concated_inputs;
}
void ParallelOpConcatenater::UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) {
if (depth >= plans_.size()) {
MS_LOG(EXCEPTION) << "Cannot get plan at depth " << depth << " vs " << plans_.size();
}
auto ew_plan = plans_[depth];
auto split_node = GraphBuilder::NewSplitNode(main_graph_, data, ew_plan.split_out_idx, branches.size());
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(split_node, branches.size()), "AutoUpdateInfo fail");
main_graph_->AddNode(split_node);
auto mng = main_graph_->manager();
for (size_t i = 0; i < branches.size(); ++i) {
auto br = branches[i];
auto target = br.ops[depth];
auto idx_val = MakeValue(SizeToLong(i));
auto gt_idx = NewValueNode(idx_val);
gt_idx->set_abstract(idx_val->ToAbstract());
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), split_node, gt_idx};
auto new_out = main_graph_->NewCNode(gt_inputs);
new_out->set_abstract(target->abstract()->Clone());
mng->Replace(target, new_out);
}
return;
}
ConcatenatePlan ParallelOpConcatenater::GetElemWiseFollowingPlan(const Group &branches, size_t depth) {
if (depth - 1 >= plans_.size()) {
MS_LOG(EXCEPTION) << "Should get " << (depth - 1) << " plan first, current plan size = " << plans_.size();
}
auto last_plan = plans_[depth - 1];
ConcatenatePlan ew_plan;
auto unique_inputs = GetUniqueInputs(branches, depth);
auto cb = Callback::Instance();
for (auto it : unique_inputs) {
for (auto in : it.second) {
ew_plan.in_shape = cb->GetOutputInferShape(in, 0);
break;
}
}
auto UpdateIdx = [](ShapeVector &base_shape, ShapeVector &new_shape, size_t base_idx) {
if (new_shape.empty()) {
return base_idx;
}
auto rank_diff = static_cast<int>(base_shape.size()) - static_cast<int>(new_shape.size());
if (rank_diff > static_cast<int>(base_idx)) {
return base_idx;
}
return base_idx - rank_diff;
};
ew_plan.concat_in_idx = UpdateIdx(last_plan.in_shape, ew_plan.in_shape, last_plan.concat_in_idx);
Branch b0 = branches[0];
auto op = b0.ops[depth];
ew_plan.out_shape = cb->GetOutputInferShape(op, 0);
ew_plan.split_out_idx = UpdateIdx(last_plan.out_shape, ew_plan.out_shape, last_plan.split_out_idx);
MS_LOG(DEBUG) << "EW plan: " << ew_plan.concat_in_idx << ", " << ew_plan.split_out_idx << ", " << ew_plan.out_shape;
return ew_plan;
}
AnfNodePtrList ParallelOpConcatenater::ReloadInputs(const Group &branches, size_t depth, AnfNodePtr shared_input) {
Branch b1 = branches[0];
auto cnode = b1.ops[depth]->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_size = cnode->size();
auto plan = plans_[depth];
auto unique_inputs = GetUniqueInputs(branches, depth);
AnfNodePtrList overall_inputs{cnode->input(0)}; // prim
auto concated_inputs = ConcatUniqueInputs(unique_inputs, plan.concat_in_idx);
for (size_t i = 1; i < input_size; ++i) {
if (concated_inputs.find(i) != concated_inputs.end()) {
overall_inputs.push_back(concated_inputs[i]);
} else {
overall_inputs.push_back(shared_input);
}
}
return overall_inputs;
}
ShapeVector GraphBuilder::InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
const ShapeVector &new_reshape_in) {
std::map<int, int> idx_map_rev;
std::map<int, int> mul_map;
int oidx = static_cast<int>(orig_reshape_out.size()) - 1;
for (int ridx = static_cast<int>(orig_reshape_in.size()) - 1; ridx >= 0; --ridx) {
auto cur_size = orig_reshape_in[ridx];
mul_map[ridx] = new_reshape_in[ridx] / orig_reshape_in[ridx];
while (oidx >= 0 && cur_size >= orig_reshape_out[oidx] && cur_size % orig_reshape_out[oidx] == 0) {
idx_map_rev[oidx] = ridx;
cur_size = cur_size / orig_reshape_out[oidx];
oidx--;
}
}
ShapeVector new_shape_out;
for (int i = 0; i < static_cast<int>(orig_reshape_out.size()); ++i) {
auto in_idx = idx_map_rev[i];
auto mul = mul_map[in_idx];
new_shape_out.push_back(orig_reshape_out[i] * mul);
mul_map[in_idx] = 1;
}
return new_shape_out;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,56 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either AnfNodePtress or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/core/parallel_op_combine.h"
namespace mindspore::graphkernel {
struct ConcatenatePlan {
int concat_in_idx{0};
int split_out_idx{0};
ShapeVector in_shape;
ShapeVector out_shape;
};
class ParallelOpConcatenater : public ParallelOpCombiner {
public:
explicit ParallelOpConcatenater(const std::string &op_name, uint64_t min_num_branches, const std::string &layout);
protected:
virtual bool IsSupportedOp(const AnfNodePtr n) = 0;
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0;
virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0;
bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b);
AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches, size_t depth) final;
void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) final;
std::map<size_t, AnfNodePtr> ConcatUniqueInputs(std::map<size_t, AnfNodePtrList> unique_inputs, size_t concat_idx);
ConcatenatePlan GetElemWiseFollowingPlan(const Group &branches, size_t depth);
AnfNodePtrList ReloadInputs(const Group &branches, size_t depth, AnfNodePtr shared_input);
std::vector<ConcatenatePlan> plans_;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_

View File

@ -319,6 +319,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("enable_lite_conv_tuning", &enable_lite_conv_tuning);
reg.AddFlag("enable_vectorization", &enable_vectorization);
reg.AddFlag("enable_dynamic_shape_fusion", &enable_dynamic_shape_fusion);
reg.AddFlag("enable_parallel_op_combine", &enable_parallel_op_combine);
// Integer flags
reg.AddFlag("reduce_fuse_depth", &reduce_fuse_depth);

View File

@ -93,6 +93,11 @@ class BACKEND_EXPORT GraphKernelFlags {
*/
unsigned int parallel_ops_level{OpLevel_0};
/**
* Enable parallel op combination, default is false.
*/
bool enable_parallel_op_combine{false};
/**
* Enable horizontal fusion in graph kernel fusion strategy, default is false.
*/

View File

@ -130,6 +130,16 @@ NodePtr LiteGraph::GraphBuilderBase::Emit(const std::string &op, const NodePtrLi
return op_ptr;
}
NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBaseList &baseinfolist,
const NodePtrList &inputs, const DAttrs &attrs) const {
PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName());
op_ptr->SetInputs(inputs);
op_ptr->SetAttrs(attrs);
op_ptr->SetBaseInfo(baseinfolist);
(void)graph_->ops_.emplace_back(op_ptr);
return op_ptr;
}
NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
const DAttrs &attrs) const {
PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName());

View File

@ -72,6 +72,8 @@ class LiteGraph::GraphBuilderBase {
NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const;
// Create op node with given baseinfo.
NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs,
const DAttrs &attrs = {}) const;
NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
const DAttrs &attrs = {}) const;
LiteGraphPtr Get() const { return graph_; }

View File

@ -31,6 +31,7 @@
#include "backend/common/graph_kernel/core/transform_op_optimizer.h"
#include "backend/common/graph_kernel/core/update_state_formatter.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
#include "tools/graph_kernel/converter/akg/utils.h"
#include "tools/graph_kernel/converter/callback_impl.h"
@ -77,6 +78,11 @@ GkPassManagerPtr GraphKernelOptimizer::Cluster() const {
auto pm = std::make_shared<GraphKernelPassManagerLite>(kStageCluster, "cluster");
// Expand complex basic kernels to composite kernels
pm->Add(std::make_shared<GraphKernelExpanderLite>(), OptLevel_1);
if (GraphKernelFlags::GetInstance().enable_parallel_op_combine || is_ascend) {
pm->Add(std::make_shared<GraphKernelOpCombiner>(), OptLevel_2);
}
pm->Add(std::make_shared<ConvTuningExpander>(), OptLevel_1, is_cpu);
// Cluster basic kernels and composite kernels

View File

@ -0,0 +1,101 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
class CustomNet(Cell):
def __init__(self):
super(CustomNet, self).__init__()
self.matmul = P.MatMul()
self.add = P.Add()
self.abs = P.Abs()
def construct(self, mx_a, mx_b1, bias1, mx_b2, bias2):
# use Abs to create a shared input for matmul
abs1 = P.Abs()(mx_a)
abs2 = P.Abs()(mx_a)
# branch 1: matmul - add - abs
m1 = self.matmul(abs1, mx_b1)
m1 = self.add(m1, bias1)
m1 = self.abs(m1)
# branch 2: matmul - add - abs
m2 = self.matmul(abs2, mx_b2)
m2 = self.add(m2, bias2)
m2 = self.abs(m2)
return m1, m2
def get_output(i0, i1, i2, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel, save_graphs=False)
net = CustomNet()
mx_a = i0
mx_b1 = i1 * 3
mx_b2 = i1 * 2
bias1 = i2 * 3
bias2 = i2 * 2
output = net(mx_a, mx_b1, bias1, mx_b2, bias2)
return output
def run():
i0 = Tensor(np.random.normal(1, 0.01, [96, 800]).astype(np.float16))
i1 = Tensor(np.random.normal(1, 0.01, [800, 128]).astype(np.float16))
i2 = Tensor(np.random.normal(1, 0.01, [1, 128]).astype(np.float16))
expect = get_output(i0, i1, i2, False)
output = get_output(i0, i1, i2, True)
for exp, out in zip(expect, output):
expect_np = exp.asnumpy().copy()
output_np = out.asnumpy().copy()
if not np.allclose(expect_np, output_np, 1.e-4, 1.e-7):
return false
return true
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parallel_matmul_combine_ascend():
"""
Feature: Parallel Matmul combination
Description: on Ascend device
Expectation: network return same result with the feature on and off
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(graph_kernel_flags="--enable_parallel_op_combine=1 --opt_level=1")
assert run()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parallel_matmul_combine_gpu():
"""
Feature: Parallel Matmul combination
Description: on GPU device
Expectation: network return same result with the feature on and off
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(graph_kernel_flags="--enable_parallel_op_combine=1 --opt_level=1")
assert run()