correct name of func and variable

hangangqiang 2021-04-13 19:23:52 +08:00
parent c9d9e1cf32
commit 0d8302c0d1
18 changed files with 323 additions and 326 deletions

View File

@@ -121,8 +121,8 @@ int RunConverter(int argc, const char **argv) {
     return RET_INPUT_PARAM_INVALID;
   }
   auto meta_graph = converter->Convert(flags);
-  NoSupportOp::GetInstance()->PrintOps();
-  status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
+  NotSupportOp::GetInstance()->PrintOps();
+  status = ReturnCode::GetSingleReturnCode()->status_code();
   if (meta_graph == nullptr) {
     oss.clear();
     oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status);

View File

@@ -28,67 +28,67 @@ namespace mindspore {
 namespace lite {
 class ReturnCode {
  public:
-  ~ReturnCode() = default;
+  virtual ~ReturnCode() = default;
   static ReturnCode *GetSingleReturnCode() {
-    static ReturnCode returnCode;
-    return &returnCode;
+    static ReturnCode return_code;
+    return &return_code;
   }
   void UpdateReturnCode(STATUS status) {
-    if (statusCode == RET_OK) {
-      statusCode = status;
+    if (status_code_ == RET_OK) {
+      status_code_ = status;
     }
   }
-  STATUS GetReturnCode() const { return statusCode; }
+  STATUS status_code() const { return status_code_; }
  private:
-  ReturnCode() { statusCode = RET_OK; }
-  int statusCode;
+  ReturnCode() = default;
+  int status_code_ = RET_OK;
 };
-class NoSupportOp {
+class NotSupportOp {
  public:
-  ~NoSupportOp() = default;
-  static NoSupportOp *GetInstance() {
-    static NoSupportOp noSupportOp;
-    return &noSupportOp;
+  virtual ~NotSupportOp() = default;
+  static NotSupportOp *GetInstance() {
+    static NotSupportOp not_support_op;
+    return &not_support_op;
   }
-  void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
-  void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
+  void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; }
+  void InsertOp(const std::string &op_name) { not_support_ops_.insert(op_name); }
   void PrintOps() const {
-    if (!noSupportOps.empty()) {
+    if (!not_support_ops_.empty()) {
       MS_LOG(ERROR) << "===========================================";
       MS_LOG(ERROR) << "UNSUPPORTED OP LIST:";
-      for (auto &op_name : noSupportOps) {
-        MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
+      for (auto &op_name : not_support_ops_) {
+        MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name;
       }
       MS_LOG(ERROR) << "===========================================";
     }
   }
  private:
-  NoSupportOp() { noSupportOps.clear(); }
-  std::set<std::string> noSupportOps;
-  std::string fmkType;
+  NotSupportOp() = default;
+  std::set<std::string> not_support_ops_;
+  std::string fmk_type_;
 };
 class TensorDataType {
  public:
   ~TensorDataType() = default;
   static TensorDataType *GetInstance() {
-    static TensorDataType tensorDataType;
-    return &tensorDataType;
+    static TensorDataType tensor_data_type;
+    return &tensor_data_type;
   }
-  void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; }
+  void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
   int32_t GetTensorType(int32_t index) const {
-    if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) {
+    if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {
       return TypeId::kTypeUnknown;
     }
-    return tensorDataTypeMap.at(index);
+    return tensor_data_type_map_.at(index);
   }
  private:
   TensorDataType() {}
-  std::map<int32_t, int32_t> tensorDataTypeMap;
+  std::map<int32_t, int32_t> tensor_data_type_map_;
 };
 }  // namespace lite
 }  // namespace mindspore
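
All three classes in this header are function-local-static singletons; the commit leaves the pattern untouched and only moves it to the Google C++ naming style (snake_case locals, trailing-underscore members, lower-case accessor names for trivial getters and setters). A minimal sketch of the shape, using a hypothetical Counter class rather than the project's types:

    class Counter {
     public:
      static Counter *GetInstance() {
        static Counter counter;  // constructed once, on first use; thread-safe since C++11
        return &counter;
      }
      void Increment() { count_ += 1; }
      int count() const { return count_; }  // trivial getter named after its member
     private:
      Counter() = default;  // private constructor: instances exist only through GetInstance()
      int count_ = 0;       // data members carry a trailing underscore
    };

Any call site can then fold results into the same instance, e.g. Counter::GetInstance()->Increment(); this is how NotSupportOp accumulates unsupported op names during conversion and ReturnCode latches the first non-RET_OK status for RunConverter to report.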

View File

@@ -30,17 +30,17 @@ Flags::Flags() {
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
           "");
-  AddFlag(&Flags::inputDataTypeIn, "inputDataType",
+  AddFlag(&Flags::inputDataTypeStr, "inputDataType",
           "Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT",
           "DEFAULT");
-  AddFlag(&Flags::outputDataTypeIn, "outputDataType",
+  AddFlag(&Flags::outputDataTypeStr, "outputDataType",
           "Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | "
           "UINT8 | DEFAULT",
           "DEFAULT");
-  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
+  AddFlag(&Flags::quantTypeStr, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
   AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
-  AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
-  AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
+  AddFlag(&Flags::quantWeightSizeStr, "quantWeightSize", "Weight quantization size threshold", "0");
+  AddFlag(&Flags::quantWeightChannelStr, "quantWeightChannel", "Channel threshold for weight quantization", "16");
   AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device. "
@@ -49,32 +49,32 @@ Flags::Flags() {
 }
 int Flags::InitInputOutputDataType() {
-  if (this->inputDataTypeIn == "FLOAT") {
+  if (this->inputDataTypeStr == "FLOAT") {
     this->inputDataType = TypeId::kNumberTypeFloat32;
-  } else if (this->inputDataTypeIn == "INT8") {
+  } else if (this->inputDataTypeStr == "INT8") {
     this->inputDataType = TypeId::kNumberTypeInt8;
-  } else if (this->inputDataTypeIn == "UINT8") {
+  } else if (this->inputDataTypeStr == "UINT8") {
     this->inputDataType = TypeId::kNumberTypeUInt8;
-  } else if (this->inputDataTypeIn == "DEFAULT") {
+  } else if (this->inputDataTypeStr == "DEFAULT") {
     this->inputDataType = TypeId::kTypeUnknown;
   } else {
     std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
-      this->inputDataTypeIn.c_str();
+      this->inputDataTypeStr.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
-  if (this->outputDataTypeIn == "FLOAT") {
+  if (this->outputDataTypeStr == "FLOAT") {
     this->outputDataType = TypeId::kNumberTypeFloat32;
-  } else if (this->outputDataTypeIn == "INT8") {
+  } else if (this->outputDataTypeStr == "INT8") {
     this->outputDataType = TypeId::kNumberTypeInt8;
-  } else if (this->outputDataTypeIn == "UINT8") {
+  } else if (this->outputDataTypeStr == "UINT8") {
     this->outputDataType = TypeId::kNumberTypeUInt8;
-  } else if (this->outputDataTypeIn == "DEFAULT") {
+  } else if (this->outputDataTypeStr == "DEFAULT") {
     this->outputDataType = TypeId::kTypeUnknown;
   } else {
     std::cerr
       << "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
-      this->outputDataTypeIn.c_str();
+      this->outputDataTypeStr.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
   return RET_OK;
@@ -110,7 +110,7 @@ bool Flags::IsValidNum(const std::string &str, int *num) {
 }
 int Flags::QuantParamInputCheck() {
-  if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
+  if (!Flags::IsValidNum(this->quantWeightChannelStr, &this->quantWeightChannel)) {
     std::cerr << "quantWeightChannel should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -118,7 +118,7 @@ int Flags::QuantParamInputCheck() {
     std::cerr << "quantWeightChannel should be greater than or equal to zero.";
     return RET_INPUT_PARAM_INVALID;
   }
-  if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
+  if (!Flags::IsValidNum(this->quantWeightSizeStr, &this->quantWeightSize)) {
     std::cerr << "quantWeightSize should be a valid number.";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -138,11 +138,11 @@ int Flags::QuantParamInputCheck() {
 }
 int Flags::InitQuantParam() {
-  if (this->quantTypeIn == "WeightQuant") {
+  if (this->quantTypeStr == "WeightQuant") {
     this->quantType = QuantType_WeightQuant;
-  } else if (this->quantTypeIn == "PostTraining") {
+  } else if (this->quantTypeStr == "PostTraining") {
     this->quantType = QuantType_PostTraining;
-  } else if (this->quantTypeIn.empty()) {
+  } else if (this->quantTypeStr.empty()) {
     this->quantType = QuantType_QUANT_NONE;
   } else {
     std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";

View File

@@ -65,25 +65,20 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string fmkIn;
   FmkType fmk;
   std::string weightFile;
-  std::string inputArrays;
-  std::string outputArrays;
-  std::string inputShapes;
-  // used for quantization
-  std::string quantTypeIn;
-  QuantType quantType;
-  std::string inferenceTypeIn;
-  std::string inputDataTypeIn;
-  std::string outputDataTypeIn;
-  // used for parse aware trainning
   TypeId inputDataType;
   TypeId outputDataType;
+  // used for quantization
+  std::string quantTypeStr;
+  QuantType quantType;
+  std::string inputDataTypeStr;
+  std::string outputDataTypeStr;
   // used for post-trainning-weight
-  std::string quantWeightSizeIn;
+  std::string quantWeightSizeStr;
   int quantWeightSize;
   std::string bitNumIn;
   int bitNum;
   std::string configFile;
-  std::string quantWeightChannelIn;
+  std::string quantWeightChannelStr;
   int quantWeightChannel;
   std::string trainModelIn;
   bool trainModel = false;
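
The header makes the convention behind the rename visible: every numeric or enum option keeps a raw string member (the Str suffix replacing In) next to the typed field parsed from it, so failures can be reported in terms of the original input. A hedged sketch of that pairing, with hypothetical names and a strtol-based check standing in for Flags::IsValidNum:

    #include <cstdlib>
    #include <string>

    struct WeightSizeFlag {
      std::string quantWeightSizeStr = "0";  // raw text as received from the CLI
      int quantWeightSize = 0;               // typed value filled in by Parse()

      bool Parse() {
        char *end = nullptr;
        long v = std::strtol(quantWeightSizeStr.c_str(), &end, 10);
        if (end == quantWeightSizeStr.c_str() || *end != '\0' || v < 0) {
          return false;  // non-numeric or negative input: caller prints the Str form
        }
        quantWeightSize = static_cast<int>(v);
        return true;
      }
    };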

View File

@@ -47,8 +47,8 @@ namespace mindspore::lite {
 std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
   std::vector<schema::CNodeT *> old_nodes{};
-  old_nodes.resize(graphDefT->nodes.size());
-  std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(),
+  old_nodes.resize(graph_defT_->nodes.size());
+  std::transform(graph_defT_->nodes.begin(), graph_defT_->nodes.end(), old_nodes.begin(),
                  [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
   return old_nodes;
 }
@@ -57,33 +57,33 @@ GraphDefTransform::GraphDefTransform() = default;
 GraphDefTransform::~GraphDefTransform() = default;
-void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
+void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }
 int GraphDefTransform::Transform(const converter::Flags &ctx) {
   STATUS status;
   {
     auto old_nodes = GetGraphNodes();
-    Optimizer unusedOpRemoveOptimizer;
+    Optimizer unused_op_remove_optimizer;
     if (!ctx.trainModel) {
-      unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
+      unused_op_remove_optimizer.AddPass(new DropoutNodeRemovePass());
     }
-    unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
-    unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
-    status = unusedOpRemoveOptimizer.Run(graphDefT);
+    unused_op_remove_optimizer.AddPass(new IsolatedNodeRemovePass());
+    unused_op_remove_optimizer.AddPass(new SubgraphNodePass(old_nodes));
+    status = unused_op_remove_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
       return status;
     }
   }
   // generate and infer quant parameters
   {
-    Optimizer inferQuantParamPass;
-    inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
-    inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
-    status = inferQuantParamPass.Run(graphDefT);
+    Optimizer infer_quant_param_pass;
+    infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
+    infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
+    status = infer_quant_param_pass.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
       return status;
     }
   }
@@ -93,40 +93,40 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer formatTransOptimizer;
-    auto formatTransPass = new (std::nothrow) FormatTransPass();
-    if (formatTransPass == nullptr) {
+    Optimizer format_trans_optimizer;
+    auto format_trans_pass = new (std::nothrow) FormatTransPass();
+    if (format_trans_pass == nullptr) {
       MS_LOG(ERROR) << "new formatTransPass failed";
       return RET_MEMORY_FAILED;
     }
-    formatTransPass->SetQuantType(ctx.quantType);
-    formatTransPass->SetFmk(ctx.fmk);
-    formatTransOptimizer.AddPass(formatTransPass);
-    formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    format_trans_pass->set_quant_type(ctx.quantType);
+    format_trans_pass->set_fmk_type(ctx.fmk);
+    format_trans_optimizer.AddPass(format_trans_pass);
+    format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
     if (ctx.fmk != converter::FmkType_TF) {
-      formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
+      format_trans_optimizer.AddPass(new (std::nothrow) InferShapePass());
     }
-    status = formatTransOptimizer.Run(graphDefT);
+    status = format_trans_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
-      MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
       return status;
     }
   }
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer formatTransOptimizer;
-    formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
-    formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
-    formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
-    formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
-    formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    status = formatTransOptimizer.Run(graphDefT);
+    Optimizer format_trans_optimizer;
+    format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
+    format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass());
+    format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass());
+    format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
+    format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    status = format_trans_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
-      MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
       return status;
     }
   }
@@ -134,15 +134,15 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer formatTransOptimizer;
+    Optimizer format_trans_optimizer;
     if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
-      formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
-      formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-      formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+      format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
+      format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+      format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
     }
-    status = formatTransOptimizer.Run(graphDefT);
+    status = format_trans_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
-      MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
       return status;
     }
   }
@@ -151,7 +151,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer fusionOptimizer;
+    Optimizer replace_optimizer;
     if (!ctx.trainModel) {
       auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
       if (batch_norm_scale_pass == nullptr) {
@@ -159,13 +159,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
         return RET_ERROR;
       }
       batch_norm_scale_pass->SetFmk(ctx.fmk);
-      fusionOptimizer.AddPass(batch_norm_scale_pass);
+      replace_optimizer.AddPass(batch_norm_scale_pass);
     }
-    fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
-    status = fusionOptimizer.Run(graphDefT);
+    replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
+    status = replace_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
+      MS_LOG(ERROR) << "Run replace_optimizer BatchNormConvertScalePass Failed";
       return status;
     }
   }
@@ -173,13 +173,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer fusionOptimizer;
-    fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
-    fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    status = fusionOptimizer.Run(graphDefT);
+    Optimizer fusion_optimizer;
+    fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass());
+    fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    status = fusion_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed";
       return status;
     }
   }
@@ -188,12 +188,12 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   if (ctx.fmk != converter::FmkType_TF) {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer tensorQuantOptimizer;
-    tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
-    tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
-    tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
-    tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    status = tensorQuantOptimizer.Run(graphDefT);
+    Optimizer tensor_quant_optimizer;
+    tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass());
+    tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
+    tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    status = tensor_quant_optimizer.Run(graph_defT_);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "DoQuantize failed!";
       return status;
@@ -204,31 +204,31 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   if (ctx.fmk != converter::FmkType_TF) {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer quantNodeOptimizer;
-    quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
-    quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
-    status = quantNodeOptimizer.Run(graphDefT);
+    Optimizer quant_node_optimizer;
+    quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass());
+    status = quant_node_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
       return status;
     }
     auto old_nodes2 = GetGraphNodes();
-    quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass());
-    auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
-    if (dTypeTransPass == nullptr) {
-      MS_LOG(ERROR) << "new dTypeTransPass failed";
+    quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
+    auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
+    if (dtype_trans_pass == nullptr) {
+      MS_LOG(ERROR) << "new dtype_trans_pass failed";
       return RET_MEMORY_FAILED;
     }
-    dTypeTransPass->SetInputDataDType(ctx.inputDataType);
-    dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
-    quantNodeOptimizer.AddPass(dTypeTransPass);
-    quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
-    quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
-    status = quantNodeOptimizer.Run(graphDefT);
+    dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
+    dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
+    quant_node_optimizer.AddPass(dtype_trans_pass);
+    quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
+    quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
+    status = quant_node_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
      return status;
    }
  }
@@ -237,22 +237,22 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer switchOptimizer;
-    switchOptimizer.AddPass(new (std::nothrow) SwitchPass());
-    switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
-    switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    status = switchOptimizer.Run(graphDefT);
+    Optimizer switch_optimizer;
+    switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
+    switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    status = switch_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run switch graphPasses Failed";
+      MS_LOG(ERROR) << "Run switch_optimizer Failed";
       return status;
     }
   }
   // subgraph tensor pass
   {
-    Optimizer subgraphTensorOptimizer;
-    subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
-    status = subgraphTensorOptimizer.Run(graphDefT);
+    Optimizer subgraph_tensor_optimizer;
+    subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
+    status = subgraph_tensor_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
       return status;
@@ -263,33 +263,33 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
   {
     // init old node indices
     auto old_nodes = GetGraphNodes();
-    Optimizer nameOptimizer;
-    nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
-    nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
-    nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
-    status = nameOptimizer.Run(graphDefT);
+    Optimizer name_optimizer;
+    name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
+    name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
+    name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
+    status = name_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
       return status;
     }
   }
   {
-    Optimizer nestedLoopOptimizer;
-    nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
-    status = nestedLoopOptimizer.Run(graphDefT);
+    Optimizer nested_loop_optimizer;
+    nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
+    status = nested_loop_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
       return status;
     }
   }
   {
-    Optimizer quantNodeOptimizer;
-    quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
-    status = quantNodeOptimizer.Run(graphDefT);
+    Optimizer quant_param_optimizer;
+    quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
+    status = quant_param_optimizer.Run(graph_defT_);
     if (status != RET_OK && status != RET_NO_CHANGE) {
-      MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
+      MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed";
       return status;
     }
   }
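
Every block in Transform() follows one pipeline shape: build an Optimizer on the stack, AddPass() a few heap-allocated passes (new (std::nothrow), so allocation failure surfaces as nullptr instead of an exception), Run() the pipeline over graph_defT_, and bail out on the first unexpected status. A reduced sketch of that shape with stand-in types (the project's real Optimizer is the one whose .cc appears at the end of this diff):

    #include <vector>

    struct Graph {};  // stand-in for schema::MetaGraphT

    class Pass {
     public:
      virtual ~Pass() = default;
      virtual int Run(Graph *graph) = 0;  // 0 on success, nonzero status otherwise
    };

    class Optimizer {
     public:
      ~Optimizer() {
        for (auto *pass : passes_) delete pass;  // the optimizer owns its passes
      }
      void AddPass(Pass *pass) {
        if (pass != nullptr) passes_.emplace_back(pass);  // tolerate failed allocations
      }
      int Run(Graph *graph) {
        for (auto *pass : passes_) {
          int status = pass->Run(graph);
          if (status != 0) return status;  // stop the pipeline on the first error
        }
        return 0;
      }
     private:
      std::vector<Pass *> passes_;
    };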

View File

@@ -36,12 +36,12 @@ class GraphDefTransform {
   GraphDefTransform();
   virtual ~GraphDefTransform();
   virtual int Transform(const converter::Flags &ctx);
-  void SetGraphDef(schema::MetaGraphT *dstDef);
-  inline schema::MetaGraphT *GetOutput() { return graphDefT; }
+  void SetGraphDef(schema::MetaGraphT *dst_def);
+  inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
  protected:
   std::vector<schema::CNodeT *> GetGraphNodes();
-  schema::MetaGraphT *graphDefT = nullptr;
+  schema::MetaGraphT *graph_defT_ = nullptr;
   Optimizer *optimizer = nullptr;
 };
 }  // namespace lite

View File

@@ -55,34 +55,35 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  auto &graphInIdxes = graph->inputIndex;
-  if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
-      this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) {
-    MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
+  auto &graph_in_idxes = graph->inputIndex;
+  if (this->input_data_dtype != TypeId::kNumberTypeFloat32 && this->input_data_dtype != TypeId::kNumberTypeUInt8 &&
+      this->input_data_dtype != TypeId::kNumberTypeInt8 && this->input_data_dtype != TypeId::kTypeUnknown) {
+    MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
     return RET_ERROR;
   }
-  for (auto graphInIdx : graphInIdxes) {
-    MS_ASSERT(graphInIdx < graph->allTensors.size());
-    auto &tensor = graph->allTensors.at(graphInIdx);
+  for (auto graph_in_idx : graph_in_idxes) {
+    MS_ASSERT(graph_in_idx < graph->allTensors.size());
+    auto &tensor = graph->allTensors.at(graph_in_idx);
     if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
-    int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown
-                               ? this->inputDataDType
-                               : TensorDataType::GetInstance()->GetTensorType(graphInIdx);
+    int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
+                                 ? this->input_data_dtype
+                                 : TensorDataType::GetInstance()->GetTensorType(graph_in_idx);
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-      auto nodeName = (*iter)->name;
-      for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
-        if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) {
+      auto node_name = (*iter)->name;
+      for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {
+        if ((*iter)->inputIndex.at(input_indexidx) == graph_in_idx) {
           STATUS status = RET_OK;
           // insert dtype cast node between input tensor and input node
-          if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
-            iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status);
+          if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
+            iter =
+              InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
           }
           if (status != RET_OK) {
-            MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
+            MS_LOG(ERROR) << "InsertDTypeTransNode before " << node_name.c_str() << " failed";
             return status;
           }
         }
@@ -94,33 +94,34 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
-      this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) {
-    MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
+  if (this->output_data_dtype != TypeId::kNumberTypeFloat32 && this->output_data_dtype != TypeId::kNumberTypeUInt8 &&
+      this->output_data_dtype != TypeId::kNumberTypeInt8 && this->output_data_dtype != TypeId::kTypeUnknown) {
+    MS_LOG(ERROR) << "Invalid outputDataType: " << this->output_data_dtype;
     return RET_ERROR;
   }
-  auto &graphOutIdxes = graph->outputIndex;
-  for (auto graphOutIdx : graphOutIdxes) {
-    MS_ASSERT(graphOutIdx < graph->allTensors.size());
-    auto &tensor = graph->allTensors.at(graphOutIdx);
+  auto &graph_out_idxes = graph->outputIndex;
+  for (auto graph_out_idx : graph_out_idxes) {
+    MS_ASSERT(graph_out_idx < graph->allTensors.size());
+    auto &tensor = graph->allTensors.at(graph_out_idx);
     if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
       continue;
     }
-    int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown
-                               ? this->outputDataDType
-                               : TensorDataType::GetInstance()->GetTensorType(graphOutIdx);
+    int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
                                  ? this->output_data_dtype
                                  : TensorDataType::GetInstance()->GetTensorType(graph_out_idx);
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-      auto nodeName = (*iter)->name;
+      auto node_name = (*iter)->name;
       MS_ASSERT(node != nullptr);
       for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) {
-        if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
+        if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
           // insert transNode
           STATUS status = RET_OK;
-          if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
-            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status);
+          if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
+            iter =
+              InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
          }
          if (status != RET_OK) {
-            MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
+            MS_LOG(ERROR) << "InsertDTypeTransNode after " << node_name.c_str() << " failed";
            return status;
          }
          break;
@@ -231,52 +233,53 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
   return RET_OK;
 }
-NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
-                                              size_t inoutIdx, int32_t inputDataType, int32_t outputDataType,
-                                              STATUS *errorCode) {
-  MS_ASSERT((*existNodeIter) != nullptr);
-  auto existNodeName = (*existNodeIter)->name;
-  std::string tileName;
+NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
+                                              size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
+                                              STATUS *error_code) {
+  MS_ASSERT((*exist_node_iter) != nullptr);
+  auto exist_node_name = (*exist_node_iter)->name;
+  std::string tile_name;
   if (place == kBefore) {
-    tileName = existNodeName + "_pre";
+    tile_name = exist_node_name + "_pre";
   } else {
-    tileName = existNodeName + "_post";
+    tile_name = exist_node_name + "_post";
   }
-  auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
-  if (transNode == nullptr) {
+  auto trans_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
+  if (trans_node == nullptr) {
     MS_LOG(ERROR) << "new TransNode failed";
-    *errorCode = RET_ERROR;
+    *error_code = RET_ERROR;
     return graph->nodes.end();
   }
-  auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
-  if (quantDTypeCastParam == nullptr) {
+  auto quant_dtype_cast_param = new (std::nothrow) QuantDTypeCastT;
+  if (quant_dtype_cast_param == nullptr) {
     MS_LOG(ERROR) << "new quantDTypeCastParam failed";
-    *errorCode = RET_ERROR;
+    *error_code = RET_ERROR;
     return graph->nodes.end();
   }
-  transNode->primitive = std::make_unique<schema::PrimitiveT>();
-  transNode->primitive->value.value = quantDTypeCastParam;
-  transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
-  transNode->quantType = QuantType_AwareTraining;
-  quantDTypeCastParam->src_t = inputDataType;
-  quantDTypeCastParam->dst_t = outputDataType;
-  if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) {
-    transNode->name = "int8toft32_" + tileName + std::to_string(id++);
-  } else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) {
-    transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
-  } else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) {
-    transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
-  } else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) {
-    transNode->name = "int8touint8_" + tileName + std::to_string(id++);
+  trans_node->primitive = std::make_unique<schema::PrimitiveT>();
+  trans_node->primitive->value.value = quant_dtype_cast_param;
+  trans_node->primitive->value.type = PrimitiveType_QuantDTypeCast;
+  trans_node->quantType = QuantType_AwareTraining;
+  quant_dtype_cast_param->src_t = input_data_type;
+  quant_dtype_cast_param->dst_t = output_data_type;
+  if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeFloat32) {
+    trans_node->name = "int8toft32_" + tile_name + std::to_string(id_++);
+  } else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeInt8) {
+    trans_node->name = "ft32toint8_" + tile_name + std::to_string(id_++);
+  } else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeInt8) {
+    trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
+  } else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
+    trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++);
  }
-  transNode->primitive->value.value = quantDTypeCastParam;
+  trans_node->primitive->value.value = quant_dtype_cast_param;
   int insert_num = 0;
-  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, castOpCopyer);
+  return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
+                    castOpCopyer);
 }
-void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
-void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
+void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
+void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
 }  // namespace lite
 }  // namespace mindspore
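
InsertDTypeTransNode composes each cast node's name from the conversion direction, the anchor node's name plus a _pre or _post suffix, and the pass-wide counter id_, yielding names like int8toft32_conv1_post7. A hedged sketch of just that naming step, with a stand-in TypeId enum rather than the project's header:

    #include <cstdint>
    #include <string>

    enum TypeId : int32_t { kNumberTypeFloat32, kNumberTypeInt8, kNumberTypeUInt8 };  // stand-ins

    // Compose e.g. "ft32toint8_" + "conv1_pre" + "3" and advance the shared counter.
    std::string CastNodeName(int32_t src, int32_t dst, const std::string &tile_name, size_t *id) {
      std::string prefix;
      if (src == kNumberTypeInt8 && dst == kNumberTypeFloat32) {
        prefix = "int8toft32_";
      } else if (src == kNumberTypeFloat32 && dst == kNumberTypeInt8) {
        prefix = "ft32toint8_";
      } else if (src == kNumberTypeUInt8 && dst == kNumberTypeInt8) {
        prefix = "uint8toint8_";
      } else if (src == kNumberTypeInt8 && dst == kNumberTypeUInt8) {
        prefix = "int8touint8_";
      }
      return prefix + tile_name + std::to_string((*id)++);
    }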

View File

@@ -30,15 +30,15 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
 class DTypeTransPass : public GraphPass {
  public:
-  DTypeTransPass() : id(0) {}
+  DTypeTransPass() : id_(0) {}
   ~DTypeTransPass() override = default;
   STATUS Run(schema::MetaGraphT *graph) override;
-  void SetInputDataDType(TypeId dataType);
-  void SetOutputDataDType(TypeId dataType);
+  void set_input_data_dtype(TypeId data_type);
+  void set_output_data_dtype(TypeId dataType);
  private:
   STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
@@ -51,13 +51,14 @@ class DTypeTransPass : public GraphPass {
   STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter);
-  NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
-                                int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
+  NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
+                                size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
+                                STATUS *error_code);
  private:
-  size_t id;
-  TypeId inputDataDType = TypeId::kNumberTypeFloat;
-  TypeId outputDataDType = TypeId::kNumberTypeFloat;
+  size_t id_;
+  TypeId input_data_dtype = TypeId::kNumberTypeFloat;
+  TypeId output_data_dtype = TypeId::kNumberTypeFloat;
   OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
     std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
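
The castOpCopyer member cut off above is an OpDefCopyer: a stored lambda that InsertNode can call to clone the cast node when a single insertion point has to serve several edges (hence the insert_num out-parameter in the .cc). A hedged sketch of the idiom with stand-in types:

    #include <functional>
    #include <memory>
    #include <string>

    struct OpDef {  // stand-in for schema::CNodeT
      std::string name;
      int quant_type = 0;
    };

    // A copier is just a stored callable that duplicates an op definition.
    using OpDefCopyer = std::function<std::unique_ptr<OpDef>(OpDef *)>;

    OpDefCopyer cast_op_copyer = [](OpDef *in_op) -> std::unique_ptr<OpDef> {
      auto new_op = std::make_unique<OpDef>();
      new_op->name = in_op->name;              // callers re-suffix duplicates as needed
      new_op->quant_type = in_op->quant_type;  // copy only what the clone must share
      return new_op;
    };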

View File

@@ -45,32 +45,32 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
   return RET_OK;
 }
-STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
-                                             FormatTransNodeType *afterNodeType) {
+STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
+                                             FormatTransNodeType *after_node_type) {
   if (fmk_type_ == converter::FmkType_TFLITE) {  // inference by nhwc
     if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) {
       return RET_NO_CHANGE;
     }
-    *beforeNodeType = kNHWC2NCHW;
-    *afterNodeType = kNCHW2NHWC;
+    *before_node_type = kNHWC2NCHW;
+    *after_node_type = kNCHW2NHWC;
     return RET_OK;
   } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS ||
              fmk_type_ == converter::FmkType_ONNX) {
     if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
       return RET_NO_CHANGE;
     }
-    *beforeNodeType = kNCHW2NHWC;
-    *afterNodeType = kNHWC2NCHW;
+    *before_node_type = kNCHW2NHWC;
+    *after_node_type = kNHWC2NCHW;
     return RET_OK;
   } else if (fmk_type_ == converter::FmkType_TF) {
     if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) {
-      *beforeNodeType = kNCHW2NHWC;
-      *afterNodeType = kNHWC2NCHW;
+      *before_node_type = kNCHW2NHWC;
+      *after_node_type = kNHWC2NCHW;
       return RET_OK;
     }
     if (IsContain(GetNchwOpList(), GetCNodeTType(node))) {
-      *beforeNodeType = kNHWC2NCHW;
-      *afterNodeType = kNCHW2NHWC;
+      *before_node_type = kNHWC2NCHW;
+      *after_node_type = kNCHW2NHWC;
       return RET_OK;
     }
     return RET_NO_CHANGE;
@@ -96,36 +96,34 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
       return RET_OK;
     }
   }
-  auto graphInputIdxes = graph->inputIndex;
-  for (size_t i = 0; i < graphInputIdxes.size(); i++) {
+  auto graph_input_idxes = graph->inputIndex;
+  for (size_t i = 0; i < graph_input_idxes.size(); i++) {
     bool transed = false;
-    auto inputIdx = graphInputIdxes.at(i);
-    MS_ASSERT(inputIdx < subGraph->allTensors.size());
-    auto &tensor = graph->allTensors.at(inputIdx);
+    auto input_idx = graph_input_idxes.at(i);
+    auto &tensor = graph->allTensors.at(input_idx);
     if (tensor->dims.size() != kNCHWDimNumber) {
       continue;
     }
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-      for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
-        if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) {
+      for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) {
+        if ((*iter)->inputIndex.at(input_index_idx) == input_idx) {
           STATUS status = RET_OK;
-          iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
+          iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status);
           if (status != RET_OK) {
             MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
             return status;
           }
           // set first tensor format to nhwc
-          auto &transNode = *(iter - 1);
-          MS_ASSERT(transNode != nullptr);
-          MS_ASSERT(transNode->inputIndex.size() == 1);
-          MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front());
-          auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front());
-          graphInTensor->format = schema::Format::Format_NHWC;
+          auto &trans_node = *(iter - 1);
+          MS_ASSERT(trans_node != nullptr);
+          MS_ASSERT(trans_node->inputIndex.size() == 1);
+          auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front());
+          graph_in_tensor->format = schema::Format::Format_NHWC;
           // assume parser not reformat shape
-          auto oldDims = graphInTensor->dims;
+          auto old_dims = graph_in_tensor->dims;
           if (!transed) {
-            graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
+            graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]};
             transed = true;
           }
         }
@@ -143,10 +141,10 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   // insert before and after the op cal by nchw/nc4hw4
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
-    FormatTransNodeType beforeNodeType = kNCHW2NHWC;
-    FormatTransNodeType afterNodeType = kNHWC2NCHW;
+    FormatTransNodeType before_node_type = kNCHW2NHWC;
+    FormatTransNodeType after_node_type = kNHWC2NCHW;
     STATUS status = RET_OK;
-    status = GetInsertFormatTrans(**iter, &beforeNodeType, &afterNodeType);
+    status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type);
     if (status == RET_NO_CHANGE) {
       continue;
     }
@@ -170,17 +168,17 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
     if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
       reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
     }
-    auto specInsertIndexes = GetExtNhwcIndexes();
-    auto opType = GetCNodeTType(**iter);
-    if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
-      for (auto insert_index : specInsertIndexes[opType]) {
-        iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
+    auto spec_insert_indexes = GetExtNhwcIndexes();
+    auto op_type = GetCNodeTType(**iter);
+    if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) {
+      for (auto insert_index : spec_insert_indexes[op_type]) {
+        iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status);
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
           return RET_ERROR;
        }
      }
-    } else if (IsContain(GetNhwcAllInputOpList(), opType)) {
+    } else if (IsContain(GetNhwcAllInputOpList(), op_type)) {
       auto input_size = node->inputIndex.size();
       if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) {
         if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) {
@@ -188,16 +186,16 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
         }
       }
       for (size_t i = 0; i < input_size; i++) {
-        iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
+        iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status);
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
           return RET_ERROR;
         }
       }
     } else {
-      iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
+      iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status);
     }
-    iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
+    iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
       return RET_ERROR;
@@ -206,29 +204,29 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
   return RET_OK;
 }
-NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
-                                                size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) {
-  MS_ASSERT((*existNodeIter) != nullptr);
+NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
+                                                size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) {
+  MS_ASSERT((*exist_node_iter) != nullptr);
   MS_ASSERT(graph != nullptr);
-  auto existNodeName = (*existNodeIter)->name;
-  std::string tileName;
+  auto exist_node_name = (*exist_node_iter)->name;
+  std::string tile_name;
   if (place == kBefore) {
-    tileName = existNodeName + "_pre";
+    tile_name = exist_node_name + "_pre";
   } else {
-    tileName = existNodeName + "_post";
+    tile_name = exist_node_name + "_post";
   }
-  auto transNode = std::make_unique<schema::CNodeT>();
-  transNode->primitive = std::make_unique<schema::PrimitiveT>();
-  transNode->primitive->value.type = schema::PrimitiveType_Transpose;
+  auto trans_node = std::make_unique<schema::CNodeT>();
+  trans_node->primitive = std::make_unique<schema::PrimitiveT>();
+  trans_node->primitive->value.type = schema::PrimitiveType_Transpose;
   auto perm_tensor = std::make_unique<schema::TensorT>();
   perm_tensor->dataType = kNumberTypeInt32;
   perm_tensor->dims = {4};
   std::vector<int> perm;
-  if (nodeType == kNCHW2NHWC) {
-    transNode->name = "nchw2nhwc_" + tileName + std::to_string(id_++);
+  if (node_type == kNCHW2NHWC) {
+    trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++);
     perm = {0, 2, 3, 1};
   } else {
-    transNode->name = "nhwc2nchw_" + tileName + std::to_string(id_++);
+    trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++);
     perm = {0, 3, 1, 2};
   }
   size_t bytes = perm.size() * sizeof(int);
@@ -236,27 +234,27 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
   if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) {
     MS_LOG(ERROR) << "memcpy data failed.";
   }
-  perm_tensor->name = transNode->name + "_perm";
-  OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
-    auto newOpDef = std::make_unique<schema::CNodeT>();
-    if (newOpDef == nullptr) {
+  perm_tensor->name = trans_node->name + "_perm";
+  OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> {
+    auto new_op_def = std::make_unique<schema::CNodeT>();
+    if (new_op_def == nullptr) {
       MS_LOG(ERROR) << "new CNodeT failed";
       return nullptr;
     }
-    newOpDef->name = inOpDef->name;
-    newOpDef->quantType = inOpDef->quantType;
-    newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
-    if (newOpDef->primitive == nullptr) {
+    new_op_def->name = in_op_def->name;
+    new_op_def->quantType = in_op_def->quantType;
+    new_op_def->primitive = std::make_unique<schema::PrimitiveT>();
+    if (new_op_def->primitive == nullptr) {
      MS_LOG(ERROR) << "new PrimitiveT failed";
      return nullptr;
    }
-    newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
-    return newOpDef;
+    new_op_def->primitive->value.type = schema::PrimitiveType_Transpose;
+    return new_op_def;
   };
   int insert_num = 0;
-  auto iter =
-    InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, TransposeOpCopyer);
+  auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
+                         transpose_op_copyer);
   size_t index = graph->allTensors.size();
   graph->allTensors.push_back(std::move(perm_tensor));
   for (int i = insert_num; i > 0; --i) {
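
The two perm vectors above are the whole transpose pair: {0, 2, 3, 1} reads NCHW positions out in NHWC order, and {0, 3, 1, 2} inverts it. A small self-contained check of what they do to a shape (not project code):

    #include <cstdio>
    #include <vector>

    // Apply a transpose permutation to a shape: out[i] = in[perm[i]].
    std::vector<int> PermuteShape(const std::vector<int> &in, const std::vector<int> &perm) {
      std::vector<int> out(perm.size());
      for (size_t i = 0; i < perm.size(); ++i) out[i] = in[perm[i]];
      return out;
    }

    int main() {
      std::vector<int> nchw = {1, 3, 224, 224};
      auto nhwc = PermuteShape(nchw, {0, 2, 3, 1});  // -> {1, 224, 224, 3}
      auto back = PermuteShape(nhwc, {0, 3, 1, 2});  // -> {1, 3, 224, 224}
      std::printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
      std::printf("%d %d %d %d\n", back[0], back[1], back[2], back[3]);
      return 0;
    }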

View File

@ -34,13 +34,13 @@ class FormatTransPass : public GraphPass {
STATUS Run(schema::MetaGraphT *graph) override; STATUS Run(schema::MetaGraphT *graph) override;
void SetQuantType(QuantType quantType) { this->quant_type_ = quantType; } void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; }
void SetFmk(converter::FmkType fmkType) { this->fmk_type_ = fmkType; } void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
protected: protected:
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place,
FormatTransNodeType nodeType, STATUS *errorCode); size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code);
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
@ -61,8 +61,8 @@ class FormatTransPass : public GraphPass {
int GetFormat(const schema::CNodeT &); int GetFormat(const schema::CNodeT &);
STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType, STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
FormatTransNodeType *afterNodeType); FormatTransNodeType *after_node_type);
protected: protected:
size_t id_ = 0; size_t id_ = 0;
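
The renames in this header follow the convention (as in the Google C++ style guide) that the commit moves the converter toward: private data members carry a trailing underscore (quant_type_, fmk_type_, id_) and trivial setters become snake_case. A minimal sketch with a hypothetical class, not part of the converter:

class ExamplePass {
 public:
  // Trivial setter: snake_case, named set_ plus the member name minus its underscore.
  void set_quant_type(int quant_type) { quant_type_ = quant_type; }
  // Trivial getter: the member name without the trailing underscore.
  int quant_type() const { return quant_type_; }

 private:
  int quant_type_ = 0;  // private data member: trailing underscore
};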
@ -20,13 +20,13 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
Optimizer::~Optimizer() { Optimizer::~Optimizer() {
for (auto pass : graphPasses) { for (auto pass : graph_passes_) {
if (pass != nullptr) { if (pass != nullptr) {
delete (pass); delete (pass);
} }
} }
for (auto pass : nodePasses) { for (auto pass : node_passes_) {
if (pass != nullptr) { if (pass != nullptr) {
delete (pass); delete (pass);
} }
@ -35,13 +35,13 @@ Optimizer::~Optimizer() {
void Optimizer::AddPass(GraphPass *graphPass) { void Optimizer::AddPass(GraphPass *graphPass) {
if (graphPass != nullptr) { if (graphPass != nullptr) {
this->graphPasses.emplace_back(graphPass); this->graph_passes_.emplace_back(graphPass);
} }
} }
void Optimizer::AddPass(NodePass *nodePass) { void Optimizer::AddPass(NodePass *nodePass) {
if (nodePass != nullptr) { if (nodePass != nullptr) {
this->nodePasses.emplace_back(nodePass); this->node_passes_.emplace_back(nodePass);
} }
} }
@ -51,7 +51,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
bool ifNotChanged = true; bool ifNotChanged = true;
// each node should go through all node pass not each node pass go through all node // each node should go through all node pass not each node pass go through all node
for (auto &opDef : graphDefT->nodes) { for (auto &opDef : graphDefT->nodes) {
for (auto pass : this->nodePasses) { for (auto pass : this->node_passes_) {
status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get())); status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get()));
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run NodePass failed"; MS_LOG(ERROR) << "Run NodePass failed";
@ -64,7 +64,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
} }
} }
for (auto pass : this->graphPasses) { for (auto pass : this->graph_passes_) {
status = pass->Run(graphDefT); status = pass->Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run GraphPass failed"; MS_LOG(ERROR) << "Run GraphPass failed";
@ -41,10 +41,10 @@ class GraphPass : public Pass<schema::MetaGraphT> {
}; };
struct GraphNode { struct GraphNode {
GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {} GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : sub_graph_(subGraph), op_def_(opDefT) {}
~GraphNode() = default; ~GraphNode() = default;
schema::MetaGraphT *subGraph = nullptr; schema::MetaGraphT *sub_graph_ = nullptr;
schema::CNodeT *opDef = nullptr; schema::CNodeT *op_def_ = nullptr;
}; };
class NodePass : public Pass<GraphNode> { class NodePass : public Pass<GraphNode> {
@ -72,8 +72,8 @@ class Optimizer {
STATUS Run(schema::MetaGraphT *graphDefT); STATUS Run(schema::MetaGraphT *graphDefT);
private: private:
std::vector<GraphPass *> graphPasses; std::vector<GraphPass *> graph_passes_;
std::vector<NodePass *> nodePasses; std::vector<NodePass *> node_passes_;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
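
Taken together with the destructor shown earlier, these members spell out the ownership contract: Optimizer stores raw pass pointers, AddPass ignores nullptr, and the destructor deletes every registered pass. A toy sketch of that contract (assumed names, not the real classes):

#include <vector>

struct ToyPass {
  virtual ~ToyPass() = default;
};

class ToyOptimizer {
 public:
  ~ToyOptimizer() {
    for (auto *pass : passes_) {
      delete pass;  // the optimizer owns and frees every registered pass
    }
  }
  void AddPass(ToyPass *pass) {
    if (pass != nullptr) {
      passes_.push_back(pass);  // nullptr is silently ignored, as in the source
    }
  }

 private:
  std::vector<ToyPass *> passes_;
};

int main() {
  ToyOptimizer opt;
  opt.AddPass(new ToyPass());  // freed by ~ToyOptimizer, no manual delete needed
  return 0;
}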
@ -98,7 +98,7 @@ STATUS CaffeModelParser::ConvertLayers() {
MS_LOG(INFO) << "parse op : " << layer.type(); MS_LOG(INFO) << "parse op : " << layer.type();
auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type()); auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
if (node_parser == nullptr) { if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(layer.type()); NotSupportOp::GetInstance()->InsertOp(layer.type());
status = (status == RET_OK ? RET_NOT_FIND_OP : status); status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue; continue;
} }
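
This hunk shows the lookup pattern all four model parsers touched by this commit share: fetch a node parser from a per-framework registry, and on a miss record the op type through NotSupportOp and continue, so a single conversion run reports every unsupported op instead of failing at the first one. A self-contained sketch of the pattern with a toy registry (assumed types, not the real CaffeNodeParserRegistry):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using ParserFn = int (*)();  // stand-in for a node-parser object

int ParseConv() { return 0; }

int main() {
  const std::map<std::string, ParserFn> registry = {{"Convolution", ParseConv}};
  std::set<std::string> unsupported;
  bool failed = false;

  for (const std::string &op_type : std::vector<std::string>{"Convolution", "Gemm", "LSTM"}) {
    auto it = registry.find(op_type);
    if (it == registry.end()) {
      unsupported.insert(op_type);  // record the miss and keep going
      failed = true;
      continue;
    }
    it->second();
  }

  for (const auto &op : unsupported) {
    std::cout << "UNSUPPORTED OP: " << op << "\n";  // one consolidated report at the end
  }
  return failed ? 1 : 0;
}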
@ -47,7 +47,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) { const QuantType &quant_type) {
NoSupportOp::GetInstance()->SetFmkType("ONNX"); NotSupportOp::GetInstance()->set_fmk_type("ONNX");
anf_root_graph_ = std::make_shared<FuncGraph>(); anf_root_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file); auto status = InitOriginModel(model_file);
if (RET_OK != status) { if (RET_OK != status) {
@ -195,7 +195,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
for (const auto &onnx_node : onnx_graph.node()) { for (const auto &onnx_node : onnx_graph.node()) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr) { if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
status = status == RET_OK ? RET_NOT_FIND_OP : status; status = status == RET_OK ? RET_NOT_FIND_OP : status;
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type(); MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
} }
@ -476,7 +476,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
FuncGraphPtr paserTfFuction() { return nullptr; } FuncGraphPtr paserTfFuction() { return nullptr; }
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) { const QuantType &quantType) {
NoSupportOp::GetInstance()->SetFmkType("TF"); NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb"); auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
@ -888,7 +888,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(INFO) << "parse op : " << op_type; MS_LOG(INFO) << "parse op : " << op_type;
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type); NotSupportOp::GetInstance()->InsertOp(op_type);
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in " MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
<< func_graph_ptr->get_attr("graph_name")->ToString(); << func_graph_ptr->get_attr("graph_name")->ToString();
return RET_NOT_FIND_OP; return RET_NOT_FIND_OP;
@ -101,7 +101,7 @@ std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type,
STATUS TfliteModelParser::ConvertOps() { STATUS TfliteModelParser::ConvertOps() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front(); const auto &tflite_subgraph = tflite_model_->subgraphs.front();
NoSupportOp::GetInstance()->SetFmkType("TFLITE"); NotSupportOp::GetInstance()->set_fmk_type("TFLITE");
STATUS status = RET_OK; STATUS status = RET_OK;
int op_idx = 0; int op_idx = 0;
for (auto &op : tflite_subgraph->operators) { for (auto &op : tflite_subgraph->operators) {
@ -113,7 +113,7 @@ STATUS TfliteModelParser::ConvertOps() {
MS_LOG(INFO) << "parse node :" << op_name; MS_LOG(INFO) << "parse node :" << op_name;
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type); auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type); NotSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status); status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue; continue;
} }
@ -1344,7 +1344,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// add quant_cast // add quant_cast
quant::QuantCast quant_cast; quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32); quant_cast.set_input_data_dtype(kNumberTypeFloat32);
status = quant_cast.Run(func_graph); status = quant_cast.Run(func_graph);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error"; MS_LOG(ERROR) << "add QuantCast error";
@ -26,12 +26,12 @@ namespace mindspore::lite::quant {
class QuantCast { class QuantCast {
public: public:
QuantCast() = default; QuantCast() = default;
~QuantCast() = default; virtual ~QuantCast() = default;
STATUS Run(const FuncGraphPtr &graph); STATUS Run(const FuncGraphPtr &graph);
void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } void set_input_data_dtype(TypeId data_type) { this->input_data_dtype_ = data_type; }
private: private:
TypeId inputDataDType = kNumberTypeFloat32; TypeId input_data_dtype_ = kNumberTypeFloat32;
}; };
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H
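
Alongside the setter rename, this hunk makes ~QuantCast() virtual, which matters if an instance is ever deleted through a base-class pointer: without a virtual destructor that delete is undefined behavior and the derived destructor never runs. A minimal sketch with toy classes:

#include <iostream>
#include <memory>

struct Base {
  virtual ~Base() = default;  // virtual: the derived destructor runs on delete
};

struct Derived : Base {
  ~Derived() override { std::cout << "Derived cleaned up\n"; }
};

int main() {
  // Deleting through a Base pointer (here via unique_ptr<Base>) is only well
  // defined because ~Base() is virtual; without it, ~Derived() would never run.
  std::unique_ptr<Base> p = std::make_unique<Derived>();
  return 0;
}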