remove unused quant_type judgement

This commit is contained in:
xuanyue 2021-08-14 10:48:21 +08:00
parent d23114fe89
commit bb34fb5d6c
15 changed files with 52 additions and 121 deletions

View File

@ -20,7 +20,6 @@
#include <map>
#include <string>
#include "include/lite_utils.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace converter {
@ -36,7 +35,6 @@ enum MS_API FmkType : int {
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
struct MS_API ConverterParameters {
FmkType fmk;
schema::QuantType quant_type;
std::string model_file;
std::string weight_file;
std::map<std::string, std::string> attrs;

View File

@ -33,7 +33,6 @@ namespace lite {
namespace {
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
converter_parameters->fmk = flag.fmk;
converter_parameters->quant_type = flag.quantType;
converter_parameters->model_file = flag.modelFile;
converter_parameters->weight_file = flag.weightFile;
}

View File

@ -28,7 +28,6 @@ class MindirAdjust {
public:
MindirAdjust() {}
~MindirAdjust() = default;
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
bool Run(const FuncGraphPtr &graph);
@ -37,7 +36,6 @@ class MindirAdjust {
int ValueNodeInt64Convert(AnfNodePtr anf_node);
int ComputeQuantParams(AnfNodePtr anf_node);
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
FmkType fmk_type_ = FmkType::kFmkTypeMs;
bool train_flag_ = false;
};

View File

@ -42,7 +42,6 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const
}
auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
mindir_adjust_pass->SetFmkType(flag.fmk);
mindir_adjust_pass->SetQuantType(flag.quantType);
mindir_adjust_pass->SetTrainFlag(flag.trainModel);
if (!mindir_adjust_pass->Run(func_graph)) {
MS_LOG(ERROR) << "MindIr adjust failed.";
@ -97,7 +96,6 @@ size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned cha
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
quant_type_ = flag.quantType;
FuncGraphPtr func_graph;
if (flag.dec_key.size() != 0) {
unsigned char key[32];
@ -128,7 +126,7 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel, flag.quantType);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel);
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -31,7 +31,6 @@ class MindsporeImporter {
private:
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
};

View File

@ -79,7 +79,6 @@ CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
auto weight_file = flag.weight_file;
quant_type_ = flag.quant_type;
STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -112,7 +111,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeCaffe, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -66,7 +66,6 @@ class CaffeModelParser : public converter::ModelParser {
caffe::NetParameter caffe_weight_;
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace mindspore::lite

View File

@ -60,7 +60,6 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
string model_file = flag.model_file;
quant_type_ = flag.quant_type;
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);
@ -95,7 +94,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeOnnx, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -99,7 +99,6 @@ class OnnxModelParser : public converter::ModelParser {
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore

View File

@ -492,7 +492,6 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensor
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
auto modelFile = flag.model_file;
quant_type_ = flag.quant_type;
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
@ -581,7 +580,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTf, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -108,7 +108,6 @@ class TFModelParser : public converter::ModelParser {
std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_;
std::unordered_map<std::string, int> node_output_num_;
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_;
};
} // namespace lite

View File

@ -54,7 +54,6 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
quant_type_ = flag.quant_type;
// load graph
tflite_model_ = ReadTfliteModel(model_file);
if (tflite_model_ == nullptr) {
@ -105,7 +104,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTflite, false, quant_type_);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false);
if (!unify_format->Run(res_graph_)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;

View File

@ -52,7 +52,6 @@ class TfliteModelParser : public converter::ModelParser {
STATUS ConvertGraphOutputs();
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
int round_type = 1);
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore

View File

@ -32,8 +32,7 @@ constexpr int kNumIndex_0 = 0;
constexpr int kNumIndex_1 = 1;
constexpr int kNumIndex_2 = 2;
constexpr int kNumIndex_3 = 3;
STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type,
schema::Format *src_format) {
STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
@ -47,13 +46,13 @@ STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType
} else if (format == schema::Format_NCHW) {
*src_format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "cnode format is invalid.";
MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope();
return RET_ERROR;
}
return RET_OK;
}
STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
@ -61,34 +60,22 @@ STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quan
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
switch (quant_type) {
case schema::QuantType_AwareTraining:
case schema::QuantType_PostTraining:
case schema::QuantType_WeightQuant:
case schema::QuantType_QUANT_NONE: {
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
*src_format = schema::Format_HWCK;
} else {
*src_format = schema::Format_HWKC;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
*src_format = schema::Format::Format_HWCK;
} else {
MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check.";
return RET_ERROR;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
*src_format = schema::Format_HWCK;
} else {
*src_format = schema::Format_HWKC;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
*src_format = schema::Format::Format_HWCK;
} else {
MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check. " << cnode->fullname_with_scope();
return RET_ERROR;
}
return RET_OK;
}
STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type,
schema::Format *src_format) {
STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
@ -96,87 +83,49 @@ STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
switch (quant_type) {
case schema::QuantType_AwareTraining:
case schema::QuantType_PostTraining:
case schema::QuantType_WeightQuant:
case schema::QuantType_QUANT_NONE: {
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
*src_format = schema::Format_KHWC;
} else {
*src_format = schema::Format_CHWK;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
*src_format = schema::Format_CHWK;
} else {
MS_LOG(ERROR) << "cannot decide weight format, current situation need to check.";
return RET_NOT_SUPPORT;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << cnode->fullname_with_scope();
return RET_ERROR;
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
*src_format = schema::Format_KHWC;
} else {
*src_format = schema::Format_CHWK;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
*src_format = schema::Format_CHWK;
} else {
MS_LOG(ERROR) << "cannot decide weight format, current situation need to check. " << cnode->fullname_with_scope();
return RET_NOT_SUPPORT;
}
return RET_OK;
}
STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr);
*src_format = schema::Format_KCHW;
return RET_OK;
}
STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
int64_t format =
prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
switch (quant_type) {
case schema::QuantType_AwareTraining: {
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
*src_format = schema::Format_KHWC;
} else {
*src_format = schema::Format_CHWK;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
*src_format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
} break;
case schema::QuantType_PostTraining:
case schema::QuantType_WeightQuant:
case schema::QuantType_QUANT_NONE: {
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
if (format == schema::Format_NHWC) {
*src_format = schema::Format_KHWC;
} else if (format == schema::Format_NCHW) {
*src_format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "format is invalid, format is " << format;
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "d an unsupported op type, which need to check. the type is " << prim->name();
return RET_NOT_SUPPORT;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << cnode->fullname_with_scope();
return lite::RET_ERROR;
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
if (format == schema::Format_NHWC) {
*src_format = schema::Format_KHWC;
} else if (format == schema::Format_NCHW) {
*src_format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "format is invalid, format is " << format;
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "unknown op, please check.";
return RET_ERROR;
}
return RET_OK;
}
@ -246,12 +195,12 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
schema::Format *dst_format) {
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
*dst_format = schema::Format_KHWC;
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::QuantType, schema::Format *)>>
decide_functions = {{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = {
{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
auto iter = decide_functions.find(fmk_type_);
if (iter == decide_functions.end()) {
MS_LOG(ERROR) << "current fmk don't support, please check.";
@ -259,7 +208,7 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
}
auto decide_func = iter->second;
MS_ASSERT(decide_func != nullptr);
if (decide_func(cnode, quant_type_, src_format) != RET_OK) {
if (decide_func(cnode, src_format) != RET_OK) {
MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format.";
return RET_ERROR;
}

View File

@ -24,9 +24,8 @@ namespace mindspore {
namespace lite {
class UnifyFormatToNHWC : public opt::ToFormatBase {
public:
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
schema::QuantType quant_type = schema::QuantType_QUANT_NONE)
: ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {}
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
: ToFormatBase(fmk_type, train_flag) {}
~UnifyFormatToNHWC() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
@ -41,7 +40,6 @@ class UnifyFormatToNHWC : public opt::ToFormatBase {
bool DecideWhetherInferShapeForNewNode() override;
STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
schema::Format *dst_format) override;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
} // namespace lite
} // namespace mindspore