!44971 Weight quantization: add support for kPrimConv2D and kPrimMatMul
Merge pull request !44971 from yeyunpeng2020/master_quant_ci
This commit is contained in:
commit
db77a41083
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "tools/optimizer/graph/node_infershape.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
@ -482,6 +483,12 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph
|
|||
CHECK_NULL_RETURN(quant_cast_cnode);
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" +
|
||||
std::to_string(input_index));
|
||||
opt::NodeInferShape infer;
|
||||
auto status = infer.InferShape(quant_cast_cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << quant_cast_cnode->fullname_with_scope() << " InferShape failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto ret = manager->Replace(input_node, quant_cast_cnode);
|
||||
|
|
|
@ -374,11 +374,12 @@ bool WeightQuantizer::CheckWeightQuantExist(const CNodePtr &cnode) {
|
|||
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
|
||||
weight_quantized_tensors_.clear();
|
||||
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
||||
prim::kPrimLstm, prim::kPrimGather,
|
||||
prim::kPrimAdam, prim::kPrimSGD,
|
||||
prim::kPrimApplyMomentum};
|
||||
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
||||
prim::kPrimLstm, prim::kPrimGather,
|
||||
prim::kPrimAdam, prim::kPrimSGD,
|
||||
prim::kPrimApplyMomentum, prim::kPrimConv2D,
|
||||
prim::kPrimMatMul};
|
||||
std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
|
||||
auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
|
||||
if (ret != RET_OK) {
|
||||
|
|
Loading…
Reference in New Issue