forked from mindspore-Ecosystem/mindspore
!19487 [MS][LITE]fix weight quantlize bug of control flow graph
Merge pull request !19487 from mengyuanli/add_control_flow
This commit is contained in:
commit
86eb2a9605
|
@ -18,6 +18,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <deque>
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "mindspore/core/ir/primitive.h"
|
#include "mindspore/core/ir/primitive.h"
|
||||||
|
@ -252,7 +253,30 @@ STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position)
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
|
||||||
|
all_func_graphs->insert(func_graph);
|
||||||
|
auto nodes = func_graph->GetOrderedCnodes();
|
||||||
|
std::deque<CNodePtr> to_process{};
|
||||||
|
to_process.insert(to_process.end(), nodes.begin(), nodes.end());
|
||||||
|
while (!to_process.empty()) {
|
||||||
|
auto &cur_cnode = to_process.front();
|
||||||
|
to_process.pop_front();
|
||||||
|
for (auto &input : cur_cnode->inputs()) {
|
||||||
|
if (!IsValueNode<FuncGraph>(input)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto new_fg = GetValueNode<FuncGraphPtr>(input);
|
||||||
|
if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
all_func_graphs->insert(new_fg);
|
||||||
|
auto new_nodes = new_fg->GetOrderedCnodes();
|
||||||
|
to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||||
// quant
|
// quant
|
||||||
if (config->quantType == schema::QuantType_PostTraining) {
|
if (config->quantType == schema::QuantType_PostTraining) {
|
||||||
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum);
|
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->configFile, config->bitNum);
|
||||||
|
@ -281,6 +305,19 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||||
|
std::set<FuncGraphPtr> all_func_graphs{};
|
||||||
|
GetFuncGraphs(old_graph, &all_func_graphs);
|
||||||
|
for (auto &item : all_func_graphs) {
|
||||||
|
auto status = DoSingleGraphQuantize(item, config);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Do Quantize failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||||
MS_ASSERT(old_graph != nullptr);
|
MS_ASSERT(old_graph != nullptr);
|
||||||
if (config == nullptr) {
|
if (config == nullptr) {
|
||||||
|
|
|
@ -54,6 +54,10 @@ class AnfTransform {
|
||||||
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
||||||
|
|
||||||
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||||
|
|
||||||
|
static void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
|
||||||
|
|
||||||
|
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue