forked from mindspore-Ecosystem/mindspore

commit 5d81221b58
!16820 [MS_LITE] weight format
From: @YeFeng_24

@@ -213,8 +213,6 @@ if(MSLITE_ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
-        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
-        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
         ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc

@@ -106,13 +106,13 @@ ml_face_glasses 2.5
 # ml_segmentation_matting 26 # output value unstable
 ml_segmentation_atlanta_10 5
 # ml_bodymask: The difference of output node divided by a very small value leads to a large error
-ml_bodymask 14 13
+ml_bodymask 16 18
 ml_Hand_deploy 4 4
 # ml_hand_3d_detection: The difference of output node divided by a very small value leads to a large error
 ml_hand_3d_detection 12 10
 ml_hand_3d_regression 3 4
 # ml_ARengine23_bodypose: The difference of output node divided by a very small value leads to a large error
-ml_ARengine23_bodypose 56 58
+ml_ARengine23_bodypose 56 59
 ml_ocr_bank_card_detection_inception_tmp 20
 ml_ocr_bank_card_recognition_fcny 0.5
 hiai_cv_aestheticsEngineModel_osp 1.6

@@ -68,7 +68,7 @@ ml_location_lane_counter0.onnx 0.5
 # The encoder and decoder models are used in the ml_asr scene; both have value overflow. Not suitable for fp16.
 # But added for guarding process.
 encoder.onnx;1,32,83 1262
-mtk_emotions-d2012-75.onnx 5
+mtk_emotions-d2012-75.onnx 6
 mtk_detect-mbv1-shortcut-400-400.onnx 0.5
 mtk_detect-mbv2-shortcut-400-400.onnx 0.5
 mtk_detect_mbv1_640_480.onnx 0.5

@@ -9,7 +9,7 @@ ml_video_edit_generate_filter.pb 2
 ml_ocr_jk.pb 0.7
 # The accumulated error causes the threshold to be exceeded
 ml_ocr_latin.pb 12
-scan_hms_angle.pb 1.5
+scan_hms_angle.pb 2.5
 scan_hms_detect.pb 2.5
 ml_face_openclose.pb;1,32,32,3 0.5
 ml_object_detect.pb;1,288,288,3 2

@@ -280,8 +280,11 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
 
   // attr weightFormat is only used by conv-like ops' second input
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
-    data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
+  if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
+       opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
+       opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&
+      (index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
+    data_info->format_ = mindspore::KHWC;
   }
   if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
     MS_LOG(ERROR) << "fetch information from default param failed.";

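Editor's note: the hunk above narrows the old kWeightFormat lookup so that only the weight input (index 2) of conv-like primitives gets its layout pinned to KHWC. A minimal self-contained sketch of that rule, using simplified stand-in types rather than the real MindSpore Primitive/DataInfo classes:

#include <iostream>
#include <map>
#include <string>

// Stand-ins (assumption: simplified) for primitive attributes and formats.
enum Format { kKCHW = 0, kKHWC = 1 };
struct Prim {
  std::map<std::string, int> attrs;
  bool HasAttr(const std::string &key) const { return attrs.count(key) != 0; }
};

// The rule encoded by the new condition: a format attribute only pins the
// layout of input index 2 (the weight) of a conv-like op to KHWC.
Format ResolveWeightFormat(const Prim &prim, size_t index, bool is_conv_like) {
  if (is_conv_like && index == 2 && prim.HasAttr("format")) {
    return kKHWC;
  }
  return kKCHW;  // everything else keeps the framework default
}

int main() {
  Prim conv{{{"format", 0}}};
  std::cout << ResolveWeightFormat(conv, 2, true) << "\n";  // 1: weight input
  std::cout << ResolveWeightFormat(conv, 1, true) << "\n";  // 0: data input
}
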
@@ -311,8 +314,8 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
   MS_ASSERT(prim != nullptr);
   if (value->isa<tensor::Tensor>()) {
     ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
-    if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
-      data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
+    if (index == 2 && prim->GetAttr(ops::kFormat) != nullptr) {
+      data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
     }
   } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
     ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);

@@ -81,8 +81,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/parallel/operator_info_register.cc
         ../optimizer/parallel/spliter.cc
         ../optimizer/parallel/split_strategy.cc
-        ../optimizer/graph/weight_format_transform_pass.cc
-        ../optimizer/graph/weight_format_hardcode_pass.cc
         ../optimizer/graph/conv1d_weight_expanding_pass.cc
         ../optimizer/graph/clip_convert_activation_pass.cc
         ../optimizer/graph/group_depthwise_op_convert_pass.cc

@@ -42,10 +42,7 @@
 #include "tools/optimizer/fusion/onnx_gelu_fusion.h"
 #include "tools/optimizer/fusion/squeeze_fusion.h"
 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
-#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
-#include "tools/optimizer/graph/weight_format_transform_pass.h"
 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
-#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
 #include "tools/optimizer/graph/update_conv2d_param_pass.h"
 #include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
 #include "tools/optimizer/graph/infershape_pass.h"

@@ -161,14 +158,6 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F
     graph_pm->AddPass(std::make_shared<opt::WhilePass>());
     graph_pm->AddPass(std::make_shared<opt::IfPass>());
   }
-  auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
-  weight_format_hardcode_pass->SetFmkType(config->fmk);
-  weight_format_hardcode_pass->SetQuantType(config->quantType);
-  graph_pm->AddPass(weight_format_hardcode_pass);
-  auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
-  weight_format_transform_pass->SetFmkType(config->fmk);
-  weight_format_transform_pass->SetQuantType(config->quantType);
-  graph_pm->AddPass(weight_format_transform_pass);
   auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
   slice_prepose_pass->SetFmkType(config->fmk);
   graph_pm->AddPass(slice_prepose_pass);

@@ -184,9 +173,6 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
   convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
-  if (config->fmk == lite::converter::FmkType_TFLITE) {
-    convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
-  }
   optimizer->AddPassManager(convert_pm);
   if (optimizer->Optimize(old_graph) == nullptr) {
     MS_LOG(ERROR) << "run graph convert pass failed.";

@@ -205,10 +191,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
   auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
   update_conv2d_param_pass->SetFmkType(config->fmk);
   const_fold_pm->AddPass(update_conv2d_param_pass);
-  auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
-  weight_format_hardcode_pass->SetFmkType(config->fmk);
-  weight_format_hardcode_pass->SetQuantType(config->quantType);
-  const_fold_pm->AddPass(weight_format_hardcode_pass);
   auto infershape_pass = std::make_shared<opt::InferShapePass>();
   infershape_pass->SetFmkType(config->fmk);
   const_fold_pm->AddPass(infershape_pass);

@@ -46,7 +46,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
       MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
       return nullptr;
     }
-    func_graph = model_parser_->Parse(flag.modelFile, flag.weightFile);
+    func_graph = model_parser_->Parse(flag);
   }
   if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
     MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";

@@ -16,12 +16,17 @@
 
 #include "tools/converter/import/mindspore_importer.h"
 #include <memory>
 #include <vector>
 #include "tools/converter/parser/parser_utils.h"
 #include "tools/converter/import/primitive_adjust.h"
 #include "tools/converter/import/mindir_adjust.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/common/tensor_util.h"
 
 namespace mindspore::lite {
 
+namespace {
+constexpr size_t kConvWeightIndex = 2;
+} // namespace
 STATUS MindsporeImporter::AdjustForMindir(const FuncGraphPtr &func_graph, const converter::Flags &flag) {
   auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>();
   primitive_adjust_pass->SetFmkType(flag.fmk);

@@ -42,7 +47,98 @@ STATUS MindsporeImporter::AdjustForMindir(const FuncGraphPtr &func_graph, const
   return RET_OK;
 }
 
+STATUS MindsporeImporter::WeightFormatTransform(const FuncGraphPtr &graph) {
+  MS_ASSERT(graph != nullptr);
+  auto node_list = TopoSort(graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto conv_cnode = node->cast<CNodePtr>();
+    if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
+        !opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
+        !opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+      continue;
+    }
+    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
+    int status = HardCodeMindir(conv_cnode, graph);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS MindsporeImporter::HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph) {
+  MS_ASSERT(conv_node != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
+    return lite::RET_ERROR;
+  }
+  int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
+  auto weight_node = conv_node->input(kConvWeightIndex);
+  schema::Format weight_dst_format = schema::Format::Format_KHWC;
+  STATUS status = RET_OK;
+  schema::Format weight_src_format = schema::Format::Format_NUM_OF_FORMAT;
+  switch (quant_type_) {
+    case QuantType_AwareTraining:
+    case QuantType_PostTraining:
+    case QuantType_WeightQuant:
+    case QuantType_QUANT_NONE: {
+      if (format == schema::Format::Format_KHWC) {
+        weight_src_format = schema::Format::Format_KHWC;
+      } else {
+        weight_src_format = schema::Format::Format_KCHW;
+      }
+    } break;
+    default: {
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
+                    << ", node: " << conv_node->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
+  if (utils::isa<CNodePtr>(weight_node)) {
+    status = HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-const failed.";
+      return RET_ERROR;
+    }
+  }
+  weight_node = conv_node->input(kConvWeightIndex);
+  auto weight_value = opt::GetTensorInfo(weight_node);
+  if (weight_value != nullptr) {
+    status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_src_format]) << "To"
+                    << EnumNameFormat(weight_dst_format) << " failed, node: " << conv_node->fullname_with_scope()
+                    << ", quant type: " << quant_type_;
+      return RET_ERROR;
+    }
+    prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+    auto type_id = static_cast<TypeId>(weight_value->data_type());
+    auto shape = weight_value->shape();
+    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
+    auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    weight_node->set_abstract(abstract);
+  }
+  if (utils::isa<ParameterPtr>(weight_node)) {
+    status = HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-sharing failed.";
+      return RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
 FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
+  quant_type_ = flag.quantType;
   auto func_graph = LoadMindIR(flag.modelFile);
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";

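HardCodeMindir above only decides the source layout (KCHW or KHWC) and delegates the actual relayout to opt::TransFilterFormat. For intuition, a self-contained sketch of the KCHW -> KHWC element movement such a transform performs; the dimensions and values are made up, and this is not the converter's implementation:

#include <cstdio>
#include <vector>

int main() {
  // Toy sizes: K output channels, C input channels, H x W kernel.
  const int K = 2, C = 3, H = 2, W = 2;
  std::vector<float> kchw(K * C * H * W);
  for (size_t i = 0; i < kchw.size(); ++i) kchw[i] = static_cast<float>(i);
  std::vector<float> khwc(kchw.size());
  // Element [k][c][h][w] of the KCHW buffer lands at [k][h][w][c] in KHWC.
  for (int k = 0; k < K; ++k)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          khwc[((k * H + h) * W + w) * C + c] = kchw[((k * C + c) * H + h) * W + w];
  std::printf("%.0f\n", khwc[1]);  // [0][0][0][1] <- kchw[0][1][0][0] == 4
}
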
@@ -54,6 +150,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
     MS_LOG(ERROR) << "AdjustForMindir failed.";
     return nullptr;
   }
+  auto status = WeightFormatTransform(func_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "WeightFormatTransform failed.";
+    return nullptr;
+  }
   return func_graph;
 }
 } // namespace mindspore::lite

@@ -29,6 +29,9 @@ class MindsporeImporter {
 
  private:
   STATUS AdjustForMindir(const FuncGraphPtr &func_graph, const converter::Flags &flag);
+  STATUS WeightFormatTransform(const FuncGraphPtr &graph);
+  STATUS HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph);
+  QuantType quant_type_ = schema::QuantType_QUANT_NONE;
 };
 
 } // namespace mindspore::lite

@@ -34,7 +34,7 @@ class ModelParser {
 
   virtual ~ModelParser() = default;
 
-  virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) { return this->res_graph_; }
+  virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; }
 
  protected:
   FuncGraphPtr res_graph_ = nullptr;

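The interface change above collapses the old two-argument Parse into a single overload driven by converter::Flags. A sketch of what an override now looks like; FooModelParser and the stand-in types below are hypothetical, for illustration only:

#include <memory>
#include <string>

// Hypothetical stand-ins so the sketch compiles on its own; the real
// converter::Flags, FuncGraph and ModelParser live in the MindSpore tree.
struct Flags {
  std::string modelFile;
  std::string weightFile;
  int quantType = 0;
};
struct FuncGraph {};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

class ModelParser {
 public:
  virtual ~ModelParser() = default;
  // The single entry point after this commit: everything comes in via flags.
  virtual FuncGraphPtr Parse(const Flags &flag) { return res_graph_; }

 protected:
  FuncGraphPtr res_graph_ = nullptr;
};

// A hypothetical backend parser: it caches the quant type and reads the
// model path from the flags instead of taking them as separate arguments.
class FooModelParser : public ModelParser {
 public:
  FuncGraphPtr Parse(const Flags &flag) override {
    quant_type_ = flag.quantType;
    model_file_ = flag.modelFile;
    res_graph_ = std::make_shared<FuncGraph>();
    return res_graph_;
  }

 private:
  int quant_type_ = 0;
  std::string model_file_;
};

int main() {
  FooModelParser parser;
  return parser.Parse(Flags{"model.foo", "", 0}) != nullptr ? 0 : 1;
}
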
@@ -30,9 +30,13 @@
 #include "tools/converter/converter_context.h"
 #include "tools/converter/quant_param_holder.h"
 #include "tools/converter/parser/parser_utils.h"
 #include "tools/optimizer/common/gllo_utils.h"
 
 namespace mindspore::lite {
 namespace {
+namespace {
+constexpr size_t kConvWeightIndex = 2;
+} // namespace
 bool IsSkipedLayer(const caffe::LayerParameter &layer) {
   if (layer.type() == "Input" || layer.type() == "Dropout" || layer.type() == "Split") {
     return true;

@@ -62,7 +66,10 @@ CaffeModelParser::CaffeModelParser() = default;
 
 CaffeModelParser::~CaffeModelParser() = default;
 
-FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
+FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
+  auto model_file = flag.modelFile;
+  auto weight_file = flag.weightFile;
+  quant_type_ = flag.quantType;
   STATUS status = InitOriginModel(model_file, weight_file);
   if (status != RET_OK) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

@@ -94,9 +101,107 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
     MS_LOG(ERROR) << "AdjustForAnf failed.";
     return nullptr;
   }
+  status = WeightFormatTransform(res_graph_);
+  if (status != RET_OK) {
+    return nullptr;
+  }
   return res_graph_;
 }
 
+STATUS CaffeModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
+  MS_ASSERT(graph != nullptr);
+  auto node_list = TopoSort(graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto conv_cnode = node->cast<CNodePtr>();
+    if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
+        !opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
+        !opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+      continue;
+    }
+    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
+    auto weight_node = conv_cnode->input(kConvWeightIndex);
+    MS_ASSERT(weight_node != nullptr);
+    auto tensor_info = opt::GetTensorInfo(weight_node);
+    if (tensor_info == nullptr) {
+      MS_LOG(ERROR) << "weight node must have a param value";
+      return RET_OK;
+    }
+    lite::STATUS status;
+    status = HardCodeCaffe(conv_cnode, tensor_info, graph);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS CaffeModelParser::HardCodeCaffe(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
+                                       const FuncGraphPtr &graph) {
+  MS_ASSERT(conv_node != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
+  auto weight_node = conv_node->input(kConvWeightIndex);
+  auto weight_value = opt::GetTensorInfo(weight_node);
+  if (weight_value == nullptr) {
+    MS_LOG(DEBUG) << "weight node must have a param value";
+    return RET_OK;
+  }
+  schema::Format weight_dst_format = schema::Format::Format_KHWC;
+  STATUS status = RET_OK;
+  schema::Format weight_src_format = Format_NUM_OF_FORMAT;
+  switch (quant_type_) {
+    case QuantType_PostTraining:
+    case QuantType_WeightQuant:
+    case QuantType_QUANT_NONE: {
+      weight_src_format = schema::Format::Format_KCHW;
+    } break;
+    default: {
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
+                    << ", node: " << conv_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+  }
+  if (utils::isa<CNodePtr>(weight_node)) {
+    auto status =
+      HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-const failed.";
+      return RET_ERROR;
+    }
+  }
+  weight_value = opt::GetTensorInfo(weight_node);
+  if (weight_value != nullptr) {
+    status = opt::TransFilterFormat(weight_value, schema::Format::Format_KCHW, weight_dst_format);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_src_format]) << "To"
+                    << EnumNameFormat(weight_dst_format) << " failed, node: " << conv_node->fullname_with_scope()
+                    << ", quant type: " << quant_type_;
+      return RET_ERROR;
+    }
+    auto type_id = static_cast<TypeId>(weight_value->data_type());
+    auto shape = weight_value->shape();
+    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
+    auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    weight_node->set_abstract(abstract);
+  }
+  if (utils::isa<ParameterPtr>(weight_node)) {
+    auto status =
+      HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-sharing failed.";
+      return RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
 STATUS CaffeModelParser::ConvertLayers() {
   STATUS status = RET_OK;
   std::map<std::string, caffe::LayerParameter> weight_layers;

@@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
 
   ~CaffeModelParser() override;
 
-  FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
+  FuncGraphPtr Parse(const converter::Flags &flag) override;
 
  private:
   STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);

@@ -56,10 +56,15 @@ class CaffeModelParser : public ModelParser {
 
   std::string GetOriginLayerName(const std::string &layer_name);
 
+  STATUS WeightFormatTransform(const FuncGraphPtr &graph);
+
+  STATUS HardCodeCaffe(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
+
   caffe::NetParameter caffe_model_;
   caffe::NetParameter caffe_weight_;
   std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
   std::unordered_map<std::string, AnfNodePtr> nodes_;
+  QuantType quant_type_ = schema::QuantType_QUANT_NONE;
 };
 } // namespace mindspore::lite
 

@@ -147,9 +147,6 @@ bool Conv1DInOutAdjust::Run(const FuncGraphPtr &func_graph) {
     auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
     MS_ASSERT(prim != nullptr);
     schema::Format schema_format = schema::Format::Format_KCHW;
-    if (prim->GetAttr(opt::kWeightFormat) != nullptr) {
-      schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)));
-    }
     // expand weight tensor to 4 dimensions.
     auto weight_tensor = opt::GetTensorInfo(weight_node);
     if (weight_tensor == nullptr) {

@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <memory>
 #include <set>
 #include <vector>
 #include <unordered_map>
 #include <utility>
 #include "tools/optimizer/common/gllo_utils.h"

@@ -33,8 +34,13 @@
 #include "tools/converter/parser/onnx/onnx_inputs_adjust_pass.h"
 #include "tools/converter/parser/onnx/onnx_pad_adjust_pass.h"
 #include "tools/converter/parser/parser_utils.h"
 #include "ops/transpose.h"
 
 namespace mindspore {
 namespace lite {
+namespace {
+constexpr size_t kConvWeightIndex = 2;
+} // namespace
 static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
   {onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
   {onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},

@@ -46,7 +52,9 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
   {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
   {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
 
-FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
+FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
+  string model_file = flag.modelFile;
+  quant_type_ = flag.quantType;
   NotSupportOp::GetInstance()->set_fmk_type("ONNX");
   res_graph_ = std::make_shared<FuncGraph>();
   auto status = InitOriginModel(model_file);

@@ -79,9 +87,148 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
     MS_LOG(ERROR) << "OnnxModelPostAdjust failed.";
     return nullptr;
   }
+  status = WeightFormatTransform(all_func_graphs);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "WeightFormatTransform failed.";
+    return nullptr;
+  }
   return res_graph_;
 }
 
+STATUS OnnxModelParser::WeightFormatTransform(const std::set<FuncGraphPtr> &all_func_graphs) {
+  for (auto graph : all_func_graphs) {
+    MS_ASSERT(graph != nullptr);
+    auto node_list = TopoSort(graph->get_return());
+    for (auto &node : node_list) {
+      if (!utils::isa<CNodePtr>(node)) {
+        continue;
+      }
+      auto conv_cnode = node->cast<CNodePtr>();
+      if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
+          !opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
+          !opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+        continue;
+      }
+      MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
+      auto weight_node = conv_cnode->input(kConvWeightIndex);
+      MS_ASSERT(weight_node != nullptr);
+      auto tensor_info = opt::GetTensorInfo(weight_node);
+      lite::STATUS status;
+      status = HardCodeONNX(conv_cnode, tensor_info, graph);
+      if (status != lite::RET_OK) {
+        MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
+        return RET_ERROR;
+      }
+    }
+  }
+  return RET_OK;
+}
+
+lite::STATUS OnnxModelParser::HardCodeONNX(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
+                                           const FuncGraphPtr &graph) {
+  MS_ASSERT(conv_node != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
+    return lite::RET_ERROR;
+  }
+  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
+  int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
+  schema::Format weight_dst_format = schema::Format::Format_KHWC;
+  STATUS status = RET_OK;
+  schema::Format weight_src_format = Format_NUM_OF_FORMAT;
+  auto weight_node = conv_node->input(kConvWeightIndex);
+  switch (quant_type_) {
+    case QuantType_AwareTraining: {
+      // sum up from current onnx quant models
+      if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
+        if (!is_depth_wise) {
+          weight_src_format = schema::Format::Format_KHWC;
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+        } else {
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+          weight_src_format = schema::Format::Format_CHWK;
+        }
+      } else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
+        prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+        weight_src_format = schema::Format::Format_KCHW;
+      } else {
+        MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
+        return lite::RET_ERROR;
+      }
+    } break;
+    case QuantType_PostTraining:
+    case QuantType_WeightQuant:
+    case QuantType_QUANT_NONE: {
+      // conv    (K x C/group x kH x kW) group = 1
+      // depth   (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
+      // deconv  (C x K/group x kH x kW) group = 1
+      // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
+      if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
+          opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
+        if (format == schema::Format::Format_NHWC) {
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NHWC));
+          weight_src_format = schema::Format::Format_KHWC;
+        } else {
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+          weight_src_format = schema::Format::Format_KCHW;
+        }
+      }
+    } break;
+    default: {
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
+                    << ", node: " << conv_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+  }
+  status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
+  if (status != RET_OK) {
+    return RET_ERROR;
+  }
+  return lite::RET_OK;
+}
+
+int OnnxModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
+                                             const FuncGraphPtr &graph, schema::Format weight_src_format,
+                                             schema::Format weight_dst_format) {
+  if (utils::isa<CNodePtr>(weight_node)) {
+    auto status =
+      HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-const failed.";
+      return RET_ERROR;
+    }
+  }
+  auto weight_value = opt::GetTensorInfo(weight_node);
+  if (weight_value != nullptr) {
+    auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_src_format]) << "To"
+                    << EnumNameFormat(weight_dst_format) << " failed, node: " << conv_node->fullname_with_scope()
+                    << ", quant type: " << quant_type_;
+      return RET_ERROR;
+    }
+    auto type_id = static_cast<TypeId>(weight_value->data_type());
+    auto shape = weight_value->shape();
+    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
+    auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    weight_node->set_abstract(abstract);
+  }
+  if (utils::isa<ParameterPtr>(weight_node)) {
+    auto status =
+      HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-sharing failed.";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
 STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
   auto status = ValidateFileStr(model_file, ".onnx");
   if (status != RET_OK) {

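To make the layout comments in HardCodeONNX above concrete: for a regular (group = 1) convolution, ONNX stores the filter as (K, C, kH, kW), i.e. KCHW, and the hard-code relays it to KHWC. A tiny illustration of the resulting shape change (sizes are made up; this is not the converter's code):

#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> kchw_shape = {8, 3, 5, 5};  // (K, C, kH, kW), made up
  std::array<int, 4> perm = {0, 2, 3, 1};        // KCHW -> KHWC
  std::array<int, 4> khwc_shape;
  for (int i = 0; i < 4; ++i) khwc_shape[i] = kchw_shape[perm[i]];
  // After the transform the weight is (8, 5, 5, 3) and the primitive's
  // kFormat attribute is set to KHWC, matching prim->AddAttr above.
  std::printf("(%d, %d, %d, %d)\n", khwc_shape[0], khwc_shape[1], khwc_shape[2], khwc_shape[3]);
}
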
@@ -223,7 +370,9 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
       status = RET_ERROR;
       continue;
     }
-    primitive_c->AddAttr(mindspore::opt::kWeightFormat, MakeValue<int64_t>(Format_NCHW));
+    if (primitive_c->GetAttr(ops::kFormat) == nullptr) {
+      primitive_c->AddAttr(mindspore::ops::kFormat, MakeValue<int64_t>(Format_NCHW));
+    }
     status = ConvertOpQuantParams(onnx_node, primitive_c);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";

@@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
 
   ~OnnxModelParser() override = default;
 
-  FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
+  FuncGraphPtr Parse(const converter::Flags &flag) override;
 
   int OnnxModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
 

@@ -92,12 +92,17 @@ class OnnxModelParser : public ModelParser {
   STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
                            const std::string &subgrah_name, const std::string &if_node_name,
                            const std::string &root_node_name);
+  STATUS WeightFormatTransform(const std::set<FuncGraphPtr> &all_func_graphs);
+  STATUS HardCodeONNX(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
+  int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
+                              schema::Format weight_src_format, schema::Format weight_dst_format);
   onnx::ModelProto onnx_model_;
   onnx::GraphProto onnx_root_graph_;
   std::vector<FuncGraphPtr> all_subgraphs_;
   std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
   std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
   std::unordered_map<std::string, std::string> child_root_map_;  // for nest control flow node
+  QuantType quant_type_ = schema::QuantType_QUANT_NONE;
 };
 } // namespace lite
 } // namespace mindspore

@@ -15,10 +15,17 @@
  */
 #include "tools/converter/parser/parser_utils.h"
 #include <memory>
+#include <algorithm>
+#include <vector>
+#include <string>
 #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
 #include "tools/converter/parser/unused_node_remove_pass.h"
 #include "tools/converter/parser/conv1d_inout_adjust.h"
 #include "tools/converter/parser/inputs_adjust.h"
+#include "ops/transpose.h"
+#include "tools/converter/quant_param_holder.h"
+#include "tools/common/tensor_util.h"
+#include "tools/optimizer/common/gllo_utils.h"
 
 namespace mindspore::lite {
 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {

@@ -78,4 +85,163 @@ int PostAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
   }
   return RET_OK;
 }
+
+int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
+  MS_ASSERT(perm != nullptr);
+  auto src_format_str = std::string(schema::EnumNameFormat(src_format));
+  auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
+  if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
+    MS_LOG(ERROR) << "src_format or dst_format is invalid.";
+    return lite::RET_ERROR;
+  }
+  for (size_t i = 0; i < src_format_str.size(); ++i) {
+    auto pos = src_format_str.find(dst_format_str[i]);
+    if (pos == std::string::npos) {
+      MS_LOG(ERROR) << "src_format and dst_format don't match.";
+      return lite::RET_ERROR;
+    }
+    perm->push_back(static_cast<int>(pos));
+  }
+  return lite::RET_OK;
+}
+
+int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
+  MS_ASSERT(perm != nullptr);
+  auto src_format_str = std::string(schema::EnumNameFormat(src_format));
+  auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
+  if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
+    MS_LOG(ERROR) << "src_format or dst_format is invalid.";
+    return lite::RET_ERROR;
+  }
+  for (size_t i = 0; i < src_format_str.size(); ++i) {
+    auto pos = dst_format_str.find(src_format_str[i]);
+    if (pos == std::string::npos) {
+      MS_LOG(ERROR) << "src_format and dst_format don't match.";
+      return lite::RET_ERROR;
+    }
+    perm->push_back(static_cast<int>(pos));
+  }
+  return lite::RET_OK;
+}
+
+int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
+                                    std::vector<int> perm) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(weight_node != nullptr);
+  auto node_list = TopoSort(graph->get_return());
+  std::vector<CNodePtr> adjust_nodes;
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    if (opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
+        opt::CheckPrimitiveType(node, prim::kPrimAdam)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto inputs = cnode->inputs();
+    if (std::any_of(inputs.begin(), inputs.end(),
+                    [&](const AnfNodePtr &anf_node) { return weight_node == anf_node; })) {
+      if (opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
+          opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
+          opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+        auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+        prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
+        continue;
+      }
+      adjust_nodes.push_back(cnode);
+    }
+  }
+  if (adjust_nodes.empty()) {
+    MS_LOG(DEBUG) << "do not need to adjust nodes.";
+    return lite::RET_OK;
+  }
+  auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_sharing_perm");
+  auto prim = std::make_shared<ops::Transpose>();
+  prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
+  auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
+  if (!weight_node->has_default()) {
+    MS_LOG(DEBUG) << "Weight parameter should have a default parameter.";
+    return lite::RET_ERROR;
+  }
+  auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
+  if (weight_tensor == nullptr) {
+    MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
+    return lite::RET_ERROR;
+  }
+  auto abstract = CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
+  if (abstract == nullptr) {
+    MS_LOG(ERROR) << "Create tensor abstract failed";
+    return RET_ERROR;
+  }
+  transpose_node->set_abstract(abstract);
+  transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_sharing_post");
+  for (auto &adjust_node : adjust_nodes) {
+    auto inputs = adjust_node->inputs();
+    std::replace_if(
+      inputs.begin(), inputs.end(), [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; },
+      transpose_node);
+    adjust_node->set_inputs(inputs);
+  }
+  return lite::RET_OK;
+}
+
+int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
+                        schema::Format src_format, schema::Format dst_format) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(weight_node != nullptr);
+  if (src_format == dst_format) {
+    return lite::RET_OK;
+  }
+  std::vector<int> perm;
+  auto status = GetTransposePermSharing(src_format, dst_format, &perm);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get perm failed.";
+    return status;
+  }
+  status = TransposeInsertForWeightSharing(graph, format, weight_node, perm);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "transpose insert failed.";
+  }
+  return status;
+}
+
+int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
+                                  std::vector<int> perm) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(weight_node != nullptr);
+  auto manager = Manage(graph);
+  if (opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) ||
+      opt::CheckPrimitiveType(weight_node, prim::kPrimLoad)) {
+    manager->Replace(weight_node, weight_node->input(1));
+    return RET_OK;
+  }
+  auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_const_perm");
+  auto prim = std::make_shared<ops::Transpose>();
+  prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
+  auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
+  transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_const_post");
+  conv_node->set_input(2, transpose_node);
+  return lite::RET_OK;
+}
+
+int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
+                      schema::Format src_format, schema::Format dst_format) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(weight_node != nullptr);
+  if (src_format == dst_format) {
+    return lite::RET_OK;
+  }
+  std::vector<int> perm;
+  auto status = GetTransposePerm(src_format, dst_format, &perm);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get perm failed.";
+    return status;
+  }
+  status = TransposeInsertForWeightConst(graph, conv_node, weight_node, perm);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "transpose insert failed.";
+  }
+  return status;
+}
+
 } // namespace mindspore::lite

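GetTransposePerm and GetTransposePermSharing above derive the permutation purely from the format names: perm[i] is the position of dst's i-th axis letter inside src, and the sharing variant swaps the lookup direction. A self-contained mirror of that loop, with string literals standing in for schema::Format:

#include <cassert>
#include <string>
#include <vector>

// perm[i] = position in src of the i-th axis letter of dst, exactly the
// loop body of GetTransposePerm.
std::vector<int> TransposePerm(const std::string &src, const std::string &dst) {
  std::vector<int> perm;
  for (char axis : dst) perm.push_back(static_cast<int>(src.find(axis)));
  return perm;
}

int main() {
  // Forward relayout KCHW -> KHWC.
  assert((TransposePerm("KCHW", "KHWC") == std::vector<int>{0, 2, 3, 1}));
  // GetTransposePermSharing(src, dst) swaps the lookup direction, which is
  // the inverse permutation: it restores the original layout for the other
  // (non-conv) users of a shared weight.
  assert((TransposePerm("KHWC", "KCHW") == std::vector<int>{0, 3, 1, 2}));
  return 0;
}
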
@@ -18,15 +18,26 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PARSER_UTILS_H
 
 #include <set>
 #include <vector>
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 #include "src/common/log_adapter.h"
 #include "schema/inner/model_generated.h"
 
 namespace mindspore {
 namespace lite {
 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
 int PostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
 
+int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
+int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
+int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
+                                  std::vector<int> perm);
+int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
+                      schema::Format src_format, schema::Format dst_format);
+int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
+                                    std::vector<int> perm);
+int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
+                        schema::Format src_format, schema::Format dst_format);
 } // namespace lite
 } // namespace mindspore
 

@@ -31,6 +31,7 @@
 #include "tools/converter/quant_param_holder.h"
 #include "tools/converter/parser/tf/functionalize_control_op_pass.h"
 #include "tools/converter/parser/parser_utils.h"
 #include "tools/common/tensor_util.h"
 
 namespace mindspore {
 namespace lite {

@@ -40,6 +41,7 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) {
          opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListSetItem) ||
          opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve);
 }
+constexpr size_t kConvWeightIndex = 2;
 
 AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
                       int index = 0) {

@@ -476,7 +478,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
   return RET_OK;
 }
 
-FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
+FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
+  auto modelFile = flag.modelFile;
+  quant_type_ = flag.quantType;
   NotSupportOp::GetInstance()->set_fmk_type("TF");
   auto status = ValidateFileStr(modelFile, ".pb");
   if (status != RET_OK) {

@@ -562,9 +566,146 @@ FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
     MS_LOG(ERROR) << "AdjustForOnnxModel failed.";
     return nullptr;
   }
+  status = WeightFormatTransform(res_graph_);
+  if (status != RET_OK) {
+    return nullptr;
+  }
   res_graph_->set_manager(nullptr);
   static auto root_func_manager = Manage(res_graph_);
   return res_graph_;
 }
 
+STATUS TFModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
+  MS_ASSERT(graph != nullptr);
+  auto node_list = TopoSort(graph->get_return());
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto conv_cnode = node->cast<CNodePtr>();
+    if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
+        !opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
+        !opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+      continue;
+    }
+    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
+    auto weight_node = conv_cnode->input(kConvWeightIndex);
+    MS_ASSERT(weight_node != nullptr);
+    auto tensor_info = opt::GetTensorInfo(weight_node);
+    lite::STATUS status;
+    status = HardCodeTF(conv_cnode, tensor_info, graph);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS TFModelParser::HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
+                                 const FuncGraphPtr &graph) {
+  MS_ASSERT(conv_node != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
+    return RET_ERROR;
+  }
+  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
+  int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
+  schema::Format weight_dst_format = schema::Format::Format_KHWC;
+  STATUS status = RET_OK;
+  schema::Format weight_src_format = Format_NUM_OF_FORMAT;
+  auto weight_node = conv_node->input(kConvWeightIndex);
+  auto weight_value = opt::GetTensorInfo(weight_node);
+  switch (quant_type_) {
+    case QuantType_AwareTraining:
+    case QuantType_PostTraining:
+    case QuantType_WeightQuant:
+    case QuantType_QUANT_NONE: {
+      if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
+        if (!is_depth_wise) {
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+          weight_src_format = schema::Format::Format_HWCK;
+        } else {
+          prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+          weight_src_format = schema::Format::Format_HWKC;
+        }
+      } else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
+        prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+        weight_src_format = schema::Format::Format_HWCK;
+      }
+      if (format == Format_NCHW) {
+        prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NCHW));
+      } else if (format == Format_KHWC) {
+        prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
+        weight_src_format = schema::Format::Format_KHWC;
+      }
+    } break;
+    default: {
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
+                    << ", node: " << conv_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+  }
+  status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
+  if (status != RET_OK) {
+    return RET_ERROR;
+  }
+  if (format == Format_NCHW) {
+    prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NCHW));
+  }
+  return RET_OK;
+}
+
+int TFModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
+                                           const FuncGraphPtr &graph, schema::Format weight_src_format,
+                                           schema::Format weight_dst_format) {
+  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
+    return RET_ERROR;
+  }
+  int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
+
+  if (utils::isa<CNodePtr>(weight_node)) {
+    auto status =
+      HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-const failed.";
+      return RET_ERROR;
+    }
+  }
+  auto weight_value = opt::GetTensorInfo(weight_node);
+  if (weight_value != nullptr) {
+    auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_src_format]) << "To"
+                    << EnumNameFormat(weight_dst_format) << " failed, node: " << conv_node->fullname_with_scope()
+                    << ", quant type: " << quant_type_;
+      return RET_ERROR;
+    }
+    auto type_id = static_cast<TypeId>(weight_value->data_type());
+    auto shape = weight_value->shape();
+    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
+    auto abstract = CreateTensorAbstract(shape_vector, type_id);
+    if (abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    weight_node->set_abstract(abstract);
+  }
+  if (utils::isa<ParameterPtr>(weight_node)) {
+    auto status =
+      HandleWeightSharing(graph, format, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "handle weight-sharing failed.";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
 STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
                                             std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
                                             const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,

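For reference on the HWCK/HWKC source formats chosen in HardCodeTF above: TensorFlow stores a regular conv filter as [filter_height, filter_width, in_channels, out_channels], i.e. HWCK. A small illustration of the HWCK -> KHWC shape relayout (sizes are made up; not the converter's code):

#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> hwck_shape = {3, 3, 16, 32};  // (kH, kW, Cin, Cout), made up
  std::array<int, 4> perm = {3, 0, 1, 2};          // HWCK -> KHWC
  std::array<int, 4> khwc_shape;
  for (int i = 0; i < 4; ++i) khwc_shape[i] = hwck_shape[perm[i]];
  // Prints (32, 3, 3, 16): out_channels first, channels last.
  std::printf("(%d, %d, %d, %d)\n", khwc_shape[0], khwc_shape[1], khwc_shape[2], khwc_shape[3]);
}
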
@@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
   TFModelParser() = default;
   ~TFModelParser() override = default;
 
-  FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile) override;
+  FuncGraphPtr Parse(const converter::Flags &flag) override;
 
   int TFModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
 

@@ -93,6 +93,13 @@ class TFModelParser : public ModelParser {
 
   STATUS ConnectNullInput();
 
+  STATUS WeightFormatTransform(const FuncGraphPtr &graph);
+
+  STATUS HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
+
+  int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
+                              schema::Format weight_src_format, schema::Format weight_dst_format);
+
   std::unique_ptr<tensorflow::GraphDef> tf_root_graph_;                     // tf root graph def
   std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_;  // tf root graph node map
   std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;

@@ -104,6 +111,7 @@ class TFModelParser : public ModelParser {
   std::vector<std::string> while_cond_branch_name_;
   std::vector<std::string> if_then_branch_name_;
   std::unordered_map<std::string, int> node_output_num_;
+  QuantType quant_type_ = schema::QuantType_QUANT_NONE;
 };
 } // namespace lite
 } // namespace mindspore

@@ -32,6 +32,9 @@
 #include "tools/converter/parser/parser_utils.h"
 
 namespace mindspore::lite {
+namespace {
+constexpr size_t kConvWeightIndex = 2;
+} // namespace
 std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::string &model_path) {
   size_t size = 0;
   tflite_model_buf_ = ReadFile(model_path.c_str(), &size);

@@ -47,7 +50,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
   return tflite::UnPackModel(tflite_model_buf_);
 }
 
-FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
+FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
+  auto model_file = flag.modelFile;
+  quant_type_ = flag.quantType;
   // load graph
   tflite_model_ = ReadTfliteModel(model_file);
   if (tflite_model_ == nullptr) {

@ -96,9 +101,124 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
|
|||
MS_LOG(ERROR) << "AdjustForOnnxModel failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = WeightFormatTransform(res_graph_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
}
|
||||
STATUS TfliteModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto conv_cnode = node->cast<CNodePtr>();
|
||||
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
|
||||
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
|
||||
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
|
||||
auto weight_node = conv_cnode->input(kConvWeightIndex);
|
||||
MS_ASSERT(weight_node != nullptr);
|
||||
auto tensor_info = opt::GetTensorInfo(weight_node);
|
||||
lite::STATUS status;
|
||||
status = HardCodeTflite(conv_cnode, tensor_info, graph);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::HardCodeTflite(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
|
||||
const FuncGraphPtr &graph) {
|
||||
MS_ASSERT(conv_cnode != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
|
||||
schema::Format weight_dst_format = schema::Format::Format_KHWC;
|
||||
STATUS status = RET_OK;
|
||||
schema::Format weight_src_format = Format_NUM_OF_FORMAT;
|
||||
auto weight_node = conv_node->input(kConvWeightIndex);
|
||||
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
|
||||
switch (quant_type_) {
|
||||
case QuantType_AwareTraining:
|
||||
case QuantType_PostTraining:
|
||||
case QuantType_WeightQuant:
|
||||
case QuantType_QUANT_NONE: {
|
||||
if (format == KHWC) {
|
||||
weight_src_format = schema::Format::Format_KHWC;
|
||||
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
|
||||
if (!is_depth_wise) {
|
||||
weight_src_format = schema::Format::Format_KHWC;
|
||||
} else {
|
||||
weight_src_format = schema::Format::Format_CHWK;
|
||||
}
|
||||
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
|
||||
weight_src_format = schema::Format::Format_CHWK;
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
|
||||
<< ", node: " << conv_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}

int TfliteModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
                                               const FuncGraphPtr &graph, schema::Format weight_src_format,
                                               schema::Format weight_dst_format) {
  if (utils::isa<CNodePtr>(weight_node)) {
    auto status =
      HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "handle weight-const failed.";
      return RET_ERROR;
    }
  }
  auto weight_value = opt::GetTensorInfo(weight_node);
  if (weight_value != nullptr) {
    auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weight_src_format) << " to "
                    << EnumNameFormat(weight_dst_format) << " failed, node: " << conv_node->fullname_with_scope()
                    << ", quant type: " << quant_type_;
      return RET_ERROR;
    }
    auto type_id = static_cast<TypeId>(weight_value->data_type());
    auto shape = weight_value->shape();
    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
    auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstract failed";
      return RET_ERROR;
    }
    weight_node->set_abstract(abstract);
  }
  if (utils::isa<ParameterPtr>(weight_node)) {
    auto status =
      HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "handle weight-sharing failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
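
TransFilterFormat itself is outside this diff; conceptually it relayouts the filter data so that the destination is always KHWC. A rough standalone illustration of one such relayout, CHWK to KHWC on a flat float buffer (hypothetical helper, not the converter's implementation):

#include <vector>

std::vector<float> ChwkToKhwc(const std::vector<float> &src, int C, int H, int W, int K) {
  std::vector<float> dst(src.size());
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int k = 0; k < K; ++k)
          // read from CHWK layout, write to KHWC layout
          dst[((k * H + h) * W + w) * C + c] = src[((c * H + h) * W + w) * K + k];
  return dst;
}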

std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
  std::string tensor_name = op_name + "/input-" + std::to_string(index);
@@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
  ~TfliteModelParser() override = default;

  FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
  FuncGraphPtr Parse(const converter::Flags &flag) override;

  int TfliteModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

@@ -52,6 +52,11 @@ class TfliteModelParser : public ModelParser {
  STATUS ConvertGraphOutputs();
  static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
                                    int round_type = 1);
  int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
                              schema::Format weight_src_format, schema::Format weight_dst_format);
  STATUS WeightFormatTransform(const FuncGraphPtr &graph);
  STATUS HardCodeTflite(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
  QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
}  // namespace lite
}  // namespace mindspore

@@ -38,7 +38,6 @@ namespace opt {
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion);
constexpr auto kWeightFormat = "weight_format";
std::vector<int> CastToInt(const ValuePtr &value);

std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value);

@@ -95,10 +95,12 @@ STATUS GenNewConvBias(const ParameterPtr &down_bias_node, const ParameterPtr &do
      new_bias_data[i] += down_bias_data[i];
    }
  }

  new_bias_node->set_name(down_bias_node->fullname_with_scope());
  new_bias_node->set_default_param(tensor_info);
  new_bias_node->set_abstract(down_bias_node->abstract());
  new_bias_node->set_name(down_weight_node->fullname_with_scope());
  auto status = lite::InitParameterFromTensorInfo(new_bias_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  return RET_OK;
}
// up weight shape[cout0,h,w,cin0] down weight shape[cout1,1,1,cout0],new weight shape [cout1,h,w,cin0]
@@ -140,8 +142,11 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
  }

  new_weight_node->set_name(down_weight_node->fullname_with_scope());
  new_weight_node->set_default_param(tensor_info);
  new_weight_node->set_abstract(down_weight_node->abstract());
  auto status = lite::InitParameterFromTensorInfo(new_weight_node, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  return RET_OK;
}

@@ -154,8 +159,10 @@ void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &u
    MS_LOG(ERROR) << "GenNewConvWeight failed.";
    return;
  }

  auto manager = func_graph->manager();
  manager->Replace(down_weight_parameter, new_weight_paramter);
  down_conv_cnode->set_input(kConvWeightIndex, new_weight_paramter);

  // whether up conv node has bias
  if (up_conv_cnode->inputs().size() == kConvWithBiasLen) {
    ParameterPtr down_bias_parameter;

@@ -169,7 +176,7 @@ void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &u
      return;
    }
    if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
      manager->Replace(down_bias_parameter, new_bias_parameter);
      down_conv_cnode->set_input(kConvBiasIndex, new_bias_parameter);
    } else {
      down_conv_cnode->add_input(new_bias_parameter);
    }

@@ -56,16 +56,10 @@ void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const
  if (dst_weight == nullptr || conv_weight == nullptr || scale_weight == nullptr) {
    return;
  }
  if (fmk == lite::converter::FmkType_TF) {
    for (int i = 0; i < weight_shape_size; i++) {
      dst_weight[i] = conv_weight[i] * scale_weight[i % kernel_num];
    }
  } else {
    auto kernel_size = weight_shape_size / kernel_num;
    for (int i = 0; i < kernel_num; i++) {
      for (int j = 0; j < kernel_size; j++) {
        dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
      }
  auto kernel_size = weight_shape_size / kernel_num;
  for (int i = 0; i < kernel_num; i++) {
    for (int j = 0; j < kernel_size; j++) {
      dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
    }
  }
}
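
With the TF-specific branch removed, the surviving loop simply folds one scale per output kernel into the weights: every element of kernel i is multiplied by scale_weight[i]. A tiny self-checking sketch of that arithmetic (names are ours, not the fusion pass's):

#include <cassert>
#include <vector>

void FoldPerKernelScale(std::vector<float> *w, const std::vector<float> &scale, int kernel_num) {
  int kernel_size = static_cast<int>(w->size()) / kernel_num;
  for (int i = 0; i < kernel_num; ++i)
    for (int j = 0; j < kernel_size; ++j)
      (*w)[i * kernel_size + j] *= scale[i];
}

int main() {
  std::vector<float> w = {1.f, 2.f, 3.f, 4.f};  // 2 kernels x 2 weights each
  FoldPerKernelScale(&w, {10.f, 100.f}, 2);
  assert(w[1] == 20.f && w[2] == 300.f);  // kernel 0 scaled by 10, kernel 1 by 100
  return 0;
}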

@@ -77,29 +71,13 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
  }
  MS_ASSERT(group > 0);
  auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
  if (fmk == lite::converter::FmkType_TF) {
    auto cin_group = weight_tensor->shape()[3] / group;
    int area_size = weight_tensor->shape()[0] * weight_tensor->shape()[1];
  auto cin_group = weight_tensor->shape()[0] / group;
  int area_size = weight_tensor->shape()[1] * weight_tensor->shape()[2];
    for (int k = 0; k < cin_group; ++k) {
      for (int j = 0; j < area_size; j++) {
        for (int i = 0; i < kernel_num; ++i) {
      for (int k = 0; k < cin_group; ++k) {
        dst_weight[k + i * cin_group + j * kernel_num * cin_group] =
          weight_data[k + i * cin_group + j * kernel_num * cin_group] * scale_weight[i];
        }
      }
    }
  } else {
    MS_ASSERT(group > 0);
    auto cin_group = weight_tensor->shape()[0] / group;
    int area_size = weight_tensor->shape()[2] * weight_tensor->shape()[3];
    int cout_size = kernel_num * area_size;
    for (int k = 0; k < cin_group; ++k) {
      for (int i = 0; i < kernel_num; ++i) {
        auto row_addr = weight_data + k * cout_size + i * area_size;
        auto new_row_addr = dst_weight + k * cout_size + i * area_size;
        for (int j = 0; j < area_size; j++) {
          new_row_addr[j] = row_addr[j] * scale_weight[i];
        }
        dst_weight[i + j * kernel_num + k * area_size * kernel_num] =
          weight_data[i + j * kernel_num + k * area_size * kernel_num] * scale_weight[i];
      }
    }
  }

@@ -488,7 +488,7 @@ CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &fu
                                   bias, stacked_hidden, input_length};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  auto prim = GetValueNode<PrimitivePtr>(new_node->input(0));
  prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::NHWC));
  prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format::NHWC));
  new_node->set_fullname_with_scope(base_name);
  return new_node;
}
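
The pattern here repeats across the remaining hunks: the pass-private opt::kWeightFormat attribute is retired and the op-level ops::kFormat attribute is written and read instead. A toy stand-in for that attribute round trip (all types and names below are illustrative, not MindSpore's real classes):

#include <cstdint>
#include <map>
#include <string>

struct ToyPrim {
  std::map<std::string, int64_t> attrs;
  void AddAttr(const std::string &key, int64_t v) { attrs[key] = v; }
  const int64_t *GetAttr(const std::string &key) const {
    auto it = attrs.find(key);
    return it == attrs.end() ? nullptr : &it->second;
  }
};

int64_t FormatOrDefault(const ToyPrim &prim, int64_t fallback) {
  const int64_t *v = prim.GetAttr("format");  // plays the role of ops::kFormat
  return v != nullptr ? *v : fallback;        // fallback mirrors the "?: 0" above
}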

@@ -77,8 +77,8 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
  auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
  MS_ASSERT(prim != nullptr);
  schema::Format schema_format = schema::Format::Format_KCHW;
  if (prim->GetAttr(opt::kWeightFormat) != nullptr) {
    schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)));
  if (prim->GetAttr(ops::kFormat) != nullptr) {
    schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat)));
  }
  // expand weight tensor to 4 dimensions.
  auto status = ExpandFilterShape(weight_node, schema_format);

@@ -95,7 +95,7 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {

    status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
    if (status == RET_OK) {
      conv2d_fusion->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(weight_dst_format));
      conv2d_fusion->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
    } else {
      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
                    << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope();

@@ -200,8 +200,8 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
      if (tensor_info->data_type() != kObjectTypeTensorType) {
        tensor->set_shape(shape);
        tensor->set_data_type(tensor_info->data_type());
        if (primitive->GetAttr(opt::kWeightFormat) != nullptr && i == WEIGHT_INDEX) {
          tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat))));
        if (primitive->GetAttr(ops::kFormat) != nullptr && i == WEIGHT_INDEX) {
          tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat))));
        } else {
          tensor->set_format(schema::Format::Format_NHWC);
        }

@@ -50,8 +50,8 @@ void SetConvWeightFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  if (prim->GetAttr(kWeightFormat) != nullptr && inputs.size() > 1) {
    inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat))));
  if (prim->GetAttr(ops::kFormat) != nullptr && inputs.size() > 1) {
    inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat))));
  }
}

@@ -34,7 +34,8 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
    MS_LOG(DEBUG) << "cnode is invalid.";
    return lite::RET_ERROR;
  }
  if (conv->GetAttr(ops::kFormat) == nullptr || conv->get_format() != mindspore::NHWC) {
  if (conv->GetAttr(ops::kFormat) == nullptr ||
      (conv->get_format() != mindspore::NHWC && conv->get_format() != mindspore::KHWC)) {
    return lite::RET_OK;
  }
  auto weight_node = cnode->input(kAnfPopulaterInputNumTwo);

@@ -54,10 +55,10 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
  auto default_param = weight_param->default_param();
  auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param);
  auto weight_shape = weight_tensor->shape();
  std::vector<int64_t> kernel_size = {weight_shape[0], weight_shape[1]};
  std::vector<int64_t> kernel_size = {weight_shape[1], weight_shape[2]};
  conv->set_kernel_size(kernel_size);
  conv->set_in_channel(weight_shape[2]);
  conv->set_out_channel(weight_shape[3]);
  conv->set_in_channel(weight_shape[3]);
  conv->set_out_channel(weight_shape[0]);
  return lite::RET_OK;
}
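
The re-indexed hunk above assumes the weight tensor is now laid out as KHWC, i.e. shape = [out_channel, kernel_h, kernel_w, in_channel]. A small helper making that mapping explicit (our own naming, for illustration only):

#include <array>
#include <cstdint>

struct ConvDims {
  int64_t out_channel, kernel_h, kernel_w, in_channel;
};

// KHWC shape -> named fields, matching kernel_size = {shape[1], shape[2]},
// in_channel = shape[3], out_channel = shape[0] in the hunk above.
ConvDims FromKhwcShape(const std::array<int64_t, 4> &shape) {
  return {shape[0], shape[1], shape[2], shape[3]};
}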

@@ -1,248 +0,0 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include <memory>
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TF;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_ALL;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_QUANT_WEIGHT;
using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
}  // namespace
void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const CNodePtr &conv_node,
                                                     const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_cnode != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
    return lite::RET_ERROR;
  }
  switch (quant_type) {
    case schema::QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE:
      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
      break;
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
                    << ", node: " << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const CNodePtr &conv_node,
                                                    const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_cnode != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
    return lite::RET_ERROR;
  }
  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
  switch (this->quant_type) {
    case QuantType_AwareTraining: {
      // sum up from current onnx quant models
      if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
        if (!is_depth_wise) {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
        } else {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
        }
      } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
      } else {
        MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
        return lite::RET_ERROR;
      }
    } break;
    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // conv (K x C/group x kH x kW) group = 1
      // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
      // deconv (C x K/group x kH x kW) group = 1
      // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
      if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
          CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
        if (format == schema::Format::Format_NHWC) {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
        } else {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
        }
      }
    } break;
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
                    << ", node: " << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
                                                  const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_cnode != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
    return lite::RET_ERROR;
  }
  auto weight_node = conv_node->input(kConvWeightIndex);
  switch (this->quant_type) {
    case QuantType_WeightQuant:
    case QuantType_PostTraining:
    case QuantType_QUANT_WEIGHT:
    case QuantType_QUANT_ALL:
    case QuantType_QUANT_NONE: {
      // sum up from current ms quant models
      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
    } break;
    default: {
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
                    << ", node: " << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node,
                                                      const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_cnode != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
    return lite::RET_ERROR;
  }
  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  switch (this->quant_type) {
    case QuantType_AwareTraining:
    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
        if (!is_depth_wise) {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
        } else {
          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
        }
      } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
      }
    } break;
    default: {
      MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node,
                                                  const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_cnode != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
    return lite::RET_ERROR;
  }
  bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    {
      if (!is_depth_wise) {
        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
      } else {
        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWKC));
      }
    }
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
  }
  return lite::RET_OK;
}

bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    auto conv_cnode = node->cast<CNodePtr>();
    if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
        (!CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || (fmk_type != FmkType_MS)) &&
        !CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
      continue;
    }
    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    MS_ASSERT(weight_node != nullptr);
    auto tensor_info = GetTensorInfo(weight_node);
    if (tensor_info == nullptr) {
      MS_LOG(ERROR) << "weight node must param value";
      return false;
    }
    lite::STATUS status;
    switch (fmk_type) {
      case FmkType_CAFFE:
        status = HardCodeCAFFE(conv_cnode, tensor_info);
        break;
      case FmkType_TFLITE:
        status = HardCodeTFLITE(conv_cnode, tensor_info);
        break;
      case FmkType_TF:
        status = HardCodeTF(conv_cnode, tensor_info);
        break;
      case FmkType_ONNX:
        status = HardCodeONNX(conv_cnode, tensor_info);
        break;
      case FmkType_MS:
        status = HardCodeMS(conv_cnode, tensor_info);
        break;
      default:
        MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
        return false;
    }
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
      return false;
    }
  }
  return false;
}
}  // namespace mindspore::opt

@@ -1,47 +0,0 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatHardCodePass : public Pass {
 public:
  WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {}
  ~WeightFormatHardCodePass() override = default;
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);
  bool Run(const FuncGraphPtr &graph) override;

 private:
  lite::STATUS HardCodeCAFFE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
  lite::STATUS HardCodeONNX(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
  lite::STATUS HardCodeMS(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
  lite::STATUS HardCodeTFLITE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
  lite::STATUS HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info) const;

 private:
  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_

@@ -1,215 +0,0 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include <memory>
#include <algorithm>
#include <vector>
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quant_param_holder.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant;

namespace mindspore::opt {
namespace {
constexpr size_t kFirstInputIndex = 1;
constexpr size_t kConvWeightIndex = 2;
lite::STATUS GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
  MS_ASSERT(perm != nullptr);
  auto src_format_str = std::string(schema::EnumNameFormat(src_format));
  auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
  if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
    MS_LOG(ERROR) << "src_format or dst_format is error.";
    return lite::RET_ERROR;
  }
  for (size_t i = 0; i < src_format_str.size(); ++i) {
    auto pos = dst_format_str.find(src_format_str[i]);
    if (pos == std::string::npos) {
      MS_LOG(ERROR) << "src_format and dst_format don't match.";
      return lite::RET_ERROR;
    }
    perm->push_back(static_cast<int>(pos));
  }
  return lite::RET_OK;
}
}  // namespace
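
GetTransposePerm above derives the transpose permutation purely from the format strings: perm[i] is the position that source axis i's letter occupies in the destination format. A standalone sketch with a worked example (the function name is ours):

#include <iostream>
#include <string>
#include <vector>

std::vector<int> FormatPerm(const std::string &src, const std::string &dst) {
  std::vector<int> perm;
  for (char axis : src) perm.push_back(static_cast<int>(dst.find(axis)));
  return perm;
}

int main() {
  // KCHW -> KHWC: K stays at 0, C lands at 3, H at 1, W at 2.
  for (int p : FormatPerm("KCHW", "KHWC")) std::cout << p << ' ';  // prints: 0 3 1 2
  return 0;
}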

void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; }
void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; }
lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const FuncGraphPtr &graph,
                                                                        const ParameterPtr &weight_node,
                                                                        std::vector<int> perm) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(weight_node != nullptr);
  auto node_list = TopoSort(graph->get_return());
  std::vector<CNodePtr> adjust_nodes;
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    if (CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) ||
        CheckPrimitiveType(node, prim::kPrimAdam)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto inputs = cnode->inputs();
    if (std::any_of(inputs.begin(), inputs.end(),
                    [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; })) {
      if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
          CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) ||
          CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
        auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
        prim->AddAttr(kWeightFormat, MakeValue<int64_t>(mindspore::KHWC));
        continue;
      }
      adjust_nodes.push_back(cnode);
    }
  }
  if (adjust_nodes.empty()) {
    MS_LOG(DEBUG) << "do not need to adjust nodes.";
    return lite::RET_OK;
  }
  auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
  auto prim = std::make_shared<ops::Transpose>();
  prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>(1, 1));
  auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
  if (!weight_node->has_default()) {
    MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
    return lite::RET_ERROR;
  }
  auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
  if (weight_tensor == nullptr) {
    MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
    return lite::RET_ERROR;
  }
  auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Create tensor abstarct failed";
    return RET_ERROR;
  }
  transpose_node->set_abstract(abstract);
  transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
  for (auto &adjust_node : adjust_nodes) {
    auto inputs = adjust_node->inputs();
    std::replace_if(
      inputs.begin(), inputs.end(), [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; },
      transpose_node);
    adjust_node->set_inputs(inputs);
  }
  return lite::RET_OK;
}

lite::STATUS WeightFormatTransformPass::HandleWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
                                                            schema::Format src_format, schema::Format dst_format) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(weight_node != nullptr);
  if (src_format == dst_format) {
    return lite::RET_OK;
  }
  std::vector<int> perm;
  auto status = GetTransposePerm(src_format, dst_format, &perm);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "get perm failed.";
    return status;
  }
  status = TransposeInsertForWeightSharing(graph, weight_node, perm);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "transpose insert failed.";
  }
  return status;
}

lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto node_list = TopoSort(graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
        !CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) &&
        !CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
      continue;
    }
    auto conv_cnode = node->cast<CNodePtr>();
    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
    auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
      return lite::RET_ERROR;
    }
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    MS_ASSERT(weight_node != nullptr);
    auto weight_value = GetTensorInfo(weight_node);
    if (weight_value == nullptr) {
      MS_LOG(ERROR) << "weight node must param value";
      return false;
    }
    MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 ||
              weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
    lite::STATUS status;
    auto value_ptr = prim->GetAttr(opt::kWeightFormat);
    auto weight_src_format = static_cast<schema::Format>(GetValue<int64_t>(value_ptr));
    schema::Format weight_dst_format = schema::Format::Format_KHWC;
    if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {
      weight_dst_format = dst_format;
    }
    status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
    if (status == RET_OK) {
      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(weight_dst_format));
    } else {
      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
                    << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
                    << "quant type:" << quant_type;
      return ERROR;
    }
    status = HandleWeightSharing(graph, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
    if (status != lite::RET_OK) {
      MS_LOG(ERROR) << "handle weight-sharing failed.";
      return false;
    }
    auto type_id = static_cast<TypeId>(weight_value->data_type());
    auto shape = weight_value->shape();
    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
    auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstarct failed";
      return RET_ERROR;
    }
    weight_node->set_abstract(abstract);
  }
  return RET_OK;
}

bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto status = ConvWeightFormatTrans(func_graph);
  if (status != lite::RET_OK) {
    MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
    return status;
  }
  return false;
}
}  // namespace mindspore::opt

@@ -1,50 +0,0 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatTransformPass : public Pass {
 public:
  WeightFormatTransformPass() : Pass("weight_format_transform_pass") {}
  ~WeightFormatTransformPass() override = default;
  void SetQuantType(QuantType type);
  void SetFmkType(FmkType fmkType);
  void SetDstFormat(schema::Format format);
  bool Run(const FuncGraphPtr &graph) override;

 private:
  lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph);
  lite::STATUS TransposeInsertForWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
                                               std::vector<int> perm);
  lite::STATUS HandleWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
                                   schema::Format src_format, schema::Format dst_format);

 private:
  QuantType quant_type = schema::QuantType_QUANT_NONE;
  FmkType fmk_type = lite::converter::FmkType_TF;
  schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT;
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_