!19487 [MS][LITE]fix weight quantlize bug of control flow graph

Merge pull request !19487 from mengyuanli/add_control_flow
This commit is contained in:
i-robot 2021-07-07 08:04:07 +00:00 committed by Gitee
commit 86eb2a9605
2 changed files with 42 additions and 1 deletions

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <deque>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/primitive.h" #include "mindspore/core/ir/primitive.h"
@ -252,7 +253,30 @@ STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position)
return RET_OK; return RET_OK;
} }
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) { void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
all_func_graphs->insert(func_graph);
auto nodes = func_graph->GetOrderedCnodes();
std::deque<CNodePtr> to_process{};
to_process.insert(to_process.end(), nodes.begin(), nodes.end());
while (!to_process.empty()) {
auto &cur_cnode = to_process.front();
to_process.pop_front();
for (auto &input : cur_cnode->inputs()) {
if (!IsValueNode<FuncGraph>(input)) {
continue;
}
auto new_fg = GetValueNode<FuncGraphPtr>(input);
if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
continue;
}
all_func_graphs->insert(new_fg);
auto new_nodes = new_fg->GetOrderedCnodes();
to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
}
}
}
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant // quant
if (config->quantType == schema::QuantType_PostTraining) { if (config->quantType == schema::QuantType_PostTraining) {
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum); this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum);
@ -281,6 +305,19 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
return RET_OK; return RET_OK;
} }
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
std::set<FuncGraphPtr> all_func_graphs{};
GetFuncGraphs(old_graph, &all_func_graphs);
for (auto &item : all_func_graphs) {
auto status = DoSingleGraphQuantize(item, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return status;
}
}
return RET_OK;
}
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(old_graph != nullptr); MS_ASSERT(old_graph != nullptr);
if (config == nullptr) { if (config == nullptr) {

View File

@ -54,6 +54,10 @@ class AnfTransform {
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position); static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore