forked from mindspore-Ecosystem/mindspore
correct name of func and variable
This commit is contained in:
parent
c9d9e1cf32
commit
0d8302c0d1
|
@ -121,8 +121,8 @@ int RunConverter(int argc, const char **argv) {
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
auto meta_graph = converter->Convert(flags);
|
||||
NoSupportOp::GetInstance()->PrintOps();
|
||||
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
|
||||
NotSupportOp::GetInstance()->PrintOps();
|
||||
status = ReturnCode::GetSingleReturnCode()->status_code();
|
||||
if (meta_graph == nullptr) {
|
||||
oss.clear();
|
||||
oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status);
|
||||
|
|
|
@ -28,67 +28,67 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class ReturnCode {
|
||||
public:
|
||||
~ReturnCode() = default;
|
||||
virtual ~ReturnCode() = default;
|
||||
static ReturnCode *GetSingleReturnCode() {
|
||||
static ReturnCode returnCode;
|
||||
return &returnCode;
|
||||
static ReturnCode return_code;
|
||||
return &return_code;
|
||||
}
|
||||
void UpdateReturnCode(STATUS status) {
|
||||
if (statusCode == RET_OK) {
|
||||
statusCode = status;
|
||||
if (status_code_ == RET_OK) {
|
||||
status_code_ = status;
|
||||
}
|
||||
}
|
||||
STATUS GetReturnCode() const { return statusCode; }
|
||||
STATUS status_code() const { return status_code_; }
|
||||
|
||||
private:
|
||||
ReturnCode() { statusCode = RET_OK; }
|
||||
int statusCode;
|
||||
ReturnCode() = default;
|
||||
int status_code_ = RET_OK;
|
||||
};
|
||||
|
||||
class NoSupportOp {
|
||||
class NotSupportOp {
|
||||
public:
|
||||
~NoSupportOp() = default;
|
||||
static NoSupportOp *GetInstance() {
|
||||
static NoSupportOp noSupportOp;
|
||||
return &noSupportOp;
|
||||
virtual ~NotSupportOp() = default;
|
||||
static NotSupportOp *GetInstance() {
|
||||
static NotSupportOp not_support_op;
|
||||
return ¬_support_op;
|
||||
}
|
||||
void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
|
||||
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
|
||||
void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; }
|
||||
void InsertOp(const std::string &op_name) { not_support_ops_.insert(op_name); }
|
||||
void PrintOps() const {
|
||||
if (!noSupportOps.empty()) {
|
||||
if (!not_support_ops_.empty()) {
|
||||
MS_LOG(ERROR) << "===========================================";
|
||||
MS_LOG(ERROR) << "UNSUPPORTED OP LIST:";
|
||||
for (auto &op_name : noSupportOps) {
|
||||
MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
|
||||
for (auto &op_name : not_support_ops_) {
|
||||
MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name;
|
||||
}
|
||||
MS_LOG(ERROR) << "===========================================";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
NoSupportOp() { noSupportOps.clear(); }
|
||||
std::set<std::string> noSupportOps;
|
||||
std::string fmkType;
|
||||
NotSupportOp() = default;
|
||||
std::set<std::string> not_support_ops_;
|
||||
std::string fmk_type_;
|
||||
};
|
||||
|
||||
class TensorDataType {
|
||||
public:
|
||||
~TensorDataType() = default;
|
||||
static TensorDataType *GetInstance() {
|
||||
static TensorDataType tensorDataType;
|
||||
return &tensorDataType;
|
||||
static TensorDataType tensor_data_type;
|
||||
return &tensor_data_type;
|
||||
}
|
||||
void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; }
|
||||
void UpdateTensorType(int32_t index, int32_t type) { tensor_data_type_map_[index] = type; }
|
||||
int32_t GetTensorType(int32_t index) const {
|
||||
if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) {
|
||||
if (tensor_data_type_map_.find(index) == tensor_data_type_map_.end()) {
|
||||
return TypeId::kTypeUnknown;
|
||||
}
|
||||
return tensorDataTypeMap.at(index);
|
||||
return tensor_data_type_map_.at(index);
|
||||
}
|
||||
|
||||
private:
|
||||
TensorDataType() {}
|
||||
std::map<int32_t, int32_t> tensorDataTypeMap;
|
||||
std::map<int32_t, int32_t> tensor_data_type_map_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,17 +30,17 @@ Flags::Flags() {
|
|||
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
|
||||
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
|
||||
"");
|
||||
AddFlag(&Flags::inputDataTypeIn, "inputDataType",
|
||||
AddFlag(&Flags::inputDataTypeStr, "inputDataType",
|
||||
"Data type of input tensors, default is same with the type defined in model. FLOAT | INT8 | UINT8 | DEFAULT",
|
||||
"DEFAULT");
|
||||
AddFlag(&Flags::outputDataTypeIn, "outputDataType",
|
||||
AddFlag(&Flags::outputDataTypeStr, "outputDataType",
|
||||
"Data type of output and output tensors, default is same with the type defined in model. FLOAT | INT8 | "
|
||||
"UINT8 | DEFAULT",
|
||||
"DEFAULT");
|
||||
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
|
||||
AddFlag(&Flags::quantTypeStr, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
|
||||
AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
|
||||
AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
|
||||
AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
|
||||
AddFlag(&Flags::quantWeightSizeStr, "quantWeightSize", "Weight quantization size threshold", "0");
|
||||
AddFlag(&Flags::quantWeightChannelStr, "quantWeightChannel", "Channel threshold for weight quantization", "16");
|
||||
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
|
||||
AddFlag(&Flags::trainModelIn, "trainModel",
|
||||
"whether the model is going to be trained on device. "
|
||||
|
@ -49,32 +49,32 @@ Flags::Flags() {
|
|||
}
|
||||
|
||||
int Flags::InitInputOutputDataType() {
|
||||
if (this->inputDataTypeIn == "FLOAT") {
|
||||
if (this->inputDataTypeStr == "FLOAT") {
|
||||
this->inputDataType = TypeId::kNumberTypeFloat32;
|
||||
} else if (this->inputDataTypeIn == "INT8") {
|
||||
} else if (this->inputDataTypeStr == "INT8") {
|
||||
this->inputDataType = TypeId::kNumberTypeInt8;
|
||||
} else if (this->inputDataTypeIn == "UINT8") {
|
||||
} else if (this->inputDataTypeStr == "UINT8") {
|
||||
this->inputDataType = TypeId::kNumberTypeUInt8;
|
||||
} else if (this->inputDataTypeIn == "DEFAULT") {
|
||||
} else if (this->inputDataTypeStr == "DEFAULT") {
|
||||
this->inputDataType = TypeId::kTypeUnknown;
|
||||
} else {
|
||||
std::cerr << "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
|
||||
this->inputDataTypeIn.c_str();
|
||||
this->inputDataTypeStr.c_str();
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->outputDataTypeIn == "FLOAT") {
|
||||
if (this->outputDataTypeStr == "FLOAT") {
|
||||
this->outputDataType = TypeId::kNumberTypeFloat32;
|
||||
} else if (this->outputDataTypeIn == "INT8") {
|
||||
} else if (this->outputDataTypeStr == "INT8") {
|
||||
this->outputDataType = TypeId::kNumberTypeInt8;
|
||||
} else if (this->outputDataTypeIn == "UINT8") {
|
||||
} else if (this->outputDataTypeStr == "UINT8") {
|
||||
this->outputDataType = TypeId::kNumberTypeUInt8;
|
||||
} else if (this->outputDataTypeIn == "DEFAULT") {
|
||||
} else if (this->outputDataTypeStr == "DEFAULT") {
|
||||
this->outputDataType = TypeId::kTypeUnknown;
|
||||
} else {
|
||||
std::cerr
|
||||
<< "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT",
|
||||
this->outputDataTypeIn.c_str();
|
||||
this->outputDataTypeStr.c_str();
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -110,7 +110,7 @@ bool Flags::IsValidNum(const std::string &str, int *num) {
|
|||
}
|
||||
|
||||
int Flags::QuantParamInputCheck() {
|
||||
if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
|
||||
if (!Flags::IsValidNum(this->quantWeightChannelStr, &this->quantWeightChannel)) {
|
||||
std::cerr << "quantWeightChannel should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ int Flags::QuantParamInputCheck() {
|
|||
std::cerr << "quantWeightChannel should be greater than or equal to zero.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
|
||||
if (!Flags::IsValidNum(this->quantWeightSizeStr, &this->quantWeightSize)) {
|
||||
std::cerr << "quantWeightSize should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
@ -138,11 +138,11 @@ int Flags::QuantParamInputCheck() {
|
|||
}
|
||||
|
||||
int Flags::InitQuantParam() {
|
||||
if (this->quantTypeIn == "WeightQuant") {
|
||||
if (this->quantTypeStr == "WeightQuant") {
|
||||
this->quantType = QuantType_WeightQuant;
|
||||
} else if (this->quantTypeIn == "PostTraining") {
|
||||
} else if (this->quantTypeStr == "PostTraining") {
|
||||
this->quantType = QuantType_PostTraining;
|
||||
} else if (this->quantTypeIn.empty()) {
|
||||
} else if (this->quantTypeStr.empty()) {
|
||||
this->quantType = QuantType_QUANT_NONE;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
|
||||
|
|
|
@ -65,25 +65,20 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
std::string fmkIn;
|
||||
FmkType fmk;
|
||||
std::string weightFile;
|
||||
std::string inputArrays;
|
||||
std::string outputArrays;
|
||||
std::string inputShapes;
|
||||
// used for quantization
|
||||
std::string quantTypeIn;
|
||||
QuantType quantType;
|
||||
std::string inferenceTypeIn;
|
||||
std::string inputDataTypeIn;
|
||||
std::string outputDataTypeIn;
|
||||
// used for parse aware trainning
|
||||
TypeId inputDataType;
|
||||
TypeId outputDataType;
|
||||
// used for quantization
|
||||
std::string quantTypeStr;
|
||||
QuantType quantType;
|
||||
std::string inputDataTypeStr;
|
||||
std::string outputDataTypeStr;
|
||||
// used for post-trainning-weight
|
||||
std::string quantWeightSizeIn;
|
||||
std::string quantWeightSizeStr;
|
||||
int quantWeightSize;
|
||||
std::string bitNumIn;
|
||||
int bitNum;
|
||||
std::string configFile;
|
||||
std::string quantWeightChannelIn;
|
||||
std::string quantWeightChannelStr;
|
||||
int quantWeightChannel;
|
||||
std::string trainModelIn;
|
||||
bool trainModel = false;
|
||||
|
|
|
@ -47,8 +47,8 @@ namespace mindspore::lite {
|
|||
|
||||
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
|
||||
std::vector<schema::CNodeT *> old_nodes{};
|
||||
old_nodes.resize(graphDefT->nodes.size());
|
||||
std::transform(graphDefT->nodes.begin(), graphDefT->nodes.end(), old_nodes.begin(),
|
||||
old_nodes.resize(graph_defT_->nodes.size());
|
||||
std::transform(graph_defT_->nodes.begin(), graph_defT_->nodes.end(), old_nodes.begin(),
|
||||
[](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
|
||||
return old_nodes;
|
||||
}
|
||||
|
@ -57,33 +57,33 @@ GraphDefTransform::GraphDefTransform() = default;
|
|||
|
||||
GraphDefTransform::~GraphDefTransform() = default;
|
||||
|
||||
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
|
||||
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *dst_def) { graph_defT_ = dst_def; }
|
||||
|
||||
int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||
STATUS status;
|
||||
{
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer unusedOpRemoveOptimizer;
|
||||
Optimizer unused_op_remove_optimizer;
|
||||
if (!ctx.trainModel) {
|
||||
unusedOpRemoveOptimizer.AddPass(new DropoutNodeRemovePass());
|
||||
unused_op_remove_optimizer.AddPass(new DropoutNodeRemovePass());
|
||||
}
|
||||
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
|
||||
unusedOpRemoveOptimizer.AddPass(new SubgraphNodePass(old_nodes));
|
||||
status = unusedOpRemoveOptimizer.Run(graphDefT);
|
||||
unused_op_remove_optimizer.AddPass(new IsolatedNodeRemovePass());
|
||||
unused_op_remove_optimizer.AddPass(new SubgraphNodePass(old_nodes));
|
||||
status = unused_op_remove_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run unused_op_remove_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
// generate and infer quant parameters
|
||||
{
|
||||
Optimizer inferQuantParamPass;
|
||||
inferQuantParamPass.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
inferQuantParamPass.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
status = inferQuantParamPass.Run(graphDefT);
|
||||
Optimizer infer_quant_param_pass;
|
||||
infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
status = infer_quant_param_pass.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -93,40 +93,40 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
|
||||
Optimizer formatTransOptimizer;
|
||||
auto formatTransPass = new (std::nothrow) FormatTransPass();
|
||||
if (formatTransPass == nullptr) {
|
||||
Optimizer format_trans_optimizer;
|
||||
auto format_trans_pass = new (std::nothrow) FormatTransPass();
|
||||
if (format_trans_pass == nullptr) {
|
||||
MS_LOG(ERROR) << "new formatTransPass failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
formatTransPass->SetQuantType(ctx.quantType);
|
||||
formatTransPass->SetFmk(ctx.fmk);
|
||||
formatTransOptimizer.AddPass(formatTransPass);
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
format_trans_pass->set_quant_type(ctx.quantType);
|
||||
format_trans_pass->set_fmk_type(ctx.fmk);
|
||||
format_trans_optimizer.AddPass(format_trans_pass);
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
if (ctx.fmk != converter::FmkType_TF) {
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
}
|
||||
status = formatTransOptimizer.Run(graphDefT);
|
||||
status = format_trans_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer formatTransOptimizer;
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = formatTransOptimizer.Run(graphDefT);
|
||||
Optimizer format_trans_optimizer;
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) TransOpRemovePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) TransOpInsertPass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = format_trans_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -134,15 +134,15 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer formatTransOptimizer;
|
||||
Optimizer format_trans_optimizer;
|
||||
if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
}
|
||||
status = formatTransOptimizer.Run(graphDefT);
|
||||
status = format_trans_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer fusionOptimizer;
|
||||
Optimizer replace_optimizer;
|
||||
if (!ctx.trainModel) {
|
||||
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
|
||||
if (batch_norm_scale_pass == nullptr) {
|
||||
|
@ -159,13 +159,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
batch_norm_scale_pass->SetFmk(ctx.fmk);
|
||||
fusionOptimizer.AddPass(batch_norm_scale_pass);
|
||||
replace_optimizer.AddPass(batch_norm_scale_pass);
|
||||
}
|
||||
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
fusionOptimizer.AddPass(new SubgraphNodePass(old_nodes));
|
||||
status = fusionOptimizer.Run(graphDefT);
|
||||
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
|
||||
status = replace_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
|
||||
MS_LOG(ERROR) << "Run replace_optimizer BatchNormConvertScalePass Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -173,13 +173,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer fusionOptimizer;
|
||||
fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = fusionOptimizer.Run(graphDefT);
|
||||
Optimizer fusion_optimizer;
|
||||
fusion_optimizer.AddPass(new (std::nothrow) MulAddFusionPass());
|
||||
fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = fusion_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run fusion_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -188,12 +188,12 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
if (ctx.fmk != converter::FmkType_TF) {
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer tensorQuantOptimizer;
|
||||
tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass());
|
||||
tensorQuantOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = tensorQuantOptimizer.Run(graphDefT);
|
||||
Optimizer tensor_quant_optimizer;
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = tensor_quant_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoQuantize failed!";
|
||||
return status;
|
||||
|
@ -204,31 +204,31 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
if (ctx.fmk != converter::FmkType_TF) {
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer quantNodeOptimizer;
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
status = quantNodeOptimizer.Run(graphDefT);
|
||||
Optimizer quant_node_optimizer;
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass());
|
||||
status = quant_node_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
auto old_nodes2 = GetGraphNodes();
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dTypeTransPass failed";
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
|
||||
if (dtype_trans_pass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dtype_trans_pass failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
|
||||
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
|
||||
quantNodeOptimizer.AddPass(dTypeTransPass);
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
|
||||
status = quantNodeOptimizer.Run(graphDefT);
|
||||
dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
|
||||
dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
|
||||
quant_node_optimizer.AddPass(dtype_trans_pass);
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
|
||||
status = quant_node_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -237,22 +237,22 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer switchOptimizer;
|
||||
switchOptimizer.AddPass(new (std::nothrow) SwitchPass());
|
||||
switchOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
switchOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = switchOptimizer.Run(graphDefT);
|
||||
Optimizer switch_optimizer;
|
||||
switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
|
||||
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
status = switch_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run switch graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run switch_optimizer Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
// subgraph tensor pass
|
||||
{
|
||||
Optimizer subgraphTensorOptimizer;
|
||||
subgraphTensorOptimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
|
||||
status = subgraphTensorOptimizer.Run(graphDefT);
|
||||
Optimizer subgraph_tensor_optimizer;
|
||||
subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
|
||||
status = subgraph_tensor_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
|
||||
return status;
|
||||
|
@ -263,33 +263,33 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes();
|
||||
Optimizer nameOptimizer;
|
||||
nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
nameOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
nameOptimizer.AddPass(new (std::nothrow) TensorNamePass());
|
||||
status = nameOptimizer.Run(graphDefT);
|
||||
Optimizer name_optimizer;
|
||||
name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
|
||||
status = name_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run nameOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
Optimizer nestedLoopOptimizer;
|
||||
nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
|
||||
status = nestedLoopOptimizer.Run(graphDefT);
|
||||
Optimizer nested_loop_optimizer;
|
||||
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
|
||||
status = nested_loop_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run nested_loop_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
Optimizer quantNodeOptimizer;
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
|
||||
status = quantNodeOptimizer.Run(graphDefT);
|
||||
Optimizer quant_param_optimizer;
|
||||
quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
|
||||
status = quant_param_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
|
||||
MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,12 +36,12 @@ class GraphDefTransform {
|
|||
GraphDefTransform();
|
||||
virtual ~GraphDefTransform();
|
||||
virtual int Transform(const converter::Flags &ctx);
|
||||
void SetGraphDef(schema::MetaGraphT *dstDef);
|
||||
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
|
||||
void SetGraphDef(schema::MetaGraphT *dst_def);
|
||||
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
|
||||
|
||||
protected:
|
||||
std::vector<schema::CNodeT *> GetGraphNodes();
|
||||
schema::MetaGraphT *graphDefT = nullptr;
|
||||
schema::MetaGraphT *graph_defT_ = nullptr;
|
||||
Optimizer *optimizer = nullptr;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -55,34 +55,35 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
|
|||
|
||||
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto &graphInIdxes = graph->inputIndex;
|
||||
if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 &&
|
||||
this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
|
||||
auto &graph_in_idxes = graph->inputIndex;
|
||||
if (this->input_data_dtype != TypeId::kNumberTypeFloat32 && this->input_data_dtype != TypeId::kNumberTypeUInt8 &&
|
||||
this->input_data_dtype != TypeId::kNumberTypeInt8 && this->input_data_dtype != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid inputDataType: " << this->input_data_dtype;
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto graphInIdx : graphInIdxes) {
|
||||
MS_ASSERT(graphInIdx < graph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(graphInIdx);
|
||||
for (auto graph_in_idx : graph_in_idxes) {
|
||||
MS_ASSERT(graph_in_idx < graph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(graph_in_idx);
|
||||
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
|
||||
continue;
|
||||
}
|
||||
int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown
|
||||
? this->inputDataDType
|
||||
: TensorDataType::GetInstance()->GetTensorType(graphInIdx);
|
||||
int32_t tensor_data_type = this->input_data_dtype != TypeId::kTypeUnknown
|
||||
? this->input_data_dtype
|
||||
: TensorDataType::GetInstance()->GetTensorType(graph_in_idx);
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto nodeName = (*iter)->name;
|
||||
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
|
||||
if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) {
|
||||
auto node_name = (*iter)->name;
|
||||
for (size_t input_indexidx = 0; input_indexidx < (*iter)->inputIndex.size(); input_indexidx++) {
|
||||
if ((*iter)->inputIndex.at(input_indexidx) == graph_in_idx) {
|
||||
STATUS status = RET_OK;
|
||||
|
||||
// insert dtype cast node between input tensor and input node
|
||||
if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
|
||||
iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status);
|
||||
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
|
||||
iter =
|
||||
InsertDTypeTransNode(graph, iter, kBefore, input_indexidx, tensor_data_type, tensor->dataType, &status);
|
||||
}
|
||||
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
|
||||
MS_LOG(ERROR) << "InsertDTypeTransNode before " << node_name.c_str() << " failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -94,33 +95,34 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
|
|||
|
||||
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 &&
|
||||
this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
|
||||
if (this->output_data_dtype != TypeId::kNumberTypeFloat32 && this->output_data_dtype != TypeId::kNumberTypeUInt8 &&
|
||||
this->output_data_dtype != TypeId::kNumberTypeInt8 && this->output_data_dtype != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid outputDataType: " << this->output_data_dtype;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto &graphOutIdxes = graph->outputIndex;
|
||||
for (auto graphOutIdx : graphOutIdxes) {
|
||||
MS_ASSERT(graphOutIdx < graph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(graphOutIdx);
|
||||
auto &graph_out_idxes = graph->outputIndex;
|
||||
for (auto graph_out_idx : graph_out_idxes) {
|
||||
MS_ASSERT(graph_out_idx < graph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(graph_out_idx);
|
||||
if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
|
||||
continue;
|
||||
}
|
||||
int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown
|
||||
? this->outputDataDType
|
||||
: TensorDataType::GetInstance()->GetTensorType(graphOutIdx);
|
||||
int32_t tensor_data_type = this->output_data_dtype != TypeId::kTypeUnknown
|
||||
? this->output_data_dtype
|
||||
: TensorDataType::GetInstance()->GetTensorType(graph_out_idx);
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto nodeName = (*iter)->name;
|
||||
auto node_name = (*iter)->name;
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) {
|
||||
if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
|
||||
if ((*iter)->outputIndex.at(outputIndexIdx) == graph_out_idx) {
|
||||
// insert transNode
|
||||
STATUS status = RET_OK;
|
||||
if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) {
|
||||
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status);
|
||||
if (tensor_data_type != tensor->dataType && tensor_data_type != kTypeUnknown) {
|
||||
iter =
|
||||
InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensor_data_type, &status);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
|
||||
MS_LOG(ERROR) << "InsertDTypeTransNode after " << node_name.c_str() << " failed";
|
||||
return status;
|
||||
}
|
||||
break;
|
||||
|
@ -231,52 +233,53 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
|
||||
size_t inoutIdx, int32_t inputDataType, int32_t outputDataType,
|
||||
STATUS *errorCode) {
|
||||
MS_ASSERT((*existNodeIter) != nullptr);
|
||||
auto existNodeName = (*existNodeIter)->name;
|
||||
std::string tileName;
|
||||
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
|
||||
size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
|
||||
STATUS *error_code) {
|
||||
MS_ASSERT((*exist_node_iter) != nullptr);
|
||||
auto exist_node_name = (*exist_node_iter)->name;
|
||||
std::string tile_name;
|
||||
if (place == kBefore) {
|
||||
tileName = existNodeName + "_pre";
|
||||
tile_name = exist_node_name + "_pre";
|
||||
} else {
|
||||
tileName = existNodeName + "_post";
|
||||
tile_name = exist_node_name + "_post";
|
||||
}
|
||||
auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
|
||||
if (transNode == nullptr) {
|
||||
auto trans_node = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
|
||||
if (trans_node == nullptr) {
|
||||
MS_LOG(ERROR) << "new TransNode failed";
|
||||
*errorCode = RET_ERROR;
|
||||
*error_code = RET_ERROR;
|
||||
return graph->nodes.end();
|
||||
}
|
||||
auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
|
||||
if (quantDTypeCastParam == nullptr) {
|
||||
auto quant_dtype_cast_param = new (std::nothrow) QuantDTypeCastT;
|
||||
if (quant_dtype_cast_param == nullptr) {
|
||||
MS_LOG(ERROR) << "new quantDTypeCastParam failed";
|
||||
*errorCode = RET_ERROR;
|
||||
*error_code = RET_ERROR;
|
||||
return graph->nodes.end();
|
||||
}
|
||||
transNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
transNode->primitive->value.value = quantDTypeCastParam;
|
||||
transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
|
||||
transNode->quantType = QuantType_AwareTraining;
|
||||
quantDTypeCastParam->src_t = inputDataType;
|
||||
quantDTypeCastParam->dst_t = outputDataType;
|
||||
if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeFloat32) {
|
||||
transNode->name = "int8toft32_" + tileName + std::to_string(id++);
|
||||
} else if (inputDataType == TypeId::kNumberTypeFloat32 && outputDataType == TypeId::kNumberTypeInt8) {
|
||||
transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
|
||||
} else if (inputDataType == TypeId::kNumberTypeUInt8 && outputDataType == TypeId::kNumberTypeInt8) {
|
||||
transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
|
||||
} else if (inputDataType == TypeId::kNumberTypeInt8 && outputDataType == TypeId::kNumberTypeUInt8) {
|
||||
transNode->name = "int8touint8_" + tileName + std::to_string(id++);
|
||||
trans_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
trans_node->primitive->value.value = quant_dtype_cast_param;
|
||||
trans_node->primitive->value.type = PrimitiveType_QuantDTypeCast;
|
||||
trans_node->quantType = QuantType_AwareTraining;
|
||||
quant_dtype_cast_param->src_t = input_data_type;
|
||||
quant_dtype_cast_param->dst_t = output_data_type;
|
||||
if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeFloat32) {
|
||||
trans_node->name = "int8toft32_" + tile_name + std::to_string(id_++);
|
||||
} else if (input_data_type == TypeId::kNumberTypeFloat32 && output_data_type == TypeId::kNumberTypeInt8) {
|
||||
trans_node->name = "ft32toint8_" + tile_name + std::to_string(id_++);
|
||||
} else if (input_data_type == TypeId::kNumberTypeUInt8 && output_data_type == TypeId::kNumberTypeInt8) {
|
||||
trans_node->name = "uint8toint8_" + tile_name + std::to_string(id_++);
|
||||
} else if (input_data_type == TypeId::kNumberTypeInt8 && output_data_type == TypeId::kNumberTypeUInt8) {
|
||||
trans_node->name = "int8touint8_" + tile_name + std::to_string(id_++);
|
||||
}
|
||||
transNode->primitive->value.value = quantDTypeCastParam;
|
||||
trans_node->primitive->value.value = quant_dtype_cast_param;
|
||||
int insert_num = 0;
|
||||
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, castOpCopyer);
|
||||
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
|
||||
castOpCopyer);
|
||||
}
|
||||
|
||||
void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
|
||||
void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
|
||||
|
||||
void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
|
||||
void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,15 +30,15 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
|
|||
|
||||
class DTypeTransPass : public GraphPass {
|
||||
public:
|
||||
DTypeTransPass() : id(0) {}
|
||||
DTypeTransPass() : id_(0) {}
|
||||
|
||||
~DTypeTransPass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
void SetInputDataDType(TypeId dataType);
|
||||
void set_input_data_dtype(TypeId data_type);
|
||||
|
||||
void SetOutputDataDType(TypeId dataType);
|
||||
void set_output_data_dtype(TypeId dataType);
|
||||
|
||||
private:
|
||||
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
|
||||
|
@ -51,13 +51,14 @@ class DTypeTransPass : public GraphPass {
|
|||
|
||||
STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter);
|
||||
|
||||
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
|
||||
int32_t inputDataType, int32_t outputDataType, STATUS *errorCode);
|
||||
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
|
||||
size_t inout_idx, int32_t input_data_type, int32_t output_data_type,
|
||||
STATUS *error_code);
|
||||
|
||||
private:
|
||||
size_t id;
|
||||
TypeId inputDataDType = TypeId::kNumberTypeFloat;
|
||||
TypeId outputDataDType = TypeId::kNumberTypeFloat;
|
||||
size_t id_;
|
||||
TypeId input_data_dtype = TypeId::kNumberTypeFloat;
|
||||
TypeId output_data_dtype = TypeId::kNumberTypeFloat;
|
||||
|
||||
OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
|
||||
std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
|
||||
|
|
|
@ -45,32 +45,32 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
|
||||
FormatTransNodeType *afterNodeType) {
|
||||
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
|
||||
FormatTransNodeType *after_node_type) {
|
||||
if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc
|
||||
if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
*beforeNodeType = kNHWC2NCHW;
|
||||
*afterNodeType = kNCHW2NHWC;
|
||||
*before_node_type = kNHWC2NCHW;
|
||||
*after_node_type = kNCHW2NHWC;
|
||||
return RET_OK;
|
||||
} else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS ||
|
||||
fmk_type_ == converter::FmkType_ONNX) {
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
*beforeNodeType = kNCHW2NHWC;
|
||||
*afterNodeType = kNHWC2NCHW;
|
||||
*before_node_type = kNCHW2NHWC;
|
||||
*after_node_type = kNHWC2NCHW;
|
||||
return RET_OK;
|
||||
} else if (fmk_type_ == converter::FmkType_TF) {
|
||||
if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) {
|
||||
*beforeNodeType = kNCHW2NHWC;
|
||||
*afterNodeType = kNHWC2NCHW;
|
||||
*before_node_type = kNCHW2NHWC;
|
||||
*after_node_type = kNHWC2NCHW;
|
||||
return RET_OK;
|
||||
}
|
||||
if (IsContain(GetNchwOpList(), GetCNodeTType(node))) {
|
||||
*beforeNodeType = kNHWC2NCHW;
|
||||
*afterNodeType = kNCHW2NHWC;
|
||||
*before_node_type = kNHWC2NCHW;
|
||||
*after_node_type = kNCHW2NHWC;
|
||||
return RET_OK;
|
||||
}
|
||||
return RET_NO_CHANGE;
|
||||
|
@ -96,36 +96,34 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
}
|
||||
auto graphInputIdxes = graph->inputIndex;
|
||||
for (size_t i = 0; i < graphInputIdxes.size(); i++) {
|
||||
auto graph_input_idxes = graph->inputIndex;
|
||||
for (size_t i = 0; i < graph_input_idxes.size(); i++) {
|
||||
bool transed = false;
|
||||
auto inputIdx = graphInputIdxes.at(i);
|
||||
MS_ASSERT(inputIdx < subGraph->allTensors.size());
|
||||
auto &tensor = graph->allTensors.at(inputIdx);
|
||||
auto input_idx = graph_input_idxes.at(i);
|
||||
auto &tensor = graph->allTensors.at(input_idx);
|
||||
if (tensor->dims.size() != kNCHWDimNumber) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) {
|
||||
if ((*iter)->inputIndex.at(inputIndexIdx) == inputIdx) {
|
||||
for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) {
|
||||
if ((*iter)->inputIndex.at(input_index_idx) == input_idx) {
|
||||
STATUS status = RET_OK;
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status);
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed";
|
||||
return status;
|
||||
}
|
||||
// set first tensor format to nhwc
|
||||
auto &transNode = *(iter - 1);
|
||||
MS_ASSERT(transNode != nullptr);
|
||||
MS_ASSERT(transNode->inputIndex.size() == 1);
|
||||
MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front());
|
||||
auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front());
|
||||
graphInTensor->format = schema::Format::Format_NHWC;
|
||||
auto &trans_node = *(iter - 1);
|
||||
MS_ASSERT(trans_node != nullptr);
|
||||
MS_ASSERT(trans_node->inputIndex.size() == 1);
|
||||
auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front());
|
||||
graph_in_tensor->format = schema::Format::Format_NHWC;
|
||||
// assume parser not reformat shape
|
||||
auto oldDims = graphInTensor->dims;
|
||||
auto old_dims = graph_in_tensor->dims;
|
||||
if (!transed) {
|
||||
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]};
|
||||
graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]};
|
||||
transed = true;
|
||||
}
|
||||
}
|
||||
|
@ -143,10 +141,10 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
MS_ASSERT(graph != nullptr);
|
||||
// insert before and after the op cal by nchw/nc4hw4
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
FormatTransNodeType beforeNodeType = kNCHW2NHWC;
|
||||
FormatTransNodeType afterNodeType = kNHWC2NCHW;
|
||||
FormatTransNodeType before_node_type = kNCHW2NHWC;
|
||||
FormatTransNodeType after_node_type = kNHWC2NCHW;
|
||||
STATUS status = RET_OK;
|
||||
status = GetInsertFormatTrans(**iter, &beforeNodeType, &afterNodeType);
|
||||
status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type);
|
||||
if (status == RET_NO_CHANGE) {
|
||||
continue;
|
||||
}
|
||||
|
@ -170,17 +168,17 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
|
||||
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
|
||||
}
|
||||
auto specInsertIndexes = GetExtNhwcIndexes();
|
||||
auto opType = GetCNodeTType(**iter);
|
||||
if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
|
||||
for (auto insert_index : specInsertIndexes[opType]) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
|
||||
auto spec_insert_indexes = GetExtNhwcIndexes();
|
||||
auto op_type = GetCNodeTType(**iter);
|
||||
if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) {
|
||||
for (auto insert_index : spec_insert_indexes[op_type]) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
} else if (IsContain(GetNhwcAllInputOpList(), opType)) {
|
||||
} else if (IsContain(GetNhwcAllInputOpList(), op_type)) {
|
||||
auto input_size = node->inputIndex.size();
|
||||
if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) {
|
||||
if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) {
|
||||
|
@ -188,16 +186,16 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
}
|
||||
}
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status);
|
||||
}
|
||||
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
|
||||
iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
|
@ -206,29 +204,29 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
|
||||
size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) {
|
||||
MS_ASSERT((*existNodeIter) != nullptr);
|
||||
NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place,
|
||||
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) {
|
||||
MS_ASSERT((*exist_node_iter) != nullptr);
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto existNodeName = (*existNodeIter)->name;
|
||||
std::string tileName;
|
||||
auto exist_node_name = (*exist_node_iter)->name;
|
||||
std::string tile_name;
|
||||
if (place == kBefore) {
|
||||
tileName = existNodeName + "_pre";
|
||||
tile_name = exist_node_name + "_pre";
|
||||
} else {
|
||||
tileName = existNodeName + "_post";
|
||||
tile_name = exist_node_name + "_post";
|
||||
}
|
||||
auto transNode = std::make_unique<schema::CNodeT>();
|
||||
transNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
transNode->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
auto trans_node = std::make_unique<schema::CNodeT>();
|
||||
trans_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
trans_node->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
auto perm_tensor = std::make_unique<schema::TensorT>();
|
||||
perm_tensor->dataType = kNumberTypeInt32;
|
||||
perm_tensor->dims = {4};
|
||||
std::vector<int> perm;
|
||||
if (nodeType == kNCHW2NHWC) {
|
||||
transNode->name = "nchw2nhwc_" + tileName + std::to_string(id_++);
|
||||
if (node_type == kNCHW2NHWC) {
|
||||
trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++);
|
||||
perm = {0, 2, 3, 1};
|
||||
} else {
|
||||
transNode->name = "nhwc2nchw_" + tileName + std::to_string(id_++);
|
||||
trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++);
|
||||
perm = {0, 3, 1, 2};
|
||||
}
|
||||
size_t bytes = perm.size() * sizeof(int);
|
||||
|
@ -236,27 +234,27 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
|
|||
if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy data failed.";
|
||||
}
|
||||
perm_tensor->name = transNode->name + "_perm";
|
||||
perm_tensor->name = trans_node->name + "_perm";
|
||||
|
||||
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
|
||||
auto newOpDef = std::make_unique<schema::CNodeT>();
|
||||
if (newOpDef == nullptr) {
|
||||
OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> {
|
||||
auto new_op_def = std::make_unique<schema::CNodeT>();
|
||||
if (new_op_def == nullptr) {
|
||||
MS_LOG(ERROR) << "new CNodeT failed";
|
||||
return nullptr;
|
||||
}
|
||||
newOpDef->name = inOpDef->name;
|
||||
newOpDef->quantType = inOpDef->quantType;
|
||||
newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (newOpDef->primitive == nullptr) {
|
||||
new_op_def->name = in_op_def->name;
|
||||
new_op_def->quantType = in_op_def->quantType;
|
||||
new_op_def->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (new_op_def->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new PrimitiveT failed";
|
||||
return nullptr;
|
||||
}
|
||||
newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
return newOpDef;
|
||||
new_op_def->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
return new_op_def;
|
||||
};
|
||||
int insert_num = 0;
|
||||
auto iter =
|
||||
InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, &insert_num, TransposeOpCopyer);
|
||||
auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
|
||||
transpose_op_copyer);
|
||||
size_t index = graph->allTensors.size();
|
||||
graph->allTensors.push_back(std::move(perm_tensor));
|
||||
for (int i = insert_num; i > 0; --i) {
|
||||
|
|
|
@ -34,13 +34,13 @@ class FormatTransPass : public GraphPass {
|
|||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
void SetQuantType(QuantType quantType) { this->quant_type_ = quantType; }
|
||||
void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; }
|
||||
|
||||
void SetFmk(converter::FmkType fmkType) { this->fmk_type_ = fmkType; }
|
||||
void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
|
||||
|
||||
protected:
|
||||
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
|
||||
FormatTransNodeType nodeType, STATUS *errorCode);
|
||||
NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place,
|
||||
size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code);
|
||||
|
||||
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
|
||||
|
||||
|
@ -61,8 +61,8 @@ class FormatTransPass : public GraphPass {
|
|||
|
||||
int GetFormat(const schema::CNodeT &);
|
||||
|
||||
STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
|
||||
FormatTransNodeType *afterNodeType);
|
||||
STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type,
|
||||
FormatTransNodeType *after_node_type);
|
||||
|
||||
protected:
|
||||
size_t id_ = 0;
|
||||
|
|
|
@ -20,13 +20,13 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
Optimizer::~Optimizer() {
|
||||
for (auto pass : graphPasses) {
|
||||
for (auto pass : graph_passes_) {
|
||||
if (pass != nullptr) {
|
||||
delete (pass);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto pass : nodePasses) {
|
||||
for (auto pass : node_passes_) {
|
||||
if (pass != nullptr) {
|
||||
delete (pass);
|
||||
}
|
||||
|
@ -35,13 +35,13 @@ Optimizer::~Optimizer() {
|
|||
|
||||
void Optimizer::AddPass(GraphPass *graphPass) {
|
||||
if (graphPass != nullptr) {
|
||||
this->graphPasses.emplace_back(graphPass);
|
||||
this->graph_passes_.emplace_back(graphPass);
|
||||
}
|
||||
}
|
||||
|
||||
void Optimizer::AddPass(NodePass *nodePass) {
|
||||
if (nodePass != nullptr) {
|
||||
this->nodePasses.emplace_back(nodePass);
|
||||
this->node_passes_.emplace_back(nodePass);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
|
|||
bool ifNotChanged = true;
|
||||
// each node should go through all node pass not each node pass go through all node
|
||||
for (auto &opDef : graphDefT->nodes) {
|
||||
for (auto pass : this->nodePasses) {
|
||||
for (auto pass : this->node_passes_) {
|
||||
status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get()));
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run NodePass failed";
|
||||
|
@ -64,7 +64,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
|
|||
}
|
||||
}
|
||||
|
||||
for (auto pass : this->graphPasses) {
|
||||
for (auto pass : this->graph_passes_) {
|
||||
status = pass->Run(graphDefT);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run GraphPass failed";
|
||||
|
|
|
@ -41,10 +41,10 @@ class GraphPass : public Pass<schema::MetaGraphT> {
|
|||
};
|
||||
|
||||
struct GraphNode {
|
||||
GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : subGraph(subGraph), opDef(opDefT) {}
|
||||
GraphNode(schema::MetaGraphT *subGraph, schema::CNodeT *opDefT) : sub_graph_(subGraph), op_def_(opDefT) {}
|
||||
~GraphNode() = default;
|
||||
schema::MetaGraphT *subGraph = nullptr;
|
||||
schema::CNodeT *opDef = nullptr;
|
||||
schema::MetaGraphT *sub_graph_ = nullptr;
|
||||
schema::CNodeT *op_def_ = nullptr;
|
||||
};
|
||||
|
||||
class NodePass : public Pass<GraphNode> {
|
||||
|
@ -72,8 +72,8 @@ class Optimizer {
|
|||
STATUS Run(schema::MetaGraphT *graphDefT);
|
||||
|
||||
private:
|
||||
std::vector<GraphPass *> graphPasses;
|
||||
std::vector<NodePass *> nodePasses;
|
||||
std::vector<GraphPass *> graph_passes_;
|
||||
std::vector<NodePass *> node_passes_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -98,7 +98,7 @@ STATUS CaffeModelParser::ConvertLayers() {
|
|||
MS_LOG(INFO) << "parse op : " << layer.type();
|
||||
auto node_parser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type());
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
NotSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
|
||||
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
NoSupportOp::GetInstance()->SetFmkType("ONNX");
|
||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||
anf_root_graph_ = std::make_shared<FuncGraph>();
|
||||
auto status = InitOriginModel(model_file);
|
||||
if (RET_OK != status) {
|
||||
|
@ -195,7 +195,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
|
|||
for (const auto &onnx_node : onnx_graph.node()) {
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
status = status == RET_OK ? RET_NOT_FIND_OP : status;
|
||||
MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
|
||||
}
|
||||
|
|
|
@ -476,7 +476,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
|||
FuncGraphPtr paserTfFuction() { return nullptr; }
|
||||
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
NoSupportOp::GetInstance()->SetFmkType("TF");
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||
auto status = ValidateFileStr(modelFile, ".pb");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
|
||||
|
@ -888,7 +888,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
MS_LOG(INFO) << "parse op : " << op_type;
|
||||
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
|
||||
<< func_graph_ptr->get_attr("graph_name")->ToString();
|
||||
return RET_NOT_FIND_OP;
|
||||
|
|
|
@ -101,7 +101,7 @@ std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type,
|
|||
|
||||
STATUS TfliteModelParser::ConvertOps() {
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TFLITE");
|
||||
STATUS status = RET_OK;
|
||||
int op_idx = 0;
|
||||
for (auto &op : tflite_subgraph->operators) {
|
||||
|
@ -113,7 +113,7 @@ STATUS TfliteModelParser::ConvertOps() {
|
|||
MS_LOG(INFO) << "parse node :" << op_name;
|
||||
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(tflite_op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1344,7 +1344,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
|
||||
// add quant_cast
|
||||
quant::QuantCast quant_cast;
|
||||
quant_cast.SetInputDataDType(kNumberTypeFloat32);
|
||||
quant_cast.set_input_data_dtype(kNumberTypeFloat32);
|
||||
status = quant_cast.Run(func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "add QuantCast error";
|
||||
|
|
|
@ -26,12 +26,12 @@ namespace mindspore::lite::quant {
|
|||
class QuantCast {
|
||||
public:
|
||||
QuantCast() = default;
|
||||
~QuantCast() = default;
|
||||
virtual ~QuantCast() = default;
|
||||
STATUS Run(const FuncGraphPtr &graph);
|
||||
void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
|
||||
void set_input_data_dtype(TypeId data_type) { this->input_data_dtype_ = data_type; }
|
||||
|
||||
private:
|
||||
TypeId inputDataDType = kNumberTypeFloat32;
|
||||
TypeId input_data_dtype_ = kNumberTypeFloat32;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H
|
||||
|
|
Loading…
Reference in New Issue