!55337 [graph kernel] add GraphKernelOpCombiner pass
Merge pull request !55337 from yangsijia/graph-rewrite
This commit is contained in:
commit
2e7652bf66
|
@ -62,6 +62,7 @@
|
|||
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "backend/common/graph_kernel/compact_tensor_liveness.h"
|
||||
#include "backend/common/graph_kernel/core/convert_op_input_attr.h"
|
||||
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
|
||||
|
||||
#ifdef ENABLE_AKG
|
||||
#include "backend/common/graph_kernel/graph_kernel_build.h"
|
||||
|
@ -110,6 +111,10 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
|||
// Expand complex basic kernels to composite kernels
|
||||
pm->Add(std::make_shared<GraphKernelExpanderWithPy>(), OptLevel_1);
|
||||
|
||||
if (GraphKernelFlags::GetInstance().enable_parallel_op_combine) {
|
||||
pm->Add(std::make_shared<GraphKernelOpCombiner>(), OptLevel_2);
|
||||
}
|
||||
|
||||
// Cluster basic kernels and composite kernels
|
||||
pm->Add(std::make_shared<GraphKernelCluster>(), OptLevel_1);
|
||||
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
bool GraphKernelOpCombiner::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto cb = Callback::Instance();
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
auto changed = false;
|
||||
for (auto node : nodes) {
|
||||
if (node->cast<CNodePtr>() == nullptr || !AnfUtils::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = node->fullname_with_scope();
|
||||
auto res = ConcatParallelMatMul(node, min_ops_to_combine_, default_layout_to_combine_, func_graph);
|
||||
if (res != nullptr) {
|
||||
changed = true;
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/backend/optimizer/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
class GraphKernelOpCombiner : public opt::Pass {
|
||||
public:
|
||||
GraphKernelOpCombiner() : Pass("graph_kernel_op_combiner") {}
|
||||
~GraphKernelOpCombiner() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
uint64_t min_ops_to_combine_{2};
|
||||
std::string default_layout_to_combine_{kOpFormat_NCHW};
|
||||
};
|
||||
using GraphKernelOpCombinerPtr = std::shared_ptr<GraphKernelOpCombiner>;
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPHKERNEL_REWRITE_H_
|
|
@ -316,15 +316,20 @@ inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
|
|||
auto todos = TopoSort(func_graph->output());
|
||||
const auto ¶ms = func_graph->parameters();
|
||||
auto cb = Callback::Instance();
|
||||
auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) {
|
||||
auto shape = cb->GetOutputShape(node, 0);
|
||||
auto type = cb->GetOutputType(node, 0);
|
||||
auto format = cb->GetOutputFormat(node, 0);
|
||||
return inner::NodeBase({shape, type, format});
|
||||
auto ExtractBuildInfo = [&cb](const AnfNodePtr &node) -> inner::NodeBaseList {
|
||||
inner::NodeBaseList listinfo;
|
||||
size_t output_num = AnfUtils::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto shape = cb->GetOutputShape(node, i);
|
||||
auto type = cb->GetOutputType(node, i);
|
||||
auto format = cb->GetOutputFormat(node, i);
|
||||
listinfo.push_back(inner::NodeBase({shape, type, format}));
|
||||
}
|
||||
return listinfo;
|
||||
};
|
||||
// set inputs
|
||||
for (auto &p : params) {
|
||||
node_map[p] = gb.Parameter(ExtractBuildInfo(p));
|
||||
node_map[p] = gb.Parameter(ExtractBuildInfo(p)[0]);
|
||||
}
|
||||
// set ops
|
||||
for (auto node : todos) {
|
||||
|
@ -332,7 +337,7 @@ inner::LiteGraphPtr GkUtils::AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
|
|||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
if (node == func_graph->output() && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
break;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
|
||||
#include "base/base.h"
|
||||
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
// Reads the 'transpose_a' / 'transpose_b' attributes of a MatMul node.
//
// @param matmul MatMul CNode; must be non-null and carry a primitive.
// @return {transpose_a, transpose_b}. When either attribute is absent (or holds
//         a null value) a warning is logged and {false, false} is returned.
MMAttr GetMatMulTransposeAttr(const CNodePtr &matmul) {
  MS_EXCEPTION_IF_NULL(matmul);
  auto prim = common::AnfAlgo::GetCNodePrimitive(matmul);
  MS_EXCEPTION_IF_NULL(prim);
  // Bind by reference: only two entries are needed, so the map is not copied.
  const auto &mm_attrs = prim->attrs();
  const auto trans_a_it = mm_attrs.find(kTransposeA);
  const auto trans_b_it = mm_attrs.find(kTransposeB);
  const bool has_both = trans_a_it != mm_attrs.end() && trans_b_it != mm_attrs.end() &&
                        trans_a_it->second != nullptr && trans_b_it->second != nullptr;
  if (!has_both) {
    MS_LOG(WARNING) << "Can not find attr 'transpose_a' or 'transpose_b' in node " << matmul->fullname_with_scope();
    return std::make_pair(false, false);
  }
  return std::make_pair(GetValue<bool>(trans_a_it->second), GetValue<bool>(trans_b_it->second));
}
|
||||
|
||||
// Creates a MatMul CNode in `func_graph` from `matmul_inputs`.
//
// The new node copies the transpose attributes of `orig_matmul`, takes its output
// dtype from the first data input, and gets `new_out_shape` as its inferred shape.
// If the first data input is a CNode carrying a non-empty kOutputsFormat attribute,
// its first format is propagated to the new node.
//
// @param func_graph     Graph that will own the node; must be non-null.
// @param matmul_inputs  Full input list (primitive at index 0, data inputs after it).
// @param orig_matmul    MatMul whose transpose attributes are copied.
// @param new_out_shape  Inferred output shape to assign.
// @return The newly created node.
CNodePtr NewMatMulNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs, const CNodePtr &orig_matmul,
                       ShapeVector new_out_shape) {
  // Validate everything before the graph is touched.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_CHECK_FAIL(matmul_inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(matmul_inputs[kIndex1]);
  auto matmul = func_graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(matmul);
  func_graph->AddNode(matmul);
  auto orig_cnode = matmul_inputs[kIndex1]->cast<CNodePtr>();
  if (orig_cnode != nullptr && orig_cnode->HasAttr(kOutputsFormat)) {
    auto input_formats = GetValue<std::vector<std::string>>(orig_cnode->GetAttr(kOutputsFormat));
    // An empty format list gives nothing to propagate.
    if (!input_formats.empty()) {
      std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(matmul), input_formats[0]);
      matmul->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
    }
  }
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(orig_matmul);
  matmul->AddAttr(kTransposeA, MakeValue(trans_a));
  matmul->AddAttr(kTransposeB, MakeValue(trans_b));
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(matmul_inputs[kIndex1], 0)};
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, matmul.get());
  matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
  return matmul;
}
|
||||
|
||||
// Extracts (batch, M, N, K) from the inferred shapes of a MatMul's two inputs.
//
// A leading batch dimension is recognised only when both inputs have rank kDim3
// and the same first dimension; otherwise batch is 1. M and K come from input A
// and N from input B, honouring the node's transpose attributes.
//
// @param matmul MatMul CNode whose input shapes are inspected.
// @return Tuple (b, m, n, k).
BMNK GetBatchMNKV1(const CNodePtr &matmul) {
  auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
  auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  size_t b = 1;
  if (shape_a.size() == kDim3 && shape_b.size() == kDim3 && shape_a[kIndex0] == shape_b[kIndex0]) {
    b = static_cast<size_t>(shape_a[kIndex0]);
    // Drop the batch dimension so indices 0/1 address the matrix dimensions below.
    shape_a.erase(shape_a.begin());
    shape_b.erase(shape_b.begin());
  }
  // Indices 0 and 1 are read below; reject shapes that are too short instead of reading out of bounds.
  MS_EXCEPTION_IF_CHECK_FAIL(shape_a.size() > kIndex1 && shape_b.size() > kIndex1,
                             "MatMul input shapes should have at least 2 dimensions");
  const size_t m = static_cast<size_t>(trans_a ? shape_a[kIndex1] : shape_a[kIndex0]);
  const size_t k = static_cast<size_t>(trans_a ? shape_a[kIndex0] : shape_a[kIndex1]);
  const size_t n = static_cast<size_t>(trans_b ? shape_b[kIndex0] : shape_b[kIndex1]);
  return std::tuple(b, m, n, k);
}
|
||||
} // namespace
|
||||
|
||||
// Builds the concatenate plan for a group of MatMul branches, using the first
// branch as the representative.
//
// When the shared root feeds input 1 of the MatMul (operand A), the other operand
// is concatenated along its N axis and the output's N grows by the branch count.
// Otherwise the A operands are concatenated along M and the output's M grows.
ConcatenatePlan ParallelMatMulConcatenater::Analyse(const Group &branches) {
  ConcatenatePlan plan;
  Branch lead_branch = branches[kIndex0];
  AnfNodePtr shared_input = lead_branch.GetRootData();
  plan.in_shape = Callback::Instance()->GetOutputInferShape(shared_input, kIndex0);
  auto matmul = lead_branch.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul);

  // Operand A is shared exactly when the first data input is the shared root.
  const bool is_a_shared = matmul->size() > kIndex1 && matmul->input(kIndex1) == shared_input;

  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  const auto bmnk = GetBatchMNKV1(matmul);
  const int64_t b = static_cast<int64_t>(std::get<kIndex0>(bmnk));
  const int64_t m = static_cast<int64_t>(std::get<kIndex1>(bmnk));
  const int64_t n = static_cast<int64_t>(std::get<kIndex2>(bmnk));

  if (is_a_shared) {
    const size_t rank_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1).size();
    plan.concat_in_idx = trans_b ? rank_b - kIndex2 : rank_b - kIndex1;
    plan.split_out_idx = rank_b - kIndex1;
    int64_t widened_n = n * branches.size();
    plan.out_shape = (rank_b == kDim3) ? ShapeVector({b, m, widened_n}) : ShapeVector({m, widened_n});
    return plan;
  }

  const size_t rank_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0).size();
  plan.concat_in_idx = trans_a ? rank_a - kIndex1 : rank_a - kIndex2;
  plan.split_out_idx = rank_a - kIndex2;
  int64_t widened_m = m * branches.size();
  plan.out_shape = (rank_a == kDim3) ? ShapeVector({b, widened_m, n}) : ShapeVector({widened_m, n});
  return plan;
}
|
||||
|
||||
bool ParallelMatMulConcatenater::CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) {
|
||||
auto matmul1 = a->cast<CNodePtr>();
|
||||
auto matmul2 = b->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(matmul1);
|
||||
MS_EXCEPTION_IF_NULL(matmul2);
|
||||
auto [trans_a1, trans_b1] = GetMatMulTransposeAttr(matmul1);
|
||||
auto [trans_a2, trans_b2] = GetMatMulTransposeAttr(matmul2);
|
||||
return trans_a1 == trans_a2 && trans_b1 == trans_b2;
|
||||
}
|
||||
|
||||
bool ParallelMatMulConcatenater::IsSupportedOp(const AnfNodePtr n) {
|
||||
if (n->cast<CNodePtr>() == nullptr || unsupported_ops_.count(GetCNodePrimitive(n)->name())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds the single MatMul that replaces the target ops of all `branches`.
// The plan computed for the group is appended to `plans_`, and an exception is
// raised when the new node's build info cannot be filled in.
AnfNodePtr ParallelMatMulConcatenater::MakeCombinedOp(const Group &branches) {
  Branch lead_branch = branches[0];
  AnfNodePtr shared_input = lead_branch.GetRootData();
  auto orig_matmul = lead_branch.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(orig_matmul);

  const auto plan = Analyse(branches);
  plans_.push_back(plan);

  auto combined_inputs = ReloadInputs(branches, lead_branch.target_op_pos, shared_input);
  auto combined_matmul = NewMatMulNode(main_graph_, combined_inputs, orig_matmul, plan.out_shape);
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(combined_matmul), "AutoUpdateInfo fail");
  return combined_matmul;
}
|
||||
|
||||
// This combiner imposes no per-argument restriction: every pair of nodes is
// reported as compatible.
bool ParallelMatMulConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) {
  return true;
}
|
||||
|
||||
// Entry point for combining parallel MatMul ops found from `root`.
//
// Only the kOpFormat_NCHW layout is handled. For any other layout a warning is
// logged and `root` is returned untouched. Otherwise the result of
// ParallelMatMulConcatenater::Combine is returned.
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
                                const FuncGraphPtr &func_graph) {
  if (layout != kOpFormat_NCHW) {
    MS_LOG(WARNING) << "Not supported combine for layout " << layout;
    return root;
  }
  ParallelMatMulConcatenater concatenater(min_num_branches, layout);
  return concatenater.Combine(root, func_graph);
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <tuple>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/common/graph_kernel/core/parallel_op_concatenate.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using BMNK = std::tuple<size_t, size_t, size_t, size_t>;
|
||||
using MMAttr = std::pair<bool, bool>;
|
||||
class ParallelMatMulConcatenater : public ParallelOpConcatenater {
|
||||
public:
|
||||
explicit ParallelMatMulConcatenater(uint64_t min_num_branches, const std::string &layout)
|
||||
: ParallelOpConcatenater("MatMul", min_num_branches, layout) {}
|
||||
|
||||
protected:
|
||||
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b);
|
||||
virtual bool IsSupportedOp(const AnfNodePtr n);
|
||||
virtual AnfNodePtr MakeCombinedOp(const Group &branches);
|
||||
bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) override;
|
||||
|
||||
private:
|
||||
ConcatenatePlan Analyse(const Group &branches);
|
||||
};
|
||||
|
||||
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
|
||||
const FuncGraphPtr &func_graph = nullptr);
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_MATMUL_CONCATENATE_H_
|
|
@ -0,0 +1,443 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/common/graph_kernel/core/parallel_op_combine.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <deque>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "include/backend/kernel_graph.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/common/graph_kernel/adapter/callback_impl.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
constexpr auto kPerm = "perm";
|
||||
constexpr auto kShape = "shape";
|
||||
std::vector<int64_t> GetTransposePerm(const PrimitivePtr &primitive) {
|
||||
ValuePtr perm = primitive->GetAttr(kPerm);
|
||||
MS_EXCEPTION_IF_NULL(perm);
|
||||
auto perm_val = perm->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(perm_val);
|
||||
auto perm_val_data = perm_val->value();
|
||||
std::vector<int64_t> perm_int;
|
||||
(void)std::transform(perm_val_data.begin(), perm_val_data.end(), std::back_inserter(perm_int),
|
||||
[=](const ValuePtr &e) -> int64_t {
|
||||
if (e->isa<Int64Imm>()) {
|
||||
return GetValue<int64_t>(e);
|
||||
} else if (e->isa<Int32Imm>()) {
|
||||
return GetValue<int>(e);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Perm must be int";
|
||||
return -1;
|
||||
}
|
||||
});
|
||||
return perm_int;
|
||||
}
|
||||
} // namespace
|
||||
// Stores the target op name and the two predicates used while searching:
// one decides whether a node is supported, the other whether two target ops
// are compatible.
BranchGroupFinder::BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op,
                                     FAreCompatibleOps fare_compatible_ops)
    : op_name_(op_name),
      fis_supported_op_(fis_supported_op),
      fare_compatible_ops_(fare_compatible_ops) {
}
|
||||
|
||||
// Collects the supported real-kernel CNode users of `producer`.
//
// Each accepted user is also recorded in `children_map_[producer]`.
//
// @param mng      Manager whose node-user table is queried; must be non-null.
// @param producer Node whose users are inspected.
// @return The (user, input index) pairs that passed the filter; empty when the
//         manager knows no users for `producer`.
AnfNodeIndexSet BranchGroupFinder::GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer) {
  MS_EXCEPTION_IF_NULL(mng);
  AnfNodeIndexSet consumers;
  // Look the producer up instead of using operator[], which would insert an empty
  // entry into the manager's table and copy the whole user set.
  const auto &users_map = mng->node_users();
  const auto users_it = users_map.find(producer);
  if (users_it == users_map.end()) {
    return consumers;
  }
  for (const auto &user_index : users_it->second) {
    const auto &user = user_index.first;
    if (user->cast<CNodePtr>() && AnfUtils::IsRealKernel(user) && fis_supported_op_(user)) {
      consumers.add(CNodeIndexPair(user_index.first, user_index.second));
      children_map_[producer].insert(user);
    }
  }
  return consumers;
}
|
||||
|
||||
std::vector<Group> BranchGroupFinder::Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph) {
|
||||
auto graph_kernel_fg = func_graph == nullptr ? common::AnfAlgo::GetCNodeFuncGraphPtr(start_node) : func_graph;
|
||||
MS_EXCEPTION_IF_NULL(graph_kernel_fg);
|
||||
auto mng = graph_kernel_fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto cnode = start_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::deque<AnfNodePtr> init_consumer;
|
||||
std::transform(graph_kernel_fg->parameters().begin(), graph_kernel_fg->parameters().end(),
|
||||
std::back_inserter(init_consumer), [](const AnfNodePtr &global_in) { return global_in; });
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
init_consumer.push_back(cnode->input(i));
|
||||
}
|
||||
while (!init_consumer.empty()) {
|
||||
auto new_node = init_consumer.front();
|
||||
init_consumer.pop_front();
|
||||
auto new_consumer = GetConsumers(mng, new_node);
|
||||
std::transform(new_consumer.begin(), new_consumer.end(), std::back_inserter(init_consumer),
|
||||
[](const CNodeIndexPair &index_pair) { return index_pair.first; });
|
||||
}
|
||||
for (auto it : children_map_) {
|
||||
if (it.second.size() > 1) {
|
||||
op_roots_.insert(it.first);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Group> groups;
|
||||
for (const auto &root : op_roots_) {
|
||||
size_t ngroups = groups.size();
|
||||
auto childrens = children_map_.at(root);
|
||||
for (auto child : childrens) {
|
||||
auto prim_name = GetCNodePrimitive(child)->name();
|
||||
// Branch should start with target node that specified by `op_name_`
|
||||
if (prim_name != op_name_) {
|
||||
continue;
|
||||
}
|
||||
auto branch = CreateBranch(child);
|
||||
branch.SetDataRoot(root);
|
||||
// // position index less than 0 means we didn't find target op in this branch
|
||||
// if (branch.target_op_pos < 0) {
|
||||
// groups.emplace_back();
|
||||
// groups.back().push_back(branch);
|
||||
// continue;
|
||||
// }
|
||||
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [this, &branch](const Group &group) {
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(!group.empty() && !group[0].ops.empty(), "group empty or group[0] empty");
|
||||
auto top_branch = group[0];
|
||||
return (branch.target_op_pos == top_branch.target_op_pos) &&
|
||||
fare_compatible_ops_(branch.GetTargetOp(), top_branch.GetTargetOp());
|
||||
});
|
||||
if (it != groups.end()) {
|
||||
it->push_back(branch);
|
||||
} else {
|
||||
groups.emplace_back();
|
||||
groups.back().push_back(branch);
|
||||
}
|
||||
}
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
|
||||
// Builds a branch starting at `lead_op` by following `children_map_` for as long
// as the current node has exactly one recorded child.
//
// @param lead_op First op of the branch.
// @return A Branch holding the collected ops and the index, within those ops, of
//         the last op whose primitive is `op_name_` (-1 when there is none).
Branch BranchGroupFinder::CreateBranch(AnfNodePtr lead_op) {
  AnfNodePtrList ops{lead_op};
  int root_idx = GetCNodePrimitive(lead_op)->name() == op_name_ ? 0 : -1;
  auto it = children_map_.find(lead_op);
  while (it != children_map_.end() && it->second.size() == 1) {
    auto node = *(it->second).begin();
    ops.push_back(node);
    if (GetCNodePrimitive(node)->name() == op_name_) {
      // `node` was just appended, so its index is size() - 1; size() itself would
      // point one past the end of `ops`.
      root_idx = static_cast<int>(ops.size()) - 1;
    }
    it = children_map_.find(node);
  }
  return Branch(ops, root_idx);
}
|
||||
|
||||
// Records the target op name, the minimum group size worth combining, and the
// layout name.
ParallelOpCombiner::ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout)
    : op_name_(op_name),
      min_num_branches_(min_num_branches),
      layout_(layout) {
}
|
||||
|
||||
// Searches from `root` for groups of parallel branches and combines every group
// that has at least `min_num_branches_` branches.
//
// @param root       Non-null node that seeds the search.
// @param func_graph Graph to work on; when null, the graph is obtained from `root`.
// @return The value of `combined_` after all groups were processed.
AnfNodePtr ParallelOpCombiner::Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(root);
  main_graph_ = func_graph ? func_graph : common::AnfAlgo::GetCNodeFuncGraphPtr(root);
  MS_EXCEPTION_IF_NULL(main_graph_);

  BranchGroupFinder finder(
    op_name_, [this](const AnfNodePtr n) { return IsSupportedOp(n); },
    [this](const AnfNodePtr a, const AnfNodePtr b) { return CanOpsBeCombined(a, b); });
  auto groups = finder.Find(root, main_graph_);
  // Take over the finder's producer -> children table for later lookups.
  children_map_ = std::move(finder.children_map_);

  for (const Group &group : groups) {
    if (group.size() >= min_num_branches_) {
      CombineBranches(group);
      continue;
    }
    MS_LOG(INFO) << "group size = " << group.size() << " < " << min_num_branches_ << ", skip.";
  }
  return combined_;
}
|
||||
|
||||
void ParallelOpCombiner::CombineBranches(const Group &branches) {
|
||||
auto combined = MakeCombinedOp(branches);
|
||||
auto it = std::min_element(branches.begin(), branches.end(), [](const Branch &branch_a, const Branch &branch_b) {
|
||||
return branch_a.ops.size() < branch_b.ops.size();
|
||||
});
|
||||
size_t depth = it->ops.size();
|
||||
int pos;
|
||||
for (pos = 0; pos < static_cast<int>(depth); ++pos) {
|
||||
if (pos == it->target_op_pos) {
|
||||
continue;
|
||||
}
|
||||
if (!CheckLevel(branches, pos)) {
|
||||
break;
|
||||
}
|
||||
combined = MakeCombinedAnfNodePtrFromFollowingOps(combined, branches, pos);
|
||||
}
|
||||
UpdateGroupOutput(combined, branches, pos - 1);
|
||||
combined_ = combined;
|
||||
}
|
||||
|
||||
bool ParallelOpCombiner::CheckLevel(const Group &branches, size_t depth) {
|
||||
auto repr = branches[0].ops[depth];
|
||||
auto repr_prim_name = GetCNodePrimitive(repr)->name();
|
||||
// check if all branches in current depth can be combined
|
||||
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
|
||||
const Branch &branch = *it;
|
||||
auto node = branch.ops[depth];
|
||||
auto prim_name = GetCNodePrimitive(node)->name();
|
||||
if (prim_name != repr_prim_name) {
|
||||
MS_LOG(INFO) << "Prim not compatible!" << prim_name << " vs " << repr_prim_name;
|
||||
return false;
|
||||
}
|
||||
if (unsupported_ops_.find(prim_name) != unsupported_ops_.end()) {
|
||||
MS_LOG(INFO) << "Op " << prim_name << " not supported for combination for now, stop.";
|
||||
return false;
|
||||
}
|
||||
if (!IsArgCompatible(repr, node)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Op " << repr_prim_name << " can be combined at depth " << depth;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fills in format (and, off Ascend, dtype) information for a newly created node.
//
// The format and dtype are taken from the node's first data input when that is a
// real kernel, otherwise from the recorded original users of that input.
// On non-Ascend targets (when MSLITE_ENABLE_GRAPH_KERNEL is not defined) a kernel
// build info is attached; in every other case the kOutputsFormat attribute is set.
//
// @param to_update Node to update; needs at least one input besides index 0.
// @param out_size  Number of output formats/types used for the kernel build info.
// @return true when the information was written, false when the node has too few
//         inputs or no format (or, off Ascend, no dtype) could be determined.
bool ParallelOpCombiner::AutoUpdateInfo(const CNodePtr &to_update, size_t out_size) {
  if (to_update->size() < 2) {
    MS_LOG(ERROR) << "Cannot auto update for " << to_update->fullname_with_scope() << " with input size "
                  << to_update->size();
    return false;
  }
  // The first data input is used as the representative for format and dtype.
  auto rep_input = to_update->input(1);
  // NOTE: We assume the inputs' formats and types are consistent with outputs'.
  // On Ascend the format starts empty and must be discovered; elsewhere it defaults to NCHW.
  std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW;
  TypeId input_type = TypeId::kTypeUnknown;
  // Copies dtype and first output format from `cnode` into the captured variables.
  // The dtype is overwritten even when no format is found; the return value only
  // reports whether a format was found.
  auto UpdateBoth = [&input_type, &input_format](const CNodePtr &cnode) -> bool {
    if (cnode == nullptr) {
      return false;
    }
    input_type = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
    if (!cnode->HasAttr(kOutputsFormat)) {
      return false;
    }
    auto prev_of = GetValue<std::vector<std::string>>(cnode->GetAttr(kOutputsFormat));
    if (prev_of.size() > 0) {
      input_format = prev_of[0];
      return true;
    }
    return false;
  };
  if (AnfUtils::IsRealKernel(rep_input)) {
    UpdateBoth(rep_input->cast<CNodePtr>());
  }
  // Still no format: fall back to the first recorded user of the input that provides one.
  if (input_format.empty()) {
    auto it = children_map_.find(rep_input);
    if (it != children_map_.end()) {
      for (auto orig_user : it->second) {
        if (UpdateBoth(orig_user->cast<CNodePtr>())) {
          break;
        }
      }
    }
  }
  // Last resort on Ascend: use the layout this combiner was constructed with.
  if (input_format.empty() && Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
    input_format = layout_;
  }
  if (input_format.empty() ||
      (Callback::Instance()->GetTargetFromContext() != kAscendDevice && input_type == TypeId::kTypeUnknown)) {
    MS_LOG(WARNING) << "Cannot auto update input format for node " << to_update->fullname_with_scope();
    return false;
  }
#ifndef MSLITE_ENABLE_GRAPH_KERNEL
  if (Callback::Instance()->GetTargetFromContext() != kAscendDevice) {
    to_update->set_kernel_info(std::make_shared<device::KernelInfo>());
    // NOTE(review): to_update->size() presumably also counts the primitive at input 0,
    // so these input vectors may be one longer than the number of data inputs - confirm.
    std::vector<std::string> to_update_in_formats(to_update->size(), input_format);
    std::vector<TypeId> to_update_in_types(to_update->size(), input_type);
    std::vector<std::string> to_update_out_formats(out_size, input_format);
    std::vector<TypeId> to_update_out_types(out_size, input_type);
    auto graph_sel_info = BuildSelectKernelBuildInfo(to_update_in_formats, to_update_in_types, to_update_out_formats,
                                                     to_update_out_types, kernel::GetProcessorFromContext());
    AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, to_update.get());
    return true;
  }
#endif
  // Ascend (or lite build): record the format as a node attribute, one entry per output tensor.
  std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format);
  to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
  return true;
}
|
||||
|
||||
std::map<size_t, AnfNodePtrList> ParallelOpCombiner::GetUniqueInputs(const Group &branches, size_t depth) {
|
||||
std::map<size_t, AnfNodePtrList> unique_inputs;
|
||||
AnfNodePtrList parent_in_branch;
|
||||
if (depth >= 1) {
|
||||
std::transform(branches.begin(), branches.end(), std::back_inserter(parent_in_branch),
|
||||
[&depth](const Branch &br) { return br.ops[depth - 1]; });
|
||||
} else {
|
||||
Branch b1 = branches[0];
|
||||
parent_in_branch.push_back(b1.GetRootData());
|
||||
}
|
||||
|
||||
for (auto br : branches) {
|
||||
auto op = br.ops[depth];
|
||||
auto cnode = op->cast<CNodePtr>();
|
||||
// Here we can know for sure that op's arg length are the same (check before)
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto in = cnode->input(i);
|
||||
if (std::any_of(parent_in_branch.begin(), parent_in_branch.end(),
|
||||
[&in](const AnfNodePtr &p) { return in == p; })) {
|
||||
continue;
|
||||
}
|
||||
unique_inputs[i].push_back(in);
|
||||
}
|
||||
}
|
||||
return unique_inputs;
|
||||
}
|
||||
|
||||
// Creates a Concat CNode in `func_graph`.
//
// On Ascend the inputs are wrapped in a single MakeTuple; on other targets they
// are passed directly. The output dtype is that of the first input, and the
// output shape is the first input's shape with dimension `concat_dim` multiplied
// by `input_num`.
//
// @param func_graph Graph that will own the node; must be non-null.
// @param input_node Nodes to concatenate; must not be empty.
// @param concat_dim Axis to concatenate along; must index into the first input's shape.
// @param input_num  Number of inputs recorded in the node's attributes.
// @return The newly created Concat node.
CNodePtr GraphBuilder::NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node,
                                     size_t concat_dim, size_t input_num) {
  // Validate first, so that a bad argument cannot leave a half-initialised node in the graph.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_CHECK_FAIL(!input_node.empty(), "Concat should have at least one input");
  MS_EXCEPTION_IF_NULL(input_node[0]);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node[0], 0);
  MS_EXCEPTION_IF_CHECK_FAIL(concat_dim < shape.size(), "concat_dim is out of range of the input shape");

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
    auto maketuple = NewTupleNode(func_graph, input_node);
    concat_inputs.push_back(maketuple);
  } else {
    (void)concat_inputs.insert(concat_inputs.end(), input_node.begin(), input_node.end());
  }
  auto concat = func_graph->NewCNode(concat_inputs);
  MS_EXCEPTION_IF_NULL(concat);
  func_graph->AddNode(concat);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(input_node[0], 0)};
  shape[concat_dim] *= input_num;
  std::vector<ShapeVector> shapes(1, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, concat.get());
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(concat_dim)), concat);
  common::AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(static_cast<int64_t>(input_num)), concat);
  common::AnfAlgo::SetNodeAttr(kAttrN, MakeValue(static_cast<int64_t>(input_num)), concat);
  return concat;
}
|
||||
|
||||
// Creates a MakeTuple CNode over `shared_inputs` in `func_graph`. Its abstract
// is a tuple of the inputs' abstracts, in the same order.
CNodePtr GraphBuilder::NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs) {
  AnfNodePtrList tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  (void)tuple_inputs.insert(tuple_inputs.end(), shared_inputs.begin(), shared_inputs.end());

  AbstractBasePtrList element_abstracts;
  (void)std::transform(shared_inputs.begin(), shared_inputs.end(), std::back_inserter(element_abstracts),
                       [](const AnfNodePtr &in) { return in->abstract(); });

  auto tuple_node = func_graph->NewCNode(tuple_inputs);
  func_graph->AddNode(tuple_node);
  tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
  return tuple_node;
}
|
||||
|
||||
// Creates a Split CNode in `func_graph` that splits `input_node` along `split_dim`
// into `split_num` outputs.
//
// Every output gets the input's dtype and the input's shape with dimension
// `split_dim` divided by `split_num`.
//
// @param func_graph Graph that will own the node; must be non-null.
// @param input_node Node to split; must be non-null.
// @param split_dim  Axis to split along; must index into the input's shape.
// @param split_num  Number of outputs; must not be zero.
// @return The newly created Split node.
CNodePtr GraphBuilder::NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
                                    size_t split_num) {
  // Validate first, so that a bad argument cannot leave a half-initialised node in the graph.
  if (split_num == 0) {
    MS_LOG(EXCEPTION) << "split_num should not be zero.";
  }
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
  MS_EXCEPTION_IF_CHECK_FAIL(split_dim < shape.size(), "split_dim is out of range of the input shape");

  std::vector<AnfNodePtr> split_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
                                          input_node};
  auto split = func_graph->NewCNode(split_inputs);
  // Check the node before handing it to the graph.
  MS_EXCEPTION_IF_NULL(split);
  func_graph->AddNode(split);
  auto dtype = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
  std::vector<TypeId> dtypes(split_num, dtype);
  shape[split_dim] /= split_num;
  std::vector<ShapeVector> shapes(split_num, shape);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
  common::AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_num), split);
  return split;
}
|
||||
|
||||
// Creates a CNode from `inputs` (inputs[0] is the primitive, inputs[1] the first data operand) and gives it a
// single output whose dtype and shape are copied from output 0 of inputs[1].
// `orig_node` is not read by this function.
// Raises if `func_graph` is null, if `inputs` has fewer than two entries, or if inputs[1] is null.
CNodePtr GraphBuilder::NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
                                             const AnfNodePtr &orig_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Validate the operands before creating anything, so invalid input never leaves a stray node in the graph.
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  std::vector<ShapeVector> shapes = {common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0)};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}
|
||||
|
||||
// Creates a CNode from `inputs` for a combined Reshape.
// The output dtype is taken from output 0 of inputs[1]. The output shape is computed by InferReshapeOut from the
// original reshape's input/output shapes (read from `orig_node`) and the new input shape, and is also stored in
// the kShape attribute of the new node's primitive.
// NOTE(review): the primitive is whatever object inputs[0] holds; if the caller reuses the primitive of an
// existing node, set_attr below modifies that shared object too - confirm this is intended.
// Raises if `func_graph`, `orig_node`, inputs[1], the created node or its primitive is null, or if `inputs` has
// fewer than two entries.
CNodePtr GraphBuilder::NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
                                      const AnfNodePtr &orig_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(orig_node);
  // Validate the operands before creating anything, so invalid input never leaves a stray node in the graph.
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
  auto orig_shape_in = common::AnfAlgo::GetPrevNodeOutputInferShape(orig_node, 0);
  auto orig_shape_out = common::AnfAlgo::GetOutputInferShape(orig_node, 0);
  auto new_out_shape = InferReshapeOut(orig_shape_in, orig_shape_out, new_shape_in);
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  prim->set_attr(kShape, MakeValue(new_out_shape));
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}
|
||||
|
||||
// Creates a CNode from `inputs` for a combined Transpose.
// The output dtype is taken from output 0 of inputs[1]; the output shape is that input's shape permuted by the
// perm obtained from the new node's primitive via GetTransposePerm.
// `orig_node` is not read by this function.
// Raises if `func_graph`, inputs[1], the created node or its primitive is null, or if `inputs` has fewer than two
// entries.
CNodePtr GraphBuilder::NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs,
                                        const AnfNodePtr &orig_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Validate the operands before creating anything, so invalid input never leaves a stray node in the graph.
  MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() > kIndex1, "Input size should be larger than 1");
  MS_EXCEPTION_IF_NULL(inputs[kIndex1]);
  auto node = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(node);
  func_graph->AddNode(node);
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(inputs[kIndex1], 0)};
  auto new_shape_in = common::AnfAlgo::GetOutputInferShape(inputs[kIndex1], 0);
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  auto perm_int = GetTransposePerm(prim);
  auto new_out_shape = InferTransposeOut(new_shape_in, perm_int);
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, node.get());
  return node;
}
|
||||
|
||||
// Infers the output shape of a combined Reshape.
// Only the case where the new input has the same rank as the original input is handled (delegated to
// InferConcatReshapeOut); any other rank raises an exception.
ShapeVector GraphBuilder::InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                          const ShapeVector &new_reshape_in) {
  const bool same_rank = (orig_reshape_in.size() == new_reshape_in.size());
  if (!same_rank) {
    MS_LOG(EXCEPTION) << "Stack combiner infer for reshape not impl yet";
  }
  return InferConcatReshapeOut(orig_reshape_in, orig_reshape_out, new_reshape_in);
}
|
||||
|
||||
// Applies a transpose permutation to a shape: out_shape[k] = in_shape[perm[k]] for every k.
// The result has perm.size() entries.
// Raises if a perm entry does not address an existing axis of `in_shape` (the original read out of bounds).
ShapeVector GraphBuilder::InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm) {
  ShapeVector out_shape;
  out_shape.reserve(perm.size());
  for (int64_t i : perm) {
    auto idx = LongToSize(i);
    if (idx >= in_shape.size()) {
      MS_LOG(EXCEPTION) << "Transpose perm value " << i << " is out of range, input rank is " << in_shape.size();
    }
    out_shape.push_back(in_shape[idx]);
  }
  return out_shape;
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/backend/optimizer/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "mindspore/core/ops/array_ops.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
struct Branch {
|
||||
Branch(AnfNodePtrList lst, int pos) : ops(lst), target_op_pos(pos) {}
|
||||
AnfNodePtrList ops;
|
||||
int target_op_pos; // -1 means no target op in this branch
|
||||
AnfNodePtr root_data{nullptr};
|
||||
size_t size() { return ops.size(); }
|
||||
AnfNodePtr GetTargetOp() { return GetOp(target_op_pos); }
|
||||
AnfNodePtr GetOp(int depth) {
|
||||
if (depth < 0 || depth >= static_cast<int>(ops.size())) {
|
||||
return nullptr;
|
||||
}
|
||||
return ops[depth];
|
||||
}
|
||||
|
||||
AnfNodePtr GetRootData() { return root_data; }
|
||||
void SetDataRoot(AnfNodePtr data) { root_data = data; }
|
||||
std::string ToString() {
|
||||
std::string res;
|
||||
res += "RootData: ";
|
||||
res += root_data->fullname_with_scope();
|
||||
res += "; Ops: [";
|
||||
for (size_t i = 0; i < ops.size(); ++i) {
|
||||
auto op = ops[i];
|
||||
res += op->fullname_with_scope();
|
||||
if (static_cast<int>(i) == target_op_pos) {
|
||||
res += "(LEAD OP)";
|
||||
}
|
||||
res += ", ";
|
||||
}
|
||||
res += "]";
|
||||
return res;
|
||||
}
|
||||
};
|
||||
using Group = std::vector<Branch>;
|
||||
using FIsSupportedOp = std::function<bool(const AnfNodePtr &n)>;
|
||||
using FAreCompatibleOps = std::function<bool(const AnfNodePtr &a, const AnfNodePtr &b)>;
|
||||
using AnfNodePtrSubstMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
|
||||
using AnfNodePtrSet = std::unordered_set<AnfNodePtr>;
|
||||
class BranchGroupFinder {
|
||||
public:
|
||||
BranchGroupFinder(const std::string &op_name, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops);
|
||||
std::vector<Group> Find(const AnfNodePtr &start_node, const FuncGraphPtr &func_graph = nullptr);
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;
|
||||
|
||||
private:
|
||||
std::string op_name_;
|
||||
AnfNodePtrSet op_roots_;
|
||||
FIsSupportedOp fis_supported_op_;
|
||||
FAreCompatibleOps fare_compatible_ops_;
|
||||
Branch CreateBranch(AnfNodePtr lead_op);
|
||||
AnfNodeIndexSet GetConsumers(FuncGraphManagerPtr mng, const AnfNodePtr &producer);
|
||||
};
|
||||
|
||||
class ParallelOpCombiner {
|
||||
public:
|
||||
explicit ParallelOpCombiner(const std::string &op_name, uint64_t min_num_branches, const std::string &layout);
|
||||
AnfNodePtr Combine(const AnfNodePtr &root, const FuncGraphPtr &func_graph = nullptr);
|
||||
|
||||
protected:
|
||||
virtual bool IsSupportedOp(const AnfNodePtr n) = 0;
|
||||
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0;
|
||||
virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0;
|
||||
virtual bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) = 0;
|
||||
virtual AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
|
||||
size_t depth) = 0;
|
||||
virtual void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) = 0;
|
||||
bool AutoUpdateInfo(const CNodePtr &to_update, size_t out_size = 1);
|
||||
|
||||
std::map<size_t, AnfNodePtrList> GetUniqueInputs(const Group &branches, size_t depth);
|
||||
|
||||
FuncGraphPtr main_graph_;
|
||||
AnfNodePtr combined_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtrSet> children_map_;
|
||||
std::unordered_set<std::string> unsupported_ops_{prim::kTranspose, prim::kReshape};
|
||||
|
||||
private:
|
||||
void CombineBranches(const Group &branches);
|
||||
bool CheckLevel(const Group &branches, size_t depth);
|
||||
|
||||
std::string op_name_;
|
||||
uint64_t min_num_branches_{2};
|
||||
std::string layout_;
|
||||
};
|
||||
|
||||
class GraphBuilder {
|
||||
public:
|
||||
static CNodePtr NewTupleNode(const FuncGraphPtr &func_graph, AnfNodePtrList shared_inputs);
|
||||
static CNodePtr NewSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, size_t split_dim,
|
||||
size_t split_num);
|
||||
static CNodePtr NewConcatNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_node, size_t concat_dim,
|
||||
size_t input_num);
|
||||
static CNodePtr NewElemwiseNoAttrNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
|
||||
const AnfNodePtr &orig_node);
|
||||
static CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
|
||||
const AnfNodePtr &orig_node);
|
||||
static CNodePtr NewTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs,
|
||||
const AnfNodePtr &orig_node);
|
||||
|
||||
static ShapeVector InferReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
|
||||
const ShapeVector &new_reshape_in);
|
||||
static ShapeVector InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
|
||||
const ShapeVector &new_reshape_in);
|
||||
static ShapeVector InferTransposeOut(const ShapeVector &in_shape, const std::vector<int64_t> &perm);
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_COMBINE_H_
|
|
@ -0,0 +1,195 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/common/graph_kernel/core/parallel_op_concatenate.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include "include/backend/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/common/graph_kernel/adapter/callback_impl.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
// Forwards all arguments unchanged to the ParallelOpCombiner base; no additional state is initialised here.
ParallelOpConcatenater::ParallelOpConcatenater(const std::string &op_name, uint64_t min_num_branches,
                                               const std::string &layout)
    : ParallelOpCombiner(op_name, min_num_branches, layout) {
}
|
||||
|
||||
bool ParallelOpConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) {
|
||||
auto cnode_a = a->cast<CNodePtr>();
|
||||
auto cnode_b = b->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode_a);
|
||||
MS_EXCEPTION_IF_NULL(cnode_b);
|
||||
auto arg_size = cnode_a->size();
|
||||
if (arg_size != cnode_b->size()) {
|
||||
MS_LOG(DEBUG) << "Args size not compatible: " << arg_size << " vs " << cnode_b->size();
|
||||
return false;
|
||||
}
|
||||
auto cb = Callback::Instance();
|
||||
for (size_t i = 1; i < arg_size; ++i) {
|
||||
auto shape_a = cb->GetInputInferShape(a, i);
|
||||
auto shape_b = cb->GetInputInferShape(b, i);
|
||||
if (shape_a != shape_b) {
|
||||
MS_LOG(ERROR) << "Args shape not compatible:" << shape_a << " vs " << shape_b;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Builds the single node that replaces the ops found at position `depth` of every branch.
// Steps: compute the concat/split plan for this depth and append it to plans_, assemble the inputs of the new
// node (ReloadInputs), then create a Reshape, Transpose or generic node depending on the primitive name of the
// op sampled from the first branch.
// Raises if `branches` is empty, if the sampled op or its primitive is null, or if AutoUpdateInfo fails.
AnfNodePtr ParallelOpConcatenater::MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches,
                                                                          size_t depth) {
  // This guard must come first: both helpers below index branches[0].
  if (branches.empty()) {
    MS_LOG(EXCEPTION) << "Fail to sample ops in a empty group.";
  }
  auto ew_plan = GetElemWiseFollowingPlan(branches, depth);
  plans_.push_back(ew_plan);
  auto overall_inputs = ReloadInputs(branches, depth, data);
  // Since all the ops of same depth in group should be the same, we just sample op in first branch.
  Branch b0 = branches[0];
  auto orig_node = b0.GetOp(static_cast<int>(depth));
  MS_EXCEPTION_IF_NULL(orig_node);
  auto orig_prim = GetCNodePrimitive(orig_node);
  MS_EXCEPTION_IF_NULL(orig_prim);
  const auto prim_name = orig_prim->name();
  CNodePtr new_node;
  if (prim_name == prim::kReshape) {
    new_node = GraphBuilder::NewReshapeNode(main_graph_, overall_inputs, orig_node);
  } else if (prim_name == prim::kTranspose) {
    new_node = GraphBuilder::NewTransposeNode(main_graph_, overall_inputs, orig_node);
  } else {
    new_node = GraphBuilder::NewElemwiseNoAttrNode(main_graph_, overall_inputs, orig_node);
  }
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(new_node), "AutoUpdateInfo fail");
  return new_node;
}
|
||||
|
||||
std::map<size_t, AnfNodePtr> ParallelOpConcatenater::ConcatUniqueInputs(std::map<size_t, AnfNodePtrList> unique_inputs,
|
||||
size_t concat_idx) {
|
||||
std::map<size_t, AnfNodePtr> concated_inputs;
|
||||
for (auto it : unique_inputs) {
|
||||
size_t input_idx = it.first;
|
||||
auto local_inputs = it.second;
|
||||
if (local_inputs.size() < kDim2) {
|
||||
MS_LOG(WARNING) << "Concat Op needs at least 2 inputs, while got " << local_inputs.size();
|
||||
continue;
|
||||
}
|
||||
auto concat_node = GraphBuilder::NewConcatNode(main_graph_, local_inputs, concat_idx, local_inputs.size());
|
||||
MS_EXCEPTION_IF_NULL(concat_node);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(concat_node), "AutoUpdateInfo fail");
|
||||
concated_inputs[input_idx] = concat_node;
|
||||
}
|
||||
return concated_inputs;
|
||||
}
|
||||
|
||||
void ParallelOpConcatenater::UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) {
|
||||
if (depth >= plans_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot get plan at depth " << depth << " vs " << plans_.size();
|
||||
}
|
||||
auto ew_plan = plans_[depth];
|
||||
auto split_node = GraphBuilder::NewSplitNode(main_graph_, data, ew_plan.split_out_idx, branches.size());
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(split_node, branches.size()), "AutoUpdateInfo fail");
|
||||
main_graph_->AddNode(split_node);
|
||||
auto mng = main_graph_->manager();
|
||||
for (size_t i = 0; i < branches.size(); ++i) {
|
||||
auto br = branches[i];
|
||||
auto target = br.ops[depth];
|
||||
auto idx_val = MakeValue(SizeToLong(i));
|
||||
auto gt_idx = NewValueNode(idx_val);
|
||||
gt_idx->set_abstract(idx_val->ToAbstract());
|
||||
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), split_node, gt_idx};
|
||||
auto new_out = main_graph_->NewCNode(gt_inputs);
|
||||
new_out->set_abstract(target->abstract()->Clone());
|
||||
mng->Replace(target, new_out);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Derives the concat/split plan for the ops at position `depth` from the plan stored at `depth - 1`.
// in_shape is the inferred output shape of the first node of the last non-empty entry returned by
// GetUniqueInputs; out_shape is the inferred output shape of the op sampled from the first branch.
// The concat and split axes of the previous plan are shifted by the rank difference between the previous and
// the new shapes (see UpdateIdx).
// Raises if `depth` is 0, if no plan exists for `depth - 1`, if `branches` is empty, or if the first branch has
// no op at `depth`.
ConcatenatePlan ParallelOpConcatenater::GetElemWiseFollowingPlan(const Group &branches, size_t depth) {
  // `depth - 1` on an unsigned value wraps around for depth == 0, so handle that case explicitly.
  if (depth == 0) {
    MS_LOG(EXCEPTION) << "depth should be greater than 0 to look up the previous plan.";
  }
  if (depth - 1 >= plans_.size()) {
    MS_LOG(EXCEPTION) << "Should get " << (depth - 1) << " plan first, current plan size = " << plans_.size();
  }
  if (branches.empty()) {
    MS_LOG(EXCEPTION) << "Cannot make a plan for an empty group.";
  }
  const auto &first_ops = branches[0].ops;
  if (depth >= first_ops.size()) {
    MS_LOG(EXCEPTION) << "First branch has no op at depth " << depth << ", branch size is " << first_ops.size();
  }
  const auto &last_plan = plans_[depth - 1];
  ConcatenatePlan ew_plan;
  auto unique_inputs = GetUniqueInputs(branches, depth);
  auto cb = Callback::Instance();
  // Only the first node of each entry is sampled; a later non-empty entry overwrites an earlier one.
  for (const auto &item : unique_inputs) {
    if (!item.second.empty()) {
      ew_plan.in_shape = cb->GetOutputInferShape(item.second.front(), 0);
    }
  }
  // Shifts `base_idx` by (rank of base_shape - rank of new_shape). The index is returned unchanged when
  // new_shape is empty or when the rank difference is larger than the index.
  auto UpdateIdx = [](const ShapeVector &base_shape, const ShapeVector &new_shape, size_t base_idx) -> size_t {
    if (new_shape.empty()) {
      return base_idx;
    }
    auto rank_diff = static_cast<int>(base_shape.size()) - static_cast<int>(new_shape.size());
    if (rank_diff > static_cast<int>(base_idx)) {
      return base_idx;
    }
    return base_idx - rank_diff;
  };
  ew_plan.concat_in_idx = UpdateIdx(last_plan.in_shape, ew_plan.in_shape, last_plan.concat_in_idx);
  const auto &op = first_ops[depth];
  ew_plan.out_shape = cb->GetOutputInferShape(op, 0);
  ew_plan.split_out_idx = UpdateIdx(last_plan.out_shape, ew_plan.out_shape, last_plan.split_out_idx);
  MS_LOG(DEBUG) << "EW plan: " << ew_plan.concat_in_idx << ", " << ew_plan.split_out_idx << ", " << ew_plan.out_shape;
  return ew_plan;
}
|
||||
|
||||
// Assembles the input list of the combined node for the ops at position `depth`.
// The primitive (input 0) is taken from the op of the first branch. For every other input index, the Concat
// node produced by ConcatUniqueInputs is used when one exists for that index; otherwise `shared_input` is used.
// Raises if `branches` is empty, if the first branch has no op at `depth`, if no plan exists for `depth`, or
// if the sampled op is null or not a CNode.
AnfNodePtrList ParallelOpConcatenater::ReloadInputs(const Group &branches, size_t depth, AnfNodePtr shared_input) {
  if (branches.empty()) {
    MS_LOG(EXCEPTION) << "Cannot reload inputs for an empty group.";
  }
  const auto &first_ops = branches[0].ops;
  if (depth >= first_ops.size()) {
    MS_LOG(EXCEPTION) << "First branch has no op at depth " << depth << ", branch size is " << first_ops.size();
  }
  if (depth >= plans_.size()) {
    MS_LOG(EXCEPTION) << "Cannot get plan at depth " << depth << " vs " << plans_.size();
  }
  MS_EXCEPTION_IF_NULL(first_ops[depth]);
  auto cnode = first_ops[depth]->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_size = cnode->size();
  // Copy the axis out of the plan instead of holding a copy of the whole plan.
  const auto concat_axis = plans_[depth].concat_in_idx;
  auto unique_inputs = GetUniqueInputs(branches, depth);
  AnfNodePtrList overall_inputs{cnode->input(0)};  // prim
  auto concated_inputs = ConcatUniqueInputs(unique_inputs, concat_axis);
  for (size_t i = 1; i < input_size; ++i) {
    // Single lookup per index.
    auto found = concated_inputs.find(i);
    if (found != concated_inputs.end()) {
      overall_inputs.push_back(found->second);
    } else {
      overall_inputs.push_back(shared_input);
    }
  }
  return overall_inputs;
}
|
||||
|
||||
// Infers the output shape of a reshape after its input changed from `orig_reshape_in` to `new_reshape_in`
// (both are indexed with the same axis numbers).
// Pass 1 walks the input axes from last to first. For each input axis it records the integer ratio
// new/orig and greedily assigns trailing output axes to it for as long as the output extent divides what is
// left of the input extent.
// Pass 2 walks the output axes from first to last and multiplies each original output extent by the ratio of
// its assigned input axis; the ratio is consumed by the first output axis that uses it, later ones get 1.
// Output axes that were never assigned fall back to input axis 0 through the map's default value.
ShapeVector GraphBuilder::InferConcatReshapeOut(const ShapeVector &orig_reshape_in, const ShapeVector &orig_reshape_out,
                                                const ShapeVector &new_reshape_in) {
  std::map<int, int> out_axis_to_in_axis;
  std::map<int, int> in_axis_ratio;
  const int in_rank = static_cast<int>(orig_reshape_in.size());
  const int out_rank = static_cast<int>(orig_reshape_out.size());

  int out_axis = out_rank - 1;
  for (int in_axis = in_rank - 1; in_axis >= 0; --in_axis) {
    auto remaining = orig_reshape_in[in_axis];
    in_axis_ratio[in_axis] = new_reshape_in[in_axis] / orig_reshape_in[in_axis];
    while (out_axis >= 0 && remaining >= orig_reshape_out[out_axis] && remaining % orig_reshape_out[out_axis] == 0) {
      out_axis_to_in_axis[out_axis] = in_axis;
      remaining = remaining / orig_reshape_out[out_axis];
      --out_axis;
    }
  }

  ShapeVector new_shape_out;
  for (int out_idx = 0; out_idx < out_rank; ++out_idx) {
    const int src_axis = out_axis_to_in_axis[out_idx];
    const int ratio = in_axis_ratio[src_axis];
    new_shape_out.push_back(orig_reshape_out[out_idx] * ratio);
    in_axis_ratio[src_axis] = 1;
  }
  return new_shape_out;
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/backend/optimizer/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/common/graph_kernel/core/parallel_op_combine.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
struct ConcatenatePlan {
|
||||
int concat_in_idx{0};
|
||||
int split_out_idx{0};
|
||||
ShapeVector in_shape;
|
||||
ShapeVector out_shape;
|
||||
};
|
||||
|
||||
class ParallelOpConcatenater : public ParallelOpCombiner {
|
||||
public:
|
||||
explicit ParallelOpConcatenater(const std::string &op_name, uint64_t min_num_branches, const std::string &layout);
|
||||
|
||||
protected:
|
||||
virtual bool IsSupportedOp(const AnfNodePtr n) = 0;
|
||||
virtual bool CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) = 0;
|
||||
virtual AnfNodePtr MakeCombinedOp(const Group &branches) = 0;
|
||||
bool IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b);
|
||||
AnfNodePtr MakeCombinedAnfNodePtrFromFollowingOps(const AnfNodePtr &data, const Group &branches, size_t depth) final;
|
||||
void UpdateGroupOutput(const AnfNodePtr &data, const Group &branches, size_t depth) final;
|
||||
|
||||
std::map<size_t, AnfNodePtr> ConcatUniqueInputs(std::map<size_t, AnfNodePtrList> unique_inputs, size_t concat_idx);
|
||||
ConcatenatePlan GetElemWiseFollowingPlan(const Group &branches, size_t depth);
|
||||
AnfNodePtrList ReloadInputs(const Group &branches, size_t depth, AnfNodePtr shared_input);
|
||||
std::vector<ConcatenatePlan> plans_;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_OP_CONCATENATE_H_
|
|
@ -319,6 +319,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
|
|||
reg.AddFlag("enable_lite_conv_tuning", &enable_lite_conv_tuning);
|
||||
reg.AddFlag("enable_vectorization", &enable_vectorization);
|
||||
reg.AddFlag("enable_dynamic_shape_fusion", &enable_dynamic_shape_fusion);
|
||||
reg.AddFlag("enable_parallel_op_combine", &enable_parallel_op_combine);
|
||||
|
||||
// Integer flags
|
||||
reg.AddFlag("reduce_fuse_depth", &reduce_fuse_depth);
|
||||
|
|
|
@ -93,6 +93,11 @@ class BACKEND_EXPORT GraphKernelFlags {
|
|||
*/
|
||||
unsigned int parallel_ops_level{OpLevel_0};
|
||||
|
||||
/**
|
||||
* Enable parallel op combination, default is false.
|
||||
*/
|
||||
bool enable_parallel_op_combine{false};
|
||||
|
||||
/**
|
||||
* Enable horizontal fusion in graph kernel fusion strategy, default is false.
|
||||
*/
|
||||
|
|
|
@ -130,6 +130,16 @@ NodePtr LiteGraph::GraphBuilderBase::Emit(const std::string &op, const NodePtrLi
|
|||
return op_ptr;
|
||||
}
|
||||
|
||||
// Creates op `op` with a fresh node name from the graph, assigns its inputs, attributes and the given list of
// base infos, appends it to the graph's op list and returns it.
NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBaseList &baseinfolist,
                                        const NodePtrList &inputs, const DAttrs &attrs) const {
  PrimOpPtr new_op = CreateOp(op, graph_->NodeName());
  new_op->SetInputs(inputs);
  new_op->SetAttrs(attrs);
  new_op->SetBaseInfo(baseinfolist);

  (void)graph_->ops_.emplace_back(new_op);
  return new_op;
}
|
||||
|
||||
NodePtr LiteGraph::GraphBuilderBase::Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
|
||||
const DAttrs &attrs) const {
|
||||
PrimOpPtr op_ptr = CreateOp(op, graph_->NodeName());
|
||||
|
|
|
@ -72,6 +72,8 @@ class LiteGraph::GraphBuilderBase {
|
|||
NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const;
|
||||
|
||||
// Create op node with given baseinfo.
|
||||
NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs,
|
||||
const DAttrs &attrs = {}) const;
|
||||
NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
|
||||
const DAttrs &attrs = {}) const;
|
||||
LiteGraphPtr Get() const { return graph_; }
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "backend/common/graph_kernel/core/transform_op_optimizer.h"
|
||||
#include "backend/common/graph_kernel/core/update_state_formatter.h"
|
||||
#include "backend/common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h"
|
||||
|
||||
#include "tools/graph_kernel/converter/akg/utils.h"
|
||||
#include "tools/graph_kernel/converter/callback_impl.h"
|
||||
|
@ -77,6 +78,11 @@ GkPassManagerPtr GraphKernelOptimizer::Cluster() const {
|
|||
auto pm = std::make_shared<GraphKernelPassManagerLite>(kStageCluster, "cluster");
|
||||
// Expand complex basic kernels to composite kernels
|
||||
pm->Add(std::make_shared<GraphKernelExpanderLite>(), OptLevel_1);
|
||||
|
||||
if (GraphKernelFlags::GetInstance().enable_parallel_op_combine || is_ascend) {
|
||||
pm->Add(std::make_shared<GraphKernelOpCombiner>(), OptLevel_2);
|
||||
}
|
||||
|
||||
pm->Add(std::make_shared<ConvTuningExpander>(), OptLevel_1, is_cpu);
|
||||
|
||||
// Cluster basic kernels and composite kernels
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
|
||||
class CustomNet(Cell):
    """Network with two parallel `MatMul -> Add -> Abs` branches that read the same input `mx_a`."""

    def __init__(self):
        super(CustomNet, self).__init__()
        self.matmul = P.MatMul()
        self.add = P.Add()
        self.abs = P.Abs()

    def construct(self, mx_a, mx_b1, bias1, mx_b2, bias2):
        # use Abs to create a shared input for matmul
        lhs1 = P.Abs()(mx_a)
        lhs2 = P.Abs()(mx_a)

        # branch 1: matmul - add - abs
        out1 = self.abs(self.add(self.matmul(lhs1, mx_b1), bias1))

        # branch 2: matmul - add - abs
        out2 = self.abs(self.add(self.matmul(lhs2, mx_b2), bias2))
        return out1, out2
|
||||
|
||||
|
||||
def get_output(i0, i1, i2, enable_graph_kernel=False):
    """Run CustomNet once and return its two outputs.

    `i0` is the shared left operand; the two branches use `i1` and `i2` scaled by 3 and by 2 respectively as
    weight and bias. `enable_graph_kernel` is forwarded to the MindSpore context before the network is built.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel, save_graphs=False)
    net = CustomNet()
    mx_b1, mx_b2 = i1 * 3, i1 * 2
    bias1, bias2 = i2 * 3, i2 * 2
    return net(i0, mx_b1, bias1, mx_b2, bias2)
|
||||
|
||||
|
||||
def run():
    """Compare the network outputs with graph kernel disabled and enabled.

    Random float16 inputs of shape [96, 800], [800, 128] and [1, 128] are generated once and fed to both runs.

    Returns:
        bool: True when every pair of outputs matches under np.allclose(rtol=1e-4, atol=1e-7), False as soon
        as one pair differs.
    """
    i0 = Tensor(np.random.normal(1, 0.01, [96, 800]).astype(np.float16))
    i1 = Tensor(np.random.normal(1, 0.01, [800, 128]).astype(np.float16))
    i2 = Tensor(np.random.normal(1, 0.01, [1, 128]).astype(np.float16))

    expect = get_output(i0, i1, i2, False)
    output = get_output(i0, i1, i2, True)
    for exp, out in zip(expect, output):
        expect_np = exp.asnumpy().copy()
        output_np = out.asnumpy().copy()
        if not np.allclose(expect_np, output_np, 1.e-4, 1.e-7):
            # Python's boolean literals are capitalised; lowercase `false`/`true` raise NameError.
            return False
    return True
|
||||
|
||||
|
||||
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parallel_matmul_combine_ascend():
    """
    Feature: Parallel Matmul combination
    Description: on Ascend device
    Expectation: network return same result with the feature on and off
    """
    # NOTE(review): the optimizer registers GraphKernelOpCombiner with OptLevel_2 while this test passes
    # --opt_level=1; confirm the pass is actually exercised with these flags.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(graph_kernel_flags="--enable_parallel_op_combine=1 --opt_level=1")
    result = run()
    assert result
|
||||
|
||||
|
||||
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parallel_matmul_combine_gpu():
    """
    Feature: Parallel Matmul combination
    Description: on GPU device
    Expectation: network return same result with the feature on and off
    """
    # NOTE(review): the optimizer registers GraphKernelOpCombiner with OptLevel_2 while this test passes
    # --opt_level=1; confirm the pass is actually exercised with these flags.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(graph_kernel_flags="--enable_parallel_op_combine=1 --opt_level=1")
    result = run()
    assert result
|
Loading…
Reference in New Issue