fix abstract of parameter

This commit is contained in:
hangangqiang 2021-04-22 16:50:38 +08:00
parent d6f58cb765
commit 84e4906c9d
34 changed files with 797 additions and 401 deletions

View File

@ -38,9 +38,12 @@ int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
const TensorC *input_k_tensor = inputs[1];
if (input_k_tensor->data_ == NULL) {
return NNACL_INFER_INVALID;
}
TopkParameter *param = (TopkParameter *)parameter;
const TensorC *input_k_tensor = inputs[1];
param->k_ = ((int32_t *)input_k_tensor->data_)[0];
int out_shape[MAX_SHAPE_SIZE];

View File

@ -75,12 +75,14 @@ class LiteModel : public Model {
} else {
node->name_ = c_node->name()->c_str();
}
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
if (c_node->inputIndex() != nullptr) {
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
}
}
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
auto count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
}

View File

@ -247,6 +247,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc

View File

@ -73,6 +73,40 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std
return tensor_info;
}
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) {
auto tensor_info = CreateTensorInfo(nullptr, 0, shape, data_type);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return nullptr;
}
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr;
}
return abstract;
}
int SetParameterAbstractAndParam(const ParameterPtr &parameter, const void *data, size_t data_size,
const std::vector<int64_t> &shape, TypeId data_type) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "Input parameter is nullptr";
return RET_INPUT_PARAM_INVALID;
}
auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
return RET_OK;
}
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) {
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "tensor info is nullptr.";

View File

@ -46,6 +46,11 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT>
tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
TypeId data_type);
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type);
int SetParameterAbstractAndParam(const ParameterPtr &parameter, const void *data, size_t data_size,
const std::vector<int64_t> &shape, TypeId data_type);
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size);
std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info,

View File

@ -54,6 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/fusion/matmul_add_fusion.cc
../optimizer/fusion/mul_add_fusion.cc
../optimizer/fusion/gelu_fusion.cc
../optimizer/fusion/tf_gelu_fusion.cc
../optimizer/fusion/onnx_gelu_fusion.cc

View File

@ -70,7 +70,7 @@ MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
}
MS_LOG(INFO) << "Run anfTransform success";
// protobuf -> flatbuf
// protobuf -> flatbuffer
auto meta_graph = Export(graph, false, false, flag->trainModel);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";

View File

@ -39,7 +39,6 @@
using std::string;
namespace mindspore::lite {
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
std::vector<schema::CNodeT *> old_nodes{};
old_nodes.resize(graph_defT_->nodes.size());
@ -71,54 +70,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// generate and infer quant parameters
{
Optimizer infer_quant_param_pass;
infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
status = infer_quant_param_pass.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
return status;
}
}
{
// format transform
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
if (ctx.fmk != converter::FmkType_TF) {
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
format_trans_optimizer.AddPass(infer_shape_pass);
}
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer format_trans_optimizer;
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = format_trans_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
return status;
}
}
// format transpose global optimize
{
// init old node indices
auto old_nodes = GetGraphNodes();
@ -134,20 +86,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// postconvert pass
{
// node replace
if (!ctx.trainModel) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer replace_optimizer;
if (!ctx.trainModel) {
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
if (batch_norm_scale_pass == nullptr) {
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
return RET_ERROR;
}
batch_norm_scale_pass->SetFmk(ctx.fmk);
replace_optimizer.AddPass(batch_norm_scale_pass);
}
replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk));
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
status = replace_optimizer.Run(graph_defT_);
@ -157,6 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// node fusion
{
// init old node indices
auto old_nodes = GetGraphNodes();
@ -171,19 +117,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// do quantization
// quantization
if (ctx.fmk != converter::FmkType_TF) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer tensor_quant_optimizer;
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
tensor_quant_optimizer.AddPass(infer_shape_pass);
tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = tensor_quant_optimizer.Run(graph_defT_);
@ -193,38 +134,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// insert quantNode and deQuantNode
// quantization
if (ctx.fmk != converter::FmkType_TF) {
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer quant_node_optimizer;
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
quant_node_optimizer.AddPass(infer_shape_pass);
status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
return status;
}
auto old_nodes2 = GetGraphNodes();
quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
if (dtype_trans_pass == nullptr) {
MS_LOG(ERROR) << "new dtype_trans_pass failed";
return RET_MEMORY_FAILED;
}
dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
quant_node_optimizer.AddPass(dtype_trans_pass);
auto old_nodes = GetGraphNodes();
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
status = quant_node_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
@ -232,7 +152,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// switch pass
// controlflow pass
{
// init old node indices
auto old_nodes = GetGraphNodes();
@ -240,6 +160,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = switch_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run switch_optimizer Failed";
@ -247,34 +168,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
// subgraph tensor pass
{
Optimizer subgraph_tensor_optimizer;
subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
status = subgraph_tensor_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
return status;
}
}
// tensor name
{
// init old node indices
auto old_nodes = GetGraphNodes();
Optimizer name_optimizer;
name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
status = name_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
return status;
}
}
{
Optimizer nested_loop_optimizer;
auto old_nodes = GetGraphNodes();
nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nested_loop_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -284,30 +182,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
{
Optimizer quant_param_optimizer;
quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
status = quant_param_optimizer.Run(graph_defT_);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed";
return status;
}
}
{
Optimizer infer_shape_optimizer;
auto infer_shape_pass = new (std::nothrow) InferShapePass();
if (infer_shape_pass == nullptr) {
MS_LOG(ERROR) << "new InferShapePass failed";
return RET_MEMORY_FAILED;
}
infer_shape_pass->set_fmk_type(ctx.fmk);
infer_shape_optimizer.AddPass(infer_shape_pass);
status = infer_shape_optimizer.Run(graph_defT_);
Optimizer forming_model_optimizer;
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
status = forming_model_optimizer.Run(graph_defT_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
return status;
}
}
return RET_OK;
} // namespace mindspore::lite
}
} // namespace mindspore::lite

View File

@ -36,14 +36,12 @@ struct BNWeightTensors {
};
class BatchNormConvertScalePass : public GraphPass {
public:
BatchNormConvertScalePass() = default;
explicit BatchNormConvertScalePass(converter::FmkType fmk) : fmkType(fmk) {}
~BatchNormConvertScalePass() = default;
STATUS Run(MetaGraphT *graph) override;
void SetFmk(converter::FmkType fmk) { this->fmkType = fmk; }
protected:
STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);

View File

@ -276,10 +276,5 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
castOpCopyer);
}
void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
} // namespace lite
} // namespace mindspore

View File

@ -30,16 +30,13 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
class DTypeTransPass : public GraphPass {
public:
DTypeTransPass() : id_(0) {}
DTypeTransPass(TypeId model_input_data_type, TypeId model_output_data_type)
: id_(0), input_data_dtype(model_input_data_type), output_data_dtype(model_output_data_type) {}
~DTypeTransPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
void set_input_data_dtype(TypeId data_type);
void set_output_data_dtype(TypeId dataType);
private:
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);

View File

@ -39,14 +39,10 @@ struct InferTensor {
class InferShapePass : public GraphPass {
public:
InferShapePass() = default;
~InferShapePass() = default;
explicit InferShapePass(converter::FmkType fmk_type) : fmk_type_(fmk_type) {}
~InferShapePass() override = default;
STATUS Run(MetaGraphT *graph) override;
void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
private:
void InitSearchTensor(MetaGraphT *graph);
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);

View File

@ -34,8 +34,28 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) = 0;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) {
auto ret = ParseToFuncGraph(model_file, weight_file, quant_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse to func graph failed : " << ret;
return nullptr;
}
ret = PostAdjust();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Adjust func graph failed : " << ret;
return nullptr;
}
return this->res_graph_;
}
protected:
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) = 0;
virtual int PostAdjust() = 0;
protected:
FuncGraphPtr res_graph_ = nullptr;
};
} // namespace mindspore::lite

View File

@ -15,6 +15,7 @@
*/
#include <vector>
#include "tools/common/tensor_util.h"
#include "tools/converter/ops/while.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
@ -55,7 +56,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
AbstractBasePtrList output;
for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
output.push_back(std::make_shared<abstract::AbstractTensor>(input_args[i]->BuildType(), shape));
auto abstract_tensor = lite::CreateTensorAbstract(shape, input_args[i]->BuildType()->type_id());
MS_EXCEPTION_IF_NULL(abstract_tensor);
output.push_back(abstract_tensor);
}
return std::make_shared<abstract::AbstractTuple>(output);
}

View File

@ -41,34 +41,34 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
func_graph_ptr_ = std::make_shared<FuncGraph>();
res_graph_ = std::make_shared<FuncGraph>();
status = ConvertGraphInputs();
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
status = ConvertLayers();
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
status = ConvertGraphOutputs();
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
return func_graph_ptr_;
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
return RET_OK;
}
STATUS CaffeModelParser::ConvertLayers() {
@ -134,7 +134,7 @@ STATUS CaffeModelParser::ConvertLayers() {
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end());
op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end());
auto new_cnode = func_graph_ptr_->NewCNode(op_inputs);
auto new_cnode = res_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(layer.name());
// convert outputs
@ -194,14 +194,17 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
for (int i = 0; i < caffe_model_.layer_size(); i++) {
auto layer = caffe_model_.layer(i);
if (layer.type() == "Input") {
auto parameter = func_graph_ptr_->add_parameter();
auto parameter = res_graph_->add_parameter();
std::vector<int64_t> shape;
for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) {
shape.push_back(layer.input_param().shape(0).dim(j));
}
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor);
auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + std::to_string(i));
nodes_.insert(std::pair(layer.top(0), parameter));
}
@ -220,10 +223,13 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
shape.push_back(caffe_model_.input_dim(j));
}
}
auto parameter = func_graph_ptr_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
parameter->set_abstract(abstract_tensor);
auto parameter = res_graph_->add_parameter();
auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
}
@ -234,10 +240,18 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
for (int j = 0; j < shape.dim_size(); j++) {
shape_vector.push_back(shape.dim(j));
}
auto parameter = func_graph_ptr_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
auto parameter = res_graph_->add_parameter();
auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract);
parameter->set_name("graph_input-" + caffe_model_.input(i));
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
}
@ -265,7 +279,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
auto cnode = nodes_.find(output_node)->second;
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_inputs);
auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs;
@ -277,9 +291,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode);
auto cnode = func_graph_ptr_->NewCNode(op_inputs);
auto cnode = res_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("Return");
func_graph_ptr_->set_return(cnode);
res_graph_->set_return(cnode);
} else {
auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) {
@ -298,9 +312,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
return RET_NOT_FIND_OP;
}
opInputs.emplace_back(cnode);
auto returnCnode = func_graph_ptr_->NewCNode(opInputs);
auto returnCnode = res_graph_->NewCNode(opInputs);
returnCnode->set_fullname_with_scope("Return");
func_graph_ptr_->set_return(returnCnode);
res_graph_->set_return(returnCnode);
}
return RET_OK;
}
@ -333,7 +347,7 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v
ConvertShape(layer.blobs(i), &shape);
// cal Weight num
auto parameter = func_graph_ptr_->add_parameter();
auto parameter = res_graph_->add_parameter();
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
std::vector<int64_t> shape_vector;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
@ -402,17 +416,25 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std::
}
STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) {
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
std::vector<int64_t> shape_vector;
if (layer.top_size() == 1) {
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
cnode->set_abstract(abstract);
nodes_[layer.top(0)] = cnode;
return RET_OK;
}
AbstractBasePtrList abstract_list;
for (int i = 0; i < layer.top_size(); i++) {
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed";
@ -421,7 +443,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(i));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs);
CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
get_item_cnode->set_fullname_with_scope(layer.top(i));
nodes_[layer.top(i)] = get_item_cnode;
}
@ -446,4 +468,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
}
return layer.name();
}
int CaffeModelParser::PostAdjust() { return RET_OK; }
} // namespace mindspore::lite

View File

@ -32,8 +32,10 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int PostAdjust() override;
private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
@ -59,7 +61,6 @@ class CaffeModelParser : public ModelParser {
caffe::NetParameter caffe_weight_;
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_;
};
} // namespace mindspore::lite

View File

@ -45,31 +45,31 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
anf_root_graph_ = std::make_shared<FuncGraph>();
res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "init origin model failed.";
return nullptr;
return status;
}
status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node");
status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node");
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert onnx graph failed.";
return nullptr;
return status;
}
static auto root_func_manager = Manage(anf_root_graph_);
static auto root_func_manager = Manage(res_graph_);
for (auto &subgraph : all_subgraphs_) {
subgraph->set_manager(root_func_manager);
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
}
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
return anf_root_graph_;
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
return RET_OK;
}
STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
@ -88,9 +88,9 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
onnx_root_graph_ = onnx_model_.graph();
if (OnnxNodeParser::opset_version() > 15) {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
} else {
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
}
return RET_OK;
}
@ -170,13 +170,16 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c
<< static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
return RET_ERROR;
}
auto type_ptr = TypeIdToType(data_type);
std::vector<int64_t> shape_vector;
auto onnx_shape = input_value.type().tensor_type().shape().dim();
std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
[](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract_tensor);
parameter->set_name(input_value.name());
anf_nodes_map->emplace(input_value.name(), parameter);
@ -490,17 +493,23 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F
return RET_NULL_PTR;
}
if (onnx_node.output_size() == 1) {
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
std::vector<int64_t> shape_vector;
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tensor);
anf_nodes_map->emplace(onnx_node.output(0), cnode);
} else {
AbstractBasePtrList abstract_list;
int op_idx = 0;
for (const auto &output_name : onnx_node.output()) {
std::vector<int64_t> shape_vector;
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed";
@ -687,7 +696,11 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
return nullptr;
}
auto const_node = anf_graph->add_parameter();
auto const_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
if (const_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr;
}
const_node->set_abstract(const_abstract);
int *tensor_data = new (std::nothrow) int[1];
if (tensor_data == nullptr) {
@ -834,9 +847,16 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v
for (int i = 0; i < act_output_num; i++) {
// tensor_array need as root while input
auto while_tensor_array_input = anf_root_graph->add_parameter();
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vector);
auto tensor_info = std::make_shared<tensor::Tensor>(kObjectTypeTensorType, shape_vector);
auto tensor_info = CreateTensorInfo(nullptr, 0, {}, kObjectTypeTensorType);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Create tensor info failed";
return RET_ERROR;
}
auto abstract_tensor = tensor_info->ToAbstract();
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
while_tensor_array_input->set_abstract(abstract_tensor);
while_tensor_array_input->set_default_param(tensor_info);
while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray");
@ -975,7 +995,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
auto input_paramter = cond_graph->add_parameter();
input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs();
auto input_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
if (input_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
input_paramter->set_abstract(input_abstract);
if (i == 0) {
auto zero_parameter = CreateConstParamter(cond_graph, 0);
@ -987,7 +1011,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
MS_LOG(ERROR) << "new cnode error";
return RET_ERROR;
}
auto less_abstract = std::make_shared<abstract::AbstractTensor>(kBool, std::vector<int64_t>());
auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
if (less_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
less_cnode->set_abstract(less_abstract);
less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
}
@ -1020,12 +1048,11 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
MS_LOG(ERROR) << "quant param type don't support.";
return RET_NOT_SUPPORT;
}
std::vector<int64_t> shape_vector;
auto parameter_node = anf_root_graph_->add_parameter();
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
auto parameter_node = res_graph_->add_parameter();
auto abstract_tensor = CreateTensorAbstract({}, type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "new abstract_tensor failed";
return RET_MEMORY_FAILED;
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter_node->set_abstract(abstract_tensor);
parameter_node->set_name(name);
@ -1051,9 +1078,12 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr &parameter_node, c
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
return RET_ERROR;
}
auto type_ptr = TypeIdToType(data_type);
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter_node->set_abstract(abstract_tensor);
parameter_node->set_name(tensor.name());
@ -1142,5 +1172,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type
}
return iter->second;
}
int OnnxModelParser::PostAdjust() { return 0; }
} // namespace lite
} // namespace mindspore

View File

@ -40,14 +40,17 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int PostAdjust() override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
const tensor::TensorPtr &param_value_lite);
STATUS InitOriginModel(const std::string &model_file);
private:
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
const std::string &root_node_name);
@ -94,7 +97,6 @@ class OnnxModelParser : public ModelParser {
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
FuncGraphPtr anf_root_graph_ = nullptr;
};
} // namespace lite
} // namespace mindspore

View File

@ -417,18 +417,17 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
type = TensorFlowUtils::GetTFDataType(attr_value.type());
}
std::vector<int> shape;
std::vector<int64_t> shape;
if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
auto &shape_attr = attr_value.shape();
for (int i = 0; i < shape_attr.dim_size(); ++i) {
shape.push_back(shape_attr.dim(i).size());
}
}
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
MS_LOG(INFO) << "Found value attr, means it has default value";
auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector);
auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert const tensor failed.";
return status;
@ -437,10 +436,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
}
auto type_ptr = TypeIdToType(type == kNumberTypeInt64 ? kNumberTypeInt32 : type);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
type = (type == kNumberTypeInt64) ? kNumberTypeInt32 : type;
auto abstract_tensor = CreateTensorAbstract(shape, type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "abstract_tensor is nullptr";
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_name(node.name());
@ -473,51 +472,51 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
}
return RET_OK;
}
FuncGraphPtr paserTfFuction() { return nullptr; }
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
return status;
}
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
anf_root_graph_ = std::make_shared<FuncGraph>();
if (anf_root_graph_ == nullptr) {
res_graph_ = std::make_shared<FuncGraph>();
if (res_graph_ == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
return status;
}
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
tf_root_graph_nodes_[node_def.name()] = &node_def;
}
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
if (status != RET_OK) {
success_flag = false;
@ -525,7 +524,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
}
if (!success_flag) {
MS_LOG(ERROR) << "Convert ops failed.";
return nullptr;
return RET_ERROR;
}
if (!nodes_with_null_input_.empty()) {
@ -533,7 +532,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
if (status != RET_OK) {
MS_LOG(ERROR) << "Connect null inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
}
@ -541,17 +540,17 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
status = ConvertSubgraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert subgraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
return anf_root_graph_;
return RET_OK;
}
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
@ -745,7 +744,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR;
}
static auto root_func_manager = Manage(anf_root_graph_);
static auto root_func_manager = Manage(res_graph_);
for (auto &kv : first_func_map) {
auto control_flow_node = kv.first;
@ -757,7 +756,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
auto second_value_node = NewValueNode(second_sub_graph);
auto inputs = control_flow_node->inputs();
inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node});
auto new_node = anf_root_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update
auto new_node = res_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update
if (new_node == nullptr) {
MS_LOG(ERROR) << "new node failed";
return RET_ERROR;
@ -811,43 +810,46 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
if (output_size == 0) {
return RET_OK;
} else if (output_size == 1) {
auto type = kFloat32;
std::vector<int64_t> shape_vector;
auto type = kNumberTypeFloat32;
if (IsTensorListOp(anf_node)) {
type = TypeIdToType(kObjectTypeTensorType);
type = kObjectTypeTensorType;
}
auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector);
if (abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed";
auto abstract_tensor = CreateTensorAbstract({}, type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
anf_node->set_abstract(abstract);
anf_node->set_abstract(abstract_tensor);
anf_node_map->insert(std::pair(op.name(), anf_node));
} else {
AbstractBasePtrList abstractList;
AbstractBasePtrList abstract_list;
for (int output_idx = 0; output_idx < output_size; output_idx++) {
std::vector<int64_t> shape_vector;
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
auto tupleGetItemPrimPtr = std::make_shared<ops::TupleGetItem>();
if (tupleGetItemPrimPtr == nullptr) {
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR;
}
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(output_idx));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, anf_node, get_item_value};
CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed";
auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
if (get_item_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
getItemCNode->set_abstract(abstract);
getItemCNode->set_fullname_with_scope(output_item_name);
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
get_item_cnode->set_abstract(get_item_abstract);
get_item_cnode->set_fullname_with_scope(output_item_name);
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode));
}
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
}
return RET_OK;
}
@ -1003,7 +1005,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
graph_output_names_.push_back(anf_node->fullname_with_scope());
}
}
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
auto status = MakeAnfGraphOutputs(&output_nodes, res_graph_);
if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;
@ -1051,5 +1053,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
}
return RET_OK;
}
int TFModelParser::PostAdjust() { return 0; }
} // namespace lite
} // namespace mindspore

View File

@ -36,9 +36,11 @@ namespace lite {
class TFModelParser : public ModelParser {
public:
TFModelParser() = default;
~TFModelParser() = default;
~TFModelParser() override = default;
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
int PostAdjust() override;
private:
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
@ -84,7 +86,6 @@ class TFModelParser : public ModelParser {
STATUS ConnectNullInput();
FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;

View File

@ -43,46 +43,46 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
return tflite::UnPackModel(tflite_model_buf_);
}
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
// load graph
tflite_model_ = ReadTfliteModel(model_file.c_str());
if (tflite_model_ == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
return RET_GRAPH_FILE_ERR;
}
if (tflite_model_->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
return RET_GRAPH_FILE_ERR;
}
func_graph_ = std::make_shared<FuncGraph>();
func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
res_graph_ = std::make_shared<FuncGraph>();
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
auto status = ConvertGraphInputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
status = ConvertOps();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
status = ConvertGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return status;
}
func_graph_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_;
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
return RET_OK;
}
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertOps() {
} else {
tensor_name = GetTensorName(i, tflite_op_type, op_name);
}
auto parameter = func_graph_->add_parameter();
auto parameter = res_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
@ -168,7 +168,7 @@ STATUS TfliteModelParser::ConvertOps() {
op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter));
}
auto new_cnode = func_graph_->NewCNode(op_inputs);
auto new_cnode = res_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(op_name);
// parse outputs
@ -284,13 +284,16 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
if (tflite_graph_input < 0) {
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
}
auto parameter = func_graph_->add_parameter();
auto parameter = res_graph_->add_parameter();
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
nodes_.insert(std::pair(tflite_graph_input, parameter));
@ -318,7 +321,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
}
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs);
auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs;
@ -330,9 +333,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode);
auto cnode = func_graph_->NewCNode(op_inputs);
auto cnode = res_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("Return");
func_graph_->set_return(cnode);
res_graph_->set_return(cnode);
} else {
auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) {
@ -350,9 +353,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
return RET_NOT_FIND_OP;
}
op_inputs.emplace_back(cnode);
auto returnCnode = func_graph_->NewCNode(op_inputs);
auto returnCnode = res_graph_->NewCNode(op_inputs);
returnCnode->set_fullname_with_scope("Return");
func_graph_->set_return(returnCnode);
res_graph_->set_return(returnCnode);
}
return RET_OK;
}
@ -436,8 +439,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
dst_cnode->set_abstract(abstract_tensor);
nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
} else {
AbstractBasePtrList abstract_list;
@ -450,8 +457,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed";
@ -460,7 +471,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
CNodePtr get_item_cnode = func_graph_->NewCNode(inputs);
CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
nodes_.insert(std::pair(output_idx, get_item_cnode));
op_idx++;
@ -469,4 +480,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
}
return RET_OK;
}
int TfliteModelParser::PostAdjust() { return 0; }
} // namespace mindspore::lite

View File

@ -32,13 +32,14 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default;
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
int PostAdjust() override;
private:
std::unordered_map<int, AnfNodePtr> nodes_;
std::unique_ptr<tflite::ModelT> tflite_model_;
FuncGraphPtr func_graph_;
char *tflite_model_buf_ = nullptr;
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr &parameter,

View File

@ -399,6 +399,24 @@ int CheckIfCNodeIsNull(const CNodePtr &node) {
return lite::RET_OK;
}
int CheckIfParameterIsNull(const ParameterPtr &node) {
if (node == nullptr) {
MS_LOG(ERROR) << "The Parameter is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
return lite::RET_OK;
}
int CheckIfValueNodeIsNull(const ValueNodePtr &node) {
if (node == nullptr) {
MS_LOG(ERROR) << "The ValueNode is null.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_NULL_PTR;
}
return lite::RET_OK;
}
int CheckIfVarIsNull(const VarPtr &var) {
if (var == nullptr) {
MS_LOG(ERROR) << "The Var is null.";

View File

@ -57,6 +57,10 @@ int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
int CheckIfCNodeIsNull(const CNodePtr &node);
int CheckIfParameterIsNull(const ParameterPtr &node);
int CheckIfValueNodeIsNull(const ValueNodePtr &node);
int CheckIfVarIsNull(const VarPtr &var);
int CheckInputSize(const CNodePtr &node, int size);

View File

@ -0,0 +1,294 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/mul_add_fusion.h"
#include <memory>
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/scale_fusion.h"
#include "ops/op_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
constexpr size_t kMulInputsLength = 3;
constexpr size_t kAddInputsLength = 3;
} // namespace
const BaseRef MulAddFusion::DefinePattern() const {
auto mul_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
auto add_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
return VectorRef({add_var, mul_var});
}
bool MulAddFusion::ScaleInputShapeValid() const {
MS_ASSERT(scale_tensor_ != nullptr);
MS_ASSERT(bias_tensor_ != nullptr);
auto scale_shape = scale_tensor_->shape_c();
auto offset_shape = bias_tensor_->shape_c();
if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) {
return false;
}
size_t rank_diff = mul_input_shape_.size() - scale_shape.size();
for (size_t i = 0; i < scale_shape.size(); ++i) {
if (mul_input_shape_[i + rank_diff] != scale_shape[i]) {
return false;
}
}
if (scale_shape != offset_shape) {
return false;
}
return true;
}
bool MulAddFusion::CheckMulNode(const FuncGraphPtr &func_graph) const {
MS_ASSERT(func_graph != nullptr);
if (mul_anode_ == nullptr) {
return false;
}
if (IsMultiOutputTensors(func_graph, mul_anode_)) {
MS_LOG(DEBUG) << "Mul op has multi-output";
return false;
}
auto mul_node = mul_anode_->cast<CNodePtr>();
if (!CheckPrimitiveType(mul_node, prim::kPrimMulFusion)) {
MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
return false;
}
auto mul_primitive = GetValueNode<std::shared_ptr<ops::MulFusion>>(mul_node->input(0));
MS_ASSERT(mul_primitive != nullptr);
auto mul_act_type = mul_primitive->get_activation_type();
if (mul_act_type != ActivationType::NO_ACTIVATION) {
MS_LOG(DEBUG) << "Only support mul node with no activation";
return false;
}
if (CheckIfCNodeIsNull(mul_node) != lite::RET_OK || CheckInputSize(mul_node, kMulInputsLength) != lite::RET_OK) {
MS_LOG(DEBUG) << "Mul op is null or has error input size";
return false;
}
// find mul's const input and mul input
AnfNodePtr mul_pre_input_node = nullptr;
AnfNodePtr mul_pre_const_node = nullptr;
auto mul_pre_node_1 = mul_node->input(1);
if (CheckIfAnfNodeIsNull(mul_pre_node_1) != lite::RET_OK) {
MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
return false;
}
auto mul_pre_node_2 = mul_node->input(2);
if (CheckIfAnfNodeIsNull(mul_pre_node_2) != lite::RET_OK) {
MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
return false;
}
if (utils::isa<CNodePtr>(mul_pre_node_1) && !utils::isa<CNodePtr>(mul_pre_node_2)) {
mul_pre_input_node = mul_pre_node_1;
mul_pre_const_node = mul_pre_node_2;
} else if (!utils::isa<CNodePtr>(mul_pre_node_1) && utils::isa<CNodePtr>(mul_pre_node_2)) {
mul_pre_input_node = mul_pre_node_1;
mul_pre_const_node = mul_pre_node_2;
} else {
MS_LOG(DEBUG) << "Mul op should has a cnode input and a const input";
return false;
}
// check mul's const input
tensor::TensorPtr mul_tensor = nullptr;
if (utils::isa<ParameterPtr>(mul_pre_const_node)) {
auto mul_bias_node = mul_pre_const_node->cast<ParameterPtr>();
MS_ASSERT(mul_bias_node != nullptr);
if (!mul_bias_node->has_default()) {
MS_LOG(DEBUG) << "Const input of mul op should has data";
return false;
}
mul_tensor = mul_bias_node->default_param()->cast<tensor::TensorPtr>();
} else if (utils::isa<ValueNodePtr>(mul_pre_const_node)) {
auto mul_bias_node = mul_pre_const_node->cast<ValueNodePtr>();
MS_ASSERT(mul_bias_node != nullptr);
if (mul_bias_node->value() == nullptr) {
MS_LOG(DEBUG) << "Const input of mul op should has data";
return false;
}
mul_tensor = mul_bias_node->value()->cast<tensor::TensorPtr>();
} else {
MS_ASSERT(false);
}
if (mul_tensor == nullptr) {
MS_LOG(DEBUG) << "Const input of add op should has data";
return false;
}
mul_input_anode_ = mul_pre_input_node;
mul_const_anode_ = mul_pre_const_node;
scale_tensor_ = mul_tensor;
return true;
}
bool MulAddFusion::CheckAddNode() const {
if (add_anode_ == nullptr) {
return false;
}
auto add_cnode = add_anode_->cast<CNodePtr>();
if (CheckIfCNodeIsNull(add_cnode) != lite::RET_OK || CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) {
MS_LOG(DEBUG) << "Add op is null or has error input size";
return false;
}
if (!CheckPrimitiveType(add_cnode, prim::kPrimAddFusion)) {
MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
return false;
}
auto add_primitive = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
MS_ASSERT(add_primitive != nullptr);
auto add_act_type = add_primitive->get_activation_type();
if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 &&
add_act_type != ActivationType::NO_ACTIVATION) {
MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation";
return false;
}
scale_act_type_ = add_act_type;
// find add's const input and mul input
AnfNodePtr add_pre_input_node = nullptr;
AnfNodePtr add_pre_const_node = nullptr;
auto add_pre_node_1 = add_cnode->input(1);
if (CheckIfAnfNodeIsNull(add_pre_node_1) != lite::RET_OK) {
MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
return false;
}
auto add_pre_node_2 = add_cnode->input(2);
if (CheckIfAnfNodeIsNull(add_pre_node_2) != lite::RET_OK) {
MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
return false;
}
if (utils::isa<CNodePtr>(add_pre_node_1) && !utils::isa<CNodePtr>(add_pre_node_2)) {
add_pre_input_node = add_pre_node_1;
add_pre_const_node = add_pre_node_2;
} else if (!utils::isa<CNodePtr>(add_pre_node_1) && utils::isa<CNodePtr>(add_pre_node_2)) {
add_pre_input_node = add_pre_node_2;
add_pre_const_node = add_pre_node_1;
} else {
MS_LOG(DEBUG) << "Add op should has a cnode input and a const input";
return false;
}
// check add's const input
tensor::TensorPtr add_tensor = nullptr;
if (utils::isa<ParameterPtr>(add_pre_const_node)) {
auto add_bias_node = add_pre_const_node->cast<ParameterPtr>();
MS_ASSERT(add_bias_node != nullptr);
if (!add_bias_node->has_default()) {
MS_LOG(DEBUG) << "Const input of add op should has data";
return false;
}
add_tensor = add_bias_node->default_param()->cast<tensor::TensorPtr>();
} else if (utils::isa<ValueNodePtr>(add_pre_const_node)) {
auto add_bias_node = add_pre_const_node->cast<ValueNodePtr>();
MS_ASSERT(add_bias_node != nullptr);
if (add_bias_node->value() == nullptr) {
MS_LOG(DEBUG) << "Const input of add op should has data";
return false;
}
add_tensor = add_bias_node->value()->cast<tensor::TensorPtr>();
} else {
MS_ASSERT(false);
}
if (add_tensor == nullptr) {
MS_LOG(DEBUG) << "Const input of add op should has data";
return false;
}
mul_anode_ = add_pre_input_node;
add_const_anode_ = add_pre_const_node;
bias_tensor_ = add_tensor;
return true;
}
bool MulAddFusion::GetMulInputShape() const {
MS_ASSERT(mul_input_anode_ != nullptr);
ShapeVector mul_input_shape;
AbstractBasePtr mul_input_abstract = nullptr;
if (utils::isa<ParameterPtr>(mul_input_anode_)) {
auto mul_input_node = mul_input_anode_->cast<ParameterPtr>();
MS_ASSERT(mul_bias_node != nullptr);
mul_input_abstract = mul_input_node->abstract();
} else if (utils::isa<ValueNodePtr>(mul_input_anode_)) {
auto mul_input_node = mul_input_anode_->cast<ValueNodePtr>();
MS_ASSERT(mul_input_node != nullptr);
mul_input_abstract = mul_input_node->abstract();
} else if (utils::isa<CNodePtr>(mul_input_anode_)) {
auto mul_input_node = mul_input_anode_->cast<CNodePtr>();
MS_ASSERT(mul_input_node != nullptr);
mul_input_abstract = mul_input_node->abstract();
} else {
MS_ASSERT(false);
}
if (mul_input_abstract == nullptr) {
MS_LOG(DEBUG) << "Mul input node has no abstract";
return false;
}
if (!utils::isa<abstract::AbstractTensorPtr>(mul_input_abstract)) {
MS_LOG(DEBUG) << "Abstract of mul input node should be AbstractTensor";
return false;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(mul_input_abstract);
MS_ASSERT(abstract_tensor != nullptr);
MS_ASSERT(abstract_tensor->BuildShape() != nullptr);
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(DEBUG) << "BuildShape of abstract of mul input node should be ShapePtr";
return false;
}
mul_input_shape_ = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
return true;
}
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
add_anode_ = node;
if (!CheckAddNode()) {
MS_LOG(DEBUG) << "Add op is not suit for mul-add-fusion: " << node->fullname_with_scope();
return nullptr;
}
MS_ASSERT(mul_anode_ != nullptr);
MS_ASSERT(bias_tensor_ != nullptr);
MS_ASSERT(add_const_anode_ != nullptr);
if (!CheckMulNode(func_graph)) {
MS_LOG(DEBUG) << "Mul op is not suit for mul-add-fusion: " << mul_anode_->fullname_with_scope();
return nullptr;
}
MS_ASSERT(mul_input_anode_ != nullptr);
MS_ASSERT(scale_tensor_ != nullptr);
MS_ASSERT(mul_const_anode_ != nullptr);
if (!GetMulInputShape()) {
MS_LOG(DEBUG) << "Get input shape of mul op failed";
return nullptr;
}
// scale requires scale shape tail sub of input shape, scale shape same as bias shape
if (!ScaleInputShapeValid()) {
MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed";
return nullptr;
}
// create scale primitive
auto scale_primitive = new (std::nothrow) mindspore::ops::ScaleFusion();
if (scale_primitive == nullptr) {
MS_LOG(ERROR) << "new scale primitive failed";
return nullptr;
}
scale_primitive->set_activation_type(scale_act_type_);
scale_primitive->set_axis(0 - bias_tensor_->shape_c().size());
// create scale op
auto scale_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(scale_primitive),
{mul_input_anode_, mul_const_anode_, add_const_anode_});
return scale_node;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace opt {
class MulAddFusion : public PatternProcessPass {
public:
explicit MulAddFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
: PatternProcessPass(name, multigraph) {}
~MulAddFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
bool CheckMulNode(const FuncGraphPtr &func_graph) const;
bool CheckAddNode() const;
bool GetMulInputShape() const;
bool ScaleInputShapeValid() const;
private:
mutable AnfNodePtr mul_anode_ = nullptr;
mutable AnfNodePtr mul_input_anode_ = nullptr;
mutable AnfNodePtr mul_const_anode_ = nullptr;
mutable ShapeVector mul_input_shape_;
mutable AnfNodePtr add_anode_ = nullptr;
mutable AnfNodePtr add_const_anode_ = nullptr;
mutable tensor::TensorPtr scale_tensor_ = nullptr;
mutable tensor::TensorPtr bias_tensor_ = nullptr;
mutable ActivationType scale_act_type_ = ActivationType::NO_ACTIVATION;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_

View File

@ -256,11 +256,12 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun
auto parameter = func_graph->add_parameter();
parameter->set_name(name);
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
if (abstract_tensor == nullptr) {
auto abstract = lite::CreateTensorAbstract(shape_vector, type);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr;
}
parameter->set_abstract(abstract_tensor);
parameter->set_abstract(abstract);
auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
if (gate_weight_default == nullptr) {

View File

@ -502,13 +502,12 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
return nullptr;
}
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value});
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "create abstract_tensor failed";
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return nullptr;
}
get_item_cnode->set_abstract(abstract_tensor);
get_item_cnode->set_abstract(abstract);
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
std::to_string(item_index));
return get_item_cnode;
@ -581,13 +580,12 @@ STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int o
MS_ASSERT(cnode != nullptr);
AbstractBasePtrList abstract_list;
for (int i = 0; i < output_num; ++i) {
std::vector<int64_t> shape_vector;
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "create abstract_tensor failed";
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
abstract_list.emplace_back(abstract);
}
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
if (abstract_tuple == nullptr) {

View File

@ -23,6 +23,7 @@
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "tools/converter/ops/while.h"
#include "tools/common/tensor_util.h"
namespace {
mindspore::ValueNodePtr GetWhileAnfPrim() {
@ -207,9 +208,13 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
auto node_users = manager->node_users()[node];
for (auto &node_user : node_users) {
// new getitem
AbstractBasePtrList abstractList;
std::vector<int64_t> shape_vector;
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
AbstractBasePtrList abstract_list;
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
@ -225,12 +230,12 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
CNodePtr get_item_node = fg_->NewCNode(inputs);
std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
if (abstract == nullptr) {
MS_LOG(ERROR) << "create AbstractTensor failed";
return RET_NULL_PTR;
auto get_item_node_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (get_item_node_abstract == nullptr) {
MS_LOG(ERROR) << "Create get_item_node_abstract failed";
return RET_ERROR;
}
get_item_node->set_abstract(abstract);
get_item_node->set_abstract(get_item_node_abstract);
get_item_node->set_fullname_with_scope(output_item_name);
// set
if (fg_->nodes().contains(node_user.first)) {

View File

@ -22,6 +22,7 @@
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "securec/include/securec.h"
namespace mindspore::opt {
@ -101,13 +102,16 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
return false;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto type_ptr = TypeIdToType(type_id);
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
weight_node->set_abstract(abstract_tensor);
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
}
return true;

View File

@ -21,6 +21,7 @@
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "src/common/common.h"
#include "src/common/tensor_util.h"
#include "src/ops/populate/populate_register.h"
#include "src/ops/ops_utils.h"
#include "src/runtime/infer_manager.h"
@ -28,19 +29,6 @@
namespace mindspore::opt {
namespace {
constexpr size_t INITIAL_SIZE = 1024;
tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
std::vector<int> shape(tensor->shape());
std::vector<int64_t> shape_vector;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "new tensor::Tensor failed";
return nullptr;
}
return tensor_info;
}
bool IsSpecialType(const CNodePtr &cnode) {
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
@ -75,21 +63,14 @@ STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr
}
} // namespace
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
abstract::AbstractBasePtr InferShapePass::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
MS_ASSERT(nullptr != tensor);
std::vector<int> shape(tensor->shape());
auto shape = tensor->shape();
auto type_id = static_cast<TypeId>(tensor->data_type());
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
if (new_abstract == nullptr) {
MS_LOG(ERROR) << "new AbstractTensor failed";
return nullptr;
}
auto tensor_info = NewTensorInfo(tensor);
auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, type_id);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "new tensor::Tensor failed";
MS_LOG(DEBUG) << "Create tensor info failed";
return nullptr;
}
@ -112,8 +93,12 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
return nullptr;
}
}
new_abstract->set_value(tensor_info);
return new_abstract;
auto abstract = tensor_info->ToAbstract();
if (abstract == nullptr) {
MS_LOG(DEBUG) << "Create tensor abstarct failed";
return nullptr;
}
return abstract;
}
STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
@ -143,8 +128,6 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
std::vector<int32_t> shape;
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
[](const int64_t &value) { return static_cast<int32_t>(value); });
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector);
if (parameter->has_default()) {
auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
@ -155,7 +138,11 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr &parameter) {
return RET_ERROR;
}
}
new_abstract->set_value(new_tensor_info);
auto new_abstract = new_tensor_info->ToAbstract();
if (new_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
parameter->set_abstract(new_abstract);
return RET_OK;
}
@ -304,7 +291,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
}
if (output_tensors.size() == 1) {
auto tensor = output_tensors.front();
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
auto new_abstract = ConvertLiteTensorToAbstract(tensor);
if (new_abstract == nullptr) {
return RET_ERROR;
}
@ -313,7 +300,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
AbstractBasePtrList abstract_list;
for (size_t i = 0; i < output_tensors.size(); i++) {
auto tensor = output_tensors.front();
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
auto new_abstract = ConvertLiteTensorToAbstract(tensor);
if (new_abstract == nullptr) {
return RET_ERROR;
}

View File

@ -36,7 +36,7 @@ class InferShapePass : public Pass {
private:
void FreeTensors(std::vector<lite::Tensor *> *tensors);
abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor);
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors);
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
STATUS SetParameterAbstract(const ParameterPtr &parameter);

View File

@ -179,23 +179,23 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
if (!utils::isa<ValueNodePtr>(anf_node)) {
return lite::RET_NO_CHANGE;
}
auto valueNode = anf_node->cast<ValueNodePtr>();
if (valueNode->abstract() == nullptr) {
auto value_node = anf_node->cast<ValueNodePtr>();
if (value_node->abstract() == nullptr) {
return lite::RET_NO_CHANGE;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueNode->abstract());
if (abstractTensor == nullptr) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract());
if (abstract_tensor == nullptr) {
return lite::RET_NO_CHANGE;
}
auto value = abstractTensor->GetValueTrack();
auto value = abstract_tensor->GetValueTrack();
if (value != nullptr && value->isa<tensor::Tensor>()) {
if (abstractTensor->element() == nullptr) {
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
return RET_ERROR;
}
auto typePtr = abstractTensor->element()->GetTypeTrack();
if (typePtr->type_id() == kNumberTypeInt64) {
auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
if (type_ptr->type_id() == kNumberTypeInt64) {
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector);
auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c());
auto src_tensor_info = value->cast<tensor::TensorPtr>();
@ -204,10 +204,10 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
dest_data_buf[i] = src_data_buf[i];
}
abstractTensor->set_value(dest_tensor_info);
abstractTensor->set_type(TypeIdToType(kNumberTypeInt32));
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
valueNode->set_value(dest_tensor_info);
abstract_tensor->set_value(dest_tensor_info);
abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
value_node->set_value(dest_tensor_info);
}
}
return lite::RET_NO_CHANGE;

View File

@ -19,6 +19,7 @@
#include <vector>
#include "ops/transpose.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_MS;
@ -92,9 +93,20 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
auto prim = std::make_shared<ops::Transpose>();
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
auto type_ptr = TypeIdToType(kTypeUnknown);
std::vector<int64_t> shape_vector;
auto abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
if (!weight_node->has_default()) {
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
return lite::RET_ERROR;
}
auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
if (weight_tensor == nullptr) {
MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
return lite::RET_ERROR;
}
auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
transpose_node->set_abstract(abstract);
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
for (auto &adjust_node : adjust_nodes) {
@ -177,11 +189,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
return false;
}
auto type_id = static_cast<TypeId>(weight_value->data_type());
auto type_ptr = TypeIdToType(type_id);
auto shape = weight_value->shape();
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
weight_node->set_abstract(abstract_tensor);
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstarct failed";
return RET_ERROR;
}
weight_node->set_abstract(abstract);
}
return RET_OK;
}