Support Ascend infer process

Fix compile errors

Add FuseLayerNorm pattern
looop5 2023-04-12 11:50:02 +08:00
parent d5e933bbb9
commit ec1570cac8
20 changed files with 544 additions and 31 deletions

View File

@ -20,6 +20,7 @@
"mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
"mindspore/mindspore/ccsrc/transform/graph_ir/op_adapter.cc" "runtime/references"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/include/backend/optimizer/op_adaptation_info_factory.h" "runtime/explicit"

View File

@ -93,6 +93,7 @@ OP_REGISTER("Acosh", ElemwiseOp);
OP_REGISTER("Atan", ElemwiseOp);
OP_REGISTER("Atan2", ElemwiseOp);
OP_REGISTER("Expm1", ElemwiseOp);
OP_REGISTER("FastGeLU", ElemwiseOp);
// broadcast ops
OP_REGISTER("BroadcastTo", BroadcastOp);
OP_REGISTER("Tile", BroadcastOp);

View File

@ -193,8 +193,8 @@ bool FuseMatMul::Match(const AreaPtr &dom) {
continue;
}
auto user_name = a->dom()->op();
// MatMul + Add/Cast/FastGeLU or BatchMatMul + elemwise
- if ((dom_name == kMatMulOpName && (user_name == kTensorAddOpName || user_name == kCastOpName)) ||
+ if ((dom_name == kMatMulOpName &&
+      (user_name == kTensorAddOpName || user_name == kCastOpName || user_name == kFastGeLUOpName)) ||
(dom_name == kBatchMatMulOpName && a->pattern() == NodePattern::ELEMWISE)) {
if (!HasCircle(dom, a)) {
(void)fused_areas_.emplace_back(a);

View File

@ -1962,7 +1962,7 @@ std::vector<OutHandler> DfGraphConvertor::GetInputHandles(const AnfNodePtr &node
MS_EXCEPTION_IF_NULL(pred_adpt);
// When node's output is dynamic or node has multiple outputs, it needs to get all handles.
// TupleGetItem's input is a dynamic output (e.g. MakeTuple), but it only needs to get one handle.
- if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp())) {
+ if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp(input))) {
MS_EXCEPTION_IF_NULL(Convert(input));
handles = pred_adpt->getOutputs(Convert(input));
} else {

View File

@ -19,11 +19,99 @@
#include <map>
#include "utils/check_convert_utils.h"
#include "ops/split_combination_ops.h"
#include "graph/operator_factory_impl.h"
namespace mindspore {
namespace transform {
static uint32_t CustomInferFunc(const Operator &) { return 0; }
static ge::graphStatus CustomAkgOpInferFunc(Operator &op) {
// output_names
std::vector<std::string> output_names;
auto status = op.GetAttr("output_names", output_names);
if (status != 0) {
return status;
}
// output_shapes
std::vector<std::vector<int64_t>> output_shapes;
status = op.GetAttr("output_shapes", output_shapes);
if (status != 0) {
return status;
}
if (output_shapes.size() != output_names.size()) {
return 1;
}
// output_formats
std::vector<int32_t> output_formats;
status = op.GetAttr("output_formats", output_formats);
if (status != 0) {
return status;
}
if (output_formats.size() != output_names.size()) {
return 1;
}
// output_types
std::vector<int32_t> output_types;
status = op.GetAttr("output_types", output_types);
if (status != 0) {
return status;
}
if (output_types.size() != output_names.size()) {
return 1;
}
// Update output tensor desc
for (size_t i = 0; i < output_names.size(); ++i) {
ge::TensorDesc output_desc(ge::Shape(output_shapes[i]), static_cast<ge::Format>(output_formats[i]),
static_cast<ge::DataType>(output_types[i]));
op.UpdateOutputDesc(output_names[i], output_desc);
}
return 0;
}
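// Illustrative sketch (an assumption, not verbatim from this commit): the attr
// contract that CustomAkgOpInferFunc above relies on. The adapter records
// per-output metadata as op attrs (see UpdateSingleOutputDesc and
// UpdateMultiOutputDesc below), roughly:
//   op.SetAttr("output_names", std::vector<std::string>{"y0"});
//   op.SetAttr("output_shapes", std::vector<std::vector<int64_t>>{{16, 1024}});
//   op.SetAttr("output_formats", std::vector<int32_t>{static_cast<int32_t>(ge::FORMAT_ND)});
//   op.SetAttr("output_types", std::vector<int32_t>{static_cast<int32_t>(ge::DT_FLOAT16)});
// The infer func then replays these into a ge::TensorDesc per output; the
// shape, format and dtype values here are made-up examples.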
// Check whether a Custom node is an akg kernel. It should only be called when the node is a Custom node.
bool IsAkgOp(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto type = prim->GetAttr("type");
return (type != nullptr && GetValue<std::string>(type) == "GraphKernel");
}
void RegisterAkgOp(const PrimitivePtr &prim, const std::string &op_type) {
if (ge::OperatorFactoryImpl::IsExistOp(op_type)) {
return;
}
MS_EXCEPTION_IF_NULL(prim);
auto input_names_v = prim->GetAttr("input_names");
MS_EXCEPTION_IF_NULL(input_names_v);
auto input_names = GetValue<std::vector<std::string>>(input_names_v);
auto output_names_v = prim->GetAttr("output_names");
MS_EXCEPTION_IF_NULL(output_names_v);
auto output_names = GetValue<std::vector<std::string>>(output_names_v);
// Register the op creator function, which describes how to create a custom akg op
(void)ge::OperatorFactoryImpl::RegisterOperatorCreator(op_type,
[op_type, input_names, output_names](const std::string &name) {
auto op = ge::CustomOperator(name, op_type);
for (const auto &in_name : input_names) {
op.CustomInputRegister(in_name);
}
for (const auto &out_name : output_names) {
op.CustomOutputRegister(out_name);
}
op.CustomRequiredAttrRegister("info_path");
op.CustomInferFuncRegister(CustomAkgOpInferFunc);
return op;
});
// Register op infer shape function
(void)ge::OperatorFactoryImpl::RegisterInferShapeFunc(op_type, CustomAkgOpInferFunc);
}
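// Rough usage sketch (illustrative assumption): RegisterAkgOp expects a
// primitive that already carries the input_names/output_names attrs set by
// AscendKernelBuilder::CreateCustomOp below, e.g.:
//   auto prim = std::make_shared<Primitive>("Custom");
//   prim->set_attr("input_names", MakeValue(std::vector<std::string>{"x0", "x1"}));
//   prim->set_attr("output_names", MakeValue(std::vector<std::string>{"y0"}));
//   RegisterAkgOp(prim, "Fused_x2_y1");  // op_type as generated by the builder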
bool OpAdapterImpl::IsCustomOp(const OperatorPtr &op) const {
MS_EXCEPTION_IF_NULL(op);
auto it = cus_input_map_->find(op->GetOpType());
@ -118,7 +206,13 @@ OperatorPtr OpAdapterImpl::GenerateCustomOp(const AnfNodePtr anf) {
MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "].";
}
- op->CustomInferFuncRegister(CustomInferFunc);
+ if (IsAkgOp(anf)) {
+   op->CustomRequiredAttrRegister("info_path");
+   op->CustomInferFuncRegister(CustomAkgOpInferFunc);
+   RegisterAkgOp(prim, op_type);
+ } else {
+   op->CustomInferFuncRegister(CustomInferFunc);
+ }
return op;
}
@ -368,6 +462,12 @@ Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstra
MS_EXCEPTION_IF_NULL(cus_op);
std::map<int, std::string> output_map = (*cus_output_map_)[op->GetOpType()];
(void)cus_op->UpdateOutputDesc(output_map[0], *desc);
std::vector<std::vector<int64_t>> out_shapes{desc->GetShape().GetDims()};
std::vector<int32_t> out_formats{static_cast<int32_t>(desc->GetFormat())};
std::vector<int32_t> out_types{static_cast<int32_t>(desc->GetDataType())};
cus_op->SetAttr("output_shapes", out_shapes);
cus_op->SetAttr("output_formats", out_formats);
cus_op->SetAttr("output_types", out_types);
} else {
if (!output_map_.empty()) {
output_map_.begin()->second.update_out_desc(op, *desc);
@ -433,6 +533,9 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
return FAILED;
}
std::vector<std::vector<int64_t>> out_shapes;
std::vector<int32_t> out_formats;
std::vector<int32_t> out_types;
for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
auto tuple_type = dyn_cast<Tuple>(type);
MS_EXCEPTION_IF_NULL(tuple_type);
@ -447,6 +550,9 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
if (is_custom_op) {
(void)std::dynamic_pointer_cast<CustomOperator>(op)->UpdateOutputDesc((*cus_output_map_)[op->GetOpType()][i],
*desc);
out_shapes.push_back(desc->GetShape().GetDims());
out_formats.push_back(static_cast<int32_t>(desc->GetFormat()));
out_types.push_back(static_cast<int32_t>(desc->GetDataType()));
} else {
auto it = output_map_.find(i);
if (it != output_map_.end()) {
@ -456,6 +562,11 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
}
}
}
if (is_custom_op) {
op->SetAttr("output_shapes", out_shapes);
op->SetAttr("output_formats", out_formats);
op->SetAttr("output_types", out_types);
}
return SUCCESS;
}

View File

@ -245,7 +245,22 @@ class OpAdapter : public BaseOpAdapter {
}
bool IsDynInputOp(uint64_t index) override { return dyn_input_map_.find(index) != dyn_input_map_.end(); }
bool IsDyOutputOp(uint64_t index) override { return dyn_output_map_.find(index) != dyn_output_map_.end(); }
- bool IsMultipleOutputOp() override { return output_map_.size() > 1; }
+ bool IsMultipleOutputOp(const AnfNodePtr &anf) override {
+   if (IsCustomCNode(anf)) {
+     // Custom op
+     auto node = anf->cast<CNodePtr>();
+     MS_EXCEPTION_IF_NULL(node);
+     auto prim = GetValueNode<PrimitivePtr>(node->inputs().at(0));
+     MS_EXCEPTION_IF_NULL(prim);
+     auto op_type = impl_->GetCustomOpType(prim);
+     if (cus_output_map_.find(op_type) != cus_output_map_.end()) {
+       return cus_output_map_[op_type].size() > 1;
+     }
+     return false;
+   }
+   // Normal op
+   return output_map_.size() > 1;
+ }
Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) {
return impl_->SetOpSubgraphFunc(op, subgraphs);

View File

@ -45,6 +45,8 @@ class CustomOperator : public Operator {
void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); }
void CustomRequiredAttrRegister(const string &name) { Operator::RequiredAttrRegister(name); }
void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
Operator::InferFuncRegister(func);
}
@ -172,7 +174,7 @@ class BaseOpAdapter {
virtual std::map<std::string, ValuePtr> GetNormalOpAttrList(const AnfNodePtr &node) = 0;
virtual bool IsDynInputOp(uint64_t index) = 0;
virtual bool IsDyOutputOp(uint64_t index) = 0;
- virtual bool IsMultipleOutputOp() = 0;
+ virtual bool IsMultipleOutputOp(const AnfNodePtr &anf) = 0;
void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
void clearAttrVect() { attrs_vec_.clear(); }

View File

@ -10,6 +10,10 @@ file(STRINGS "${TOP_DIR}/version.txt" MSVERSION)
add_definitions(-DMSVERSION=\"${MSVERSION}\")
add_compile_definitions(ENABLE_SECURITY)
if(MSLITE_ENABLE_CONVERTER AND MSLITE_ENABLE_GRAPH_KERNEL)
add_compile_definitions(MSLITE_ENABLE_GRAPH_KERNEL)
endif()
#link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
file(GLOB GE_EXECUTOR_SRC

View File

@ -33,6 +33,9 @@
#include "src/extendrt/utils/func_graph_utils.h"
#include "transform/graph_ir/transform_util.h"
#include "flow_graph/data_flow.h"
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
#include "tools/graph_kernel/converter/graph_kernel_optimization.h"
#endif
namespace mindspore {
namespace {
@ -245,6 +248,12 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map
MS_LOG(ERROR) << "Input param graph_id is nullptr.";
return false;
}
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
if (GraphKernelOptimize(anf_graph, nullptr) != lite::RET_OK) {
MS_LOG(ERROR) << "Run graphkernel optimization failed.";
return false;
}
#endif
std::map<std::string, std::string> ge_options;
GetGeGraphOptions(anf_graph, &ge_options);

View File

@ -31,6 +31,8 @@ AnfNodePtr AscendKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, c
auto inputs = cnode->inputs();
inputs[0] = NewValueNode(custom_prim);
auto custom_cnode = func_graph->NewCNode(inputs);
custom_prim->EraseAttr("IsFeatureMapInputList");
custom_prim->EraseAttr("IsFeatureMapOutput");
auto json_kernel_name = node_info_map_[cnode->cast<AnfNodePtr>()];
auto input_num = AnfUtils::GetInputTensorNum(cnode);
@ -44,7 +46,10 @@ AnfNodePtr AscendKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, c
output_names.push_back("y" + std::to_string(i));
}
custom_prim->set_attr("reg_op_name", MakeValue(json_kernel_name));
std::ostringstream oss;
oss << "Fused_x" << input_num << "_y" << output_num;
std::string op_tye = oss.str();
custom_prim->set_attr("reg_op_name", MakeValue(op_tye));
custom_prim->set_attr("info_path", MakeValue(dir_path_ + "/" + json_kernel_name + ".info"));
custom_prim->set_attr("input_names", MakeValue(input_names));
custom_prim->set_attr("output_names", MakeValue(output_names));

View File

@ -0,0 +1,126 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/graph_kernel/converter/eliminate_redundant_op.h"
#include <algorithm>
#include <vector>
#include "mindspore/core/ops/core_ops.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "utils/anf_utils.h"
#include "backend/common/graph_kernel/core/graph_kernel_callback.h"
namespace mindspore::graphkernel {
namespace {
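/*
 * Illustrative IR for EliminateReshape (shapes are made-up examples):
 *   %1 = Reshape(%0)        // (8, 77, 512) -> (616, 512)
 *   %2 = FastGeLU(%1)
 *   %3 = Reshape(%2)        // (616, 512) -> (8, 77, 512), same shape as %0
 *   =================>
 *   %2 = FastGeLU(%0)
 * */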
bool EliminateReshape(const CNodePtr &cnode, const FuncGraphManagerPtr &mng) {
// Reshape + FastGeLU + Reshape --> FastGeLU
auto input = cnode->input(kIndex1);
if (!IsPrimitiveCNode(input, prim::kPrimReshape) || mng->node_users()[input].size() > 1) {
return false;
}
auto users = mng->node_users()[cnode];
if (users.size() != 1) {
return false;
}
auto user = users.begin()->first;
if (!IsPrimitiveCNode(user, prim::kPrimReshape)) {
return false;
}
auto cb = Callback::Instance();
auto input_in_shape = cb->GetInputShape(input, 0);
auto user_out_shape = cb->GetOutputShape(user, 0);
if (input_in_shape == user_out_shape) {
MS_LOG(INFO) << "Eliminate Reshape around: " << cnode->fullname_with_scope();
auto input_cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto input_in_node = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(input_in_node);
cnode->set_input(kIndex1, input_in_node);
cnode->set_abstract(input_in_node->abstract()->Clone());
(void)mng->Replace(user, cnode);
return true;
}
return false;
}
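/*
 * Illustrative IR for EliminateTranspose: the perm {0, 2, 1, 3} is a no-op on
 * the underlying data when one of the swapped axes has size 1 (example shapes):
 *   %1 = Reshape(%0)                    // -> (8, 1, 77, 64)
 *   %2 = Transpose(%1, (0, 2, 1, 3))    // -> (8, 77, 1, 64), layout unchanged
 *   %3 = Reshape(%2)
 *   =================>
 *   %3 = Reshape(%0)
 * */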
bool EliminateTranspose(const CNodePtr &cnode, const FuncGraphManagerPtr &mng) {
// Reshape + Transpose + Reshape --> Reshape
auto input = cnode->input(kIndex1);
if (!IsPrimitiveCNode(input, prim::kPrimReshape) || mng->node_users()[input].size() > 1) {
return false;
}
auto users = mng->node_users()[cnode];
if (users.size() != 1) {
return false;
}
auto user = users.begin()->first;
if (!IsPrimitiveCNode(user, prim::kPrimReshape)) {
return false;
}
std::vector<int64_t> perm_list;
if (cnode->input(kIndex2)->isa<Parameter>()) {
auto perm_para = cnode->input(kIndex2)->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(perm_para);
auto perm_tensor = perm_para->default_param()->cast<tensor::TensorPtr>();
auto perm = static_cast<int32_t *>(perm_tensor->data_ptr()->data());
std::transform(perm, perm + perm_tensor->shape()[0], std::back_inserter(perm_list), IntToLong);
} else {
auto perm_value = cnode->input(kIndex2)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(perm_value);
perm_list = GetValue<std::vector<int64_t>>(perm_value->value());
}
std::vector<int64_t> opt_perm = {0, 2, 1, 3};
auto cb = Callback::Instance();
auto x_shape = cb->GetInputShape(cnode, 0);
if (perm_list == opt_perm && x_shape.size() == opt_perm.size() && (x_shape[kIndex1] == 1 || x_shape[kIndex2] == 1)) {
MS_LOG(INFO) << "Eliminate Transpose: " << cnode->fullname_with_scope();
auto user_cnode = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
auto input_cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
user_cnode->set_input(kIndex1, input_cnode->input(kIndex1));
return true;
}
return false;
}
} // namespace
bool EliminateRedundantOp::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (const auto &node : todos) {
if (node == nullptr) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
if (IsPrimitiveCNode(cnode, prim::kPrimFastGeLU)) {
changed = EliminateReshape(cnode, mng) || changed;
} else if (IsPrimitiveCNode(cnode, prim::kPrimTranspose)) {
changed = EliminateTranspose(cnode, mng) || changed;
}
}
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,30 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_REDUNDANT_OP_H_
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_REDUNDANT_OP_H_
#include "ir/func_graph.h"
#include "include/backend/optimizer/pass.h"
namespace mindspore::graphkernel {
class EliminateRedundantOp : public opt::Pass {
public:
EliminateRedundantOp() : Pass("eliminate_redundant_op") {}
~EliminateRedundantOp() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::graphkernel
#endif  // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_REDUNDANT_OP_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <memory>
#include "backend/common/graph_kernel/expanders/op_desc_registry.h"
#include "tools/graph_kernel/converter/expanders/activation.h"
#include "mindapi/base/types.h"
#include "ir/dtype.h"
namespace mindspore::graphkernel::expanders {
class ReduceMean : public OpDesc {
public:
ReduceMean() {
std::initializer_list<std::string> attrs{"axis", "keep_dims"};
(void)validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
}
~ReduceMean() = default;
protected:
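// Expands ReduceMean into ReduceSum / N, where N is the number of reduced
// elements. e.g. (illustrative) x: [2, 3, 4], axis: [1, 2] ->
// ReduceSum(x, [1, 2], keep_dims) / 12.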
NodePtrList Expand(const NodePtrList &inputs) override {
const auto &x = inputs[0];
auto rank = SizeToLong(x->shape.size());
auto axis = GetAxisList(attrs_["axis"]);
(void)std::for_each(axis.begin(), axis.end(), [rank](auto &a) { a = a < 0 ? a + rank : a; });
if (axis.empty()) {
for (int64_t i = 0; i < rank; ++i) {
axis.push_back(i);
}
}
int64_t sz = 1;
for (size_t i = 0; i < x->shape.size(); ++i) {
if (std::find(axis.begin(), axis.end(), SizeToLong(i)) != axis.end()) {
sz *= SizeToLong(x->shape[i]);
}
}
auto sum_x = gb.ReduceSum(x, axis, GetValue<bool>(attrs_["keep_dims"]));
auto result = gb.Div(sum_x, gb.Const(sz, x->type));
return {result};
}
};
EXPANDER_OP_DESC_REGISTER("ReduceMean", ReduceMean);
} // namespace mindspore::graphkernel::expanders

View File

@ -27,16 +27,31 @@
namespace mindspore::graphkernel {
std::vector<PrimitivePtr> GraphKernelClusterLite::GetClusterableOpList() {
std::vector<OpWithLevel> clusterable_ops_with_level = {
- {kAllTarget, OpLevel_0, prim::kPrimAdd}, {kAllTarget, OpLevel_0, prim::kPrimMul},
- {kAllTarget, OpLevel_0, prim::kPrimSub}, {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
- {kAllTarget, OpLevel_0, prim::kPrimLog}, {kAllTarget, OpLevel_0, prim::kPrimExp},
- {kAllTarget, OpLevel_0, prim::kPrimPow}, {kAllTarget, OpLevel_0, prim::kPrimNeg},
- {kAllTarget, OpLevel_0, prim::kPrimRsqrt}, {kAllTarget, OpLevel_0, prim::kPrimSqrt},
- {kAllTarget, OpLevel_0, prim::kPrimSin}, {kAllTarget, OpLevel_0, prim::kPrimTanh},
- {kAllTarget, OpLevel_0, prim::kPrimCos}, {kAllTarget, OpLevel_0, prim::kPrimGreater},
- {kAllTarget, OpLevel_0, prim::kPrimGreaterEqual}, {kAllTarget, OpLevel_0, prim::kPrimLess},
- {kAllTarget, OpLevel_0, prim::kPrimLessEqual}, {kAllTarget, OpLevel_0, prim::kPrimLogicalAnd},
- {kAllTarget, OpLevel_0, prim::kPrimLogicalOr}, {kAllTarget, OpLevel_0, prim::kPrimLogicalNot},
+ {kAllTarget, OpLevel_0, prim::kPrimAdd},
+ {kAllTarget, OpLevel_0, prim::kPrimMul},
+ {kAllTarget, OpLevel_0, prim::kPrimSub},
+ {kAllTarget, OpLevel_0, prim::kPrimSqrt},
+ {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
+ // ascend device
+ {kAscendDevice, OpLevel_0, prim::kPrimMatMul},
+ {kAscendDevice, OpLevel_0, prim::kPrimAssign},
+ {kAscendDevice, OpLevel_0, prim::kPrimFastGeLU},
+ // cpu device
+ {kCPUDevice, OpLevel_0, prim::kPrimLog},
+ {kCPUDevice, OpLevel_0, prim::kPrimExp},
+ {kCPUDevice, OpLevel_0, prim::kPrimPow},
+ {kCPUDevice, OpLevel_0, prim::kPrimNeg},
+ {kCPUDevice, OpLevel_0, prim::kPrimRsqrt},
+ {kCPUDevice, OpLevel_0, prim::kPrimSin},
+ {kCPUDevice, OpLevel_0, prim::kPrimTanh},
+ {kCPUDevice, OpLevel_0, prim::kPrimCos},
+ {kCPUDevice, OpLevel_0, prim::kPrimGreater},
+ {kCPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
+ {kCPUDevice, OpLevel_0, prim::kPrimLess},
+ {kCPUDevice, OpLevel_0, prim::kPrimLessEqual},
+ {kCPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
+ {kCPUDevice, OpLevel_0, prim::kPrimLogicalOr},
+ {kCPUDevice, OpLevel_0, prim::kPrimLogicalNot},
};
const auto &flags = GraphKernelFlags::GetInstance();
return GkUtils::GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level, flags.enable_cluster_ops_only,

View File

@ -148,16 +148,18 @@ bool GraphKernelExpanderLite::DisableConvTuning() {
std::vector<PrimitivePtr> GraphKernelExpanderLite::InitOpList() {
std::vector<OpWithLevel> expand_ops_with_level = {{kAllTarget, OpLevel_0, prim::kPrimSquare},
- {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
- {kAllTarget, OpLevel_1, prim::kPrimSqueeze},
- {kAllTarget, OpLevel_1, prim::kPrimTranspose},
- {kAllTarget, OpLevel_1, prim::kPrimReshape},
- {kAllTarget, OpLevel_1, prim::kPrimGather},
- {kAllTarget, OpLevel_1, prim::kPrimShape},
- {kAllTarget, OpLevel_1, prim::kPrimConcat},
- {kAllTarget, OpLevel_1, prim::kPrimFusedBatchNorm},
- {kAllTarget, OpLevel_1, prim::kPrimSoftmax},
+ // ascend device
+ {kAscendDevice, OpLevel_0, prim::kPrimReduceMean},
+ // cpu device
+ {kCPUDevice, OpLevel_1, prim::kPrimExpandDims},
+ {kCPUDevice, OpLevel_1, prim::kPrimSqueeze},
+ {kCPUDevice, OpLevel_1, prim::kPrimTranspose},
+ {kCPUDevice, OpLevel_1, prim::kPrimReshape},
+ {kCPUDevice, OpLevel_1, prim::kPrimGather},
+ {kCPUDevice, OpLevel_1, prim::kPrimShape},
+ {kCPUDevice, OpLevel_1, prim::kPrimConcat},
+ {kCPUDevice, OpLevel_1, prim::kPrimFusedBatchNorm},
+ {kCPUDevice, OpLevel_1, prim::kPrimSoftmax},
{kCPUDevice, OpLevel_0, prim::kPrimAddFusion},
{kCPUDevice, OpLevel_0, prim::kPrimMulFusion},
{kCPUDevice, OpLevel_0, prim::kPrimSubFusion},
@ -226,6 +228,7 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
{prim::kPrimConstantOfShape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimTranspose->name(), {InputToAttrDeco::Creator}},
{prim::kPrimGather->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimReduceMean->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
{prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
{prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},

View File

@ -38,6 +38,8 @@
#include "tools/graph_kernel/converter/parameter_to_tensor.h"
#include "tools/graph_kernel/converter/eliminate_maketuple_getitem.h"
#include "tools/graph_kernel/converter/callback_impl.h"
#include "tools/graph_kernel/converter/split_umonad.h"
#include "tools/graph_kernel/converter/eliminate_redundant_op.h"
namespace mindspore {
namespace graphkernel {
@ -61,6 +63,12 @@ GkPassManagerPtr GraphKernelOptimizer::PreProcess() const {
// put an empty pass here to dump the ir before GraphKernel
pm->Add(std::make_shared<EmptyPass>(), OptLevel_1);
// Assign(p, a, U) --> Depend(Assign(p, a), U)
pm->Add(std::make_shared<SplitAssign>(), OptLevel_1, is_ascend);
// Eliminate redundant ops, such as Reshape
pm->Add(std::make_shared<EliminateRedundantOp>(), OptLevel_1, is_ascend);
// Recognize the formats for all CNodes
pm->Add(std::make_shared<FormatRecognition>(), OptLevel_1);
@ -153,7 +161,12 @@ lite::STATUS GraphKernelOptimize(const FuncGraphPtr &func_graph, const std::shar
#endif
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
MS_LOG(INFO) << "Run graphkernel optimization begin.";
- graphkernel::GraphKernelOptimizer(param).Run(func_graph);
+ auto p = param;
+ if (p == nullptr) {
+   p = std::make_shared<ConverterPara>();
+   p->device = "Ascend";
+ }
+ graphkernel::GraphKernelOptimizer(p).Run(func_graph);
MS_LOG(INFO) << "Run graphkernel optimization end.";
}
return lite::RET_OK;

View File

@ -23,8 +23,36 @@ SPLIT_MODEL_REGISTER("Ascend", SplitModelAscend);
constexpr size_t kReduceFusionDepth = 10;
constexpr size_t kBroadcastFusionDepth = 6;
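// Matches the area pattern a decomposed LayerNorm produces (illustrative):
// a ReduceSum (the mean) whose single Broadcast user (e.g. x - mean) feeds
// both a Reduce area (the variance path) and a Broadcast area (the normalize
// path); the dom and all three user areas are fused into one composite area.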
class FuseLayerNorm : public FusePattern {
public:
FuseLayerNorm() : FusePattern("layer_norm") { direction_ = FuseDirection::BACKWARD; }
~FuseLayerNorm() = default;
protected:
bool Check(const AreaPtr &dom) override { return (dom->dom()->op() == "ReduceSum"); }
bool Match(const AreaPtr &dom) override {
constexpr size_t c1 = 1;
constexpr size_t c2 = 2;
auto users = dom->users();
if (users.size() != c1 || users[0]->pattern() != NodePattern::BROADCAST) {
return false;
}
auto user_users = users[0]->users();
if (user_users.size() != c2) {
return false;
}
if ((user_users[0]->pattern() == NodePattern::REDUCE && user_users[1]->pattern() == NodePattern::BROADCAST) ||
(user_users[0]->pattern() == NodePattern::BROADCAST && user_users[1]->pattern() == NodePattern::REDUCE)) {
(void)fused_areas_.emplace_back(users[0]);
(void)fused_areas_.emplace_back(user_users[0]);
(void)fused_areas_.emplace_back(user_users[1]);
}
return !fused_areas_.empty();
}
};
void SplitModelAscend::InitFusePatterns() {
AddPattern(std::make_shared<FuseReshape>(), true);
AddPattern(std::make_shared<FuseVirtualNode>(), true);
AddPattern(std::make_shared<ascend::FuseMatMul>(), true);
AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true);
AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true);
@ -32,12 +60,13 @@ void SplitModelAscend::InitFusePatterns() {
AddPattern(FuseReduceFwd::CreateWidthMatcher(kReduceFusionDepth), true);
AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(kBroadcastFusionDepth), true);
AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(kBroadcastFusionDepth), true);
AddPattern(std::make_shared<FuseLayerNorm>(), true);
}
AreaMode SplitModelAscend::GetDefaultAreaMode(const PrimOpPtr &node) const {
- if (node != nullptr && node->op() == kReshapeOpName) {
-   return AreaMode::BASIC;
+ if (node != nullptr && node->op() == "MatMul") {
+   return AreaMode::COMPOSITE;
  }
- return AreaMode::COMPOSITE;
+ return AreaMode::BASIC;
}
} // namespace mindspore::graphkernel::inner

View File

@ -0,0 +1,61 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/graph_kernel/converter/split_umonad.h"
#include "mindspore/core/ops/core_ops.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "utils/anf_utils.h"
namespace mindspore::graphkernel {
/*
* %1 = Assign(param, %0, UMonad)
* =================>
* %1 = Assign(param, %0)
* %2 = Depend(%1, UMonad)
* */
bool SplitAssign::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (const auto &node : todos) {
if (node == nullptr) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || !IsPrimitiveCNode(cnode, prim::kPrimAssign)) {
continue;
}
constexpr size_t umonad_idx = 3;
if (cnode->inputs().size() != umonad_idx + 1) {
continue;
}
auto umonad = cnode->input(umonad_idx);
if (!HasAbstractUMonad(umonad)) {
continue;
}
AnfNodePtrList new_inputs(cnode->inputs().begin(), cnode->inputs().begin() + umonad_idx);
cnode->set_inputs(new_inputs);
auto depend_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimDepend), cnode, umonad});
depend_cnode->set_abstract(node->abstract()->Clone());
(void)mng->Replace(node, depend_cnode);
changed = true;
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,30 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_
#include "ir/func_graph.h"
#include "include/backend/optimizer/pass.h"
namespace mindspore::graphkernel {
class SplitAssign : public opt::Pass {
public:
SplitAssign() : Pass("split_assign") {}
~SplitAssign() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_

View File

@ -102,6 +102,8 @@ Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared
return SUCCESS;
}
std::string OpAdapterImpl::GetCustomOpType(const PrimitivePtr &prim) const { return ""; }
bool IsCustomCNode(const mindspore::AnfNodePtr &node) { return true; }
std::string TransformUtil::NormOpName(const std::string &anf_name) { return ""; }
} // namespace transform