remove unused quant_type judgement
This commit is contained in:
parent
d23114fe89
commit
bb34fb5d6c
|
@ -20,7 +20,6 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include "include/lite_utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace converter {
|
||||
|
@ -36,7 +35,6 @@ enum MS_API FmkType : int {
|
|||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
struct MS_API ConverterParameters {
|
||||
FmkType fmk;
|
||||
schema::QuantType quant_type;
|
||||
std::string model_file;
|
||||
std::string weight_file;
|
||||
std::map<std::string, std::string> attrs;
|
||||
|
|
|
@ -33,7 +33,6 @@ namespace lite {
|
|||
namespace {
|
||||
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
|
||||
converter_parameters->fmk = flag.fmk;
|
||||
converter_parameters->quant_type = flag.quantType;
|
||||
converter_parameters->model_file = flag.modelFile;
|
||||
converter_parameters->weight_file = flag.weightFile;
|
||||
}
|
||||
|
|
|
@ -28,7 +28,6 @@ class MindirAdjust {
|
|||
public:
|
||||
MindirAdjust() {}
|
||||
~MindirAdjust() = default;
|
||||
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
|
||||
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
|
||||
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
|
||||
bool Run(const FuncGraphPtr &graph);
|
||||
|
@ -37,7 +36,6 @@ class MindirAdjust {
|
|||
int ValueNodeInt64Convert(AnfNodePtr anf_node);
|
||||
int ComputeQuantParams(AnfNodePtr anf_node);
|
||||
|
||||
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||
FmkType fmk_type_ = FmkType::kFmkTypeMs;
|
||||
bool train_flag_ = false;
|
||||
};
|
||||
|
|
|
@ -42,7 +42,6 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const
|
|||
}
|
||||
auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
|
||||
mindir_adjust_pass->SetFmkType(flag.fmk);
|
||||
mindir_adjust_pass->SetQuantType(flag.quantType);
|
||||
mindir_adjust_pass->SetTrainFlag(flag.trainModel);
|
||||
if (!mindir_adjust_pass->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "MindIr adjust failed.";
|
||||
|
@ -97,7 +96,6 @@ size_t MindsporeImporter::Hex2ByteArray(const std::string &hex_str, unsigned cha
|
|||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||
quant_type_ = flag.quantType;
|
||||
FuncGraphPtr func_graph;
|
||||
if (flag.dec_key.size() != 0) {
|
||||
unsigned char key[32];
|
||||
|
@ -128,7 +126,7 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel, flag.quantType);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel);
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -31,7 +31,6 @@ class MindsporeImporter {
|
|||
|
||||
private:
|
||||
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||
};
|
||||
|
||||
|
|
|
@ -79,7 +79,6 @@ CaffeModelParser::~CaffeModelParser() = default;
|
|||
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
auto weight_file = flag.weight_file;
|
||||
quant_type_ = flag.quant_type;
|
||||
STATUS status = InitOriginModel(model_file, weight_file);
|
||||
if (status != RET_OK) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -112,7 +111,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeCaffe, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -66,7 +66,6 @@ class CaffeModelParser : public converter::ModelParser {
|
|||
caffe::NetParameter caffe_weight_;
|
||||
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
|
||||
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
|
|
|
@ -60,7 +60,6 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
|
||||
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
string model_file = flag.model_file;
|
||||
quant_type_ = flag.quant_type;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto status = InitOriginModel(model_file);
|
||||
|
@ -95,7 +94,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeOnnx, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -99,7 +99,6 @@ class OnnxModelParser : public converter::ModelParser {
|
|||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
|
||||
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -492,7 +492,6 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensor
|
|||
|
||||
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto modelFile = flag.model_file;
|
||||
quant_type_ = flag.quant_type;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||
auto status = ValidateFileStr(modelFile, ".pb");
|
||||
if (status != RET_OK) {
|
||||
|
@ -581,7 +580,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTf, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -108,7 +108,6 @@ class TFModelParser : public converter::ModelParser {
|
|||
std::vector<std::string> while_cond_branch_name_;
|
||||
std::vector<std::string> if_then_branch_name_;
|
||||
std::unordered_map<std::string, int> node_output_num_;
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_;
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -54,7 +54,6 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
|
|||
|
||||
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
quant_type_ = flag.quant_type;
|
||||
// load graph
|
||||
tflite_model_ = ReadTfliteModel(model_file);
|
||||
if (tflite_model_ == nullptr) {
|
||||
|
@ -105,7 +104,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeTflite, false, quant_type_);
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -52,7 +52,6 @@ class TfliteModelParser : public converter::ModelParser {
|
|||
STATUS ConvertGraphOutputs();
|
||||
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
|
||||
int round_type = 1);
|
||||
schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,8 +32,7 @@ constexpr int kNumIndex_0 = 0;
|
|||
constexpr int kNumIndex_1 = 1;
|
||||
constexpr int kNumIndex_2 = 2;
|
||||
constexpr int kNumIndex_3 = 3;
|
||||
STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type,
|
||||
schema::Format *src_format) {
|
||||
STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
|
@ -47,13 +46,13 @@ STATUS DecideMINDIRConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType
|
|||
} else if (format == schema::Format_NCHW) {
|
||||
*src_format = schema::Format_KCHW;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "cnode format is invalid.";
|
||||
MS_LOG(ERROR) << "cnode format is invalid. " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
|
||||
STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
|
@ -61,34 +60,22 @@ STATUS DecideTFConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quan
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
switch (quant_type) {
|
||||
case schema::QuantType_AwareTraining:
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_HWCK;
|
||||
} else {
|
||||
*src_format = schema::Format_HWKC;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
*src_format = schema::Format::Format_HWCK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_HWCK;
|
||||
} else {
|
||||
*src_format = schema::Format_HWKC;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
*src_format = schema::Format::Format_HWCK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "depthwise-conv2dTranspose need to check. " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type,
|
||||
schema::Format *src_format) {
|
||||
STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
|
@ -96,87 +83,49 @@ STATUS DecideTFLITEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
switch (quant_type) {
|
||||
case schema::QuantType_AwareTraining:
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
} else {
|
||||
*src_format = schema::Format_CHWK;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
*src_format = schema::Format_CHWK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "cannot decide weight format, current situation need to check.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
|
||||
<< ", node: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
} else {
|
||||
*src_format = schema::Format_CHWK;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
*src_format = schema::Format_CHWK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "cannot decide weight format, current situation need to check. " << cnode->fullname_with_scope();
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
|
||||
STATUS DecideCAFFEConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr);
|
||||
*src_format = schema::Format_KCHW;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::QuantType quant_type, schema::Format *src_format) {
|
||||
STATUS DecideONNXConvWeightSrcFormat(const CNodePtr &cnode, schema::Format *src_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
int64_t format =
|
||||
prim->GetAttr(ops::kOriginalFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kOriginalFormat)) : 0;
|
||||
switch (quant_type) {
|
||||
case schema::QuantType_AwareTraining: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
} else {
|
||||
*src_format = schema::Format_CHWK;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
*src_format = schema::Format_KCHW;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported op: " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} break;
|
||||
case schema::QuantType_PostTraining:
|
||||
case schema::QuantType_WeightQuant:
|
||||
case schema::QuantType_QUANT_NONE: {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
|
||||
if (format == schema::Format_NHWC) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
} else if (format == schema::Format_NCHW) {
|
||||
*src_format = schema::Format_KCHW;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "format is invalid, format is " << format;
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "d an unsupported op type, which need to check. the type is " << prim->name();
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
|
||||
<< ", node: " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
|
||||
if (format == schema::Format_NHWC) {
|
||||
*src_format = schema::Format_KHWC;
|
||||
} else if (format == schema::Format_NCHW) {
|
||||
*src_format = schema::Format_KCHW;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "format is invalid, format is " << format;
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unknown op, please check.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -246,12 +195,12 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
|
|||
schema::Format *dst_format) {
|
||||
MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
|
||||
*dst_format = schema::Format_KHWC;
|
||||
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::QuantType, schema::Format *)>>
|
||||
decide_functions = {{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
|
||||
std::map<converter::FmkType, std::function<int(const CNodePtr &, schema::Format *)>> decide_functions = {
|
||||
{converter::kFmkTypeMs, DecideMINDIRConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTf, DecideTFConvWeightSrcFormat},
|
||||
{converter::kFmkTypeTflite, DecideTFLITEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeCaffe, DecideCAFFEConvWeightSrcFormat},
|
||||
{converter::kFmkTypeOnnx, DecideONNXConvWeightSrcFormat}};
|
||||
auto iter = decide_functions.find(fmk_type_);
|
||||
if (iter == decide_functions.end()) {
|
||||
MS_LOG(ERROR) << "current fmk don't support, please check.";
|
||||
|
@ -259,7 +208,7 @@ STATUS UnifyFormatToNHWC::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode,
|
|||
}
|
||||
auto decide_func = iter->second;
|
||||
MS_ASSERT(decide_func != nullptr);
|
||||
if (decide_func(cnode, quant_type_, src_format) != RET_OK) {
|
||||
if (decide_func(cnode, src_format) != RET_OK) {
|
||||
MS_LOG(ERROR) << "run decide function failed, cannot decide conv weight format.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -24,9 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class UnifyFormatToNHWC : public opt::ToFormatBase {
|
||||
public:
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false,
|
||||
schema::QuantType quant_type = schema::QuantType_QUANT_NONE)
|
||||
: ToFormatBase(fmk_type, train_flag), quant_type_(quant_type) {}
|
||||
explicit UnifyFormatToNHWC(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false)
|
||||
: ToFormatBase(fmk_type, train_flag) {}
|
||||
~UnifyFormatToNHWC() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
|
@ -41,7 +40,6 @@ class UnifyFormatToNHWC : public opt::ToFormatBase {
|
|||
bool DecideWhetherInferShapeForNewNode() override;
|
||||
STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
|
||||
schema::Format *dst_format) override;
|
||||
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue