support ascend infer process
fix compile; add FuseLayerNorm pattern
This commit is contained in:
parent d5e933bbb9
commit ec1570cac8
@@ -20,6 +20,7 @@
 "mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
 "mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
 "mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
+"mindspore/mindspore/ccsrc/transform/graph_ir/op_adapter.cc" "runtime/references"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/drop_out_gen_mask_kernels.cc" "build/include"
 "mindspore/mindspore/ccsrc/include/backend/optimizer/op_adaptation_info_factory.h" "runtime/explicit"
@@ -93,6 +93,7 @@ OP_REGISTER("Acosh", ElemwiseOp);
 OP_REGISTER("Atan", ElemwiseOp);
 OP_REGISTER("Atan2", ElemwiseOp);
 OP_REGISTER("Expm1", ElemwiseOp);
+OP_REGISTER("FastGeLU", ElemwiseOp);
 // broadcast ops
 OP_REGISTER("BroadcastTo", BroadcastOp);
 OP_REGISTER("Tile", BroadcastOp);
@@ -193,8 +193,8 @@ bool FuseMatMul::Match(const AreaPtr &dom) {
       continue;
     }
     auto user_name = a->dom()->op();
-    // MatMul + Add/Cast or BatchMatMul + elemwise
-    if ((dom_name == kMatMulOpName && (user_name == kTensorAddOpName || user_name == kCastOpName)) ||
+    if ((dom_name == kMatMulOpName &&
+         (user_name == kTensorAddOpName || user_name == kCastOpName || user_name == kFastGeLUOpName)) ||
        (dom_name == kBatchMatMulOpName && a->pattern() == NodePattern::ELEMWISE)) {
      if (!HasCircle(dom, a)) {
        (void)fused_areas_.emplace_back(a);
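The HasCircle(dom, a) guard is what keeps this backward fusion legal: if a third area both consumes the MatMul and feeds the elementwise user, merging the two endpoints would make the fused area its own ancestor. A minimal standalone sketch of that check, using a toy adjacency list rather than MindSpore's Area graph (the node numbering and the ReachableViaOthers helper are illustrative only):

#include <cstdio>
#include <vector>

// Returns true if `target` is reachable from `cur` while skipping the direct
// cur->target edge (i.e. via a path of length >= 2), which is exactly the
// case where merging producer and consumer would create a cycle in the DAG.
static bool ReachableViaOthers(const std::vector<std::vector<int>> &succ, int cur, int target, bool first) {
  for (int nxt : succ[cur]) {
    if (nxt == target) {
      if (!first) return true;  // reached target through an intermediate node
      continue;                 // skip the direct producer->consumer edge
    }
    if (ReachableViaOthers(succ, nxt, target, false)) return true;
  }
  return false;
}

int main() {
  // 0:MatMul -> 1:Cast -> 2:Add, and also 0 -> 2 directly.
  // Fusing 0 and 2 would leave node 1 both consuming and feeding the merged
  // area, so the merged area would depend on itself: a circle.
  std::vector<std::vector<int>> succ = {{1, 2}, {2}, {}};
  std::printf("fuse(0,2) creates circle: %s\n", ReachableViaOthers(succ, 0, 2, true) ? "yes" : "no");
  std::printf("fuse(0,1) creates circle: %s\n", ReachableViaOthers(succ, 0, 1, true) ? "yes" : "no");
  return 0;
}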
@@ -1962,7 +1962,7 @@ std::vector<OutHandler> DfGraphConvertor::GetInputHandles(const AnfNodePtr &node
   MS_EXCEPTION_IF_NULL(pred_adpt);
   // When node's output is dynamic or node has multiple output, it need to get all handles.
   // TupleGetItem's input is dynamic output(eg:MakeTuple), but it only need to get one handle.
-  if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp())) {
+  if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp(input))) {
     MS_EXCEPTION_IF_NULL(Convert(input));
     handles = pred_adpt->getOutputs(Convert(input));
   } else {
@@ -19,11 +19,99 @@
 #include <map>
 #include "utils/check_convert_utils.h"
 #include "ops/split_combination_ops.h"
+#include "graph/operator_factory_impl.h"
 
 namespace mindspore {
 namespace transform {
 static uint32_t CustomInferFunc(const Operator &) { return 0; }
+
+static ge::graphStatus CustomAkgOpInferFunc(Operator &op) {
+  // output_names
+  std::vector<std::string> output_names;
+  auto status = op.GetAttr("output_names", output_names);
+  if (status != 0) {
+    return status;
+  }
+
+  // output_shapes
+  std::vector<std::vector<int64_t>> output_shapes;
+  status = op.GetAttr("output_shapes", output_shapes);
+  if (status != 0) {
+    return status;
+  }
+  if (output_shapes.size() != output_names.size()) {
+    return 1;
+  }
+
+  // output_formats
+  std::vector<int32_t> output_formats;
+  status = op.GetAttr("output_formats", output_formats);
+  if (status != 0) {
+    return status;
+  }
+  if (output_formats.size() != output_names.size()) {
+    return 1;
+  }
+
+  // output_types
+  std::vector<int32_t> output_types;
+  status = op.GetAttr("output_types", output_types);
+  if (status != 0) {
+    return status;
+  }
+  if (output_types.size() != output_names.size()) {
+    return 1;
+  }
+
+  // Update output tensor desc
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    ge::TensorDesc output_desc(ge::Shape(output_shapes[i]), static_cast<ge::Format>(output_formats[i]),
+                               static_cast<ge::DataType>(output_types[i]));
+    op.UpdateOutputDesc(output_names[i], output_desc);
+  }
+
+  return 0;
+}
+
+// Check whether a Custom node is an AKG kernel; it should only be called when the node is a Custom node.
+bool IsAkgOp(const AnfNodePtr &node) {
+  auto prim = GetCNodePrimitive(node);
+  if (prim == nullptr) {
+    return false;
+  }
+  auto type = prim->GetAttr("type");
+  return (type != nullptr && GetValue<std::string>(type) == "GraphKernel");
+}
+
+void RegisterAkgOp(const PrimitivePtr &prim, const std::string &op_type) {
+  if (ge::OperatorFactoryImpl::IsExistOp(op_type)) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(prim);
+  auto input_names_v = prim->GetAttr("input_names");
+  MS_EXCEPTION_IF_NULL(input_names_v);
+  auto input_names = GetValue<std::vector<std::string>>(input_names_v);
+  auto output_names_v = prim->GetAttr("output_names");
+  MS_EXCEPTION_IF_NULL(output_names_v);
+  auto output_names = GetValue<std::vector<std::string>>(output_names_v);
+  // Register op create function, which describes how to create a custom akg op
+  (void)ge::OperatorFactoryImpl::RegisterOperatorCreator(op_type,
+                                                         [op_type, input_names, output_names](const std::string &name) {
+                                                           auto op = ge::CustomOperator(name, op_type);
+                                                           for (const auto &in_name : input_names) {
+                                                             op.CustomInputRegister(in_name);
+                                                           }
+                                                           for (const auto &out_name : output_names) {
+                                                             op.CustomOutputRegister(out_name);
+                                                           }
+                                                           op.CustomRequiredAttrRegister("info_path");
+                                                           op.CustomInferFuncRegister(CustomAkgOpInferFunc);
+                                                           return op;
+                                                         });
+  // Register op infer shape function
+  (void)ge::OperatorFactoryImpl::RegisterInferShapeFunc(op_type, CustomAkgOpInferFunc);
+}
+
 bool OpAdapterImpl::IsCustomOp(const OperatorPtr &op) const {
   MS_EXCEPTION_IF_NULL(op);
   auto it = cus_input_map_->find(op->GetOpType());
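CustomAkgOpInferFunc never inspects the kernel itself: it replays output descriptions from the string and integer attributes recorded at conversion time (the output_shapes/output_formats/output_types attrs set in UpdateSingleOutputDesc and UpdateMultiOutputDesc further down in this commit), and the one invariant it enforces is that those lists stay index-aligned with output_names. A simplified stand-in for that contract, with MockOp and MockTensorDesc replacing the real ge::Operator and ge::TensorDesc:

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-ins for the GE types used above.
struct MockTensorDesc {
  std::vector<int64_t> shape;
  int32_t format;
  int32_t dtype;
};

struct MockOp {
  std::vector<std::string> output_names;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int32_t> output_formats;
  std::vector<int32_t> output_types;
  std::vector<MockTensorDesc> descs;
};

// Mirrors the contract of CustomAkgOpInferFunc: every per-output attribute
// list must be index-aligned with output_names, otherwise infer fails.
int InferFromAttrs(MockOp *op) {
  const size_t n = op->output_names.size();
  if (op->output_shapes.size() != n || op->output_formats.size() != n || op->output_types.size() != n) {
    return 1;  // attribute lists out of sync with output_names
  }
  op->descs.clear();
  for (size_t i = 0; i < n; ++i) {
    op->descs.push_back({op->output_shapes[i], op->output_formats[i], op->output_types[i]});
  }
  return 0;
}

int main() {
  MockOp op{{"y0"}, {{16, 1024}}, {0 /* e.g. NCHW */}, {1 /* e.g. float16 */}, {}};
  std::printf("infer status: %d, rank of y0: %zu\n", InferFromAttrs(&op), op.descs[0].shape.size());
  return 0;
}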
@@ -118,7 +206,13 @@ OperatorPtr OpAdapterImpl::GenerateCustomOp(const AnfNodePtr anf) {
     MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "].";
   }
 
-  op->CustomInferFuncRegister(CustomInferFunc);
+  if (IsAkgOp(anf)) {
+    op->CustomRequiredAttrRegister("info_path");
+    op->CustomInferFuncRegister(CustomAkgOpInferFunc);
+    RegisterAkgOp(prim, op_type);
+  } else {
+    op->CustomInferFuncRegister(CustomInferFunc);
+  }
 
   return op;
 }
@@ -368,6 +462,12 @@ Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstra
     MS_EXCEPTION_IF_NULL(cus_op);
     std::map<int, std::string> output_map = (*cus_output_map_)[op->GetOpType()];
     (void)cus_op->UpdateOutputDesc(output_map[0], *desc);
+    std::vector<std::vector<int64_t>> out_shapes{desc->GetShape().GetDims()};
+    std::vector<int32_t> out_formats{static_cast<int32_t>(desc->GetFormat())};
+    std::vector<int32_t> out_types{static_cast<int32_t>(desc->GetDataType())};
+    cus_op->SetAttr("output_shapes", out_shapes);
+    cus_op->SetAttr("output_formats", out_formats);
+    cus_op->SetAttr("output_types", out_types);
   } else {
     if (!output_map_.empty()) {
       output_map_.begin()->second.update_out_desc(op, *desc);
@@ -433,6 +533,9 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
     return FAILED;
   }
 
+  std::vector<std::vector<int64_t>> out_shapes;
+  std::vector<int32_t> out_formats;
+  std::vector<int32_t> out_types;
   for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
     auto tuple_type = dyn_cast<Tuple>(type);
     MS_EXCEPTION_IF_NULL(tuple_type);
@@ -447,6 +550,9 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
     if (is_custom_op) {
       (void)std::dynamic_pointer_cast<CustomOperator>(op)->UpdateOutputDesc((*cus_output_map_)[op->GetOpType()][i],
                                                                             *desc);
+      out_shapes.push_back(desc->GetShape().GetDims());
+      out_formats.push_back(static_cast<int32_t>(desc->GetFormat()));
+      out_types.push_back(static_cast<int32_t>(desc->GetDataType()));
     } else {
       auto it = output_map_.find(i);
       if (it != output_map_.end()) {
@@ -456,6 +562,11 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
       }
     }
   }
+  if (is_custom_op) {
+    op->SetAttr("output_shapes", out_shapes);
+    op->SetAttr("output_formats", out_formats);
+    op->SetAttr("output_types", out_types);
+  }
   return SUCCESS;
 }
@@ -245,7 +245,22 @@ class OpAdapter : public BaseOpAdapter {
   }
   bool IsDynInputOp(uint64_t index) override { return dyn_input_map_.find(index) != dyn_input_map_.end(); }
   bool IsDyOutputOp(uint64_t index) override { return dyn_output_map_.find(index) != dyn_output_map_.end(); }
-  bool IsMultipleOutputOp() override { return output_map_.size() > 1; }
+  bool IsMultipleOutputOp(const AnfNodePtr &anf) override {
+    if (IsCustomCNode(anf)) {
+      // Custom op
+      auto node = anf->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(node);
+      auto prim = GetValueNode<PrimitivePtr>(node->inputs().at(0));
+      MS_EXCEPTION_IF_NULL(prim);
+      auto op_type = impl_->GetCustomOpType(prim);
+      if (cus_output_map_.find(op_type) != cus_output_map_.end()) {
+        return cus_output_map_[op_type].size() > 1;
+      }
+      return false;
+    }
+    // Normal op
+    return output_map_.size() > 1;
+  }
 
   Status SetOpSubgraphFunc(const OperatorPtr &op, std::shared_ptr<std::vector<DfGraph>> subgraphs) {
     return impl_->SetOpSubgraphFunc(op, subgraphs);
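The overload takes the node because, for Custom ops, the output layout is not in the adapter's static output_map_ but in cus_output_map_, keyed by the generated op type. A toy sketch of the two lookups (the tables and op type below are illustrative, not the real maps):

#include <cstdio>
#include <map>
#include <string>

// Simplified stand-in for the adapter's two tables: static ops know their
// outputs from output_map_; Custom ops record theirs per generated op type
// in cus_output_map_ (type -> {output index -> output name}).
static std::map<int, std::string> output_map_ = {{0, "y"}};
static std::map<std::string, std::map<int, std::string>> cus_output_map_ = {
    {"Fused_x2_y2", {{0, "y0"}, {1, "y1"}}}};

bool IsMultipleOutputOp(bool is_custom, const std::string &op_type) {
  if (is_custom) {
    auto it = cus_output_map_.find(op_type);
    return it != cus_output_map_.end() && it->second.size() > 1;
  }
  return output_map_.size() > 1;  // the adapter's own static table
}

int main() {
  std::printf("custom Fused_x2_y2: %d\n", IsMultipleOutputOp(true, "Fused_x2_y2"));  // 1
  std::printf("normal op:          %d\n", IsMultipleOutputOp(false, ""));            // 0
  return 0;
}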
@@ -45,6 +45,8 @@ class CustomOperator : public Operator {
 
   void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); }
 
+  void CustomRequiredAttrRegister(const string &name) { Operator::RequiredAttrRegister(name); }
+
   void CustomInferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
     Operator::InferFuncRegister(func);
   }
@@ -172,7 +174,7 @@ class BaseOpAdapter {
   virtual std::map<std::string, ValuePtr> GetNormalOpAttrList(const AnfNodePtr &node) = 0;
   virtual bool IsDynInputOp(uint64_t index) = 0;
   virtual bool IsDyOutputOp(uint64_t index) = 0;
-  virtual bool IsMultipleOutputOp() = 0;
+  virtual bool IsMultipleOutputOp(const AnfNodePtr &anf) = 0;
   void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); }
   const std::vector<std::string> &GetAttrsFromDrawGraph() const { return attrs_vec_; }
   void clearAttrVect() { attrs_vec_.clear(); }
@@ -10,6 +10,10 @@ file(STRINGS "${TOP_DIR}/version.txt" MSVERSION)
 add_definitions(-DMSVERSION=\"${MSVERSION}\")
 add_compile_definitions(ENABLE_SECURITY)
 
+if(MSLITE_ENABLE_CONVERTER AND MSLITE_ENABLE_GRAPH_KERNEL)
+    add_compile_definitions(MSLITE_ENABLE_GRAPH_KERNEL)
+endif()
+
 #link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
 
 file(GLOB GE_EXECUTOR_SRC
@@ -33,6 +33,9 @@
 #include "src/extendrt/utils/func_graph_utils.h"
 #include "transform/graph_ir/transform_util.h"
 #include "flow_graph/data_flow.h"
+#ifdef MSLITE_ENABLE_GRAPH_KERNEL
+#include "tools/graph_kernel/converter/graph_kernel_optimization.h"
+#endif
 
 namespace mindspore {
 namespace {
@@ -245,6 +248,12 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map
     MS_LOG(ERROR) << "Input param graph_id is nullptr.";
     return false;
   }
+#ifdef MSLITE_ENABLE_GRAPH_KERNEL
+  if (GraphKernelOptimize(anf_graph, nullptr) != lite::RET_OK) {
+    MS_LOG(ERROR) << "Run graphkernel optimization failed.";
+    return false;
+  }
+#endif
   std::map<std::string, std::string> ge_options;
   GetGeGraphOptions(anf_graph, &ge_options);
@@ -31,6 +31,8 @@ AnfNodePtr AscendKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, c
   auto inputs = cnode->inputs();
   inputs[0] = NewValueNode(custom_prim);
   auto custom_cnode = func_graph->NewCNode(inputs);
+  custom_prim->EraseAttr("IsFeatureMapInputList");
+  custom_prim->EraseAttr("IsFeatureMapOutput");
 
   auto json_kernel_name = node_info_map_[cnode->cast<AnfNodePtr>()];
   auto input_num = AnfUtils::GetInputTensorNum(cnode);
@@ -44,7 +46,10 @@
     output_names.push_back("y" + std::to_string(i));
   }
 
-  custom_prim->set_attr("reg_op_name", MakeValue(json_kernel_name));
+  std::ostringstream oss;
+  oss << "Fused_x" << input_num << "_y" << output_num;
+  std::string op_type = oss.str();
+  custom_prim->set_attr("reg_op_name", MakeValue(op_type));
   custom_prim->set_attr("info_path", MakeValue(dir_path_ + "/" + json_kernel_name + ".info"));
   custom_prim->set_attr("input_names", MakeValue(input_names));
   custom_prim->set_attr("output_names", MakeValue(output_names));
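Note that the registered GE type now depends only on arity: every fused AKG kernel with the same input and output counts shares one type such as Fused_x2_y1, and individual kernels are then told apart by their per-node info_path attribute. A tiny sketch of the naming scheme used above:

#include <iostream>
#include <sstream>
#include <string>

// Reproduces the arity-based naming: all fused AKG kernels with the same
// input/output counts map to one registered GE op type.
std::string MakeRegOpType(size_t input_num, size_t output_num) {
  std::ostringstream oss;
  oss << "Fused_x" << input_num << "_y" << output_num;
  return oss.str();
}

int main() {
  std::cout << MakeRegOpType(2, 1) << "\n";  // Fused_x2_y1
  std::cout << MakeRegOpType(3, 2) << "\n";  // Fused_x3_y2
  return 0;
}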
@@ -0,0 +1,126 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/graph_kernel/converter/eliminate_redundant_op.h"
+#include <algorithm>
+#include <vector>
+#include "mindspore/core/ops/core_ops.h"
+#include "ir/anf.h"
+#include "ir/graph_utils.h"
+#include "utils/anf_utils.h"
+#include "backend/common/graph_kernel/core/graph_kernel_callback.h"
+
+namespace mindspore::graphkernel {
+namespace {
+bool EliminateReshape(const CNodePtr &cnode, const FuncGraphManagerPtr &mng) {
+  // Reshape + FastGeLU + Reshape --> FastGeLU
+  auto input = cnode->input(kIndex1);
+  if (!IsPrimitiveCNode(input, prim::kPrimReshape) || mng->node_users()[input].size() > 1) {
+    return false;
+  }
+  auto users = mng->node_users()[cnode];
+  if (users.size() != 1) {
+    return false;
+  }
+  auto user = users.begin()->first;
+  if (!IsPrimitiveCNode(user, prim::kPrimReshape)) {
+    return false;
+  }
+  auto cb = Callback::Instance();
+  auto input_in_shape = cb->GetInputShape(input, 0);
+  auto user_out_shape = cb->GetOutputShape(user, 0);
+  if (input_in_shape == user_out_shape) {
+    MS_LOG(INFO) << "Eliminate Reshape around: " << cnode->fullname_with_scope();
+    auto input_cnode = input->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(input_cnode);
+    auto input_in_node = input_cnode->input(kIndex1);
+    MS_EXCEPTION_IF_NULL(input_in_node);
+    cnode->set_input(kIndex1, input_in_node);
+    cnode->set_abstract(input_in_node->abstract()->Clone());
+    (void)mng->Replace(user, cnode);
+    return true;
+  }
+  return false;
+}
+
+bool EliminateTranspose(const CNodePtr &cnode, const FuncGraphManagerPtr &mng) {
+  // Reshape + Transpose + Reshape --> Reshape
+  auto input = cnode->input(kIndex1);
+  if (!IsPrimitiveCNode(input, prim::kPrimReshape) || mng->node_users()[input].size() > 1) {
+    return false;
+  }
+  auto users = mng->node_users()[cnode];
+  if (users.size() != 1) {
+    return false;
+  }
+  auto user = users.begin()->first;
+  if (!IsPrimitiveCNode(user, prim::kPrimReshape)) {
+    return false;
+  }
+  std::vector<int64_t> perm_list;
+  if (cnode->input(kIndex2)->isa<Parameter>()) {
+    auto perm_para = cnode->input(kIndex2)->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(perm_para);
+    auto perm_tensor = perm_para->default_param()->cast<tensor::TensorPtr>();
+    auto perm = static_cast<int32_t *>(perm_tensor->data_ptr()->data());
+    std::transform(perm, perm + perm_tensor->shape()[0], std::back_inserter(perm_list), IntToLong);
+  } else {
+    auto perm_value = cnode->input(kIndex2)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(perm_value);
+    perm_list = GetValue<std::vector<int64_t>>(perm_value->value());
+  }
+  std::vector<int64_t> opt_perm = {0, 2, 1, 3};
+  auto cb = Callback::Instance();
+  auto x_shape = cb->GetInputShape(cnode, 0);
+  if (perm_list == opt_perm && x_shape.size() == opt_perm.size() && (x_shape[kIndex1] == 1 || x_shape[kIndex2] == 1)) {
+    MS_LOG(INFO) << "Eliminate Transpose: " << cnode->fullname_with_scope();
+    auto user_cnode = user->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(user_cnode);
+    auto input_cnode = input->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(input_cnode);
+    user_cnode->set_input(kIndex1, input_cnode->input(kIndex1));
+    return true;
+  }
+  return false;
+}
+}  // namespace
+
+bool EliminateRedundantOp::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  bool changed = false;
+  auto todos = TopoSort(func_graph->get_return());
+  for (const auto &node : todos) {
+    if (node == nullptr) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+    if (IsPrimitiveCNode(cnode, prim::kPrimFastGeLU)) {
+      changed = EliminateReshape(cnode, mng) || changed;
+    } else if (IsPrimitiveCNode(cnode, prim::kPrimTranspose)) {
+      changed = EliminateTranspose(cnode, mng) || changed;
+    }
+  }
+  if (changed) {
+    mng->RemoveRoots();
+    mng->KeepRoots({func_graph});
+  }
+  return changed;
+}
+}  // namespace mindspore::graphkernel
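EliminateReshape is sound because FastGeLU is elementwise: a Reshape changes only shape metadata, never the flat buffer, so when the outer shapes match, both Reshapes cancel. A standalone check of that commutation (FastGeLU is approximated here as x · sigmoid(1.702x); the exact kernel formula is an assumption, and any elementwise function makes the same point):

#include <cmath>
#include <cstdio>
#include <vector>

// FastGeLU approximated as x * sigmoid(1.702 * x); assumption, see lead-in.
static float FastGelu(float x) { return x / (1.0f + std::exp(-1.702f * x)); }

int main() {
  // A 2x3 tensor and its 3x2 "reshape" share one flat buffer; Reshape only
  // rewrites shape metadata. Applying an elementwise op before or after the
  // reshape therefore yields byte-identical data, which is why
  // Reshape -> FastGeLU -> Reshape(back) collapses to FastGeLU alone
  // whenever the outer shapes match.
  std::vector<float> flat = {-2.f, -1.f, 0.f, 1.f, 2.f, 3.f};
  std::vector<float> as_2x3 = flat, as_3x2 = flat;  // same data, different shape metadata
  for (auto &v : as_2x3) v = FastGelu(v);
  for (auto &v : as_3x2) v = FastGelu(v);
  std::printf("flat data identical: %s\n", as_2x3 == as_3x2 ? "yes" : "no");
  return 0;
}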
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OP_H
+#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OP_H
+
+#include "ir/func_graph.h"
+#include "include/backend/optimizer/pass.h"
+
+namespace mindspore::graphkernel {
+class EliminateRedundantOp : public opt::Pass {
+ public:
+  EliminateRedundantOp() : Pass("eliminate_redundant_op") {}
+  ~EliminateRedundantOp() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace mindspore::graphkernel
+#endif  // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OP_H
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+
+#include "backend/common/graph_kernel/expanders/op_desc_registry.h"
+#include "tools/graph_kernel/converter/expanders/activation.h"
+#include "mindapi/base/types.h"
+#include "ir/dtype.h"
+
+namespace mindspore::graphkernel::expanders {
+class ReduceMean : public OpDesc {
+ public:
+  ReduceMean() {
+    std::initializer_list<std::string> attrs{"axis", "keep_dims"};
+    (void)validators_.emplace_back(std::make_unique<CheckAttr>(attrs));
+  }
+  ~ReduceMean() = default;
+
+ protected:
+  NodePtrList Expand(const NodePtrList &inputs) override {
+    const auto &x = inputs[0];
+    auto rank = SizeToLong(x->shape.size());
+    auto axis = GetAxisList(attrs_["axis"]);
+    (void)std::for_each(axis.begin(), axis.end(), [rank](auto &a) { a = a < 0 ? a + rank : a; });
+    if (axis.empty()) {
+      for (int64_t i = 0; i < rank; ++i) {
+        axis.push_back(i);
+      }
+    }
+    int64_t sz = 1;
+    for (size_t i = 0; i < x->shape.size(); ++i) {
+      if (std::find(axis.begin(), axis.end(), SizeToLong(i)) != axis.end()) {
+        sz *= SizeToLong(x->shape[i]);
+      }
+    }
+    auto sum_x = gb.ReduceSum(x, axis, GetValue<bool>(attrs_["keep_dims"]));
+    auto result = gb.Div(sum_x, gb.Const(sz, x->type));
+    return {result};
+  }
+};
+EXPANDER_OP_DESC_REGISTER("ReduceMean", ReduceMean);
+}  // namespace mindspore::graphkernel::expanders
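The expander turns ReduceMean into ReduceSum over the normalized axes followed by division by the number of reduced elements, exactly as Expand() computes sz. A plain-C++ sketch of the same size arithmetic, independent of the GraphBuilder API:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mean over `axis` = Sum over `axis` divided by the product of the reduced
// dimension sizes; this mirrors the Expand() above, which emits
// ReduceSum(x, axis, keep_dims) followed by a Div by this constant.
int64_t ReducedElementCount(const std::vector<int64_t> &shape, std::vector<int64_t> axis) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  // Normalize negative axes; an empty axis list means "reduce all".
  for (auto &a : axis) a = a < 0 ? a + rank : a;
  if (axis.empty()) {
    for (int64_t i = 0; i < rank; ++i) axis.push_back(i);
  }
  int64_t sz = 1;
  for (int64_t i = 0; i < rank; ++i) {
    if (std::find(axis.begin(), axis.end(), i) != axis.end()) sz *= shape[i];
  }
  return sz;
}

int main() {
  // ReduceMean over the last axis of a [2, 8, 1024] tensor divides the
  // per-row ReduceSum by 1024.
  std::printf("%lld\n", static_cast<long long>(ReducedElementCount({2, 8, 1024}, {-1})));
  return 0;
}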
@@ -27,16 +27,31 @@
 namespace mindspore::graphkernel {
 std::vector<PrimitivePtr> GraphKernelClusterLite::GetClusterableOpList() {
   std::vector<OpWithLevel> clusterable_ops_with_level = {
-    {kAllTarget, OpLevel_0, prim::kPrimAdd}, {kAllTarget, OpLevel_0, prim::kPrimMul},
-    {kAllTarget, OpLevel_0, prim::kPrimSub}, {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
-    {kAllTarget, OpLevel_0, prim::kPrimLog}, {kAllTarget, OpLevel_0, prim::kPrimExp},
-    {kAllTarget, OpLevel_0, prim::kPrimPow}, {kAllTarget, OpLevel_0, prim::kPrimNeg},
-    {kAllTarget, OpLevel_0, prim::kPrimRsqrt}, {kAllTarget, OpLevel_0, prim::kPrimSqrt},
-    {kAllTarget, OpLevel_0, prim::kPrimSin}, {kAllTarget, OpLevel_0, prim::kPrimTanh},
-    {kAllTarget, OpLevel_0, prim::kPrimCos}, {kAllTarget, OpLevel_0, prim::kPrimGreater},
-    {kAllTarget, OpLevel_0, prim::kPrimGreaterEqual}, {kAllTarget, OpLevel_0, prim::kPrimLess},
-    {kAllTarget, OpLevel_0, prim::kPrimLessEqual}, {kAllTarget, OpLevel_0, prim::kPrimLogicalAnd},
-    {kAllTarget, OpLevel_0, prim::kPrimLogicalOr}, {kAllTarget, OpLevel_0, prim::kPrimLogicalNot},
+    {kAllTarget, OpLevel_0, prim::kPrimAdd},
+    {kAllTarget, OpLevel_0, prim::kPrimMul},
+    {kAllTarget, OpLevel_0, prim::kPrimSub},
+    {kAllTarget, OpLevel_0, prim::kPrimSqrt},
+    {kAllTarget, OpLevel_0, prim::kPrimRealDiv},
+    // ascend device
+    {kAscendDevice, OpLevel_0, prim::kPrimMatMul},
+    {kAscendDevice, OpLevel_0, prim::kPrimAssign},
+    {kAscendDevice, OpLevel_0, prim::kPrimFastGeLU},
+    // cpu device
+    {kCPUDevice, OpLevel_0, prim::kPrimLog},
+    {kCPUDevice, OpLevel_0, prim::kPrimExp},
+    {kCPUDevice, OpLevel_0, prim::kPrimPow},
+    {kCPUDevice, OpLevel_0, prim::kPrimNeg},
+    {kCPUDevice, OpLevel_0, prim::kPrimRsqrt},
+    {kCPUDevice, OpLevel_0, prim::kPrimSin},
+    {kCPUDevice, OpLevel_0, prim::kPrimTanh},
+    {kCPUDevice, OpLevel_0, prim::kPrimCos},
+    {kCPUDevice, OpLevel_0, prim::kPrimGreater},
+    {kCPUDevice, OpLevel_0, prim::kPrimGreaterEqual},
+    {kCPUDevice, OpLevel_0, prim::kPrimLess},
+    {kCPUDevice, OpLevel_0, prim::kPrimLessEqual},
+    {kCPUDevice, OpLevel_0, prim::kPrimLogicalAnd},
+    {kCPUDevice, OpLevel_0, prim::kPrimLogicalOr},
+    {kCPUDevice, OpLevel_0, prim::kPrimLogicalNot},
   };
   const auto &flags = GraphKernelFlags::GetInstance();
   return GkUtils::GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level, flags.enable_cluster_ops_only,
@@ -148,16 +148,18 @@ bool GraphKernelExpanderLite::DisableConvTuning() {
 
 std::vector<PrimitivePtr> GraphKernelExpanderLite::InitOpList() {
   std::vector<OpWithLevel> expand_ops_with_level = {{kAllTarget, OpLevel_0, prim::kPrimSquare},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimSqueeze},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimTranspose},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimReshape},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimGather},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimShape},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimConcat},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimFusedBatchNorm},
-                                                    {kAllTarget, OpLevel_1, prim::kPrimSoftmax},
+                                                    // ascend device
+                                                    {kAscendDevice, OpLevel_0, prim::kPrimReduceMean},
+                                                    // cpu device
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimExpandDims},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimSqueeze},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimTranspose},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimReshape},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimGather},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimShape},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimConcat},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimFusedBatchNorm},
+                                                    {kCPUDevice, OpLevel_1, prim::kPrimSoftmax},
                                                     {kCPUDevice, OpLevel_0, prim::kPrimAddFusion},
                                                     {kCPUDevice, OpLevel_0, prim::kPrimMulFusion},
                                                     {kCPUDevice, OpLevel_0, prim::kPrimSubFusion},
@@ -226,6 +228,7 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
     {prim::kPrimConstantOfShape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
     {prim::kPrimTranspose->name(), {InputToAttrDeco::Creator}},
     {prim::kPrimGather->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
+    {prim::kPrimReduceMean->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
    {prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
    {prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},
@@ -38,6 +38,8 @@
 #include "tools/graph_kernel/converter/parameter_to_tensor.h"
 #include "tools/graph_kernel/converter/eliminate_maketuple_getitem.h"
 #include "tools/graph_kernel/converter/callback_impl.h"
+#include "tools/graph_kernel/converter/split_umonad.h"
+#include "tools/graph_kernel/converter/eliminate_redundant_op.h"
 
 namespace mindspore {
 namespace graphkernel {
@@ -61,6 +63,12 @@ GkPassManagerPtr GraphKernelOptimizer::PreProcess() const {
   // put an empty pass here to dump the ir before GraphKernel
   pm->Add(std::make_shared<EmptyPass>(), OptLevel_1);
 
+  // Assign(p, a, U) --> Depend(Assign(p, a), U)
+  pm->Add(std::make_shared<SplitAssign>(), OptLevel_1, is_ascend);
+
+  // Eliminate redundant op, such as Reshape
+  pm->Add(std::make_shared<EliminateRedundantOp>(), OptLevel_1, is_ascend);
+
   // Recognize the formats for all CNodes
   pm->Add(std::make_shared<FormatRecognition>(), OptLevel_1);
@@ -153,7 +161,12 @@ lite::STATUS GraphKernelOptimize(const FuncGraphPtr &func_graph, const std::shar
 #endif
   if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     MS_LOG(INFO) << "Run graphkernel optimization begin.";
-    graphkernel::GraphKernelOptimizer(param).Run(func_graph);
+    auto p = param;
+    if (p == nullptr) {
+      p = std::make_shared<ConverterPara>();
+      p->device = "Ascend";
+    }
+    graphkernel::GraphKernelOptimizer(p).Run(func_graph);
     MS_LOG(INFO) << "Run graphkernel optimization end.";
   }
   return lite::RET_OK;
@@ -23,8 +23,36 @@ SPLIT_MODEL_REGISTER("Ascend", SplitModelAscend);
 constexpr size_t kReduceFusionDepth = 10;
 constexpr size_t kBroadcastFusionDepth = 6;
 
+class FuseLayerNorm : public FusePattern {
+ public:
+  FuseLayerNorm() : FusePattern("layer_norm") { direction_ = FuseDirection::BACKWARD; }
+  ~FuseLayerNorm() = default;
+
+ protected:
+  bool Check(const AreaPtr &dom) override { return (dom->dom()->op() == "ReduceSum"); }
+  bool Match(const AreaPtr &dom) override {
+    constexpr size_t c1 = 1;
+    constexpr size_t c2 = 2;
+    auto users = dom->users();
+    if (users.size() != c1 || users[0]->pattern() != NodePattern::BROADCAST) {
+      return false;
+    }
+    auto user_users = users[0]->users();
+    if (user_users.size() != c2) {
+      return false;
+    }
+    if ((user_users[0]->pattern() == NodePattern::REDUCE && user_users[1]->pattern() == NodePattern::BROADCAST) ||
+        (user_users[0]->pattern() == NodePattern::BROADCAST && user_users[1]->pattern() == NodePattern::REDUCE)) {
+      (void)fused_areas_.emplace_back(users[0]);
+      (void)fused_areas_.emplace_back(user_users[0]);
+      (void)fused_areas_.emplace_back(user_users[1]);
+    }
+    return !fused_areas_.empty();
+  }
+};
+
 void SplitModelAscend::InitFusePatterns() {
   AddPattern(std::make_shared<FuseReshape>(), true);
   AddPattern(std::make_shared<FuseVirtualNode>(), true);
   AddPattern(std::make_shared<ascend::FuseMatMul>(), true);
   AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true);
   AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true);
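For intuition on the shape of the pattern: a LayerNorm decomposed into primitive ops produces exactly the area chain Match() looks for, since the centered tensor has precisely two consumers, one reduce (variance) and one broadcast (normalization). A sketch under the standard LayerNorm definition (the fusion engine's exact lowering is assumed, not spelled out in this diff):

\begin{align*}
\mu &= \tfrac{1}{n}\sum_{i} x_i && \text{ReduceSum: the dominant area} \\
t_i &= x_i - \mu && \text{its single BROADCAST user} \\
\sigma^2 &= \tfrac{1}{n}\sum_{i} t_i^2 && \text{REDUCE user of } t \\
y_i &= \gamma\,\frac{t_i}{\sqrt{\sigma^2 + \epsilon}} + \beta && \text{BROADCAST user of } t
\end{align*}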
@@ -32,12 +60,13 @@ void SplitModelAscend::InitFusePatterns() {
   AddPattern(FuseReduceFwd::CreateWidthMatcher(kReduceFusionDepth), true);
   AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(kBroadcastFusionDepth), true);
   AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(kBroadcastFusionDepth), true);
+  AddPattern(std::make_shared<FuseLayerNorm>(), true);
 }
 
 AreaMode SplitModelAscend::GetDefaultAreaMode(const PrimOpPtr &node) const {
-  if (node != nullptr && node->op() == kReshapeOpName) {
-    return AreaMode::BASIC;
+  if (node != nullptr && node->op() == "MatMul") {
+    return AreaMode::COMPOSITE;
   }
-  return AreaMode::COMPOSITE;
+  return AreaMode::BASIC;
 }
 } // namespace mindspore::graphkernel::inner
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/graph_kernel/converter/split_umonad.h"
+
+#include "mindspore/core/ops/core_ops.h"
+#include "ir/anf.h"
+#include "ir/graph_utils.h"
+#include "utils/anf_utils.h"
+
+namespace mindspore::graphkernel {
+/*
+ * %1 = Assign(param, %0, UMonad)
+ * =================>
+ * %1 = Assign(param, %0)
+ * %2 = Depend(%1, UMonad)
+ * */
+bool SplitAssign::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  bool changed = false;
+  auto todos = TopoSort(func_graph->get_return());
+  for (const auto &node : todos) {
+    if (node == nullptr) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr || !IsPrimitiveCNode(cnode, prim::kPrimAssign)) {
+      continue;
+    }
+    constexpr size_t umonad_idx = 3;
+    if (cnode->inputs().size() != umonad_idx + 1) {
+      continue;
+    }
+    auto umonad = cnode->input(umonad_idx);
+    if (!HasAbstractUMonad(umonad)) {
+      continue;
+    }
+    AnfNodePtrList new_inputs(cnode->inputs().begin(), cnode->inputs().begin() + umonad_idx);
+    cnode->set_inputs(new_inputs);
+    auto depend_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimDepend), cnode, umonad});
+    depend_cnode->set_abstract(node->abstract()->Clone());
+    (void)mng->Replace(node, depend_cnode);
+    changed = true;
+  }
+  return changed;
+}
+}  // namespace mindspore::graphkernel
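The pass peels the UMonad off Assign and re-attaches it through Depend, so the side-effect ordering is preserved while Assign itself becomes a plain dataflow node that later fusion passes can cluster. A toy model of the rewrite (the Node struct is illustrative; note the real pass uses input index 3 because a CNode's input 0 is the primitive itself):

#include <cstdio>
#include <memory>
#include <string>
#include <vector>

// Toy node model: the UMonad input that serialises side effects is peeled
// off the Assign and re-attached through a Depend node.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr SplitAssign(const NodePtr &assign) {
  constexpr size_t kUMonadIdx = 2;  // Assign(param, value, UMonad) in this toy model
  if (assign->op != "Assign" || assign->inputs.size() != kUMonadIdx + 1) return assign;
  NodePtr umonad = assign->inputs.back();
  assign->inputs.pop_back();  // now Assign(param, value)
  return std::make_shared<Node>(Node{"Depend", {assign, umonad}});  // Depend(Assign, UMonad)
}

int main() {
  auto param = std::make_shared<Node>(Node{"param", {}});
  auto value = std::make_shared<Node>(Node{"value", {}});
  auto u = std::make_shared<Node>(Node{"UMonad", {}});
  auto assign = std::make_shared<Node>(Node{"Assign", {param, value, u}});
  auto root = SplitAssign(assign);
  std::printf("%s(%s[%zu inputs], %s)\n", root->op.c_str(), root->inputs[0]->op.c_str(),
              root->inputs[0]->inputs.size(), root->inputs[1]->op.c_str());
  return 0;  // prints: Depend(Assign[2 inputs], UMonad)
}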
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_
+#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_
+
+#include "ir/func_graph.h"
+#include "include/backend/optimizer/pass.h"
+
+namespace mindspore::graphkernel {
+class SplitAssign : public opt::Pass {
+ public:
+  SplitAssign() : Pass("split_assign") {}
+  ~SplitAssign() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace mindspore::graphkernel
+#endif  // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_UMONAD_H_
@@ -102,6 +102,8 @@ Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared
   return SUCCESS;
 }
 
+std::string OpAdapterImpl::GetCustomOpType(const PrimitivePtr &prim) const { return ""; }
+
 bool IsCustomCNode(const mindspore::AnfNodePtr &node) { return true; }
 std::string TransformUtil::NormOpName(const std::string &anf_name) { return ""; }
 }  // namespace transform