!16820 [MS_LITE] weight format

From: @YeFeng_24
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-06-02 12:21:02 +08:00 committed by Gitee
commit 5d81221b58
35 changed files with 881 additions and 659 deletions


@ -213,8 +213,6 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc


@ -106,13 +106,13 @@ ml_face_glasses 2.5
# ml_segmentation_matting 26 # output value unstable
ml_segmentation_atlanta_10 5
# ml_bodymask: dividing the output-node difference by a very small value leads to a large error
ml_bodymask 14 13
ml_bodymask 16 18
ml_Hand_deploy 4 4
# ml_hand_3d_detection: dividing the output-node difference by a very small value leads to a large error
ml_hand_3d_detection 12 10
ml_hand_3d_regression 3 4
# ml_ARengine23_bodypose: dividing the output-node difference by a very small value leads to a large error
ml_ARengine23_bodypose 56 58
ml_ARengine23_bodypose 56 59
ml_ocr_bank_card_detection_inception_tmp 20
ml_ocr_bank_card_recognition_fcny 0.5
hiai_cv_aestheticsEngineModel_osp 1.6


@ -68,7 +68,7 @@ ml_location_lane_counter0.onnx 0.5
#The encoder and decoder models are used in the ml_asr scene; both have value overflow. Not suitable for fp16.
#But added to guard the process.
encoder.onnx;1,32,83 1262
mtk_emotions-d2012-75.onnx 5
mtk_emotions-d2012-75.onnx 6
mtk_detect-mbv1-shortcut-400-400.onnx 0.5
mtk_detect-mbv2-shortcut-400-400.onnx 0.5
mtk_detect_mbv1_640_480.onnx 0.5


@ -9,7 +9,7 @@ ml_video_edit_generate_filter.pb 2
ml_ocr_jk.pb 0.7
# The accumulated error causes the threshold to be exceeded
ml_ocr_latin.pb 12
scan_hms_angle.pb 1.5
scan_hms_angle.pb 2.5
scan_hms_detect.pb 2.5
ml_face_openclose.pb;1,32,32,3 0.5
ml_object_detect.pb;1,288,288,3 2


@ -280,8 +280,11 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
// attr weightFormat is only used by conv-like ops' second input
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
if ((opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) &&
(index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
data_info->format_ = mindspore::KHWC;
}
if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
MS_LOG(ERROR) << "fetch information from default param failed.";
@ -311,8 +314,8 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
MS_ASSERT(prim != nullptr);
if (value->isa<tensor::Tensor>()) {
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
if (index == 2 && prim->GetAttr(ops::kFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
}
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
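These two hunks retire the converter-private opt::kWeightFormat attribute in favor of the public ops::kFormat, and for conv-like nodes the exporter stops trusting the stored value for the weight input: by the time export runs, every parser (see the parser changes below) has already transposed conv weights to KHWC, so KHWC can be hardcoded. A condensed sketch of the new convention (illustrative only; it assumes the MindSpore Primitive attribute API and is not code from this commit):

auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
bool conv_like = opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
                 opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) ||
                 opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion);
if (index == 2 && prim->GetAttr(ops::kFormat) != nullptr) {
  // conv-like weights were normalized by the parsers; other weights keep the recorded format
  data_info->format_ = conv_like ? static_cast<int64_t>(mindspore::KHWC)
                                 : GetValue<int64_t>(prim->GetAttr(ops::kFormat));
}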


@ -81,8 +81,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/parallel/operator_info_register.cc
../optimizer/parallel/spliter.cc
../optimizer/parallel/split_strategy.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/conv1d_weight_expanding_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/group_depthwise_op_convert_pass.cc


@ -42,10 +42,7 @@
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
@ -161,14 +158,6 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
graph_pm->AddPass(std::make_shared<opt::IfPass>());
}
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_hardcode_pass);
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
weight_format_transform_pass->SetFmkType(config->fmk);
weight_format_transform_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_transform_pass);
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
slice_prepose_pass->SetFmkType(config->fmk);
graph_pm->AddPass(slice_prepose_pass);
@ -184,9 +173,6 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
}
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run graph convert pass failed.";
@ -205,10 +191,6 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
update_conv2d_param_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(update_conv2d_param_pass);
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
const_fold_pm->AddPass(weight_format_hardcode_pass);
auto infershape_pass = std::make_shared<opt::InferShapePass>();
infershape_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(infershape_pass);


@ -46,7 +46,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
return nullptr;
}
func_graph = model_parser_->Parse(flag.modelFile, flag.weightFile);
func_graph = model_parser_->Parse(flag);
}
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";


@ -16,12 +16,17 @@
#include "tools/converter/import/mindspore_importer.h"
#include <memory>
#include <vector>
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/import/primitive_adjust.h"
#include "tools/converter/import/mindir_adjust.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
namespace mindspore::lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
STATUS MindsporeImporter::AdjustForMindir(const FuncGraphPtr &func_graph, const converter::Flags &flag) {
auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>();
primitive_adjust_pass->SetFmkType(flag.fmk);
@ -42,7 +47,98 @@ STATUS MindsporeImporter::AdjustForMindir(const FuncGraphPtr &func_graph, const
return RET_OK;
}
STATUS MindsporeImporter::WeightFormatTransform(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
int status = HardCodeMindir(conv_cnode, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
}
}
return RET_OK;
}
STATUS MindsporeImporter::HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph) {
MS_ASSERT(conv_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return lite::RET_ERROR;
}
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
auto weight_node = conv_node->input(kConvWeightIndex);
schema::Format weight_dst_format = schema::Format::Format_KHWC;
STATUS status = RET_OK;
schema::Format weight_src_format = schema::Format::Format_NUM_OF_FORMAT;
switch (quant_type_) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (format == schema::Format::Format_KHWC) {
weight_src_format = schema::Format::Format_KHWC;
} else {
weight_src_format = schema::Format::Format_KCHW;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
<< ", node: " << conv_node->fullname_with_scope();
return RET_ERROR;
}
}
if (utils::isa<CNodePtr>(weight_node)) {
status = HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-const failed.";
return RET_ERROR;
}
}
weight_node = conv_node->input(kConvWeightIndex);
auto weight_value = opt::GetTensorInfo(weight_node);
if (weight_value != nullptr) {
status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << conv_node->fullname_with_scope()
<< "quant type:" << quant_type_;
return RET_ERROR;
}
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
if (utils::isa<ParameterPtr>(weight_node)) {
status = HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
return lite::RET_OK;
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
quant_type_ = flag.quantType;
auto func_graph = LoadMindIR(flag.modelFile);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
@ -54,6 +150,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
MS_LOG(ERROR) << "AdjustForMindir failed.";
return nullptr;
}
auto status = WeightFormatTransform(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
return nullptr;
}
return func_graph;
}
} // namespace mindspore::lite
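The MindIR path is the simplest of the five importers: unless the primitive already records ops::kFormat as KHWC, the filter is assumed to be KCHW and transposed. A hypothetical trace for a Conv2DFusion whose weight is a KCHW Parameter of shape {32, 16, 3, 3} (shapes invented for illustration):

// weight_src_format = KCHW, weight_dst_format = KHWC
// opt::TransFilterFormat repacks the tensor: {32, 16, 3, 3} -> {32, 3, 3, 16}
// prim->AddAttr(ops::kFormat, KHWC) records the new layout on the primitive
// a fresh abstract is built from the transposed shape, so later shape
// inference sees {32, 3, 3, 16} instead of the stale KCHW shape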


@ -29,6 +29,9 @@ class MindsporeImporter {
private:
STATUS AdjustForMindir(const FuncGraphPtr &func_graph, const converter::Flags &flag);
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
STATUS HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph);
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace mindspore::lite


@ -34,7 +34,7 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) { return this->res_graph_; }
virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; }
protected:
FuncGraphPtr res_graph_ = nullptr;


@ -30,9 +30,13 @@
#include "tools/converter/converter_context.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
bool IsSkipedLayer(const caffe::LayerParameter &layer) {
if (layer.type() == "Input" || layer.type() == "Dropout" || layer.type() == "Split") {
return true;
@ -62,7 +66,10 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
auto model_file = flag.modelFile;
auto weight_file = flag.weightFile;
quant_type_ = flag.quantType;
STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -94,9 +101,107 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
MS_LOG(ERROR) << "AdjustForAnf failed.";
return nullptr;
}
status = WeightFormatTransform(res_graph_);
if (status != RET_OK) {
return nullptr;
}
return res_graph_;
}
STATUS CaffeModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "weight node must param value";
return RET_OK;
}
lite::STATUS status;
status = HardCodeCaffe(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
}
}
return RET_OK;
}
STATUS CaffeModelParser::HardCodeCaffe(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
const FuncGraphPtr &graph) {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(tensor_info != nullptr);
auto weight_node = conv_node->input(kConvWeightIndex);
auto weight_value = opt::GetTensorInfo(weight_node);
if (weight_value == nullptr) {
MS_LOG(DEBUG) << "weight node must param value";
return RET_OK;
}
schema::Format weight_dst_format = schema::Format::Format_KHWC;
STATUS status = RET_OK;
schema::Format weight_src_format = Format_NUM_OF_FORMAT;
switch (quant_type_) {
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
weight_src_format = schema::Format::Format_KCHW;
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
if (utils::isa<CNodePtr>(weight_node)) {
auto status =
HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-const failed.";
return RET_ERROR;
}
}
weight_value = opt::GetTensorInfo(weight_node);
if (weight_value != nullptr) {
status = opt::TransFilterFormat(weight_value, schema::Format::Format_KCHW, weight_dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << conv_node->fullname_with_scope()
<< "quant type:" << quant_type_;
return RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
if (utils::isa<ParameterPtr>(weight_node)) {
auto status =
HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS CaffeModelParser::ConvertLayers() {
STATUS status = RET_OK;
std::map<std::string, caffe::LayerParameter> weight_layers;
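Caffe stores convolution filters as KCHW (num_output, channels/group, kH, kW) for every layer this converter supports, which is why HardCodeCaffe sets the source format unconditionally and only the quant-type whitelist gates the branch. An invented example of the resulting repack:

// blob from the caffemodel: {32, 16, 3, 3} (KCHW, 3x3 kernel, 16 in, 32 out)
// after opt::TransFilterFormat: {32, 3, 3, 16} (KHWC)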


@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
FuncGraphPtr Parse(const converter::Flags &flag) override;
private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
@ -56,10 +56,15 @@ class CaffeModelParser : public ModelParser {
std::string GetOriginLayerName(const std::string &layer_name);
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
STATUS HardCodeCaffe(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
caffe::NetParameter caffe_model_;
caffe::NetParameter caffe_weight_;
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace mindspore::lite


@ -147,9 +147,6 @@ bool Conv1DInOutAdjust::Run(const FuncGraphPtr &func_graph) {
auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_ASSERT(prim != nullptr);
schema::Format schema_format = schema::Format::Format_KCHW;
if (prim->GetAttr(opt::kWeightFormat) != nullptr) {
schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)));
}
// expand weight tensor to 4 dimensions.
auto weight_tensor = opt::GetTensorInfo(weight_node);
if (weight_tensor == nullptr) {


@ -18,6 +18,7 @@
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include <unordered_map>
#include <utility>
#include "tools/optimizer/common/gllo_utils.h"
@ -33,8 +34,13 @@
#include "tools/converter/parser/onnx/onnx_inputs_adjust_pass.h"
#include "tools/converter/parser/onnx/onnx_pad_adjust_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "ops/transpose.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_INT8, mindspore::kNumberTypeInt8},
{onnx::TensorProto_DataType_UINT8, mindspore::kNumberTypeUInt8},
@ -46,7 +52,9 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
auto model_file = flag.modelFile;
quant_type_ = flag.quantType;
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);
@ -79,9 +87,148 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
MS_LOG(ERROR) << "OnnxModelPostAdjust failed.";
return nullptr;
}
status = WeightFormatTransform(all_func_graphs);
if (status != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
return nullptr;
}
return res_graph_;
}
STATUS OnnxModelParser::WeightFormatTransform(const std::set<FuncGraphPtr> &all_func_graphs) {
for (auto graph : all_func_graphs) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
lite::STATUS status;
status = HardCodeONNX(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
}
}
}
return RET_OK;
}
lite::STATUS OnnxModelParser::HardCodeONNX(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
const FuncGraphPtr &graph) {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(tensor_info != nullptr);
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
schema::Format weight_dst_format = schema::Format::Format_KHWC;
STATUS status = RET_OK;
schema::Format weight_src_format = Format_NUM_OF_FORMAT;
auto weight_node = conv_node->input(kConvWeightIndex);
switch (quant_type_) {
case QuantType_AwareTraining: {
// summarized from the ONNX quant models seen so far
if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
weight_src_format = schema::Format::Format_KHWC;
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
} else {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_CHWK;
}
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_KCHW;
} else {
MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
} break;
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
// conv (K x C/group x kH x kW) group = 1
// depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
// deconv (C x K/group x kH x kW) group = 1
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
if (format == schema::Format::Format_NHWC) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NHWC));
weight_src_format = schema::Format::Format_KHWC;
} else {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_KCHW;
}
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
if (status != RET_OK) {
return RET_ERROR;
}
return lite::RET_OK;
}
int OnnxModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
const FuncGraphPtr &graph, schema::Format weight_src_format,
schema::Format weight_dst_format) {
if (utils::isa<CNodePtr>(weight_node)) {
auto status =
HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-const failed.";
return RET_ERROR;
}
}
auto weight_value = opt::GetTensorInfo(weight_node);
if (weight_value != nullptr) {
auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << conv_node->fullname_with_scope()
<< "quant type:" << quant_type_;
return RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
if (utils::isa<ParameterPtr>(weight_node)) {
auto status =
HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
return RET_OK;
}
STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
auto status = ValidateFileStr(model_file, ".onnx");
if (status != RET_OK) {
@ -223,7 +370,9 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
status = RET_ERROR;
continue;
}
primitive_c->AddAttr(mindspore::opt::kWeightFormat, MakeValue<int64_t>(Format_NCHW));
if (primitive_c->GetAttr(ops::kFormat) == nullptr) {
primitive_c->AddAttr(mindspore::ops::kFormat, MakeValue<int64_t>(Format_NCHW));
}
status = ConvertOpQuantParams(onnx_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";


@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
FuncGraphPtr Parse(const converter::Flags &flag) override;
int OnnxModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
@ -92,12 +92,17 @@ class OnnxModelParser : public ModelParser {
STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
const std::string &subgrah_name, const std::string &if_node_name,
const std::string &root_node_name);
STATUS WeightFormatTransform(const std::set<FuncGraphPtr> &all_func_graphs);
STATUS HardCodeONNX(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
schema::Format weight_src_format, schema::Format weight_dst_format);
onnx::ModelProto onnx_model_;
onnx::GraphProto onnx_root_graph_;
std::vector<FuncGraphPtr> all_subgraphs_;
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore


@ -15,10 +15,17 @@
*/
#include "tools/converter/parser/parser_utils.h"
#include <memory>
#include <algorithm>
#include <vector>
#include <string>
#include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
#include "tools/converter/parser/unused_node_remove_pass.h"
#include "tools/converter/parser/conv1d_inout_adjust.h"
#include "tools/converter/parser/inputs_adjust.h"
#include "ops/transpose.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/common/tensor_util.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite {
void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
@ -78,4 +85,163 @@ int PostAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
}
return RET_OK;
}
int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
MS_ASSERT(perm != nullptr);
auto src_format_str = std::string(schema::EnumNameFormat(src_format));
auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
MS_LOG(ERROR) << "src_format or dst_format is error.";
return lite::RET_ERROR;
}
for (size_t i = 0; i < src_format_str.size(); ++i) {
auto pos = src_format_str.find(dst_format_str[i]);
if (pos == std::string::npos) {
MS_LOG(ERROR) << "src_format and dst_format don't match.";
return lite::RET_ERROR;
}
perm->push_back(static_cast<int>(pos));
}
return lite::RET_OK;
}
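// Worked example (illustrative, not part of this commit): for src KCHW and
// dst KHWC, EnumNameFormat yields "KCHW" and "KHWC"; looking up each dst axis
// in the src string gives perm = {0, 2, 3, 1}, the permutation a Transpose
// node needs to turn a KCHW filter into a KHWC one.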
int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
MS_ASSERT(perm != nullptr);
auto src_format_str = std::string(schema::EnumNameFormat(src_format));
auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
MS_LOG(ERROR) << "src_format or dst_format is error.";
return lite::RET_ERROR;
}
for (size_t i = 0; i < src_format_str.size(); ++i) {
auto pos = dst_format_str.find(src_format_str[i]);
if (pos == std::string::npos) {
MS_LOG(ERROR) << "src_format and dst_format don't match.";
return lite::RET_ERROR;
}
perm->push_back(static_cast<int>(pos));
}
return lite::RET_OK;
}
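// Illustrative note (not part of this commit): this is the inverse lookup.
// For src KCHW and dst KHWC it returns {0, 3, 1, 2}, the inverse of
// {0, 2, 3, 1}: the shared weight is stored in dst format, and this perm
// transposes it back to src format for the non-conv users handled below.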
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
std::vector<int> perm) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
auto node_list = TopoSort(graph->get_return());
std::vector<CNodePtr> adjust_nodes;
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
opt::CheckPrimitiveType(node, prim::kPrimAdam)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(),
[&](const AnfNodePtr &anf_node) { return weight_node == anf_node; })) {
if (opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
continue;
}
adjust_nodes.push_back(cnode);
}
}
if (adjust_nodes.empty()) {
MS_LOG(DEBUG) << "do not need to adjust nodes.";
return lite::RET_OK;
}
auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_sharing_perm");
auto prim = std::make_shared<ops::Transpose>();
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
if (!weight_node->has_default()) {
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
return lite::RET_ERROR;
}
auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
if (weight_tensor == nullptr) {
MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
return lite::RET_ERROR;
}
auto abstract = CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
transpose_node->set_abstract(abstract);
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_sharing_post");
for (auto &adjust_node : adjust_nodes) {
auto inputs = adjust_node->inputs();
std::replace_if(
inputs.begin(), inputs.end(), [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; },
transpose_node);
adjust_node->set_inputs(inputs);
}
return lite::RET_OK;
}
int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
schema::Format src_format, schema::Format dst_format) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
if (src_format == dst_format) {
return lite::RET_OK;
}
std::vector<int> perm;
auto status = GetTransposePermSharing(src_format, dst_format, &perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get perm failed.";
return status;
}
status = TransposeInsertForWeightSharing(graph, format, weight_node, perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "transpose insert failed.";
}
return status;
}
int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
std::vector<int> perm) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
auto manager = Manage(graph);
if (opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) ||
opt::CheckPrimitiveType(weight_node, prim::kPrimLoad)) {
manager->Replace(weight_node, weight_node->input(1));
return RET_OK;
}
auto perm_node = opt::BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_const_perm");
auto prim = std::make_shared<ops::Transpose>();
prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(1, 1));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_const_post");
conv_node->set_input(2, transpose_node);
return lite::RET_OK;
}
int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
schema::Format src_format, schema::Format dst_format) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
if (src_format == dst_format) {
return lite::RET_OK;
}
std::vector<int> perm;
auto status = GetTransposePerm(src_format, dst_format, &perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get perm failed.";
return status;
}
status = TransposeInsertForWeightConst(graph, conv_node, weight_node, perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "transpose insert failed.";
}
return status;
}
} // namespace mindspore::lite


@ -18,15 +18,26 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_PARSER_UTILS_H
#include <set>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
int PostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
int TransposeInsertForWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
std::vector<int> perm);
int HandleWeightConst(const FuncGraphPtr &graph, const CNodePtr &conv_node, const CNodePtr &weight_node,
schema::Format src_format, schema::Format dst_format);
int TransposeInsertForWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
std::vector<int> perm);
int HandleWeightSharing(const FuncGraphPtr &graph, int64_t format, const ParameterPtr &weight_node,
schema::Format src_format, schema::Format dst_format);
} // namespace lite
} // namespace mindspore


@ -31,6 +31,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
namespace mindspore {
namespace lite {
@ -40,6 +41,7 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) {
opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListSetItem) ||
opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve);
}
constexpr size_t kConvWeightIndex = 2;
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
int index = 0) {
@ -476,7 +478,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
return RET_OK;
}
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
auto modelFile = flag.modelFile;
quant_type_ = flag.quantType;
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
@ -562,9 +566,146 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
MS_LOG(ERROR) << "AdjustForOnnxModel failed.";
return nullptr;
}
status = WeightFormatTransform(res_graph_);
if (status != RET_OK) {
return nullptr;
}
res_graph_->set_manager(nullptr);
static auto root_func_manager = Manage(res_graph_);
return res_graph_;
}
STATUS TFModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
lite::STATUS status;
status = HardCodeTF(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
}
}
return RET_OK;
}
STATUS TFModelParser::HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
const FuncGraphPtr &graph) {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(tensor_info != nullptr);
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
schema::Format weight_dst_format = schema::Format::Format_KHWC;
STATUS status = RET_OK;
schema::Format weight_src_format = Format_NUM_OF_FORMAT;
auto weight_node = conv_node->input(kConvWeightIndex);
auto weight_value = opt::GetTensorInfo(weight_node);
switch (quant_type_) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_HWCK;
} else {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_HWKC;
}
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_HWCK;
}
if (format == Format_NCHW) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NCHW));
} else if (format == Format_KHWC) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
weight_src_format = schema::Format::Format_KHWC;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
if (status != RET_OK) {
return RET_ERROR;
}
if (format == Format_NCHW) {
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format_NCHW));
}
return RET_OK;
}
int TFModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
const FuncGraphPtr &graph, schema::Format weight_src_format,
schema::Format weight_dst_format) {
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return RET_ERROR;
}
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
if (utils::isa<CNodePtr>(weight_node)) {
auto status =
HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-const failed.";
return RET_ERROR;
}
}
auto weight_value = opt::GetTensorInfo(weight_node);
if (weight_value != nullptr) {
auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << conv_node->fullname_with_scope()
<< "quant type:" << quant_type_;
return RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
if (utils::isa<ParameterPtr>(weight_node)) {
auto status =
HandleWeightSharing(graph, format, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
return RET_OK;
}
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
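TF differs from ONNX only in its source layouts: Conv2D and transposed-conv filters arrive as HWCK (kH, kW, in, out) and depthwise filters as HWKC. Tracing GetTransposePerm from parser_utils with these names (an illustrative check, not converter output): EnumNameFormat gives "HWCK" and "KHWC", so perm = {3, 0, 1, 2}, and an invented {3, 3, 16, 32} TF filter becomes a {32, 3, 3, 16} KHWC tensor.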


@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
TFModelParser() = default;
~TFModelParser() override = default;
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile) override;
FuncGraphPtr Parse(const converter::Flags &flag) override;
int TFModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
@ -93,6 +93,13 @@ class TFModelParser : public ModelParser {
STATUS ConnectNullInput();
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
STATUS HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
schema::Format weight_src_format, schema::Format weight_dst_format);
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
@ -104,6 +111,7 @@ class TFModelParser : public ModelParser {
std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_;
std::unordered_map<std::string, int> node_output_num_;
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore


@ -32,6 +32,9 @@
#include "tools/converter/parser/parser_utils.h"
namespace mindspore::lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::string &model_path) {
size_t size = 0;
tflite_model_buf_ = ReadFile(model_path.c_str(), &size);
@ -47,7 +50,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
return tflite::UnPackModel(tflite_model_buf_);
}
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file) {
FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
auto model_file = flag.modelFile;
quant_type_ = flag.quantType;
// load graph
tflite_model_ = ReadTfliteModel(model_file);
if (tflite_model_ == nullptr) {
@ -96,9 +101,124 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
MS_LOG(ERROR) << "AdjustForOnnxModel failed.";
return nullptr;
}
status = WeightFormatTransform(res_graph_);
if (status != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
return nullptr;
}
return res_graph_;
}
STATUS TfliteModelParser::WeightFormatTransform(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = opt::GetTensorInfo(weight_node);
lite::STATUS status;
status = HardCodeTflite(conv_cnode, tensor_info, graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return RET_ERROR;
}
}
return RET_OK;
}
STATUS TfliteModelParser::HardCodeTflite(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info,
const FuncGraphPtr &graph) {
MS_ASSERT(conv_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
schema::Format weight_dst_format = schema::Format::Format_KHWC;
STATUS status = RET_OK;
schema::Format weight_src_format = Format_NUM_OF_FORMAT;
auto weight_node = conv_node->input(kConvWeightIndex);
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
switch (quant_type_) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (format == KHWC) {
weight_src_format = schema::Format::Format_KHWC;
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
weight_src_format = schema::Format::Format_KHWC;
} else {
weight_src_format = schema::Format::Format_CHWK;
}
} else if (opt::CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
weight_src_format = schema::Format::Format_CHWK;
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type_)
<< ", node: " << conv_node->fullname_with_scope();
return RET_ERROR;
}
}
status = DoWeightFormatTransform(conv_node, weight_node, graph, weight_src_format, weight_dst_format);
if (status != RET_OK) {
return RET_ERROR;
}
return lite::RET_OK;
}
int TfliteModelParser::DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node,
const FuncGraphPtr &graph, schema::Format weight_src_format,
schema::Format weight_dst_format) {
if (utils::isa<CNodePtr>(weight_node)) {
auto status =
HandleWeightConst(graph, conv_node, weight_node->cast<CNodePtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-const failed.";
return RET_ERROR;
}
}
auto weight_value = opt::GetTensorInfo(weight_node);
if (weight_value != nullptr) {
auto status = opt::TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << conv_node->fullname_with_scope()
<< "quant type:" << quant_type_;
return RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
if (utils::isa<ParameterPtr>(weight_node)) {
auto status =
HandleWeightSharing(graph, KHWC, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
return RET_ERROR;
}
}
return RET_OK;
}
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
std::string tensor_name = op_name + "/input-" + std::to_string(index);
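TFLite needs the least work of the four framework parsers: regular Conv2D filters are already KHWC, so src equals dst and HandleWeightConst/HandleWeightSharing return immediately, while depthwise and transposed convolutions carry CHWK filters. For those, GetTransposePerm("CHWK", "KHWC") works out to {3, 1, 2, 0} (an illustrative trace), turning a {C, H, W, K} filter into {K, H, W, C}.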


@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) override;
FuncGraphPtr Parse(const converter::Flags &flag) override;
int TfliteModelPostAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
@ -52,6 +52,11 @@ class TfliteModelParser : public ModelParser {
STATUS ConvertGraphOutputs();
static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params,
int round_type = 1);
int DoWeightFormatTransform(const CNodePtr &conv_node, const AnfNodePtr &weight_node, const FuncGraphPtr &graph,
schema::Format weight_src_format, schema::Format weight_dst_format);
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
STATUS HardCodeTflite(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info, const FuncGraphPtr &graph);
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
};
} // namespace lite
} // namespace mindspore


@ -38,7 +38,6 @@ namespace opt {
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(ops::kNameConv2DBackpropInputFusion);
constexpr auto kWeightFormat = "weight_format";
std::vector<int> CastToInt(const ValuePtr &value);
std::vector<std::vector<int>> CastToVec2DInt(const ValuePtr &value);


@ -95,10 +95,12 @@ STATUS GenNewConvBias(const ParameterPtr &down_bias_node, const ParameterPtr &do
new_bias_data[i] += down_bias_data[i];
}
}
new_bias_node->set_name(down_bias_node->fullname_with_scope());
new_bias_node->set_default_param(tensor_info);
new_bias_node->set_abstract(down_bias_node->abstract());
new_bias_node->set_name(down_weight_node->fullname_with_scope());
auto status = lite::InitParameterFromTensorInfo(new_bias_node, tensor_info);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return RET_ERROR;
}
return RET_OK;
}
// up weight shape[cout0,h,w,cin0] down weight shape[cout1,1,1,cout0],new weight shape [cout1,h,w,cin0]
@ -140,8 +142,11 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
}
new_weight_node->set_name(down_weight_node->fullname_with_scope());
new_weight_node->set_default_param(tensor_info);
new_weight_node->set_abstract(down_weight_node->abstract());
auto status = lite::InitParameterFromTensorInfo(new_weight_node, tensor_info);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
return RET_ERROR;
}
return RET_OK;
}
@ -154,8 +159,10 @@ void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &u
MS_LOG(ERROR) << "GenNewConvWeight failed.";
return;
}
auto manager = func_graph->manager();
manager->Replace(down_weight_parameter, new_weight_paramter);
down_conv_cnode->set_input(kConvWeightIndex, new_weight_paramter);
// whether up conv node has bias
if (up_conv_cnode->inputs().size() == kConvWithBiasLen) {
ParameterPtr down_bias_parameter;
@ -169,7 +176,7 @@ void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &u
return;
}
if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
manager->Replace(down_bias_parameter, new_bias_parameter);
down_conv_cnode->set_input(kConvBiasIndex, new_bias_parameter);
} else {
down_conv_cnode->add_input(new_bias_parameter);
}
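The switch from manager->Replace to set_input is the behavioral point of this hunk: Replace rewrites every use of the old parameter anywhere in the graph, whereas set_input only rewires this convolution, so a weight or bias parameter shared with other nodes stays intact for them.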


@ -56,16 +56,10 @@ void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const
if (dst_weight == nullptr || conv_weight == nullptr || scale_weight == nullptr) {
return;
}
if (fmk == lite::converter::FmkType_TF) {
for (int i = 0; i < weight_shape_size; i++) {
dst_weight[i] = conv_weight[i] * scale_weight[i % kernel_num];
}
} else {
auto kernel_size = weight_shape_size / kernel_num;
for (int i = 0; i < kernel_num; i++) {
for (int j = 0; j < kernel_size; j++) {
dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
}
auto kernel_size = weight_shape_size / kernel_num;
for (int i = 0; i < kernel_num; i++) {
for (int j = 0; j < kernel_size; j++) {
dst_weight[i * kernel_size + j] = conv_weight[i * kernel_size + j] * scale_weight[i];
}
}
}
@ -77,29 +71,13 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig
}
MS_ASSERT(group > 0);
auto weight_data = reinterpret_cast<float *>(weight_tensor->data_c());
if (fmk == lite::converter::FmkType_TF) {
auto cin_group = weight_tensor->shape()[3] / group;
int area_size = weight_tensor->shape()[0] * weight_tensor->shape()[1];
auto cin_group = weight_tensor->shape()[0] / group;
int area_size = weight_tensor->shape()[1] * weight_tensor->shape()[2];
for (int k = 0; k < cin_group; ++k) {
for (int j = 0; j < area_size; j++) {
for (int i = 0; i < kernel_num; ++i) {
for (int k = 0; k < cin_group; ++k) {
dst_weight[k + i * cin_group + j * kernel_num * cin_group] =
weight_data[k + i * cin_group + j * kernel_num * cin_group] * scale_weight[i];
}
}
}
} else {
MS_ASSERT(group > 0);
auto cin_group = weight_tensor->shape()[0] / group;
int area_size = weight_tensor->shape()[2] * weight_tensor->shape()[3];
int cout_size = kernel_num * area_size;
for (int k = 0; k < cin_group; ++k) {
for (int i = 0; i < kernel_num; ++i) {
auto row_addr = weight_data + k * cout_size + i * area_size;
auto new_row_addr = dst_weight + k * cout_size + i * area_size;
for (int j = 0; j < area_size; j++) {
new_row_addr[j] = row_addr[j] * scale_weight[i];
}
dst_weight[i + j * kernel_num + k * area_size * kernel_num] =
weight_data[i + j * kernel_num + k * area_size * kernel_num] * scale_weight[i];
}
}
}
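With every parser now emitting KHWC filters, the conv+scale fusion no longer needs per-framework branches. For Conv2DFusion the output-channel axis is outermost, so each output channel owns a contiguous block of kernel_size = weight_shape_size / kernel_num values, and the whole block is multiplied by that channel's scale. An invented check with kernel_num = 2 and weight_shape_size = 8: elements [0..3] are scaled by scale_weight[0] and elements [4..7] by scale_weight[1].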


@ -488,7 +488,7 @@ CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &fu
bias, stacked_hidden, input_length};
auto new_node = func_graph->NewCNode(new_node_inputs);
auto prim = GetValueNode<PrimitivePtr>(new_node->input(0));
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::NHWC));
prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format::NHWC));
new_node->set_fullname_with_scope(base_name);
return new_node;
}


@ -77,8 +77,8 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
MS_ASSERT(prim != nullptr);
schema::Format schema_format = schema::Format::Format_KCHW;
if (prim->GetAttr(opt::kWeightFormat) != nullptr) {
schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat)));
if (prim->GetAttr(ops::kFormat) != nullptr) {
schema_format = static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat)));
}
// expand weight tensor to 4 dimensions.
auto status = ExpandFilterShape(weight_node, schema_format);


@ -95,7 +95,7 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status == RET_OK) {
conv2d_fusion->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(weight_dst_format));
conv2d_fusion->AddAttr(ops::kFormat, MakeValue<int64_t>(weight_dst_format));
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope();


@ -200,8 +200,8 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
if (tensor_info->data_type() != kObjectTypeTensorType) {
tensor->set_shape(shape);
tensor->set_data_type(tensor_info->data_type());
if (primitive->GetAttr(opt::kWeightFormat) != nullptr && i == WEIGHT_INDEX) {
tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat))));
if (primitive->GetAttr(ops::kFormat) != nullptr && i == WEIGHT_INDEX) {
tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(ops::kFormat))));
} else {
tensor->set_format(schema::Format::Format_NHWC);
}
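In sketch form, the selection rule above: only the weight input inherits the primitive's format attr, everything else defaults to NHWC (the enum values are toy stand-ins, not mindspore::Format codes):

#include <cstddef>
#include <cstdint>
#include <optional>

enum class Fmt : int64_t { NHWC = 1, KHWC = 8 };  // toy values

// Mirrors the hunk: the attr only applies to the weight tensor slot.
Fmt PickTensorFormat(std::optional<int64_t> prim_format_attr, size_t input_index,
                     size_t weight_index) {
  if (prim_format_attr.has_value() && input_index == weight_index) {
    return static_cast<Fmt>(*prim_format_attr);
  }
  return Fmt::NHWC;
}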

View File

@ -50,8 +50,8 @@ void SetConvWeightFormat(const CNodePtr &cnode, const std::vector<lite::Tensor *
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(prim != nullptr);
if (prim->GetAttr(kWeightFormat) != nullptr && inputs.size() > 1) {
inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat))));
if (prim->GetAttr(ops::kFormat) != nullptr && inputs.size() > 1) {
inputs[1]->set_format(static_cast<schema::Format>(GetValue<int64_t>(prim->GetAttr(ops::kFormat))));
}
}

View File

@ -34,7 +34,8 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
MS_LOG(DEBUG) << "cnode is invalid.";
return lite::RET_ERROR;
}
if (conv->GetAttr(ops::kFormat) == nullptr || conv->get_format() != mindspore::NHWC) {
if (conv->GetAttr(ops::kFormat) == nullptr ||
(conv->get_format() != mindspore::NHWC && conv->get_format() != mindspore::KHWC)) {
return lite::RET_OK;
}
auto weight_node = cnode->input(kAnfPopulaterInputNumTwo);
@ -54,10 +55,10 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) {
auto default_param = weight_param->default_param();
auto weight_tensor = std::dynamic_pointer_cast<tensor::Tensor>(default_param);
auto weight_shape = weight_tensor->shape();
std::vector<int64_t> kernel_size = {weight_shape[0], weight_shape[1]};
std::vector<int64_t> kernel_size = {weight_shape[1], weight_shape[2]};
conv->set_kernel_size(kernel_size);
conv->set_in_channel(weight_shape[2]);
conv->set_out_channel(weight_shape[3]);
conv->set_in_channel(weight_shape[3]);
conv->set_out_channel(weight_shape[0]);
return lite::RET_OK;
}
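The new field order follows from the weight tensor now being stored KHWC, i.e. shape = {out_channel, kernel_h, kernel_w, in_channel}. A one-liner sketch of the extraction (ConvParams and FromKHWC are illustrative names):

#include <cstdint>
#include <vector>

struct ConvParams {
  std::vector<int64_t> kernel_size;
  int64_t in_channel;
  int64_t out_channel;
};

// KHWC weight shape {K, H, W, C} -> kernel_size {H, W}, in = C, out = K.
ConvParams FromKHWC(const std::vector<int64_t> &shape) {
  return {{shape[1], shape[2]}, shape[3], shape[0]};
}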

View File

@ -1,248 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include <memory>
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TF;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_ALL;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_QUANT_WEIGHT;
using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const CNodePtr &conv_node,
const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
return lite::RET_ERROR;
}
switch (quant_type) {
case schema::QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE:
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const CNodePtr &conv_node,
const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
int64_t format = prim->GetAttr(ops::kFormat) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kFormat)) : 0;
switch (this->quant_type) {
case QuantType_AwareTraining: {
      // summarized from current onnx quant models
if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
} else {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
}
} else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
} else {
MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
} break;
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
// conv (K x C/group x kH x kW) group = 1
// depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
// deconv (C x K/group x kH x kW) group = 1
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
if (format == schema::Format::Format_NHWC) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
} else {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
}
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
return lite::RET_ERROR;
}
auto weight_node = conv_node->input(kConvWeightIndex);
switch (this->quant_type) {
case QuantType_WeightQuant:
case QuantType_PostTraining:
case QuantType_QUANT_WEIGHT:
case QuantType_QUANT_ALL:
case QuantType_QUANT_NONE: {
      // summarized from current ms quant models
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node,
const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
switch (this->quant_type) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
if (!is_depth_wise) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
} else {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
}
} else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
}
} break;
default: {
MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node,
const tensor::TensorPtr &tensor_info) const {
  MS_ASSERT(conv_node != nullptr);
  MS_ASSERT(tensor_info != nullptr);
  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
  if (prim == nullptr) {
    MS_LOG(ERROR) << "Invalid anfnode, which doesn't have a primitive.";
return lite::RET_ERROR;
}
bool is_depth_wise = prim->GetAttr(ops::kIsDepthWise) != nullptr && GetValue<bool>(prim->GetAttr(ops::kIsDepthWise));
  if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
    if (!is_depth_wise) {
      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
    } else {
      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWKC));
    }
  } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
    prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
  }
return lite::RET_OK;
}
bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
(!CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || (fmk_type != FmkType_MS)) &&
!CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto tensor_info = GetTensorInfo(weight_node);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "weight node must param value";
return false;
}
lite::STATUS status;
switch (fmk_type) {
case FmkType_CAFFE:
status = HardCodeCAFFE(conv_cnode, tensor_info);
break;
case FmkType_TFLITE:
status = HardCodeTFLITE(conv_cnode, tensor_info);
break;
case FmkType_TF:
status = HardCodeTF(conv_cnode, tensor_info);
break;
case FmkType_ONNX:
status = HardCodeONNX(conv_cnode, tensor_info);
break;
case FmkType_MS:
status = HardCodeMS(conv_cnode, tensor_info);
break;
default:
MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
return false;
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Format hard code failed: " << status << ", node: " << node->fullname_with_scope();
return false;
}
}
return false;
}
} // namespace mindspore::opt
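For reference, the float / weight-quant defaults this deleted pass hardcoded, condensed into one sketch (quant-aware ONNX/TFLITE and the transpose-conv variants differ; see the switch statements above):

#include <string>

// Condensed view of the per-framework defaults; illustrative only.
std::string DefaultConvWeightLayout(const std::string &fmk, bool depth_wise) {
  if (fmk == "CAFFE" || fmk == "MS") return "KCHW";
  if (fmk == "TF") return depth_wise ? "HWKC" : "HWCK";
  if (fmk == "TFLITE") return depth_wise ? "CHWK" : "KHWC";
  if (fmk == "ONNX") return "KCHW";  // KHWC when the node itself is NHWC
  return "KCHW";
}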

View File

@ -1,47 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatHardCodePass : public Pass {
public:
WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {}
~WeightFormatHardCodePass() override = default;
void SetQuantType(QuantType type);
void SetFmkType(FmkType fmkType);
bool Run(const FuncGraphPtr &graph) override;
private:
lite::STATUS HardCodeCAFFE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
lite::STATUS HardCodeONNX(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
lite::STATUS HardCodeMS(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
lite::STATUS HardCodeTFLITE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const;
lite::STATUS HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info) const;
private:
QuantType quant_type = schema::QuantType_QUANT_NONE;
FmkType fmk_type = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_

View File

@ -1,215 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include <memory>
#include <algorithm>
#include <vector>
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quant_param_holder.h"
using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
constexpr size_t kFirstInputIndex = 1;
constexpr size_t kConvWeightIndex = 2;
lite::STATUS GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
MS_ASSERT(perm != nullptr);
auto src_format_str = std::string(schema::EnumNameFormat(src_format));
auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
MS_LOG(ERROR) << "src_format or dst_format is error.";
return lite::RET_ERROR;
}
for (size_t i = 0; i < src_format_str.size(); ++i) {
auto pos = dst_format_str.find(src_format_str[i]);
if (pos == std::string::npos) {
MS_LOG(ERROR) << "src_format and dst_format don't match.";
return lite::RET_ERROR;
}
perm->push_back(static_cast<int>(pos));
}
return lite::RET_OK;
}
} // namespace
void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; }
void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; }
lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const FuncGraphPtr &graph,
const ParameterPtr &weight_node,
std::vector<int> perm) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
auto node_list = TopoSort(graph->get_return());
std::vector<CNodePtr> adjust_nodes;
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
if (CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) ||
CheckPrimitiveType(node, prim::kPrimAdam)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(),
[&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; })) {
if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) ||
CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->AddAttr(kWeightFormat, MakeValue<int64_t>(mindspore::KHWC));
continue;
}
adjust_nodes.push_back(cnode);
}
}
if (adjust_nodes.empty()) {
MS_LOG(DEBUG) << "do not need to adjust nodes.";
return lite::RET_OK;
}
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
auto prim = std::make_shared<ops::Transpose>();
prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>(1, 1));
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
if (!weight_node->has_default()) {
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
return lite::RET_ERROR;
}
auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
if (weight_tensor == nullptr) {
MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
return lite::RET_ERROR;
}
auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
transpose_node->set_abstract(abstract);
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
for (auto &adjust_node : adjust_nodes) {
auto inputs = adjust_node->inputs();
std::replace_if(
inputs.begin(), inputs.end(), [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; },
transpose_node);
adjust_node->set_inputs(inputs);
}
return lite::RET_OK;
}
lite::STATUS WeightFormatTransformPass::HandleWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
schema::Format src_format, schema::Format dst_format) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(weight_node != nullptr);
if (src_format == dst_format) {
return lite::RET_OK;
}
std::vector<int> perm;
auto status = GetTransposePerm(src_format, dst_format, &perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get perm failed.";
return status;
}
status = TransposeInsertForWeightSharing(graph, weight_node, perm);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "transpose insert failed.";
}
return status;
}
lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) &&
!CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto prim = GetValueNode<PrimitivePtr>(conv_cnode->input(0));
if (prim == nullptr) {
MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
return lite::RET_ERROR;
}
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto weight_value = GetTensorInfo(weight_node);
if (weight_value == nullptr) {
MS_LOG(ERROR) << "weight node must param value";
return false;
}
MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 ||
weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
lite::STATUS status;
auto value_ptr = prim->GetAttr(opt::kWeightFormat);
auto weight_src_format = static_cast<schema::Format>(GetValue<int64_t>(value_ptr));
schema::Format weight_dst_format = schema::Format::Format_KHWC;
if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {
weight_dst_format = dst_format;
}
status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
if (status == RET_OK) {
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(weight_dst_format));
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
<< "quant type:" << quant_type;
return ERROR;
}
status = HandleWeightSharing(graph, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "handle weight-sharing failed.";
      return lite::RET_ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
return RET_OK;
}
bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto status = ConvWeightFormatTrans(func_graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
return status;
}
return false;
}
} // namespace mindspore::opt
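A standalone copy of the perm computation from GetTransposePerm above, to pin down its convention: perm[i] is the destination position of source axis i, so "KCHW" -> "KHWC" yields {0, 3, 1, 2}:

#include <cassert>
#include <string>
#include <vector>

// Same logic as GetTransposePerm, minus the error handling: for each source
// axis character, record where it sits in the destination format string.
std::vector<int> PermFor(const std::string &src, const std::string &dst) {
  std::vector<int> perm;
  for (char axis : src) {
    perm.push_back(static_cast<int>(dst.find(axis)));
  }
  return perm;
}

int main() {
  assert((PermFor("KCHW", "KHWC") == std::vector<int>{0, 3, 1, 2}));
  return 0;
}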

View File

@ -1,50 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatTransformPass : public Pass {
public:
WeightFormatTransformPass() : Pass("weight_format_transform_pass") {}
~WeightFormatTransformPass() override = default;
void SetQuantType(QuantType type);
void SetFmkType(FmkType fmkType);
void SetDstFormat(schema::Format format);
bool Run(const FuncGraphPtr &graph) override;
private:
lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph);
lite::STATUS TransposeInsertForWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
std::vector<int> perm);
lite::STATUS HandleWeightSharing(const FuncGraphPtr &graph, const ParameterPtr &weight_node,
schema::Format src_format, schema::Format dst_format);
private:
QuantType quant_type = schema::QuantType_QUANT_NONE;
FmkType fmk_type = lite::converter::FmkType_TF;
schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_