!44971 weight quant support kPrimConv2D&kPrimMatMul

Merge pull request !44971 from yeyunpeng2020/master_quant_ci
This commit is contained in:
i-robot 2022-11-02 02:00:57 +00:00 committed by Gitee
commit db77a41083
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 5 deletions

View File

@ -22,6 +22,7 @@
#include <vector>
#include <string>
#include <algorithm>
#include "tools/optimizer/graph/node_infershape.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
@ -482,6 +483,12 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph
CHECK_NULL_RETURN(quant_cast_cnode);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" +
std::to_string(input_index));
opt::NodeInferShape infer;
auto status = infer.InferShape(quant_cast_cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << quant_cast_cnode->fullname_with_scope() << " InferShape failed.";
return RET_ERROR;
}
auto manager = func_graph->manager();
CHECK_NULL_RETURN(manager);
auto ret = manager->Replace(input_node, quant_cast_cnode);

View File

@ -374,11 +374,12 @@ bool WeightQuantizer::CheckWeightQuantExist(const CNodePtr &cnode) {
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
weight_quantized_tensors_.clear();
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLstm, prim::kPrimGather,
prim::kPrimAdam, prim::kPrimSGD,
prim::kPrimApplyMomentum};
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLstm, prim::kPrimGather,
prim::kPrimAdam, prim::kPrimSGD,
prim::kPrimApplyMomentum, prim::kPrimConv2D,
prim::kPrimMatMul};
std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
if (ret != RET_OK) {