train convert fusion
This commit is contained in:
parent
6da6713e7a
commit
0e66043d3b
|
@ -84,7 +84,84 @@ AnfTransform::AnfTransform() = default;
|
|||
|
||||
AnfTransform::~AnfTransform() = default;
|
||||
|
||||
STATUS AnfTransform::FindInputCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
if (!utils::isa<CNodePtr>(input_node)) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = utils::cast<CNodePtr>(input_node);
|
||||
MS_CHECK_TRUE_RET(input_cnode != nullptr, RET_ERROR);
|
||||
auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(DEBUG) << "Primitive is nullptr.";
|
||||
continue;
|
||||
}
|
||||
prim->AddAttr("trainOp", MakeValue(true));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AnfTransform::FindSameParameterCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto graph_cnode = utils::cast<CNodePtr>(node);
|
||||
MS_CHECK_TRUE_RET(graph_cnode != nullptr, RET_ERROR);
|
||||
auto graph_prim = GetValueNode<PrimitivePtr>(graph_cnode->input(0));
|
||||
if (graph_prim == nullptr) {
|
||||
MS_LOG(DEBUG) << "Primitive is nullptr.";
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < graph_cnode->inputs().size(); i++) {
|
||||
for (size_t j = 1; j < cnode->inputs().size(); j++) {
|
||||
if ((graph_cnode->input(i) == cnode->input(j)) && utils::isa<Parameter>(cnode->input(j))) {
|
||||
graph_prim->AddAttr("trainOp", MakeValue(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AnfTransform::FindTrainOp(const FuncGraphPtr &func_graph) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(DEBUG) << "Primitive is nullptr.";
|
||||
continue;
|
||||
}
|
||||
if (opt::IsTrainOp(cnode)) {
|
||||
prim->AddAttr("trainOp", MakeValue(true));
|
||||
auto status = FindInputCnode(func_graph, cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FindInputCnode failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
status = FindSameParameterCnode(func_graph, cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FindSameParameterCnode failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
auto status = FindTrainOp(old_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FindTrainOp failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
|
||||
|
||||
|
@ -126,9 +203,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
}
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
|
||||
if (!config->trainModel) {
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
|
||||
}
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
|
||||
optimizer->AddPassManager(fusion_pm);
|
||||
if (optimizer->Optimize(old_graph) == nullptr) {
|
||||
MS_LOG(ERROR) << "run op fusion failed.";
|
||||
|
|
|
@ -58,6 +58,12 @@ class AnfTransform {
|
|||
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
void AppendPassToStoreRoom(const converter::Flags *config);
|
||||
|
||||
static STATUS FindInputCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
static STATUS FindSameParameterCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
static STATUS FindTrainOp(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1087,5 +1087,26 @@ STATUS FetchShapeFromAbstract(const abstract::AbstractBasePtr &abstract, ShapeVe
|
|||
*shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool IsTrainOp(const CNodePtr &cnode) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
auto cnode_type = prim->type_name();
|
||||
// optimizer op
|
||||
if (cnode_type == "Adam" || cnode_type == "SGD" || cnode_type == "ApplyMomentum") {
|
||||
return true;
|
||||
}
|
||||
// loss op
|
||||
if (cnode_type == "SoftmaxCrossEntropyWithLogits" || cnode_type == "SpareSoftmaxCrossEntropyWithLogits" ||
|
||||
cnode_type == "SmoothL1Loss" || cnode_type == "SmoothL1LossGrad" ||
|
||||
cnode_type == "SigmoidCrossEntropyWithLogits" || cnode_type == "SigmoidCrossEntropyWithLogpitsGrad") {
|
||||
return true;
|
||||
}
|
||||
// grad op
|
||||
if (cnode_type.find("Grad") != std::string::npos ||
|
||||
cnode->fullname_with_scope().find("Gradients") != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -123,6 +123,9 @@ inline bool IsSpecifiedNode(const BaseRef &n) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsTrainOp(const CNodePtr &cnode);
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_
|
||||
|
|
|
@ -57,6 +57,11 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
|
||||
continue;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, false);
|
||||
if (prim->GetAttr("trainOp") != nullptr && GetValue<bool>(prim->GetAttr("trainOp"))) {
|
||||
continue;
|
||||
}
|
||||
size_t index = 0;
|
||||
if (!CheckAndGetMatMulIndex(cnode, &index)) {
|
||||
continue;
|
||||
|
@ -67,6 +72,11 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
(!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
|
||||
continue;
|
||||
}
|
||||
auto matmul_prim = GetValueNode<PrimitivePtr>(matmul_cnode->input(0));
|
||||
MS_CHECK_TRUE_RET(matmul_prim != nullptr, false);
|
||||
if (matmul_prim->GetAttr("trainOp") != nullptr && GetValue<bool>(matmul_prim->GetAttr("trainOp"))) {
|
||||
continue;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
|
||||
|
|
Loading…
Reference in New Issue