correct names of functions and variables

hangangqiang 2021-04-13 19:23:52 +08:00
parent c9d9e1cf32
commit 0d8302c0d1
18 changed files with 323 additions and 326 deletions

View File

@@ -121,8 +121,8 @@ int RunConverter(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
auto meta_graph = converter->Convert(flags);
NoSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
NotSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->status_code();
if (meta_graph == nullptr) {
oss.clear();
oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status);

View File

@@ -28,67 +28,67 @@ namespace mindspore {
namespace lite {
class ReturnCode {
public:
~ReturnCode() = default;
virtual ~ReturnCode() = default;
static ReturnCode *GetSingleReturnCode() {
static ReturnCode returnCode;
return &returnCode;
static ReturnCode return_code;
return &return_code;
}
void UpdateReturnCode(STATUS status) {
if (statusCode == RET_OK) {
statusCode = status;
if (status_code_ == RET_OK) {
status_code_ = status;
}
}
STATUS GetReturnCode() const { return statusCode; }
STATUS status_code() const { return status_code_; }
private:
ReturnCode() { statusCode = RET_OK; }
int statusCode;
ReturnCode() = default;
int status_code_ = RET_OK;
};
class NoSupportOp {
class NotSupportOp {
public:
~NoSupportOp() = default;
static NoSupportOp *GetInstance() {
static NoSupportOp noSupportOp;
return &noSupportOp;
virtual ~NotSupportOp() = default;
static NotSupportOp *GetInstance() {
static NotSupportOp not_support_op;
return &not_support_op;
}
void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; }
void InsertOp(const std::string &op_name) { not_support_ops_.insert(op_name); }
void PrintOps() const {
if (!noSupportOps.empty()) {
if (!not_support_ops_.empty()) {
MS_LOG(ERROR) << "===========================================";
MS_LOG(ERROR) << "UNSUPPORTED OP LIST:";
for (auto &op_name : noSupportOps) {
MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
for (auto &op_name : not_support_ops_) {
MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name;
}
MS_LOG(ERROR) << "===========================================";
}
}
private:
NoSupportOp() { noSupportOps.clear(); }
std::set<std::string> noSupportOps;
std::string fmkType;
NotSupportOp() = default;
std::set<std::string> not_support_ops_;
std::string fmk_type_;
};
class TensorDataType {
public:
~TensorDataType() = default;
static TensorDataType *GetInstance() {
static TensorDataType tensorDataType;
return &tensorDataType;
static TensorDataType tensor_data_type;
return &tensor_data_type;
}
void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; }
void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
int32_t GetTensorType(int32_t index) const {
if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) {
if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {
return TypeId::kTypeUnknown;
}
return tensorDataTypeMap.at(index);
return tensor_data_type_map_.at(index);
}
private:
TensorDataType() {}
std::map<int32_t, int32_t> tensorDataTypeMap;
std::map<int32_t, int32_t> tensor_data_type_map_;
};
} // namespace lite
} // namespace mindspore
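For reference, a minimal usage sketch of the renamed interfaces above (illustrative only: the op name is made up, but every member shown is defined in this header):

// Sketch: exercising the renamed singletons from this header.
// "SomeUnsupportedOp" is a hypothetical op name used for illustration.
void ReportExample() {
  NotSupportOp::GetInstance()->set_fmk_type("ONNX");
  NotSupportOp::GetInstance()->InsertOp("SomeUnsupportedOp");
  NotSupportOp::GetInstance()->PrintOps();  // logs the collected unsupported ops
  // UpdateReturnCode records only the first non-RET_OK status it is given.
  ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_FIND_OP);
  STATUS status = ReturnCode::GetSingleReturnCode()->status_code();  // renamed from GetReturnCode()
  TensorDataType::GetInstance()->UpdateTensorType(0, TypeId::kNumberTypeInt8);
  int32_t dtype = TensorDataType::GetInstance()->GetTensorType(0);  // kTypeUnknown if unset
  (void)status;
  (void)dtype;
}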

View File

@@ -30,17 +30,17 @@ Flags::Flags() {
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
"");
AddFlag(&Flags::inputDataTypeIn, "inputDataType",
AddFlag(&Flags::inputDataTypeStr, "inputDataType",
"Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT",
"DEFAULT");
AddFlag(&Flags::outputDataTypeIn, "outputDataType",
AddFlag(&Flags::outputDataTypeStr, "outputDataType",
"Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | "
"UINT8 | DEFAULT",
"DEFAULT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::quantTypeStr, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::quantWeightSizeStr, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannelStr, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
AddFlag(&Flags::trainModelIn, "trainModel",
"whether the model is going to be trained on device. "
@@ -49,32 +49,32 @@ Flags::Flags() {
}
int Flags::InitInputOutputDataType() {
if (this->inputDataTypeIn == "FLOAT") {
if (this->inputDataTypeStr == "FLOAT") {
this->inputDataType = TypeId::kNumberTypeFloat32;
} else if (this->inputDataTypeIn == "INT8") {
} else if (this->inputDataTypeStr == "INT8") {
this->inputDataType = TypeId::kNumberTypeInt8;
} else if (this->inputDataTypeIn == "UINT8") {
} else if (this->inputDataTypeStr == "UINT8") {
this->inputDataType = TypeId::kNumberTypeUInt8;
} else if (this->inputDataTypeIn == "DEFAULT") {
} else if (this->inputDataTypeStr == "DEFAULT") {
this->inputDataType = TypeId::kTypeUnknown;
} else {
std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
this->inputDataTypeIn.c_str();
this->inputDataTypeStr.c_str();
return RET_INPUT_PARAM_INVALID;
}
if (this->outputDataTypeIn == "FLOAT") {
if (this->outputDataTypeStr == "FLOAT") {
this->outputDataType = TypeId::kNumberTypeFloat32;
} else if (this->outputDataTypeIn == "INT8") {
} else if (this->outputDataTypeStr == "INT8") {
this->outputDataType = TypeId::kNumberTypeInt8;
} else if (this->outputDataTypeIn == "UINT8") {
} else if (this->outputDataTypeStr == "UINT8") {
this->outputDataType = TypeId::kNumberTypeUInt8;
} else if (this->outputDataTypeIn == "DEFAULT") {
} else if (this->outputDataTypeStr == "DEFAULT") {
this->outputDataType = TypeId::kTypeUnknown;
} else {
std::cerr << "INPUT INVALID: outputDataType is invalid: " << this->outputDataTypeIn << ", supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT";
std::cerr << "INPUT INVALID: outputDataType is invalid: " << this->outputDataTypeStr << ", supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
@@ -110,7 +110,7 @@ bool Flags::IsValidNum(const std::string &str, int *num) {
}
int Flags::QuantParamInputCheck() {
if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
if (!Flags::IsValidNum(this->quantWeightChannelStr, &this->quantWeightChannel)) {
std::cerr << "quantWeightChannel should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
@@ -118,7 +118,7 @@ int Flags::QuantParamInputCheck() {
std::cerr << "quantWeightChannel should be greater than or equal to zero.";
return RET_INPUT_PARAM_INVALID;
}
if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
if (!Flags::IsValidNum(this->quantWeightSizeStr, &this->quantWeightSize)) {
std::cerr << "quantWeightSize should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
@@ -138,11 +138,11 @@ int Flags::QuantParamInputCheck() {
}
int Flags::InitQuantParam() {
if (this->quantTypeIn == "WeightQuant") {
if (this->quantTypeStr == "WeightQuant") {
this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") {
} else if (this->quantTypeStr == "PostTraining") {
this->quantType = QuantType_PostTraining;
} else if (this->quantTypeIn.empty()) {
} else if (this->quantTypeStr.empty()) {
this->quantType = QuantType_QUANT_NONE;
} else {
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";

View File

@@ -65,25 +65,20 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string fmkIn;
FmkType fmk;
std::string weightFile;
std::string inputArrays;
std::string outputArrays;
std::string inputShapes;
// used for quantization
std::string quantTypeIn;
QuantType quantType;
std::string inferenceTypeIn;
std::string inputDataTypeIn;
std::string outputDataTypeIn;
// used for parsing aware training
TypeId inputDataType;
TypeId outputDataType;
// used for quantization
std::string quantTypeStr;
QuantType quantType;
std::string inputDataTypeStr;
std::string outputDataTypeStr;
// used for post-training weight quantization
std::string quantWeightSizeIn;
std::string quantWeightSizeStr;
int quantWeightSize;
std::string bitNumIn;
int bitNum;
std::string configFile;
std::string quantWeightChannelIn;
std::string quantWeightChannelStr;
int quantWeightChannel;
std::string trainModelIn;
bool trainModel = false;

View File

@@ -47,8 +47,8 @@ namespace mindspore::lite {
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graphDefT->nodes.size());
std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(),
old_nodes.resize(graph_defT_->nodes.size());
std::transform(graph_defT_->nodes.begin(), graph_defT_->nodes.end(), old_nodes.begin(),
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
return old_nodes;
}
@@ -57,33 +57,33 @@ GraphDefTransform::GraphDefTransform() = default;
GraphDefTransform::~GraphDefTransform() = default;
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }
int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
{
auto old_nodes = GetGraphNodes();
Optimizer unusedOpRemoveOptimizer;
Optimizer unused_op_remove_optimizer;
if (!ctx.trainModel) {
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
unused_op_remove_optimizer.AddPass(new DropoutNodeRemovePass());
}
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unusedOpRemoveOptimizer.Run(graphDefT);
unused_op_remove_optimizer.AddPass(new IsolatedNodeRemovePass());
unused_op_remove_optimizer.AddPass(new SubgraphNodePass(old_nodes));
status = unused_op_remove_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
return status;
}
}
// generate and infer quant parameters
{
Optimizer inferQuantParamPass;
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
status = inferQuantParamPass.Run(graphDefT);
Optimizer infer_quant_param_pass;
infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
status = infer_quant_param_pass.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
return status;
}
}
@@ -93,40 +93,40 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
auto formatTransPass = new (std::nothrow) FormatTransPass();
if (formatTransPass == nullptr) {
Optimizer format_trans_optimizer;
auto format_trans_pass = new (std::nothrow) FormatTransPass();
if (format_trans_pass == nullptr) {
MS_LOG(ERROR) << "new formatTransPass failed";
return RET_MEMORY_FAILED;
}
formatTransPass->SetQuantType(ctx.quantType);
formatTransPass->SetFmk(ctx.fmk);
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
format_trans_pass->set_quant_type(ctx.quantType);
format_trans_pass->set_fmk_type(ctx.fmk);
format_trans_optimizer.AddPass(format_trans_pass);
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
if (ctx.fmk != converter::FmkType_TF) {
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
format_trans_optimizer.AddPass(new (std::nothrow) InferShapePass());
}
status = formatTransOptimizer.Run(graphDefT);
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = formatTransOptimizer.Run(graphDefT);
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass());
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
@@ -134,15 +134,15 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer formatTransOptimizer;
Optimizer format_trans_optimizer;
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
status = formatTransOptimizer.Run(graphDefT);
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
@@ -151,7 +151,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
Optimizer replace_optimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
@@ -159,13 +159,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
fusionOptimizer.AddPass(batch_norm_scale_pass);
replace_optimizer.AddPass(batch_norm_scale_pass);
}
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
status = replace_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
MS_LOG(ERROR) << "Run replace_optimizer BatchNormConvertScalePass Failed";
return status;
}
}
@@ -173,13 +173,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusionOptimizer.Run(graphDefT);
Optimizer fusion_optimizer;
fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass());
fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = fusion_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed";
return status;
}
}
@@ -188,12 +188,12 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
if (ctx.fmk != converter::FmkType_TF) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer tensorQuantOptimizer;
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensorQuantOptimizer.Run(graphDefT);
Optimizer tensor_quant_optimizer;
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass());
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensor_quant_optimizer.Run(graph_defT_);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
@@ -204,31 +204,31 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
if (ctx.fmk != converter::FmkType_TF) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer quantNodeOptimizer;
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
status = quantNodeOptimizer.Run(graphDefT);
Optimizer quant_node_optimizer;
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass());
status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass());
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
if (dtype_trans_pass == nullptr) {
MS_LOG(ERROR) << "new dtype_trans_pass failed";
return RET_MEMORY_FAILED;
}
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quantNodeOptimizer.Run(graphDefT);
dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
quant_node_optimizer.AddPass(dtype_trans_pass);
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
return status;
}
}
@@ -237,22 +237,22 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer switchOptimizer;
switchOptimizer.AddPass(new (std::nothrow) SwitchPass());
switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = switchOptimizer.Run(graphDefT);
Optimizer switch_optimizer;
switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = switch_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run switch graphPasses Failed";
MS_LOG(ERROR) << "Run switch_optimizer Failed";
return status;
}
}
// subgraph tensor pass
{
Optimizer subgraphTensorOptimizer;
subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraphTensorOptimizer.Run(graphDefT);
Optimizer subgraph_tensor_optimizer;
subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraph_tensor_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
return status;
@@ -263,33 +263,33 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer nameOptimizer;
nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
status = nameOptimizer.Run(graphDefT);
Optimizer name_optimizer;
name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
status = name_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
return status;
}
}
{
Optimizer nestedLoopOptimizer;
nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nestedLoopOptimizer.Run(graphDefT);
Optimizer nested_loop_optimizer;
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nested_loop_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
return status;
}
}
{
Optimizer quantNodeOptimizer;
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
status = quantNodeOptimizer.Run(graphDefT);
Optimizer quant_param_optimizer;
quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
status = quant_param_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed";
return status;
}
}
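Every block in Transform() above repeats one pattern, condensed here as a hedged sketch (some_optimizer is a placeholder name; the passes are ones already used above): snapshot the node pointers, register heap-allocated passes that the Optimizer deletes in its destructor, run over graph_defT_, and treat RET_NO_CHANGE as success.

// Sketch of the recurring optimizer-block pattern in GraphDefTransform::Transform().
{
  // snapshot current node pointers so SubgraphNodePass can repair subgraph indices
  auto old_nodes = GetGraphNodes();
  Optimizer some_optimizer;
  some_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
  some_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
  status = some_optimizer.Run(graph_defT_);  // passes are freed by some_optimizer's destructor
  if (status != RET_OK && status != RET_NO_CHANGE) {
    MS_LOG(ERROR) << "Run some_optimizer graphPasses Failed";
    return status;
  }
}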

View File

@@ -36,12 +36,12 @@ class GraphDefTransform {
GraphDefTransform();
virtual ~GraphDefTransform();
virtual int Transform(const converter::Flags &ctx);
void SetGraphDef(schema::MetaGraphT *dstDef);
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
void SetGraphDef(schema::MetaGraphT *dst_def);
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
protected:
std::vector<schema::CNodeT *> GetGraphNodes();
schema::MetaGraphT *graphDefT = nullptr;
schema::MetaGraphT *graph_defT_ = nullptr;
Optimizer *optimizer = nullptr;
};
} // namespace lite

View File

@@ -55,34 +55,35 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto &graphInIdxes = graph->inputIndex;
if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
auto &graph_in_idxes = graph->inputIndex;
if (this->input_data_dtype != TypeId::kNumberTypeFloat32 && this->input_data_dtype != TypeId::kNumberTypeUInt8 &&
this->input_data_dtype != TypeId::kNumberTypeInt8 && this->input_data_dtype != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
return RET_ERROR;
}
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graphInIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphInIdx);
for (auto graph_in_idx : graph_in_idxes) {
MS_ASSERT(graph_in_idx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graph_in_idx);
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
continue;
}
int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown
? this->inputDataDType
: TensorDataType::GetInstance()->GetTensorType(graphInIdx);
int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
? this->input_data_dtype
: TensorDataType::GetInstance()->GetTensorType(graph_in_idx);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto nodeName = (*iter)->name;
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) {
auto node_name = (*iter)->name;
for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {
if ((*iter)->inputIndex.at(input_indexidx) == graph_in_idx) {
STATUS status = RET_OK;
// insert dtype cast node between input tensor and input node
if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status);
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
iter =
InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
MS_LOG(ERROR) << "InsertDTypeTransNode before " << node_name.c_str() << " failed";
return status;
}
}
@@ -94,33 +95,34 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
if (this->output_data_dtype != TypeId::kNumberTypeFloat32 && this->output_data_dtype != TypeId::kNumberTypeUInt8 &&
this->output_data_dtype != TypeId::kNumberTypeInt8 && this->output_data_dtype != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid outputDataType: " << this->output_data_dtype;
return RET_ERROR;
}
auto &graphOutIdxes = graph->outputIndex;
for (auto graphOutIdx : graphOutIdxes) {
MS_ASSERT(graphOutIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphOutIdx);
auto &graph_out_idxes = graph->outputIndex;
for (auto graph_out_idx : graph_out_idxes) {
MS_ASSERT(graph_out_idx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graph_out_idx);
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
continue;
}
int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown
? this->outputDataDType
: TensorDataType::GetInstance()->GetTensorType(graphOutIdx);
int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
? this->output_data_dtype
: TensorDataType::GetInstance()->GetTensorType(graph_out_idx);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto nodeName = (*iter)->name;
auto node_name = (*iter)->name;
MS_ASSERT((*iter) != nullptr);
for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) {
if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
// insert transNode
STATUS status = RET_OK;
if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status);
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
iter =
InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
MS_LOG(ERROR) << "InsertDTypeTransNode after " << node_name.c_str() << " failed";
return status;
}
break;
@@ -231,52 +233,53 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
return RET_OK;
}
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
size_t inoutIdx, int32_t inputDataType, int32_t outputDataType,
STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);
auto existNodeName = (*existNodeIter)->name;
std::string tileName;
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
STATUS *error_code) {
MS_ASSERT((*exist_node_iter) != nullptr);
auto exist_node_name = (*exist_node_iter)->name;
std::string tile_name;
if (place == kBefore) {
tileName = existNodeName + "_pre";
tile_name = exist_node_name + "_pre";
} else {
tileName = existNodeName + "_post";
tile_name = exist_node_name + "_post";
}
auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
if (transNode == nullptr) {
auto trans_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
if (trans_node == nullptr) {
MS_LOG(ERROR) << "new TransNode failed";
*errorCode = RET_ERROR;
*error_code = RET_ERROR;
return graph->nodes.end();
}
auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
if (quantDTypeCastParam == nullptr) {
auto quant_dtype_cast_param = new (std::nothrow) QuantDTypeCastT;
if (quant_dtype_cast_param == nullptr) {
MS_LOG(ERROR) << "new quantDTypeCastParam failed";
*errorCode = RET_ERROR;
*error_code = RET_ERROR;
return graph->nodes.end();
}
transNode->primitive = std::make_unique<schema::PrimitiveT>();
transNode->primitive->value.value = quantDTypeCastParam;
transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
transNode->quantType = QuantType_AwareTraining;
quantDTypeCastParam->src_t = inputDataType;
quantDTypeCastParam->dst_t = outputDataType;
if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) {
transNode->name = "int8toft32_" + tileName + std::to_string(id++);
} else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) {
transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
} else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) {
transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
} else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) {
transNode->name = "int8touint8_" + tileName + std::to_string(id++);
trans_node->primitive = std::make_unique<schema::PrimitiveT>();
trans_node->primitive->value.value = quant_dtype_cast_param;
trans_node->primitive->value.type = PrimitiveType_QuantDTypeCast;
trans_node->quantType = QuantType_AwareTraining;
quant_dtype_cast_param->src_t = input_data_type;
quant_dtype_cast_param->dst_t = output_data_type;
if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeFloat32) {
trans_node->name = "int8toft32_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeInt8) {
trans_node->name = "ft32toint8_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeInt8) {
trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
} else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++);
}
transNode->primitive->value.value = quantDTypeCastParam;
trans_node->primitive->value.value = quant_dtype_cast_param;
int insert_num = 0;
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, castOpCopyer);
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
castOpCopyer);
}
void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
} // namespace lite
} // namespace mindspore
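A hedged sketch of driving the renamed setters, mirroring the call site in GraphDefTransform::Transform() above (the dtypes are examples; quant_node_optimizer is the optimizer from that call site):

// Sketch: configuring DTypeTransPass through the renamed setters before
// handing it to an Optimizer, which takes ownership of the pass.
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
if (dtype_trans_pass == nullptr) {
  return RET_MEMORY_FAILED;
}
dtype_trans_pass->set_input_data_dtype(TypeId::kNumberTypeInt8);      // example input dtype
dtype_trans_pass->set_output_data_dtype(TypeId::kNumberTypeFloat32);  // example output dtype
quant_node_optimizer.AddPass(dtype_trans_pass);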

View File

@@ -30,15 +30,15 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
class DTypeTransPass : public GraphPass {
public:
DTypeTransPass() : id(0) {}
DTypeTransPass() : id_(0) {}
~DTypeTransPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
void SetInputDataDType(TypeId dataType);
void set_input_data_dtype(TypeId data_type);
void SetOutputDataDType(TypeId dataType);
void set_output_data_dtype(TypeId data_type);
private:
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
@@ -51,13 +51,14 @@ class DTypeTransPass : public GraphPass {
STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter);
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
STATUS *error_code);
private:
size_t id;
TypeId inputDataDType = TypeId::kNumberTypeFloat;
TypeId outputDataDType = TypeId::kNumberTypeFloat;
size_t id_;
TypeId input_data_dtype = TypeId::kNumberTypeFloat;
TypeId output_data_dtype = TypeId::kNumberTypeFloat;
OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);

View File

@@ -45,32 +45,32 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
return RET_OK;
}
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
FormatTransNodeType *afterNodeType) {
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
FormatTransNodeType *after_node_type) {
if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc
if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*beforeNodeType = kNHWC2NCHW;
*afterNodeType = kNCHW2NHWC;
*before_node_type = kNHWC2NCHW;
*after_node_type = kNCHW2NHWC;
return RET_OK;
} else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS ||
fmk_type_ == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*beforeNodeType = kNCHW2NHWC;
*afterNodeType = kNHWC2NCHW;
*before_node_type = kNCHW2NHWC;
*after_node_type = kNHWC2NCHW;
return RET_OK;
} else if (fmk_type_ == converter::FmkType_TF) {
if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) {
*beforeNodeType = kNCHW2NHWC;
*afterNodeType = kNHWC2NCHW;
*before_node_type = kNCHW2NHWC;
*after_node_type = kNHWC2NCHW;
return RET_OK;
}
if (IsContain(GetNchwOpList(), GetCNodeTType(node))) {
*beforeNodeType = kNHWC2NCHW;
*afterNodeType = kNCHW2NHWC;
*before_node_type = kNHWC2NCHW;
*after_node_type = kNCHW2NHWC;
return RET_OK;
}
return RET_NO_CHANGE;
@@ -96,36 +96,34 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
return RET_OK;
}
}
auto graphInputIdxes = graph->inputIndex;
for (size_t i = 0; i < graphInputIdxes.size(); i++) {
auto graph_input_idxes = graph->inputIndex;
for (size_t i = 0; i < graph_input_idxes.size(); i++) {
bool transed = false;
auto inputIdx = graphInputIdxes.at(i);
MS_ASSERT(inputIdx < subGraph->allTensors.size());
auto &tensor = graph->allTensors.at(inputIdx);
auto input_idx = graph_input_idxes.at(i);
auto &tensor = graph->allTensors.at(input_idx);
if (tensor->dims.size() != kNCHWDimNumber) {
continue;
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) {
for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) {
if ((*iter)->inputIndex.at(input_index_idx) == input_idx) {
STATUS status = RET_OK;
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
return status;
}
// set first tensor format to nhwc
auto &transNode = *(iter - 1);
MS_ASSERT(transNode != nullptr);
MS_ASSERT(transNode->inputIndex.size() == 1);
MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front());
auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front());
graphInTensor->format = schema::Format::Format_NHWC;
auto &trans_node = *(iter - 1);
MS_ASSERT(trans_node != nullptr);
MS_ASSERT(trans_node->inputIndex.size() == 1);
auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front());
graph_in_tensor->format = schema::Format::Format_NHWC;
// assume the parser did not reformat the shape
auto oldDims = graphInTensor->dims;
auto old_dims = graph_in_tensor->dims;
if (!transed) {
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]};
transed = true;
}
}
@@ -143,10 +141,10 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert before and after the ops calculated in nchw/nc4hw4
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
FormatTransNodeType beforeNodeType = kNCHW2NHWC;
FormatTransNodeType afterNodeType = kNHWC2NCHW;
FormatTransNodeType before_node_type = kNCHW2NHWC;
FormatTransNodeType after_node_type = kNHWC2NCHW;
STATUS status = RET_OK;
status = GetInsertFormatTrans(**iter, &beforeNodeType, &afterNodeType);
status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type);
if (status == RET_NO_CHANGE) {
continue;
}
@@ -170,17 +168,17 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
}
auto specInsertIndexes = GetExtNhwcIndexes();
auto opType = GetCNodeTType(**iter);
if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
for (auto insert_index : specInsertIndexes[opType]) {
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
auto spec_insert_indexes = GetExtNhwcIndexes();
auto op_type = GetCNodeTType(**iter);
if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) {
for (auto insert_index : spec_insert_indexes[op_type]) {
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
return RET_ERROR;
}
}
} else if (IsContain(GetNhwcAllInputOpList(), opType)) {
} else if (IsContain(GetNhwcAllInputOpList(), op_type)) {
auto input_size = node->inputIndex.size();
if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) {
if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) {
@@ -188,16 +186,16 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
}
}
for (size_t i = 0; i < input_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
return RET_ERROR;
}
}
} else {
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status);
}
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
return RET_ERROR;
@@ -206,29 +204,29 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
return RET_OK;
}
NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);
NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) {
MS_ASSERT((*exist_node_iter) != nullptr);
MS_ASSERT(graph != nullptr);
auto existNodeName = (*existNodeIter)->name;
std::string tileName;
auto exist_node_name = (*exist_node_iter)->name;
std::string tile_name;
if (place == kBefore) {
tileName = existNodeName + "_pre";
tile_name = exist_node_name + "_pre";
} else {
tileName = existNodeName + "_post";
tile_name = exist_node_name + "_post";
}
auto transNode = std::make_unique<schema::CNodeT>();
transNode->primitive = std::make_unique<schema::PrimitiveT>();
transNode->primitive->value.type = schema::PrimitiveType_Transpose;
auto trans_node = std::make_unique<schema::CNodeT>();
trans_node->primitive = std::make_unique<schema::PrimitiveT>();
trans_node->primitive->value.type = schema::PrimitiveType_Transpose;
auto perm_tensor = std::make_unique<schema::TensorT>();
perm_tensor->dataType = kNumberTypeInt32;
perm_tensor->dims = {4};
std::vector<int> perm;
if (nodeType == kNCHW2NHWC) {
transNode->name = "nchw2nhwc_" + tileName + std::to_string(id_++);
if (node_type == kNCHW2NHWC) {
trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++);
perm = {0, 2, 3, 1};
} else {
transNode->name = "nhwc2nchw_" + tileName + std::to_string(id_++);
trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++);
perm = {0, 3, 1, 2};
}
size_t bytes = perm.size() * sizeof(int);
@@ -236,27 +234,27 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
}
perm_tensor->name = transNode->name + "_perm";
perm_tensor->name = trans_node->name + "_perm";
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
auto newOpDef = std::make_unique<schema::CNodeT>();
if (newOpDef == nullptr) {
OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> {
auto new_op_def = std::make_unique<schema::CNodeT>();
if (new_op_def == nullptr) {
MS_LOG(ERROR) << "new CNodeT failed";
return nullptr;
}
newOpDef->name = inOpDef->name;
newOpDef->quantType = inOpDef->quantType;
newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
if (newOpDef->primitive == nullptr) {
new_op_def->name = in_op_def->name;
new_op_def->quantType = in_op_def->quantType;
new_op_def->primitive = std::make_unique<schema::PrimitiveT>();
if (new_op_def->primitive == nullptr) {
MS_LOG(ERROR) << "new PrimitiveT failed";
return nullptr;
}
newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
return newOpDef;
new_op_def->primitive->value.type = schema::PrimitiveType_Transpose;
return new_op_def;
};
int insert_num = 0;
auto iter =
InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, TransposeOpCopyer);
auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
transpose_op_copyer);
size_t index = graph->allTensors.size();
graph->allTensors.push_back(std::move(perm_tensor));
for (int i = insert_num; i > 0; --i) {

View File

@@ -34,13 +34,13 @@ class FormatTransPass : public GraphPass {
STATUS Run(schema::MetaGraphT *graph) override;
void SetQuantType(QuantType quantType) { this->quant_type_ = quantType; }
void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; }
void SetFmk(converter::FmkType fmkType) { this->fmk_type_ = fmkType; }
void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
protected:
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
FormatTransNodeType nodeType, STATUS *errorCode);
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code);
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
@@ -61,8 +61,8 @@ class FormatTransPass : public GraphPass {
int GetFormat(const schema::CNodeT &);
STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
FormatTransNodeType *afterNodeType);
STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
FormatTransNodeType *after_node_type);
protected:
size_t id_ = 0;

View File

@@ -20,13 +20,13 @@
namespace mindspore {
namespace lite {
Optimizer::~Optimizer() {
for (auto pass : graphPasses) {
for (auto pass : graph_passes_) {
if (pass != nullptr) {
delete (pass);
}
}
for (auto pass : nodePasses) {
for (auto pass : node_passes_) {
if (pass != nullptr) {
delete (pass);
}
@@ -35,13 +35,13 @@ Optimizer::~Optimizer() {
void Optimizer::AddPass(GraphPass *graphPass) {
if (graphPass != nullptr) {
this->graphPasses.emplace_back(graphPass);
this->graph_passes_.emplace_back(graphPass);
}
}
void Optimizer::AddPass(NodePass *nodePass) {
if (nodePass != nullptr) {
this->nodePasses.emplace_back(nodePass);
this->node_passes_.emplace_back(nodePass);
}
}
@@ -51,7 +51,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
bool ifNotChanged = true;
// each node goes through all node passes, rather than each node pass going through all nodes
for (auto &opDef : graphDefT->nodes) {
for (auto pass : this->nodePasses) {
for (auto pass : this->node_passes_) {
status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get()));
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run NodePass failed";
@@ -64,7 +64,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
}
}
for (auto pass : this->graphPasses) {
for (auto pass : this->graph_passes_) {
status = pass->Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run GraphPass failed";

View File

@@ -41,10 +41,10 @@ class GraphPass : public Pass<schema::MetaGraphT> {
};
struct GraphNode {
GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {}
GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : sub_graph_(subGraph), op_def_(opDefT) {}
~GraphNode() = default;
schema::MetaGraphT *subGraph = nullptr;
schema::CNodeT *opDef = nullptr;
schema::MetaGraphT *sub_graph_ = nullptr;
schema::CNodeT *op_def_ = nullptr;
};
class NodePass : public Pass<GraphNode> {
@@ -72,8 +72,8 @@ class Optimizer {
STATUS Run(schema::MetaGraphT *graphDefT);
private:
std::vector<GraphPass *> graphPasses;
std::vector<NodePass *> nodePasses;
std::vector<GraphPass *> graph_passes_;
std::vector<NodePass *> node_passes_;
};
} // namespace lite
} // namespace mindspore
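Two points about the Optimizer API above, shown as a hedged sketch (meta_graph is an assumed caller-owned schema::MetaGraphT*): passes registered through AddPass are deleted by the Optimizer destructor shown in optimizer.cc, so callers allocate them with new and never free them; and Run() applies every node pass to each node before the graph passes execute.

// Sketch: typical Optimizer lifetime and ownership.
Optimizer optimizer;
optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());  // freed in ~Optimizer()
STATUS status = optimizer.Run(meta_graph);  // node passes per node first, then graph passes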

View File

@@ -98,7 +98,7 @@ STATUS CaffeModelParser::ConvertLayers() {
MS_LOG(INFO) << "parse op : " << layer.type();
auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(layer.type());
NotSupportOp::GetInstance()->InsertOp(layer.type());
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}

View File

@@ -47,7 +47,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
NoSupportOp::GetInstance()->SetFmkType("ONNX");
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
anf_root_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);
if (RET_OK != status) {
@@ -195,7 +195,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
for (const auto &onnx_node : onnx_graph.node()) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
status = status == RET_OK ? RET_NOT_FIND_OP : status;
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
}

View File

@@ -476,7 +476,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
FuncGraphPtr paserTfFuction() { return nullptr; }
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
NoSupportOp::GetInstance()->SetFmkType("TF");
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
@@ -888,7 +888,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(INFO) << "parse op : " << op_type;
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
NotSupportOp::GetInstance()->InsertOp(op_type);
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
<< func_graph_ptr->get_attr("graph_name")->ToString();
return RET_NOT_FIND_OP;

View File

@@ -101,7 +101,7 @@ std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type,
STATUS TfliteModelParser::ConvertOps() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
NotSupportOp::GetInstance()->set_fmk_type("TFLITE");
STATUS status = RET_OK;
int op_idx = 0;
for (auto &op : tflite_subgraph->operators) {
@@ -113,7 +113,7 @@ STATUS TfliteModelParser::ConvertOps() {
MS_LOG(INFO) << "parse node :" << op_name;
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
NotSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}

View File

@@ -1344,7 +1344,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// add quant_cast
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
quant_cast.set_input_data_dtype(kNumberTypeFloat32);
status = quant_cast.Run(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";

View File

@@ -26,12 +26,12 @@ namespace mindspore::lite::quant {
class QuantCast {
public:
QuantCast() = default;
~QuantCast() = default;
virtual ~QuantCast() = default;
STATUS Run(const FuncGraphPtr &graph);
void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
void set_input_data_dtype(TypeId data_type) { this->input_data_dtype_ = data_type; }
private:
TypeId inputDataDType = kNumberTypeFloat32;
TypeId input_data_dtype_ = kNumberTypeFloat32;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H
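A matching usage sketch: this is the call sequence PostTrainingQuantizer::DoQuantize() uses above (func_graph is an assumed FuncGraphPtr supplied by the caller):

// Sketch: QuantCast with the renamed setter.
quant::QuantCast quant_cast;
quant_cast.set_input_data_dtype(kNumberTypeFloat32);
STATUS status = quant_cast.Run(func_graph);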