converter: mark training ops with a "trainOp" attribute and let MatMulAddFusion skip them, so the fusion pass can run for train models too

This commit is contained in:
yefeng 2021-09-08 10:32:21 +08:00
parent 6da6713e7a
commit 0e66043d3b
5 changed files with 118 additions and 3 deletions

View File

@ -84,7 +84,84 @@ AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
STATUS AnfTransform::FindInputCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Tag every CNode that directly feeds `cnode` with the "trainOp" attribute,
  // so later fusion passes can skip operators that belong to the training graph.
  // Non-CNode inputs (parameters, value nodes) are left untouched.
  const size_t input_count = cnode->inputs().size();
  for (size_t idx = 1; idx < input_count; ++idx) {  // input 0 is the primitive itself
    auto in_node = cnode->input(idx);
    if (!utils::isa<CNodePtr>(in_node)) {
      continue;
    }
    auto in_cnode = utils::cast<CNodePtr>(in_node);
    MS_CHECK_TRUE_RET(in_cnode != nullptr, RET_ERROR);
    auto in_prim = GetValueNode<PrimitivePtr>(in_cnode->input(0));
    if (in_prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    in_prim->AddAttr("trainOp", MakeValue(true));
  }
  return RET_OK;
}
STATUS AnfTransform::FindSameParameterCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Walk the whole graph and tag with "trainOp" every CNode that shares a
  // Parameter input with `cnode` (a known train op). Such nodes read the same
  // weights as the training op and therefore must not be fused away.
  for (auto &graph_node : TopoSort(func_graph->get_return())) {
    if (!utils::isa<CNodePtr>(graph_node)) {
      continue;
    }
    auto candidate = utils::cast<CNodePtr>(graph_node);
    MS_CHECK_TRUE_RET(candidate != nullptr, RET_ERROR);
    auto candidate_prim = GetValueNode<PrimitivePtr>(candidate->input(0));
    if (candidate_prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    // Compare every real input of the candidate against every real input of
    // the train op; a shared Parameter marks the candidate as train-related.
    for (size_t cand_idx = 1; cand_idx < candidate->inputs().size(); ++cand_idx) {
      for (size_t train_idx = 1; train_idx < cnode->inputs().size(); ++train_idx) {
        bool same_input = candidate->input(cand_idx) == cnode->input(train_idx);
        if (same_input && utils::isa<Parameter>(cnode->input(train_idx))) {
          candidate_prim->AddAttr("trainOp", MakeValue(true));
        }
      }
    }
  }
  return RET_OK;
}
STATUS AnfTransform::FindTrainOp(const FuncGraphPtr &func_graph) {
  // Locate every training operator in the graph, mark it with "trainOp", and
  // propagate the mark to its producer CNodes and to any CNode sharing one of
  // its Parameter inputs. Returns RET_ERROR if either propagation step fails.
  for (auto &graph_node : TopoSort(func_graph->get_return())) {
    if (!utils::isa<CNodePtr>(graph_node)) {
      continue;
    }
    auto op_cnode = utils::cast<CNodePtr>(graph_node);
    MS_CHECK_TRUE_RET(op_cnode != nullptr, RET_ERROR);
    auto op_prim = GetValueNode<PrimitivePtr>(op_cnode->input(0));
    if (op_prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    if (!opt::IsTrainOp(op_cnode)) {
      continue;
    }
    op_prim->AddAttr("trainOp", MakeValue(true));
    if (FindInputCnode(func_graph, op_cnode) != RET_OK) {
      MS_LOG(ERROR) << "FindInputCnode failed.";
      return RET_ERROR;
    }
    if (FindSameParameterCnode(func_graph, op_cnode) != RET_OK) {
      MS_LOG(ERROR) << "FindSameParameterCnode failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto status = FindTrainOp(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "FindTrainOp failed.";
return RET_ERROR;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
@ -126,9 +203,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
}
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
if (!config->trainModel) {
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
}
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
optimizer->AddPassManager(fusion_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run op fusion failed.";

View File

@ -58,6 +58,12 @@ class AnfTransform {
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
void AppendPassToStoreRoom(const converter::Flags *config);
static STATUS FindInputCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
static STATUS FindSameParameterCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
static STATUS FindTrainOp(const FuncGraphPtr &func_graph);
};
} // namespace lite
} // namespace mindspore

View File

@ -1087,5 +1087,26 @@ STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVe
*shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
return lite::RET_OK;
}
bool IsTrainOp(const CNodePtr &cnode) {
  // Returns true when `cnode` is a training-only operator: an optimizer op,
  // a loss op, or a gradient op. Used by the converter to keep fusion passes
  // from folding nodes that the training graph still needs.
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
    // Guard against a nullptr deref; the sibling Find*Cnode helpers apply
    // the same check before touching the primitive.
    return false;
  }
  auto cnode_type = prim->type_name();
  // optimizer op
  if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
    return true;
  }
  // loss op
  // Fixed op-name typos: "SpareSoftmax..." -> "SparseSoftmax..." and
  // "...WithLogpitsGrad" -> "...WithLogitsGrad" (the actual MindSpore op names);
  // the misspelled strings could never match any real node type.
  if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SparseSoftmaxCrossEntropyWithLogits" ||
      cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
      cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogitsGrad") {
    return true;
  }
  // grad op: any type containing "Grad", or any node created under a
  // "Gradients" scope by automatic differentiation.
  if (cnode_type.find("Grad") != std::string::npos ||
      cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
    return true;
  }
  return false;
}
} // namespace opt
} // namespace mindspore

View File

@ -123,6 +123,9 @@ inline bool IsSpecifiedNode(const BaseRef &n) {
}
return false;
}
bool IsTrainOp(const CNodePtr &cnode);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_

View File

@ -57,6 +57,11 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
continue;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_RET(prim != nullptr, false);
if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
continue;
}
size_t index = 0;
if (!CheckAndGetMatMulIndex(cnode, &index)) {
continue;
@ -67,6 +72,11 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
(!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
continue;
}
auto matmul_prim = GetValueNode<PrimitivePtr>(matmul_cnode->input(0));
MS_CHECK_TRUE_RET(matmul_prim != nullptr, false);
if (matmul_prim->GetAttr("trainOp") != nullptr && GetValue<bool>(matmul_prim->GetAttr("trainOp"))) {
continue;
}
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());