!15228 fix abstract of parameter

From: @hangangqiang
Reviewed-by: @zhang_xue_tong,@ddwsky
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-04-27 09:00:31 +08:00 committed by Gitee
commit 1a9463c513
34 changed files with 797 additions and 401 deletions

View File

@ -38,9 +38,12 @@ int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
if (!parameter->infer_flag_) { if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID; return NNACL_INFER_INVALID;
} }
const TensorC *input_k_tensor = inputs[1];
if (input_k_tensor->data_ == NULL) {
return NNACL_INFER_INVALID;
}
TopkParameter *param = (TopkParameter *)parameter; TopkParameter *param = (TopkParameter *)parameter;
const TensorC *input_k_tensor = inputs[1];
param->k_ = ((int32_t *)input_k_tensor->data_)[0]; param->k_ = ((int32_t *)input_k_tensor->data_)[0];
int out_shape[MAX_SHAPE_SIZE]; int out_shape[MAX_SHAPE_SIZE];

View File

@ -75,12 +75,14 @@ class LiteModel : public Model {
} else { } else {
node->name_ = c_node->name()->c_str(); node->name_ = c_node->name()->c_str();
} }
auto count = c_node->inputIndex()->size(); if (c_node->inputIndex() != nullptr) {
for (uint32_t j = 0; j < count; ++j) { auto count = c_node->inputIndex()->size();
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j))); for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
}
} }
if (c_node->outputIndex() != nullptr) { if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size(); auto count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) { for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j))); node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
} }

View File

@ -247,6 +247,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc

View File

@ -73,6 +73,40 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std
return tensor_info; return tensor_info;
} }
// Creates an abstract (AbstractBasePtr) describing a tensor of the given shape
// and data type, without allocating any tensor data.
//
// @param shape      dimensions of the tensor the abstract describes
// @param data_type  element type id of the tensor
// @return the tensor abstract, or nullptr if the tensor info or its abstract
//         could not be created (an ERROR log is emitted in either case)
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) {
  // Data pointer is nullptr and size 0: only the meta information (shape/type)
  // is needed to build an abstract.
  auto tensor_info = CreateTensorInfo(nullptr, 0, shape, data_type);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return nullptr;
  }
  auto abstract = tensor_info->ToAbstract();
  if (abstract == nullptr) {
    // Typo fix: "abstarct" -> "abstract" in the log message.
    MS_LOG(ERROR) << "Create tensor abstract failed";
    return nullptr;
  }
  return abstract;
}
int SetParameterAbstractAndParam(const ParameterPtr &parameter, const void *data, size_t data_size,
const std::vector<int64_t> &shape, TypeId data_type) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "Input parameter is nullptr";
return RET_INPUT_PARAM_INVALID;
}
auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
return RET_OK;
}
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) { int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) {
if (tensor_info == nullptr) { if (tensor_info == nullptr) {
MS_LOG(ERROR) << "tensor info is nullptr."; MS_LOG(ERROR) << "tensor info is nullptr.";

View File

@ -46,6 +46,11 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT>
tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape, tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
TypeId data_type); TypeId data_type);
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type);
int SetParameterAbstractAndParam(const ParameterPtr &parameter, const void *data, size_t data_size,
const std::vector<int64_t> &shape, TypeId data_type);
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size); int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size);
std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info,

View File

@ -54,6 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_bidirection_gru_fusion.cc ../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/fusion/matmul_add_fusion.cc ../optimizer/fusion/matmul_add_fusion.cc
../optimizer/fusion/mul_add_fusion.cc
../optimizer/fusion/gelu_fusion.cc ../optimizer/fusion/gelu_fusion.cc
../optimizer/fusion/tf_gelu_fusion.cc ../optimizer/fusion/tf_gelu_fusion.cc
../optimizer/fusion/onnx_gelu_fusion.cc ../optimizer/fusion/onnx_gelu_fusion.cc

View File

@ -70,7 +70,7 @@ MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
} }
MS_LOG(INFO) << "Run anfTransform success"; MS_LOG(INFO) << "Run anfTransform success";
// protobuf -> flatbuf // protobuf -> flatbuffer
auto meta_graph = Export(graph, false, false, flag->trainModel); auto meta_graph = Export(graph, false, false, flag->trainModel);
if (meta_graph == nullptr) { if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr"; MS_LOG(ERROR) << "Export to meta graph return nullptr";

View File

@ -39,7 +39,6 @@
using std::string; using std::string;
namespace mindspore::lite { namespace mindspore::lite {
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
std::vector<schema::CNodeT *> old_nodes{}; std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graph_defT_->nodes.size()); old_nodes.resize(graph_defT_->nodes.size());
@ -71,54 +70,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// generate and infer quant parameters // format transpose global optimize
{
Optimizer infer_quant_param_pass;
infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
status = infer_quant_param_pass.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
return status;
}
}
{
// format transform
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
if (ctx.fmk != converter::FmkType_TF) {
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
format_trans_optimizer.AddPass(infer_shape_pass);
}
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
{ {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();
@ -134,20 +86,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// postconvert pass // node replace
{ if (!ctx.trainModel) {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();
Optimizer replace_optimizer; Optimizer replace_optimizer;
if (!ctx.trainModel) { replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk));
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
replace_optimizer.AddPass(batch_norm_scale_pass);
}
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes)); replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
status = replace_optimizer.Run(graph_defT_); status = replace_optimizer.Run(graph_defT_);
@ -157,6 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// node fusion
{ {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();
@ -171,19 +117,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// do quantization // quantization
if (ctx.fmk != converter::FmkType_TF) { if (ctx.fmk != converter::FmkType_TF) {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();
Optimizer tensor_quant_optimizer; Optimizer tensor_quant_optimizer;
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
auto infer_shape_pass = new (std::nothrow) InferShapePass(); tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
if (infer_shape_pass == nullptr) { tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
tensor_quant_optimizer.AddPass(infer_shape_pass);
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensor_quant_optimizer.Run(graph_defT_); status = tensor_quant_optimizer.Run(graph_defT_);
@ -193,38 +134,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// insert quantNode and deQuantNode // quantization
if (ctx.fmk != converter::FmkType_TF) { if (ctx.fmk != converter::FmkType_TF) {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes();
Optimizer quant_node_optimizer; Optimizer quant_node_optimizer;
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
auto infer_shape_pass = new (std::nothrow) InferShapePass(); auto old_nodes = GetGraphNodes();
if (infer_shape_pass == nullptr) { quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
MS_LOG(ERROR) << "new InferShapePass failed"; quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
quant_node_optimizer.AddPass(infer_shape_pass);
status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
if (dtype_trans_pass == nullptr) {
MS_LOG(ERROR) << "new dtype_trans_pass failed";
return RET_MEMORY_FAILED;
}
dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
quant_node_optimizer.AddPass(dtype_trans_pass);
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = quant_node_optimizer.Run(graph_defT_); status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) { if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
@ -232,7 +152,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// switch pass // controlflow pass
{ {
// init old node indices // init old node indices
auto old_nodes = GetGraphNodes(); auto old_nodes = GetGraphNodes();
@ -240,6 +160,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
switch_optimizer.AddPass(new (std::nothrow) SwitchPass()); switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = switch_optimizer.Run(graph_defT_); status = switch_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) { if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run switch_optimizer Failed"; MS_LOG(ERROR) << "Run switch_optimizer Failed";
@ -247,34 +168,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
// subgraph tensor pass
{
Optimizer subgraph_tensor_optimizer;
subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraph_tensor_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
return status;
}
}
// tensor name
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer name_optimizer;
name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
status = name_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
return status;
}
}
{ {
Optimizer nested_loop_optimizer; Optimizer nested_loop_optimizer;
auto old_nodes = GetGraphNodes();
nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nested_loop_optimizer.Run(graph_defT_); status = nested_loop_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) { if (status != RET_OK && status != RET_NO_CHANGE) {
@ -284,30 +182,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
{ {
Optimizer quant_param_optimizer; Optimizer forming_model_optimizer;
quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
status = quant_param_optimizer.Run(graph_defT_); forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
if (status != RET_OK && status != RET_NO_CHANGE) { forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed"; status = forming_model_optimizer.Run(graph_defT_);
return status;
}
}
{
Optimizer infer_shape_optimizer;
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
infer_shape_optimizer.AddPass(infer_shape_pass);
status = infer_shape_optimizer.Run(graph_defT_);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed."; MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
return status; return status;
} }
} }
return RET_OK; return RET_OK;
} // namespace mindspore::lite }
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -36,14 +36,12 @@ struct BNWeightTensors {
}; };
class BatchNormConvertScalePass : public GraphPass { class BatchNormConvertScalePass : public GraphPass {
public: public:
BatchNormConvertScalePass() = default; explicit BatchNormConvertScalePass(converter::FmkType fmk) : fmkType(fmk) {}
~BatchNormConvertScalePass() = default; ~BatchNormConvertScalePass() = default;
STATUS Run(MetaGraphT *graph) override; STATUS Run(MetaGraphT *graph) override;
void SetFmk(converter::FmkType fmk) { this->fmkType = fmk; }
protected: protected:
STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode); STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);

View File

@ -276,10 +276,5 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
castOpCopyer); castOpCopyer);
} }
void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -30,16 +30,13 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
class DTypeTransPass : public GraphPass { class DTypeTransPass : public GraphPass {
public: public:
DTypeTransPass() : id_(0) {} DTypeTransPass(TypeId model_input_data_type, TypeId model_output_data_type)
: id_(0), input_data_dtype(model_input_data_type), output_data_dtype(model_output_data_type) {}
~DTypeTransPass() override = default; ~DTypeTransPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override; STATUS Run(schema::MetaGraphT *graph) override;
void set_input_data_dtype(TypeId data_type);
void set_output_data_dtype(TypeId dataType);
private: private:
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);

View File

@ -39,14 +39,10 @@ struct InferTensor {
class InferShapePass : public GraphPass { class InferShapePass : public GraphPass {
public: public:
InferShapePass() = default; explicit InferShapePass(converter::FmkType fmk_type) : fmk_type_(fmk_type) {}
~InferShapePass() override = default;
~InferShapePass() = default;
STATUS Run(MetaGraphT *graph) override; STATUS Run(MetaGraphT *graph) override;
void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
private: private:
void InitSearchTensor(MetaGraphT *graph); void InitSearchTensor(MetaGraphT *graph);
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index); void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);

View File

@ -34,8 +34,28 @@ class ModelParser {
virtual ~ModelParser() = default; virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) {
const QuantType &quant_type) = 0; auto ret = ParseToFuncGraph(model_file, weight_file, quant_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse to func graph failed : " << ret;
return nullptr;
}
ret = PostAdjust();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Adjust func graph failed : " << ret;
return nullptr;
}
return this->res_graph_;
}
protected:
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) = 0;
virtual int PostAdjust() = 0;
protected:
FuncGraphPtr res_graph_ = nullptr;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -15,6 +15,7 @@
*/ */
#include <vector> #include <vector>
#include "tools/common/tensor_util.h"
#include "tools/converter/ops/while.h" #include "tools/converter/ops/while.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h" #include "abstract/primitive_infer_map.h"
@ -55,7 +56,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
AbstractBasePtrList output; AbstractBasePtrList output;
for (int64_t i = 0; i < (int64_t)input_args.size(); i++) { for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
output.push_back(std::make_shared<abstract::AbstractTensor>(input_args[i]->BuildType(), shape)); auto abstract_tensor = lite::CreateTensorAbstract(shape, input_args[i]->BuildType()->type_id());
MS_EXCEPTION_IF_NULL(abstract_tensor);
output.push_back(abstract_tensor);
} }
return std::make_shared<abstract::AbstractTuple>(output); return std::make_shared<abstract::AbstractTuple>(output);
} }

View File

@ -41,34 +41,34 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file, int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) { const QuantType &quant_type) {
STATUS status = InitOriginModel(model_file, weight_file); STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
func_graph_ptr_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
status = ConvertGraphInputs(); status = ConvertGraphInputs();
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
status = ConvertLayers(); status = ConvertLayers();
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
status = ConvertGraphOutputs(); status = ConvertGraphOutputs();
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); res_graph_->set_attr("graph_name", MakeValue("main_graph"));
func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
return func_graph_ptr_; return RET_OK;
} }
STATUS CaffeModelParser::ConvertLayers() { STATUS CaffeModelParser::ConvertLayers() {
@ -134,7 +134,7 @@ STATUS CaffeModelParser::ConvertLayers() {
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))}; std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end()); op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end());
op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end()); op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end());
auto new_cnode = func_graph_ptr_->NewCNode(op_inputs); auto new_cnode = res_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(layer.name()); new_cnode->set_fullname_with_scope(layer.name());
// convert outputs // convert outputs
@ -194,14 +194,17 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
for (int i = 0; i < caffe_model_.layer_size(); i++) { for (int i = 0; i < caffe_model_.layer_size(); i++) {
auto layer = caffe_model_.layer(i); auto layer = caffe_model_.layer(i);
if (layer.type() == "Input") { if (layer.type() == "Input") {
auto parameter = func_graph_ptr_->add_parameter(); auto parameter = res_graph_->add_parameter();
std::vector<int64_t> shape; std::vector<int64_t> shape;
for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) {
shape.push_back(layer.input_param().shape(0).dim(j)); shape.push_back(layer.input_param().shape(0).dim(j));
} }
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); if (abstract == nullptr) {
parameter->set_abstract(abstract_tensor); MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + std::to_string(i)); parameter->set_name("graph_input-" + std::to_string(i));
nodes_.insert(std::pair(layer.top(0), parameter)); nodes_.insert(std::pair(layer.top(0), parameter));
} }
@ -220,10 +223,13 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
shape.push_back(caffe_model_.input_dim(j)); shape.push_back(caffe_model_.input_dim(j));
} }
} }
auto parameter = func_graph_ptr_->add_parameter(); auto parameter = res_graph_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); if (abstract == nullptr) {
parameter->set_abstract(abstract_tensor); MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + caffe_model_.input(i)); parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter)); nodes_.insert(std::pair(caffe_model_.input(i), parameter));
} }
@ -234,10 +240,18 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
for (int j = 0; j < shape.dim_size(); j++) { for (int j = 0; j < shape.dim_size(); j++) {
shape_vector.push_back(shape.dim(j)); shape_vector.push_back(shape.dim(j));
} }
auto parameter = func_graph_ptr_->add_parameter(); auto parameter = res_graph_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); if (tensor_info == nullptr) {
parameter->set_abstract(abstract_tensor); MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + caffe_model_.input(i)); parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter)); nodes_.insert(std::pair(caffe_model_.input(i), parameter));
} }
@ -265,7 +279,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
auto cnode = nodes_.find(output_node)->second; auto cnode = nodes_.find(output_node)->second;
make_tuple_inputs.emplace_back(cnode); make_tuple_inputs.emplace_back(cnode);
} }
auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_inputs); auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs; std::vector<AnfNodePtr> op_inputs;
@ -277,9 +291,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
auto value_node = NewValueNode(return_prim_ptr); auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node); op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode); op_inputs.emplace_back(make_tuple_cnode);
auto cnode = func_graph_ptr_->NewCNode(op_inputs); auto cnode = res_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
func_graph_ptr_->set_return(cnode); res_graph_->set_return(cnode);
} else { } else {
auto returnPrim = std::make_shared<ops::Return>(); auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) { if (returnPrim == nullptr) {
@ -298,9 +312,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
return RET_NOT_FIND_OP; return RET_NOT_FIND_OP;
} }
opInputs.emplace_back(cnode); opInputs.emplace_back(cnode);
auto returnCnode = func_graph_ptr_->NewCNode(opInputs); auto returnCnode = res_graph_->NewCNode(opInputs);
returnCnode->set_fullname_with_scope("Return"); returnCnode->set_fullname_with_scope("Return");
func_graph_ptr_->set_return(returnCnode); res_graph_->set_return(returnCnode);
} }
return RET_OK; return RET_OK;
} }
@ -333,7 +347,7 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v
ConvertShape(layer.blobs(i), &shape); ConvertShape(layer.blobs(i), &shape);
// cal Weight num // cal Weight num
auto parameter = func_graph_ptr_->add_parameter(); auto parameter = res_graph_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
@ -402,17 +416,25 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std::
} }
STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) { STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) {
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
std::vector<int64_t> shape_vector;
if (layer.top_size() == 1) { if (layer.top_size() == 1) {
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
cnode->set_abstract(abstract);
nodes_[layer.top(0)] = cnode; nodes_[layer.top(0)] = cnode;
return RET_OK; return RET_OK;
} }
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
for (int i = 0; i < layer.top_size(); i++) { for (int i = 0; i < layer.top_size(); i++) {
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
@ -421,7 +443,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(i)); auto get_item_value = NewValueNode(MakeValue<int>(i));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value}; std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs); CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
get_item_cnode->set_fullname_with_scope(layer.top(i)); get_item_cnode->set_fullname_with_scope(layer.top(i));
nodes_[layer.top(i)] = get_item_cnode; nodes_[layer.top(i)] = get_item_cnode;
} }
@ -446,4 +468,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
} }
return layer.name(); return layer.name();
} }
int CaffeModelParser::PostAdjust() { return RET_OK; }
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -32,8 +32,10 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override; ~CaffeModelParser() override;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override; const QuantType &quant_type) override;
int PostAdjust() override;
private: private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
@ -59,7 +61,6 @@ class CaffeModelParser : public ModelParser {
caffe::NetParameter caffe_weight_; caffe::NetParameter caffe_weight_;
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_; std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
std::unordered_map<std::string, AnfNodePtr> nodes_; std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -45,31 +45,31 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) { const QuantType &quant_type) {
NotSupportOp::GetInstance()->set_fmk_type("ONNX"); NotSupportOp::GetInstance()->set_fmk_type("ONNX");
anf_root_graph_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file); auto status = InitOriginModel(model_file);
if (RET_OK != status) { if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "init origin model failed."; MS_LOG(ERROR) << "init origin model failed.";
return nullptr; return status;
} }
status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node");
if (RET_OK != status) { if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert onnx graph failed."; MS_LOG(ERROR) << "convert onnx graph failed.";
return nullptr; return status;
} }
static auto root_func_manager = Manage(anf_root_graph_); static auto root_func_manager = Manage(res_graph_);
for (auto &subgraph : all_subgraphs_) { for (auto &subgraph : all_subgraphs_) {
subgraph->set_manager(root_func_manager); subgraph->set_manager(root_func_manager);
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
} }
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); res_graph_->set_attr("graph_name", MakeValue("main_graph"));
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
return anf_root_graph_; return RET_OK;
} }
STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
@ -88,9 +88,9 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
onnx_root_graph_ = onnx_model_.graph(); onnx_root_graph_ = onnx_model_.graph();
if (OnnxNodeParser::opset_version() > 15) { if (OnnxNodeParser::opset_version() > 15) {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
} else { } else {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
} }
return RET_OK; return RET_OK;
} }
@ -170,13 +170,16 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c
<< static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()); << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
return RET_ERROR; return RET_ERROR;
} }
auto type_ptr = TypeIdToType(data_type);
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
auto onnx_shape = input_value.type().tensor_type().shape().dim(); auto onnx_shape = input_value.type().tensor_type().shape().dim();
std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector), std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
[](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); }); [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
std::replace(shape_vector.begin(), shape_vector.end(), 0, -1); std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name(input_value.name()); parameter->set_name(input_value.name());
anf_nodes_map->emplace(input_value.name(), parameter); anf_nodes_map->emplace(input_value.name(), parameter);
@ -490,17 +493,23 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (onnx_node.output_size() == 1) { if (onnx_node.output_size() == 1) {
auto type_ptr = TypeIdToType(kNumberTypeFloat32); auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
std::vector<int64_t> shape_vector; if (abstract_tensor == nullptr) {
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tensor);
anf_nodes_map->emplace(onnx_node.output(0), cnode); anf_nodes_map->emplace(onnx_node.output(0), cnode);
} else { } else {
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
int op_idx = 0; int op_idx = 0;
for (const auto &output_name : onnx_node.output()) { for (const auto &output_name : onnx_node.output()) {
std::vector<int64_t> shape_vector; auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
auto type_ptr = TypeIdToType(kNumberTypeFloat32); if (abstract_tensor == nullptr) {
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
@ -687,7 +696,11 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
return nullptr; return nullptr;
} }
auto const_node = anf_graph->add_parameter(); auto const_node = anf_graph->add_parameter();
auto const_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>()); auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
if (const_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr;
}
const_node->set_abstract(const_abstract); const_node->set_abstract(const_abstract);
int *tensor_data = new (std::nothrow) int[1]; int *tensor_data = new (std::nothrow) int[1];
if (tensor_data == nullptr) { if (tensor_data == nullptr) {
@ -834,9 +847,16 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v
for (int i = 0; i < act_output_num; i++) { for (int i = 0; i < act_output_num; i++) {
// tensor_array need as root while input // tensor_array need as root while input
auto while_tensor_array_input = anf_root_graph->add_parameter(); auto while_tensor_array_input = anf_root_graph->add_parameter();
std::vector<int64_t> shape_vector; auto tensor_info = CreateTensorInfo(nullptr, 0, {}, kObjectTypeTensorType);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vector); if (tensor_info == nullptr) {
auto tensor_info = std::make_shared<tensor::Tensor>(kObjectTypeTensorType, shape_vector); MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract_tensor = tensor_info->ToAbstract();
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
while_tensor_array_input->set_abstract(abstract_tensor); while_tensor_array_input->set_abstract(abstract_tensor);
while_tensor_array_input->set_default_param(tensor_info); while_tensor_array_input->set_default_param(tensor_info);
while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray"); while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray");
@ -975,7 +995,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
auto input_paramter = cond_graph->add_parameter(); auto input_paramter = cond_graph->add_parameter();
input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter"); input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs(); auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs();
auto input_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>()); auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
if (input_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
input_paramter->set_abstract(input_abstract); input_paramter->set_abstract(input_abstract);
if (i == 0) { if (i == 0) {
auto zero_parameter = CreateConstParamter(cond_graph, 0); auto zero_parameter = CreateConstParamter(cond_graph, 0);
@ -987,7 +1011,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
MS_LOG(ERROR) << "new cnode error"; MS_LOG(ERROR) << "new cnode error";
return RET_ERROR; return RET_ERROR;
} }
auto less_abstract = std::make_shared<abstract::AbstractTensor>(kBool, std::vector<int64_t>()); auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
if (less_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
less_cnode->set_abstract(less_abstract); less_cnode->set_abstract(less_abstract);
less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode"); less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
} }
@ -1020,12 +1048,11 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
MS_LOG(ERROR) << "quant param type don't support."; MS_LOG(ERROR) << "quant param type don't support.";
return RET_NOT_SUPPORT; return RET_NOT_SUPPORT;
} }
std::vector<int64_t> shape_vector; auto parameter_node = res_graph_->add_parameter();
auto parameter_node = anf_root_graph_->add_parameter(); auto abstract_tensor = CreateTensorAbstract({}, type);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
if (abstract_tensor == nullptr) { if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "new abstract_tensor failed"; MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_MEMORY_FAILED; return RET_ERROR;
} }
parameter_node->set_abstract(abstract_tensor); parameter_node->set_abstract(abstract_tensor);
parameter_node->set_name(name); parameter_node->set_name(name);
@ -1051,9 +1078,12 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr &parameter_node, c
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type()); MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
return RET_ERROR; return RET_ERROR;
} }
auto type_ptr = TypeIdToType(data_type);
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end()); std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter_node->set_abstract(abstract_tensor); parameter_node->set_abstract(abstract_tensor);
parameter_node->set_name(tensor.name()); parameter_node->set_name(tensor.name());
@ -1142,5 +1172,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type
} }
return iter->second; return iter->second;
} }
int OnnxModelParser::PostAdjust() { return 0; }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -40,14 +40,17 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default; ~OnnxModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override; const QuantType &quant_type) override;
int PostAdjust() override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
const tensor::TensorPtr &param_value_lite); const tensor::TensorPtr &param_value_lite);
STATUS InitOriginModel(const std::string &model_file);
private: private:
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs, std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
const std::string &root_node_name); const std::string &root_node_name);
@ -94,7 +97,6 @@ class OnnxModelParser : public ModelParser {
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_; std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_; std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
FuncGraphPtr anf_root_graph_ = nullptr;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -418,18 +418,17 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
type = TensorFlowUtils::GetTFDataType(attr_value.type()); type = TensorFlowUtils::GetTFDataType(attr_value.type());
} }
std::vector<int> shape; std::vector<int64_t> shape;
if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
auto &shape_attr = attr_value.shape(); auto &shape_attr = attr_value.shape();
for (int i = 0; i < shape_attr.dim_size(); ++i) { for (int i = 0; i < shape_attr.dim_size(); ++i) {
shape.push_back(shape_attr.dim(i).size()); shape.push_back(shape_attr.dim(i).size());
} }
} }
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
MS_LOG(INFO) << "Found value attr, means it has default value"; MS_LOG(INFO) << "Found value attr, means it has default value";
auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector); auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "convert const tensor failed."; MS_LOG(ERROR) << "convert const tensor failed.";
return status; return status;
@ -438,10 +437,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
} }
auto type_ptr = TypeIdToType(type == kNumberTypeInt64 ? kNumberTypeInt32 : type); type = (type == kNumberTypeInt64) ? kNumberTypeInt32 : type;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract_tensor = CreateTensorAbstract(shape, type);
if (abstract_tensor == nullptr) { if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "abstract_tensor is nullptr"; MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR; return RET_ERROR;
} }
parameter->set_name(node.name()); parameter->set_name(node.name());
@ -474,51 +473,51 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
} }
return RET_OK; return RET_OK;
} }
FuncGraphPtr paserTfFuction() { return nullptr; }
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) { const QuantType &quantType) {
NotSupportOp::GetInstance()->set_fmk_type("TF"); NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb"); auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph_ == nullptr) { if (tf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr; return status;
} }
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
anf_root_graph_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
if (anf_root_graph_ == nullptr) { if (res_graph_ == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr"; MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr; return status;
} }
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); res_graph_->set_attr("graph_name", MakeValue("main_graph"));
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
for (int i = 0; i < tf_root_graph_->node_size(); i++) { for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i); auto &node_def = tf_root_graph_->node(i);
tf_root_graph_nodes_[node_def.name()] = &node_def; tf_root_graph_nodes_[node_def.name()] = &node_def;
} }
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
bool success_flag = true; bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) { for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i); auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
if (status != RET_OK) { if (status != RET_OK) {
success_flag = false; success_flag = false;
@ -526,7 +525,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
} }
if (!success_flag) { if (!success_flag) {
MS_LOG(ERROR) << "Convert ops failed."; MS_LOG(ERROR) << "Convert ops failed.";
return nullptr; return RET_ERROR;
} }
if (!nodes_with_null_input_.empty()) { if (!nodes_with_null_input_.empty()) {
@ -534,7 +533,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Connect null inputs failed."; MS_LOG(ERROR) << "Connect null inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
} }
@ -542,17 +541,17 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed."; MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
status = ConvertSubgraph(); status = ConvertSubgraph();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert subgraph failed."; MS_LOG(ERROR) << "Convert subgraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
return anf_root_graph_; return RET_OK;
} }
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map, STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
@ -746,7 +745,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
MS_LOG(ERROR) << "while cond body size error"; MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR; return RET_ERROR;
} }
static auto root_func_manager = Manage(anf_root_graph_); static auto root_func_manager = Manage(res_graph_);
for (auto &kv : first_func_map) { for (auto &kv : first_func_map) {
auto control_flow_node = kv.first; auto control_flow_node = kv.first;
@ -758,7 +757,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
auto second_value_node = NewValueNode(second_sub_graph); auto second_value_node = NewValueNode(second_sub_graph);
auto inputs = control_flow_node->inputs(); auto inputs = control_flow_node->inputs();
inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node}); inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node});
auto new_node = anf_root_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update auto new_node = res_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update
if (new_node == nullptr) { if (new_node == nullptr) {
MS_LOG(ERROR) << "new node failed"; MS_LOG(ERROR) << "new node failed";
return RET_ERROR; return RET_ERROR;
@ -812,43 +811,46 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
if (output_size == 0) { if (output_size == 0) {
return RET_OK; return RET_OK;
} else if (output_size == 1) { } else if (output_size == 1) {
auto type = kFloat32; auto type = kNumberTypeFloat32;
std::vector<int64_t> shape_vector;
if (IsTensorListOp(anf_node)) { if (IsTensorListOp(anf_node)) {
type = TypeIdToType(kObjectTypeTensorType); type = kObjectTypeTensorType;
} }
auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector); auto abstract_tensor = CreateTensorAbstract({}, type);
if (abstract == nullptr) { if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed"; MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR; return RET_ERROR;
} }
anf_node->set_abstract(abstract); anf_node->set_abstract(abstract_tensor);
anf_node_map->insert(std::pair(op.name(), anf_node)); anf_node_map->insert(std::pair(op.name(), anf_node));
} else { } else {
AbstractBasePtrList abstractList; AbstractBasePtrList abstract_list;
for (int output_idx = 0; output_idx < output_size; output_idx++) { for (int output_idx = 0; output_idx < output_size; output_idx++) {
std::vector<int64_t> shape_vector; auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); if (abstract_tensor == nullptr) {
auto tupleGetItemPrimPtr = std::make_shared<ops::TupleGetItem>(); MS_LOG(ERROR) << "Create tensor abstarct failed";
if (tupleGetItemPrimPtr == nullptr) { return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); auto get_item_value = NewValueNode(MakeValue<int>(output_idx));
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; std::vector<AnfNodePtr> inputs{tuple_get_item_prim, anf_node, get_item_value};
CNodePtr getItemCNode = anf_graph->NewCNode(inputs); CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) { if (get_item_abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed"; MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR; return RET_ERROR;
} }
getItemCNode->set_abstract(abstract); get_item_cnode->set_abstract(get_item_abstract);
getItemCNode->set_fullname_with_scope(output_item_name); get_item_cnode->set_fullname_with_scope(output_item_name);
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode));
} }
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
} }
return RET_OK; return RET_OK;
} }
@ -1022,7 +1024,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
} }
} }
} }
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); auto status = MakeAnfGraphOutputs(&output_nodes, res_graph_);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error"; MS_LOG(ERROR) << "make anf graph outputs node error";
return status; return status;
@ -1070,5 +1072,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
} }
return RET_OK; return RET_OK;
} }
int TFModelParser::PostAdjust() { return 0; }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -36,9 +36,11 @@ namespace lite {
class TFModelParser : public ModelParser { class TFModelParser : public ModelParser {
public: public:
TFModelParser() = default; TFModelParser() = default;
~TFModelParser() = default; ~TFModelParser() override = default;
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
int PostAdjust() override;
private: private:
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
@ -87,7 +89,6 @@ class TFModelParser : public ModelParser {
STATUS ConnectNullInput(); STATUS ConnectNullInput();
FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;

View File

@ -43,46 +43,46 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
return tflite::UnPackModel(tflite_model_buf_); return tflite::UnPackModel(tflite_model_buf_);
} }
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file, int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) { const QuantType &quant_type) {
// load graph // load graph
tflite_model_ = ReadTfliteModel(model_file.c_str()); tflite_model_ = ReadTfliteModel(model_file.c_str());
if (tflite_model_ == nullptr) { if (tflite_model_ == nullptr) {
MS_LOG(ERROR) << "read tflite model failed"; MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr; return RET_GRAPH_FILE_ERR;
} }
if (tflite_model_->subgraphs.size() != 1) { if (tflite_model_->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed"; MS_LOG(ERROR) << "read tflite model subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr; return RET_GRAPH_FILE_ERR;
} }
func_graph_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE))); res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
auto status = ConvertGraphInputs(); auto status = ConvertGraphInputs();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed."; MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
status = ConvertOps(); status = ConvertOps();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed."; MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
status = ConvertGraphOutputs(); status = ConvertGraphOutputs();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed."; MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr; return status;
} }
func_graph_->set_attr("graph_name", MakeValue("main_graph")); res_graph_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_; return RET_OK;
} }
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) { std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertOps() {
} else { } else {
tensor_name = GetTensorName(i, tflite_op_type, op_name); tensor_name = GetTensorName(i, tflite_op_type, op_name);
} }
auto parameter = func_graph_->add_parameter(); auto parameter = res_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name); status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
@ -168,7 +168,7 @@ STATUS TfliteModelParser::ConvertOps() {
op_inputs.emplace_back(parameter); op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter)); nodes_.insert(std::pair(input_idx, parameter));
} }
auto new_cnode = func_graph_->NewCNode(op_inputs); auto new_cnode = res_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(op_name); new_cnode->set_fullname_with_scope(op_name);
// parse outputs // parse outputs
@ -284,13 +284,16 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
if (tflite_graph_input < 0) { if (tflite_graph_input < 0) {
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
} }
auto parameter = func_graph_->add_parameter(); auto parameter = res_graph_->add_parameter();
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); }); [](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input)); parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
nodes_.insert(std::pair(tflite_graph_input, parameter)); nodes_.insert(std::pair(tflite_graph_input, parameter));
@ -318,7 +321,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
} }
make_tuple_inputs.emplace_back(cnode); make_tuple_inputs.emplace_back(cnode);
} }
auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs; std::vector<AnfNodePtr> op_inputs;
@ -330,9 +333,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
auto value_node = NewValueNode(return_prim_ptr); auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node); op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode); op_inputs.emplace_back(make_tuple_cnode);
auto cnode = func_graph_->NewCNode(op_inputs); auto cnode = res_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
func_graph_->set_return(cnode); res_graph_->set_return(cnode);
} else { } else {
auto returnPrim = std::make_shared<ops::Return>(); auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) { if (returnPrim == nullptr) {
@ -350,9 +353,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
return RET_NOT_FIND_OP; return RET_NOT_FIND_OP;
} }
op_inputs.emplace_back(cnode); op_inputs.emplace_back(cnode);
auto returnCnode = func_graph_->NewCNode(op_inputs); auto returnCnode = res_graph_->NewCNode(op_inputs);
returnCnode->set_fullname_with_scope("Return"); returnCnode->set_fullname_with_scope("Return");
func_graph_->set_return(returnCnode); res_graph_->set_return(returnCnode);
} }
return RET_OK; return RET_OK;
} }
@ -436,8 +439,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); }); [](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
dst_cnode->set_abstract(abstract_tensor);
nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
} else { } else {
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
@ -450,8 +457,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); }); [](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
@ -460,7 +471,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(op_idx)); auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
nodes_.insert(std::pair(output_idx, get_item_cnode)); nodes_.insert(std::pair(output_idx, get_item_cnode));
op_idx++; op_idx++;
@ -469,4 +480,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
} }
return RET_OK; return RET_OK;
} }
int TfliteModelParser::PostAdjust() { return 0; }
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -32,13 +32,14 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default; ~TfliteModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override; const QuantType &quant_type) override;
int PostAdjust() override;
private: private:
std::unordered_map<int, AnfNodePtr> nodes_; std::unordered_map<int, AnfNodePtr> nodes_;
std::unique_ptr<tflite::ModelT> tflite_model_; std::unique_ptr<tflite::ModelT> tflite_model_;
FuncGraphPtr func_graph_;
char *tflite_model_buf_ = nullptr; char *tflite_model_buf_ = nullptr;
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr &parameter, STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr &parameter,

View File

@ -399,6 +399,24 @@ int CheckIfCNodeIsNull(const CNodePtr &node) {
return lite::RET_OK; return lite::RET_OK;
} }
int CheckIfParameterIsNull(const ParameterPtr &node) {
if (node == nullptr) {
MS_LOG(ERROR) << "The Parameter is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
return lite::RET_OK;
}
int CheckIfValueNodeIsNull(const ValueNodePtr &node) {
if (node == nullptr) {
MS_LOG(ERROR) << "The ValueNode is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
return lite::RET_OK;
}
int CheckIfVarIsNull(const VarPtr &var) { int CheckIfVarIsNull(const VarPtr &var) {
if (var == nullptr) { if (var == nullptr) {
MS_LOG(ERROR) << "The Var is null."; MS_LOG(ERROR) << "The Var is null.";

View File

@ -57,6 +57,10 @@ int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
int CheckIfCNodeIsNull(const CNodePtr &node); int CheckIfCNodeIsNull(const CNodePtr &node);
int CheckIfParameterIsNull(const ParameterPtr &node);
int CheckIfValueNodeIsNull(const ValueNodePtr &node);
int CheckIfVarIsNull(const VarPtr &var); int CheckIfVarIsNull(const VarPtr &var);
int CheckInputSize(const CNodePtr &node, int size); int CheckInputSize(const CNodePtr &node, int size);

View File

@ -0,0 +1,294 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/mul_add_fusion.h"
#include <memory>
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/scale_fusion.h"
#include "ops/op_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
// Expected cnode input counts: primitive value node + two operand inputs.
constexpr size_t kMulInputsLength = 3;
constexpr size_t kAddInputsLength = 3;
}  // namespace
// Pattern: an AddFusion cnode whose (first matched) input is a MulFusion cnode.
const BaseRef MulAddFusion::DefinePattern() const {
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  return VectorRef({is_add, is_mul});
}
// Checks that the recorded scale/offset tensors can drive a Scale op on the mul input:
// the offset shape must equal the scale shape, and the scale shape must be a non-empty
// suffix (tail) of the mul input shape.
bool MulAddFusion::ScaleInputShapeValid() const {
  MS_ASSERT(scale_tensor_ != nullptr);
  MS_ASSERT(bias_tensor_ != nullptr);
  auto scale_shape = scale_tensor_->shape_c();
  auto offset_shape = bias_tensor_->shape_c();
  // Scale requires scale and offset to have identical shapes; cheap vector compare first.
  if (scale_shape != offset_shape) {
    return false;
  }
  if (scale_shape.empty() || mul_input_shape_.size() < scale_shape.size()) {
    return false;
  }
  // The scale shape must match the trailing dimensions of the mul input shape.
  size_t rank_diff = mul_input_shape_.size() - scale_shape.size();
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    if (mul_input_shape_[i + rank_diff] != scale_shape[i]) {
      return false;
    }
  }
  return true;
}
// Validates the matched Mul node for fusion and records its variable (cnode) input,
// its const input and the scale tensor (the const data). Returns false when the node
// is not fusable (multi-output, wrong type, activation present, or bad inputs).
bool MulAddFusion::CheckMulNode(const FuncGraphPtr &func_graph) const {
  MS_ASSERT(func_graph != nullptr);
  if (mul_anode_ == nullptr) {
    return false;
  }
  // Fusion consumes the mul output, so it must not feed multiple consumers.
  if (IsMultiOutputTensors(func_graph, mul_anode_)) {
    MS_LOG(DEBUG) << "Mul op has multi-output";
    return false;
  }
  auto mul_node = mul_anode_->cast<CNodePtr>();
  if (!CheckPrimitiveType(mul_node, prim::kPrimMulFusion)) {
    MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
    return false;
  }
  auto mul_primitive = GetValueNode<std::shared_ptr<ops::MulFusion>>(mul_node->input(0));
  MS_ASSERT(mul_primitive != nullptr);
  auto mul_act_type = mul_primitive->get_activation_type();
  // An activation on the mul cannot be folded into the fused Scale op.
  if (mul_act_type != ActivationType::NO_ACTIVATION) {
    MS_LOG(DEBUG) << "Only support mul node with no activation";
    return false;
  }
  if (CheckIfCNodeIsNull(mul_node) != lite::RET_OK || CheckInputSize(mul_node, kMulInputsLength) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Mul op is null or has error input size";
    return false;
  }
  // find mul's const input and mul input
  AnfNodePtr mul_pre_input_node = nullptr;
  AnfNodePtr mul_pre_const_node = nullptr;
  auto mul_pre_node_1 = mul_node->input(1);
  if (CheckIfAnfNodeIsNull(mul_pre_node_1) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
    return false;
  }
  auto mul_pre_node_2 = mul_node->input(2);
  if (CheckIfAnfNodeIsNull(mul_pre_node_2) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
    return false;
  }
  // Exactly one input must be a cnode (the variable input); the other is the constant.
  if (utils::isa<CNodePtr>(mul_pre_node_1) && !utils::isa<CNodePtr>(mul_pre_node_2)) {
    mul_pre_input_node = mul_pre_node_1;
    mul_pre_const_node = mul_pre_node_2;
  } else if (!utils::isa<CNodePtr>(mul_pre_node_1) && utils::isa<CNodePtr>(mul_pre_node_2)) {
    // BUGFIX: here input(2) is the cnode, so it is the variable input and input(1) is
    // the constant. The original assigned them as in the branch above, swapping the
    // scale constant with the variable input (cf. the correct handling in CheckAddNode).
    mul_pre_input_node = mul_pre_node_2;
    mul_pre_const_node = mul_pre_node_1;
  } else {
    MS_LOG(DEBUG) << "Mul op should has a cnode input and a const input";
    return false;
  }
  // check mul's const input
  tensor::TensorPtr mul_tensor = nullptr;
  if (utils::isa<ParameterPtr>(mul_pre_const_node)) {
    auto mul_bias_node = mul_pre_const_node->cast<ParameterPtr>();
    MS_ASSERT(mul_bias_node != nullptr);
    if (!mul_bias_node->has_default()) {
      MS_LOG(DEBUG) << "Const input of mul op should has data";
      return false;
    }
    mul_tensor = mul_bias_node->default_param()->cast<tensor::TensorPtr>();
  } else if (utils::isa<ValueNodePtr>(mul_pre_const_node)) {
    auto mul_bias_node = mul_pre_const_node->cast<ValueNodePtr>();
    MS_ASSERT(mul_bias_node != nullptr);
    if (mul_bias_node->value() == nullptr) {
      MS_LOG(DEBUG) << "Const input of mul op should has data";
      return false;
    }
    mul_tensor = mul_bias_node->value()->cast<tensor::TensorPtr>();
  } else {
    // Unreachable given the isa checks above; mul_tensor stays null and is rejected below.
    MS_ASSERT(false);
  }
  if (mul_tensor == nullptr) {
    // BUGFIX: message said "add op" in this mul check (copy-paste).
    MS_LOG(DEBUG) << "Const input of mul op should has data";
    return false;
  }
  mul_input_anode_ = mul_pre_input_node;
  mul_const_anode_ = mul_pre_const_node;
  scale_tensor_ = mul_tensor;
  return true;
}
// Validates the matched Add node for fusion: it must have exactly one cnode input
// (kept as the mul candidate) and one const input (the bias/offset). On success,
// records the mul candidate in mul_anode_, the const input in add_const_anode_,
// the bias tensor in bias_tensor_ and the add's activation in scale_act_type_.
bool MulAddFusion::CheckAddNode() const {
  if (add_anode_ == nullptr) {
    return false;
  }
  auto add_cnode = add_anode_->cast<CNodePtr>();
  if (CheckIfCNodeIsNull(add_cnode) != lite::RET_OK || CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Add op is null or has error input size";
    return false;
  }
  if (!CheckPrimitiveType(add_cnode, prim::kPrimAddFusion)) {
    MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
    return false;
  }
  auto add_primitive = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
  MS_ASSERT(add_primitive != nullptr);
  auto add_act_type = add_primitive->get_activation_type();
  // Only activations Scale supports can be folded into the fused op.
  if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 &&
      add_act_type != ActivationType::NO_ACTIVATION) {
    MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation";
    return false;
  }
  scale_act_type_ = add_act_type;
  // find add's const input and mul input
  AnfNodePtr add_pre_input_node = nullptr;
  AnfNodePtr add_pre_const_node = nullptr;
  auto add_pre_node_1 = add_cnode->input(1);
  if (CheckIfAnfNodeIsNull(add_pre_node_1) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
    return false;
  }
  auto add_pre_node_2 = add_cnode->input(2);
  if (CheckIfAnfNodeIsNull(add_pre_node_2) != lite::RET_OK) {
    MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
    return false;
  }
  // Exactly one of the two inputs must be a cnode (the mul); the other is the constant.
  if (utils::isa<CNodePtr>(add_pre_node_1) && !utils::isa<CNodePtr>(add_pre_node_2)) {
    add_pre_input_node = add_pre_node_1;
    add_pre_const_node = add_pre_node_2;
  } else if (!utils::isa<CNodePtr>(add_pre_node_1) && utils::isa<CNodePtr>(add_pre_node_2)) {
    add_pre_input_node = add_pre_node_2;
    add_pre_const_node = add_pre_node_1;
  } else {
    MS_LOG(DEBUG) << "Add op should has a cnode input and a const input";
    return false;
  }
  // check add's const input
  tensor::TensorPtr add_tensor = nullptr;
  if (utils::isa<ParameterPtr>(add_pre_const_node)) {
    auto add_bias_node = add_pre_const_node->cast<ParameterPtr>();
    MS_ASSERT(add_bias_node != nullptr);
    if (!add_bias_node->has_default()) {
      MS_LOG(DEBUG) << "Const input of add op should has data";
      return false;
    }
    add_tensor = add_bias_node->default_param()->cast<tensor::TensorPtr>();
  } else if (utils::isa<ValueNodePtr>(add_pre_const_node)) {
    auto add_bias_node = add_pre_const_node->cast<ValueNodePtr>();
    MS_ASSERT(add_bias_node != nullptr);
    if (add_bias_node->value() == nullptr) {
      MS_LOG(DEBUG) << "Const input of add op should has data";
      return false;
    }
    add_tensor = add_bias_node->value()->cast<tensor::TensorPtr>();
  } else {
    // Unreachable given the isa checks above; in release builds add_tensor stays
    // null and is rejected by the check below.
    MS_ASSERT(false);
  }
  if (add_tensor == nullptr) {
    MS_LOG(DEBUG) << "Const input of add op should has data";
    return false;
  }
  mul_anode_ = add_pre_input_node;
  add_const_anode_ = add_pre_const_node;
  bias_tensor_ = add_tensor;
  return true;
}
// Reads the shape of the mul op's variable input from its abstract and stores it in
// mul_input_shape_. Returns false when the abstract is missing, is not an
// AbstractTensor, or does not carry a concrete Shape.
bool MulAddFusion::GetMulInputShape() const {
  MS_ASSERT(mul_input_anode_ != nullptr);
  AbstractBasePtr mul_input_abstract = nullptr;
  if (utils::isa<ParameterPtr>(mul_input_anode_)) {
    auto mul_input_node = mul_input_anode_->cast<ParameterPtr>();
    // BUGFIX: the original asserted `mul_bias_node`, an identifier that does not
    // exist in this branch (copy-paste from CheckMulNode).
    MS_ASSERT(mul_input_node != nullptr);
    mul_input_abstract = mul_input_node->abstract();
  } else if (utils::isa<ValueNodePtr>(mul_input_anode_)) {
    auto mul_input_node = mul_input_anode_->cast<ValueNodePtr>();
    MS_ASSERT(mul_input_node != nullptr);
    mul_input_abstract = mul_input_node->abstract();
  } else if (utils::isa<CNodePtr>(mul_input_anode_)) {
    auto mul_input_node = mul_input_anode_->cast<CNodePtr>();
    MS_ASSERT(mul_input_node != nullptr);
    mul_input_abstract = mul_input_node->abstract();
  } else {
    // Unreachable for well-formed graphs; mul_input_abstract stays null and is
    // rejected below.
    MS_ASSERT(false);
  }
  if (mul_input_abstract == nullptr) {
    MS_LOG(DEBUG) << "Mul input node has no abstract";
    return false;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(mul_input_abstract)) {
    MS_LOG(DEBUG) << "Abstract of mul input node should be AbstractTensor";
    return false;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(mul_input_abstract);
  MS_ASSERT(abstract_tensor != nullptr);
  MS_ASSERT(abstract_tensor->BuildShape() != nullptr);
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(DEBUG) << "BuildShape of abstract of mul input node should be ShapePtr";
    return false;
  }
  mul_input_shape_ = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  return true;
}
// Pattern-pass entry: given a matched Add node, validates the add and its producing
// mul, checks the scale/bias shapes, and replaces the pair with a single ScaleFusion
// cnode. Returns nullptr (no rewrite) when any precondition fails.
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  add_anode_ = node;
  if (!CheckAddNode()) {
    MS_LOG(DEBUG) << "Add op is not suit for mul-add-fusion: " << node->fullname_with_scope();
    return nullptr;
  }
  MS_ASSERT(mul_anode_ != nullptr);
  MS_ASSERT(bias_tensor_ != nullptr);
  MS_ASSERT(add_const_anode_ != nullptr);
  if (!CheckMulNode(func_graph)) {
    MS_LOG(DEBUG) << "Mul op is not suit for mul-add-fusion: " << mul_anode_->fullname_with_scope();
    return nullptr;
  }
  MS_ASSERT(mul_input_anode_ != nullptr);
  MS_ASSERT(scale_tensor_ != nullptr);
  MS_ASSERT(mul_const_anode_ != nullptr);
  if (!GetMulInputShape()) {
    MS_LOG(DEBUG) << "Get input shape of mul op failed";
    return nullptr;
  }
  // scale requires scale shape tail sub of input shape, scale shape same as bias shape
  if (!ScaleInputShapeValid()) {
    MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed";
    return nullptr;
  }
  // Create the scale primitive; wrap the nothrow allocation immediately so ownership
  // is never held by a raw pointer.
  auto scale_primitive = std::shared_ptr<mindspore::ops::ScaleFusion>(new (std::nothrow) mindspore::ops::ScaleFusion());
  if (scale_primitive == nullptr) {
    MS_LOG(ERROR) << "new scale primitive failed";
    return nullptr;
  }
  scale_primitive->set_activation_type(scale_act_type_);
  // BUGFIX: shape_c().size() is unsigned, so `0 - size` wraps around before the
  // conversion to the axis parameter; negate through int64_t instead.
  scale_primitive->set_axis(-static_cast<int64_t>(bias_tensor_->shape_c().size()));
  // create scale op
  auto scale_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(scale_primitive),
                                         {mul_input_anode_, mul_const_anode_, add_const_anode_});
  return scale_node;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace opt {
// Fuses a MulFusion (no activation, one const input) followed by an AddFusion
// (activation in {none, relu, relu6}, one const input) into a single ScaleFusion op
// when the const shapes form a valid scale/offset pair for the mul's input shape.
class MulAddFusion : public PatternProcessPass {
 public:
  // NOTE(review): default pass name "conv_activation_fusion" looks copy-pasted from
  // another pass — confirm whether it should be "mul_add_fusion".
  explicit MulAddFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
      : PatternProcessPass(name, multigraph) {}
  ~MulAddFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Validation helpers; each records matched pieces into the mutable members below.
  bool CheckMulNode(const FuncGraphPtr &func_graph) const;
  bool CheckAddNode() const;
  bool GetMulInputShape() const;
  bool ScaleInputShapeValid() const;

 private:
  // Mutable match state: Process() is const (PatternProcessPass contract), so the
  // per-match bookkeeping is stored in mutable members.
  mutable AnfNodePtr mul_anode_ = nullptr;         // the matched Mul cnode
  mutable AnfNodePtr mul_input_anode_ = nullptr;   // Mul's variable (cnode) input
  mutable AnfNodePtr mul_const_anode_ = nullptr;   // Mul's const input (scale)
  mutable ShapeVector mul_input_shape_;            // shape of the variable input
  mutable AnfNodePtr add_anode_ = nullptr;         // the matched Add cnode
  mutable AnfNodePtr add_const_anode_ = nullptr;   // Add's const input (offset)
  mutable tensor::TensorPtr scale_tensor_ = nullptr;
  mutable tensor::TensorPtr bias_tensor_ = nullptr;
  mutable ActivationType scale_act_type_ = ActivationType::NO_ACTIVATION;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_

View File

@ -256,11 +256,12 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun
auto parameter = func_graph->add_parameter(); auto parameter = func_graph->add_parameter();
parameter->set_name(name); parameter->set_name(name);
std::vector<int64_t> shape_vector(shape.begin(), shape.end()); std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector); auto abstract = lite::CreateTensorAbstract(shape_vector, type);
if (abstract_tensor == nullptr) { if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr; return nullptr;
} }
parameter->set_abstract(abstract_tensor); parameter->set_abstract(abstract);
auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector); auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
if (gate_weight_default == nullptr) { if (gate_weight_default == nullptr) {

View File

@ -502,13 +502,12 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
return nullptr; return nullptr;
} }
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value}); CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value});
std::vector<int64_t> shape_vector; auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); if (abstract == nullptr) {
if (abstract_tensor == nullptr) { MS_LOG(ERROR) << "Create tensor abstarct failed";
MS_LOG(ERROR) << "create abstract_tensor failed";
return nullptr; return nullptr;
} }
get_item_cnode->set_abstract(abstract_tensor); get_item_cnode->set_abstract(abstract);
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
std::to_string(item_index)); std::to_string(item_index));
return get_item_cnode; return get_item_cnode;
@ -581,13 +580,12 @@ STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int o
MS_ASSERT(cnode != nullptr); MS_ASSERT(cnode != nullptr);
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
for (int i = 0; i < output_num; ++i) { for (int i = 0; i < output_num; ++i) {
std::vector<int64_t> shape_vector; auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); if (abstract == nullptr) {
if (abstract_tensor == nullptr) { MS_LOG(ERROR) << "Create tensor abstarct failed";
MS_LOG(ERROR) << "create abstract_tensor failed";
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract_tensor); abstract_list.emplace_back(abstract);
} }
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
if (abstract_tuple == nullptr) { if (abstract_tuple == nullptr) {

View File

@ -23,6 +23,7 @@
#include "ops/return.h" #include "ops/return.h"
#include "ops/tuple_get_item.h" #include "ops/tuple_get_item.h"
#include "tools/converter/ops/while.h" #include "tools/converter/ops/while.h"
#include "tools/common/tensor_util.h"
namespace { namespace {
mindspore::ValueNodePtr GetWhileAnfPrim() { mindspore::ValueNodePtr GetWhileAnfPrim() {
@ -207,9 +208,13 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
auto node_users = manager->node_users()[node]; auto node_users = manager->node_users()[node];
for (auto &node_user : node_users) { for (auto &node_user : node_users) {
// new getitem // new getitem
AbstractBasePtrList abstractList; AbstractBasePtrList abstract_list;
std::vector<int64_t> shape_vector; auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
@ -225,12 +230,12 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue}; std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
CNodePtr get_item_node = fg_->NewCNode(inputs); CNodePtr get_item_node = fg_->NewCNode(inputs);
std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); auto get_item_node_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) { if (get_item_node_abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed"; MS_LOG(ERROR) << "Create get_item_node_abstract failed";
return RET_NULL_PTR; return RET_ERROR;
} }
get_item_node->set_abstract(abstract); get_item_node->set_abstract(get_item_node_abstract);
get_item_node->set_fullname_with_scope(output_item_name); get_item_node->set_fullname_with_scope(output_item_name);
// set // set
if (fg_->nodes().contains(node_user.first)) { if (fg_->nodes().contains(node_user.first)) {

View File

@ -22,6 +22,7 @@
#include "src/tensor.h" #include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
namespace mindspore::opt { namespace mindspore::opt {
@ -101,13 +102,16 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
return false; return false;
} }
auto type_id = static_cast<TypeId>(weight_value->data_type()); auto type_id = static_cast<TypeId>(weight_value->data_type());
auto type_ptr = TypeIdToType(type_id);
auto shape = weight_value->shape(); auto shape = weight_value->shape();
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); }); [](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
weight_node->set_abstract(abstract_tensor); if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
} }
} }
return true; return true;

View File

@ -21,6 +21,7 @@
#include "tools/common/node_util.h" #include "tools/common/node_util.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "src/common/common.h" #include "src/common/common.h"
#include "src/common/tensor_util.h"
#include "src/ops/populate/populate_register.h" #include "src/ops/populate/populate_register.h"
#include "src/ops/ops_utils.h" #include "src/ops/ops_utils.h"
#include "src/runtime/infer_manager.h" #include "src/runtime/infer_manager.h"
@ -28,19 +29,6 @@
namespace mindspore::opt { namespace mindspore::opt {
namespace { namespace {
constexpr size_t INITIAL_SIZE = 1024; constexpr size_t INITIAL_SIZE = 1024;
tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
std::vector<int> shape(tensor->shape());
std::vector<int64_t> shape_vector;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "new tensor::Tensor failed";
return nullptr;
}
return tensor_info;
}
bool IsSpecialType(const CNodePtr &cnode) { bool IsSpecialType(const CNodePtr &cnode) {
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
@ -75,21 +63,14 @@ STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr
} }
} // namespace } // namespace
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { abstract::AbstractBasePtr InferShapePass::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
MS_ASSERT(nullptr != tensor); MS_ASSERT(nullptr != tensor);
std::vector<int> shape(tensor->shape()); auto shape = tensor->shape();
auto type_id = static_cast<TypeId>(tensor->data_type()); auto type_id = static_cast<TypeId>(tensor->data_type());
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector(shape.begin(), shape.end()); std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, type_id);
if (new_abstract == nullptr) {
MS_LOG(ERROR) << "new AbstractTensor failed";
return nullptr;
}
auto tensor_info = NewTensorInfo(tensor);
if (tensor_info == nullptr) { if (tensor_info == nullptr) {
MS_LOG(ERROR) << "new tensor::Tensor failed"; MS_LOG(DEBUG) << "Create tensor info failed";
return nullptr; return nullptr;
} }
@ -112,8 +93,12 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
return nullptr; return nullptr;
} }
} }
new_abstract->set_value(tensor_info); auto abstract = tensor_info->ToAbstract();
return new_abstract; if (abstract == nullptr) {
MS_LOG(DEBUG) << "Create tensor abstarct failed";
return nullptr;
}
return abstract;
} }
STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) { STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
@ -143,8 +128,6 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
std::vector<int32_t> shape; std::vector<int32_t> shape;
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &value) { return static_cast<int32_t>(value); }); [](const int64_t &value) { return static_cast<int32_t>(value); });
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector); auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector);
if (parameter->has_default()) { if (parameter->has_default()) {
auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
@ -155,7 +138,11 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
return RET_ERROR; return RET_ERROR;
} }
} }
new_abstract->set_value(new_tensor_info); auto new_abstract = new_tensor_info->ToAbstract();
if (new_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(new_abstract); parameter->set_abstract(new_abstract);
return RET_OK; return RET_OK;
} }
@ -304,7 +291,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
} }
if (output_tensors.size() == 1) { if (output_tensors.size() == 1) {
auto tensor = output_tensors.front(); auto tensor = output_tensors.front();
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); auto new_abstract = ConvertLiteTensorToAbstract(tensor);
if (new_abstract == nullptr) { if (new_abstract == nullptr) {
return RET_ERROR; return RET_ERROR;
} }
@ -313,7 +300,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
for (size_t i = 0; i < output_tensors.size(); i++) { for (size_t i = 0; i < output_tensors.size(); i++) {
auto tensor = output_tensors.front(); auto tensor = output_tensors.front();
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); auto new_abstract = ConvertLiteTensorToAbstract(tensor);
if (new_abstract == nullptr) { if (new_abstract == nullptr) {
return RET_ERROR; return RET_ERROR;
} }

View File

@ -36,7 +36,7 @@ class InferShapePass : public Pass {
private: private:
void FreeTensors(std::vector<lite::Tensor *> *tensors); void FreeTensors(std::vector<lite::Tensor *> *tensors);
abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors); STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors);
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
STATUS SetParameterAbstract(const ParameterPtr &parameter); STATUS SetParameterAbstract(const ParameterPtr &parameter);

View File

@ -179,23 +179,23 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
if (!utils::isa<ValueNodePtr>(anf_node)) { if (!utils::isa<ValueNodePtr>(anf_node)) {
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto valueNode = anf_node->cast<ValueNodePtr>(); auto value_node = anf_node->cast<ValueNodePtr>();
if (valueNode->abstract() == nullptr) { if (value_node->abstract() == nullptr) {
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueNode->abstract()); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract());
if (abstractTensor == nullptr) { if (abstract_tensor == nullptr) {
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto value = abstractTensor->GetValueTrack(); auto value = abstract_tensor->GetValueTrack();
if (value != nullptr && value->isa<tensor::Tensor>()) { if (value != nullptr && value->isa<tensor::Tensor>()) {
if (abstractTensor->element() == nullptr) { if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor->element() is nullptr."; MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
return RET_ERROR; return RET_ERROR;
} }
auto typePtr = abstractTensor->element()->GetTypeTrack(); auto type_ptr = abstract_tensor->element()->GetTypeTrack();
if (typePtr->type_id() == kNumberTypeInt64) { if (type_ptr->type_id() == kNumberTypeInt64) {
auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector); auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector);
auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c()); auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c());
auto src_tensor_info = value->cast<tensor::TensorPtr>(); auto src_tensor_info = value->cast<tensor::TensorPtr>();
@ -204,10 +204,10 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) { for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
dest_data_buf[i] = src_data_buf[i]; dest_data_buf[i] = src_data_buf[i];
} }
abstractTensor->set_value(dest_tensor_info); abstract_tensor->set_value(dest_tensor_info);
abstractTensor->set_type(TypeIdToType(kNumberTypeInt32)); abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
valueNode->set_value(dest_tensor_info); value_node->set_value(dest_tensor_info);
} }
} }
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;

View File

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "ops/transpose.h" #include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::converter::FmkType_CAFFE; using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS; using mindspore::lite::converter::FmkType_MS;
@ -92,9 +93,20 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm"); auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
auto prim = std::make_shared<ops::Transpose>(); auto prim = std::make_shared<ops::Transpose>();
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node}); auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
auto type_ptr = TypeIdToType(kTypeUnknown); if (!weight_node->has_default()) {
std::vector<int64_t> shape_vector; MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
auto abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); return lite::RET_ERROR;
}
auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
if (weight_tensor == nullptr) {
MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
return lite::RET_ERROR;
}
auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
transpose_node->set_abstract(abstract); transpose_node->set_abstract(abstract);
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post"); transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
for (auto &adjust_node : adjust_nodes) { for (auto &adjust_node : adjust_nodes) {
@ -177,11 +189,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
return false; return false;
} }
auto type_id = static_cast<TypeId>(weight_value->data_type()); auto type_id = static_cast<TypeId>(weight_value->data_type());
auto type_ptr = TypeIdToType(type_id);
auto shape = weight_value->shape(); auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end()); std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
weight_node->set_abstract(abstract_tensor); if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
} }
return RET_OK; return RET_OK;
} }