forked from mindspore-Ecosystem/mindspore
!19487 [MS][LITE]fix weight quantlize bug of control flow graph
Merge pull request !19487 from mengyuanli/add_control_flow
This commit is contained in:
commit
86eb2a9605
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <deque>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
|
@ -252,7 +253,30 @@ STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position)
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
|
||||
all_func_graphs->insert(func_graph);
|
||||
auto nodes = func_graph->GetOrderedCnodes();
|
||||
std::deque<CNodePtr> to_process{};
|
||||
to_process.insert(to_process.end(), nodes.begin(), nodes.end());
|
||||
while (!to_process.empty()) {
|
||||
auto &cur_cnode = to_process.front();
|
||||
to_process.pop_front();
|
||||
for (auto &input : cur_cnode->inputs()) {
|
||||
if (!IsValueNode<FuncGraph>(input)) {
|
||||
continue;
|
||||
}
|
||||
auto new_fg = GetValueNode<FuncGraphPtr>(input);
|
||||
if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
|
||||
continue;
|
||||
}
|
||||
all_func_graphs->insert(new_fg);
|
||||
auto new_nodes = new_fg->GetOrderedCnodes();
|
||||
to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
// quant
|
||||
if (config->quantType == schema::QuantType_PostTraining) {
|
||||
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum);
|
||||
|
@ -281,6 +305,19 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
std::set<FuncGraphPtr> all_func_graphs{};
|
||||
GetFuncGraphs(old_graph, &all_func_graphs);
|
||||
for (auto &item : all_func_graphs) {
|
||||
auto status = DoSingleGraphQuantize(item, config);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do Quantize failed.";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
MS_ASSERT(old_graph != nullptr);
|
||||
if (config == nullptr) {
|
||||
|
|
|
@ -54,6 +54,10 @@ class AnfTransform {
|
|||
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
||||
|
||||
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
|
||||
|
||||
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue