forked from OSSInnovation/mindspore
!6084 MSLITE errorcode return and add reduce parser for caffe,lsh_projection for tflite,fix bugs
Merge pull request !6084 from 徐安越/master
This commit is contained in:
commit
93e6bb9aca
1
build.sh
1
build.sh
|
@ -687,7 +687,6 @@ build_lite()
|
|||
|
||||
if [[ "X$COMPILE_LITE" = "Xon" ]]; then
|
||||
build_lite
|
||||
exit
|
||||
else
|
||||
build_mindspore
|
||||
fi
|
||||
|
|
|
@ -202,6 +202,7 @@ union PrimitiveType {
|
|||
NegGrad,
|
||||
LogGrad,
|
||||
BatchToSpaceND,
|
||||
LshProjection,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -120,6 +120,12 @@ enum PaddingMode : byte {
|
|||
MODE_RESERVED = 3
|
||||
}
|
||||
|
||||
enum LshProjectionType : byte {
|
||||
UNKNOWN = 0,
|
||||
SPARSE = 1,
|
||||
DENSE = 2
|
||||
}
|
||||
|
||||
table Pad {
|
||||
paddings: [int];
|
||||
paddingMode: PaddingMode;
|
||||
|
@ -661,7 +667,8 @@ enum ReduceMode : byte {
|
|||
ReduceMin = 2,
|
||||
ReduceProd = 3,
|
||||
ReduceSum = 4,
|
||||
ReduceSumSquare = 5
|
||||
ReduceSumSquare = 5,
|
||||
ReduceASum = 6
|
||||
}
|
||||
|
||||
table Reduce {
|
||||
|
@ -785,7 +792,7 @@ table FloorMod {
|
|||
table L2Norm {
|
||||
axis: [int];
|
||||
epsilon: float;
|
||||
activationType: ActivationType;
|
||||
activationType: ActivationType = 0;
|
||||
}
|
||||
|
||||
table LogicalAnd {
|
||||
|
@ -937,3 +944,7 @@ table BlackBox {
|
|||
size : int;
|
||||
address : [ubyte];
|
||||
}
|
||||
|
||||
table LshProjection {
|
||||
type : LshProjectionType;
|
||||
}
|
||||
|
|
|
@ -106,7 +106,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
|
|||
if (i >= dst_node->inputIndex.size()) {
|
||||
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
|
||||
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
|
||||
break;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto activate_index = dst_node->inputIndex[i];
|
||||
auto tensor_input = meta_graph->allTensors[activate_index].get();
|
||||
|
@ -170,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
|
|||
}
|
||||
}
|
||||
|
||||
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *return_node) {
|
||||
MS_ASSERT(nullptr != meta_graph);
|
||||
MS_ASSERT(nullptr != return_node);
|
||||
|
@ -178,31 +178,34 @@ void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_p
|
|||
auto input_node = cnode->input(i);
|
||||
if (input_node == nullptr) {
|
||||
MS_LOG(ERROR) << "output node is nullptr";
|
||||
return;
|
||||
return RET_NULL_PTR;
|
||||
} else if (input_node->isa<CNode>()) {
|
||||
auto ret = ConvertInputCNode(input_node, return_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain outputs failed";
|
||||
return;
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
|
||||
return;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
|
||||
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
||||
int ret = RET_OK;
|
||||
for (const auto &cnode : cnodes) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return nullptr;
|
||||
ret = RET_MEMORY_FAILED;
|
||||
break;
|
||||
}
|
||||
if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
|
||||
primitive_c->Type() == schema::PrimitiveType_MakeTuple ||
|
||||
|
@ -216,32 +219,41 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
|
|||
auto node = std::make_unique<schema::CNodeT>();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "object failed to be constructed";
|
||||
return nullptr;
|
||||
ret = RET_MEMORY_FAILED;
|
||||
break;
|
||||
}
|
||||
if (primT->value.type == schema::PrimitiveType_Return) {
|
||||
node->name = "return_node";
|
||||
SetGraphoutputIndex(cnode, meta_graphT, node.get());
|
||||
ret = SetGraphoutputIndex(cnode, meta_graphT, node.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOpOutputN failed";
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->name = cnode->fullname_with_scope();
|
||||
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
||||
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
||||
ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOpInputNode failed";
|
||||
return nullptr;
|
||||
break;
|
||||
}
|
||||
SetOpOutputNode(cnode, meta_graphT, node.get());
|
||||
ret = ConvertQuantParam(meta_graphT, primitive_c, node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertQuantParam failed";
|
||||
return nullptr;
|
||||
break;
|
||||
}
|
||||
if (!keep_graph) {
|
||||
primitive_c->ClearPrimitiveT();
|
||||
}
|
||||
meta_graphT->nodes.emplace_back(std::move(node));
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
||||
return nullptr;
|
||||
}
|
||||
// set graph input tensors
|
||||
SetGraphInputIndex(meta_graphT);
|
||||
return meta_graphT.release();
|
||||
|
@ -297,11 +309,11 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
|
|||
auto abstractBase = paramNode->abstract();
|
||||
if (abstractBase == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
|
||||
return RET_ERROR;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name();
|
||||
return RET_ERROR;
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
||||
|
@ -309,7 +321,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
|
|||
paramTensor->dataType = typePtr->type_id();
|
||||
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
|
||||
return RET_ERROR;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
|
||||
|
@ -431,13 +443,13 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
|
|||
auto ret = ConvertInputCNode(input_node, fb_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertInputCNode failed";
|
||||
return RET_ERROR;
|
||||
return ret;
|
||||
}
|
||||
} else if (input_node->isa<Parameter>()) {
|
||||
auto ret = ConvertInputParameter(input_node, meta_graphT, fb_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertInputParameter failed";
|
||||
return RET_ERROR;
|
||||
return ret;
|
||||
}
|
||||
if (!input_node->cast<ParameterPtr>()->has_default()) {
|
||||
is_graph_input = true;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class AnfExporter {
|
||||
|
@ -45,7 +46,7 @@ class AnfExporter {
|
|||
int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
||||
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||
void SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *return_node);
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
|
||||
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
||||
|
|
|
@ -54,7 +54,7 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
|
|||
char *tensor_data = new (std::nothrow) char[size];
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
std::memcpy(tensor_data, tensor->data.data(), size);
|
||||
param_value->set_tensor_addr(tensor_data);
|
||||
|
@ -128,7 +128,7 @@ int AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNo
|
|||
auto tuple_get_item_prim_ptr = GetTupleGetItemPrim();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||
return RET_ERROR;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
auto get_item_value = NewValueNode(MakeValue<int>(i));
|
||||
|
@ -153,16 +153,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
|
|||
auto node = GetNode(j);
|
||||
if (nullptr == node) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_ERROR;
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
op_inputs.push_back(node);
|
||||
}
|
||||
auto new_cnode = func_graph_->NewCNode(op_inputs);
|
||||
new_cnode->set_fullname_with_scope(cNode->name);
|
||||
auto ret = ConvertAbstract(cNode, new_cnode);
|
||||
if (ret != RET_OK) {
|
||||
auto status = ConvertAbstract(cNode, new_cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertAbstract failed.";
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -176,7 +176,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() {
|
|||
auto make_tuple_prim_ptr = GetMakeTuplePrim();
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
||||
return RET_ERROR;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
||||
make_tuple_inputs.emplace_back(make_tuple_prim);
|
||||
|
@ -184,7 +184,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() {
|
|||
auto cNode = GetNode(tensor_id);
|
||||
if (nullptr == cNode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_ERROR;
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
make_tuple_inputs.emplace_back(cNode);
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() {
|
|||
auto return_prim_ptr = GetReturnPrim();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_ERROR;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto value_node = NewValueNode(return_prim_ptr);
|
||||
op_inputs.emplace_back(value_node);
|
||||
|
@ -207,14 +207,14 @@ int AnfImporterFromMetaGraphT::AddReturnCNode() {
|
|||
auto return_prim_ptr = GetReturnPrim();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_ERROR;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto value_node = NewValueNode(return_prim_ptr);
|
||||
std::vector<AnfNodePtr> op_inputs{value_node};
|
||||
auto cnode = GetNode(meta_graph_->outputIndex.front());
|
||||
if (nullptr == cnode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_ERROR;
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
op_inputs.emplace_back(cnode);
|
||||
auto return_cnode = func_graph_->NewCNode(op_inputs);
|
||||
|
|
|
@ -201,23 +201,23 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool)
|
|||
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
|
||||
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!value_proto.has_type() || !value_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
||||
return false;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
node->set_name(value_proto.name());
|
||||
const auto &type_proto = value_proto.type();
|
||||
if (!type_proto.has_tensor_type()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
|
||||
return false;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
|
||||
if (!tensor_typeproto.has_elem_type() || !tensor_typeproto.has_shape()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto_Tensor has no elem_type or shape! ";
|
||||
return false;
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape();
|
||||
std::vector<int> shape;
|
||||
|
@ -227,7 +227,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
|
|||
|
||||
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!";
|
||||
return false;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]);
|
||||
|
@ -248,7 +248,7 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
|
|||
MS_LOG(ERROR) << "memcpy_s error";
|
||||
delete tensor_data_buf;
|
||||
delete tensor_info;
|
||||
return false;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
|
@ -261,10 +261,10 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod
|
|||
delete tensor_info;
|
||||
}
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
return true;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
||||
|
@ -273,20 +273,22 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outpu
|
|||
const onnx::TensorProto &initializer_proto = importProto.initializer(i);
|
||||
if (!initializer_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i;
|
||||
return false;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
default_para_map_[initializer_proto.name()] = initializer_proto;
|
||||
}
|
||||
|
||||
int status = RET_OK;
|
||||
MS_LOG(INFO) << "all parameters size: " << importProto.input_size();
|
||||
for (int i = 0; i < importProto.input_size(); ++i) {
|
||||
const onnx::ValueInfoProto &input_proto = importProto.input(i);
|
||||
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) {
|
||||
status = BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return status;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
|
@ -662,7 +664,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
|
@ -674,22 +676,25 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
if (node_type == kConstantValueNode) {
|
||||
if (!BuildValueNodeForFuncGraph(node_proto)) {
|
||||
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
return RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
|
||||
return false;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
}
|
||||
|
||||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
|
||||
MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
||||
|
@ -697,47 +702,51 @@ bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph
|
|||
if (importProto.has_name()) {
|
||||
debug_info_ptr->set_name(importProto.name());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
|
||||
MS_LOG(INFO) << "FuncGraph under converting has not name!";
|
||||
}
|
||||
|
||||
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
|
||||
return false;
|
||||
auto status = ImportParametersForGraph(outputFuncGraph, importProto);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
return ImportNodesForGraph(outputFuncGraph, importProto, quantType);
|
||||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) {
|
||||
if (!model_proto.has_producer_name()) {
|
||||
MS_LOG(ERROR) << "Parse model producer name from pb file failed!";
|
||||
return false;
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
}
|
||||
producer_name_ = model_proto.producer_name();
|
||||
|
||||
if (!model_proto.has_model_version()) {
|
||||
MS_LOG(ERROR) << "Parse model producer version from pb file failed!";
|
||||
return false;
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
}
|
||||
model_version_ = model_proto.model_version();
|
||||
|
||||
if (!model_proto.has_ir_version()) {
|
||||
MS_LOG(ERROR) << "Parse model version from pb file failed!";
|
||||
return false;
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
}
|
||||
ir_version_ = model_proto.ir_version();
|
||||
return true;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||
if (!ParseModelConfigureInfo(*onnx_model_)) {
|
||||
int status = ParseModelConfigureInfo(*onnx_model_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
return status;
|
||||
}
|
||||
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
||||
if (!BuildFuncGraph(dstGraph, graphBuild, quantType)) {
|
||||
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Build funcgraph failed!";
|
||||
func_graph_ = nullptr;
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
func_graph_ = dstGraph;
|
||||
MS_LOG(INFO) << "Parse pb to build FuncGraph Success!";
|
||||
|
|
|
@ -45,13 +45,13 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
int ConverterConstTensor() override { return RET_ERROR; };
|
||||
int ConverterCNode() override { return RET_ERROR; };
|
||||
int AddReturnCNode() override { return RET_ERROR; };
|
||||
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
int ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
|
||||
int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
|
|
|
@ -61,25 +61,31 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
pm->AddPass(std::make_shared<opt::ConstFoldPass>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
|
||||
|
||||
if (new_graph == nullptr) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
// quant
|
||||
if (config != nullptr) {
|
||||
if (config->quantType == schema::QuantType_PostTraining) {
|
||||
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
|
||||
if (mQuantizer == nullptr) {
|
||||
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
return nullptr;
|
||||
}
|
||||
} else if (config->quantType == schema::QuantType_WeightQuant) {
|
||||
auto bitNum = static_cast<size_t>(std::stoull(config->bitNum));
|
||||
if (bitNum != quant::UINT8_QUANTIZATION) {
|
||||
MS_LOG(ERROR) << "Current Only Support 8 bit weight quant";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return nullptr;
|
||||
}
|
||||
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(
|
||||
new_graph, config->quantSize, config->convWeightQuantChannelThreshold, config->bitNum);
|
||||
if (mQuantizer == nullptr) {
|
||||
MS_LOG(ERROR) << "New WeightQuantizer failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -89,6 +95,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
auto status = mQuantizer->DoQuantize(new_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Quant failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
if (config->quantType == schema::QuantType_PostTraining) {
|
||||
|
@ -97,6 +104,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
status = quant_cast.Run(new_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "add QuantCast error";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "ir/anf.h"
|
||||
#include "tools/converter/quantizer/quantizer.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -66,6 +66,8 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
if (flag->fmk == converter::FmkType_MS) {
|
||||
MS_ASSERT(nullptr != modelImporter);
|
||||
modelImporter->Import(flag->quantType);
|
||||
int status = modelImporter->Import(flag->quantType);
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
graph = modelImporter->GetResult();
|
||||
} else {
|
||||
MS_ASSERT(nullptr != modelParser);
|
||||
|
@ -94,8 +96,9 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
transform->SetGraphDef(meta_graph);
|
||||
transform->CreateQuantizer(flag);
|
||||
auto status = transform->Transform(*flag);
|
||||
if (status != 0) {
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transform meta graph failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -106,15 +109,16 @@ int RunConverter(int argc, const char **argv) {
|
|||
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
|
||||
if (flags == nullptr) {
|
||||
MS_LOG(ERROR) << "new flags error ";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto status = flags->Init(argc, argv);
|
||||
if (status == RET_SUCCESS_EXIT) {
|
||||
return 0;
|
||||
return status;
|
||||
}
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
|
||||
return 1;
|
||||
std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl;
|
||||
return status;
|
||||
}
|
||||
// Load graph
|
||||
std::string modelName = flags->modelFile.substr(flags->modelFile.find_last_of(DELIM_SLASH) + 1);
|
||||
|
@ -147,9 +151,11 @@ int RunConverter(int argc, const char **argv) {
|
|||
return 1;
|
||||
}
|
||||
}
|
||||
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
|
||||
if (fb_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert model return nullptr";
|
||||
return 1;
|
||||
std::cout << "CONVERT RESULT: FAILED!" << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
// save graph to file
|
||||
|
@ -158,13 +164,14 @@ int RunConverter(int argc, const char **argv) {
|
|||
status = storage.Save(*fb_graph, flags->outputFile);
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "Save graph failed";
|
||||
return 1;
|
||||
std::cout << "SAVE GRAPH FAILED!" << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
delete fb_graph;
|
||||
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
|
||||
|
||||
return 0;
|
||||
std::cout << "CONVERT RESULT: SUCCESS!" << std::endl;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "tools/anf_importer/anf_importer.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -120,6 +120,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
status = mQuantizer->DetermineNodeQuantType();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DetermineNodeQuant failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -142,7 +143,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
auto formatTransPass = new (std::nothrow) FormatTransPass();
|
||||
if (formatTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new formatTransPass failed";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
formatTransPass->SetQuantType(ctx.quantType);
|
||||
formatTransPass->SetFmk(ctx.fmk);
|
||||
|
@ -154,7 +155,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
status = formatTransOptimizer.Run(graphDefT);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) {
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
|
@ -196,7 +197,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dTypeTransPass failed";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
|
||||
dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
|
||||
|
|
|
@ -117,6 +117,13 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
|
|||
if (ret == RET_INFER_INVALID) {
|
||||
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type) << "flag set to false.";
|
||||
for (auto input_tensor : input_tensors) {
|
||||
delete input_tensor;
|
||||
}
|
||||
for (auto output_tensor : output_tensors) {
|
||||
delete output_tensor;
|
||||
}
|
||||
return RET_INFER_INVALID;
|
||||
} else if (ret != RET_OK) {
|
||||
MS_LOG(WARNING) << "InferShape failed, name: " << node->name
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(node->primitive->value.type);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/anf_importer/import_from_meta_graphT.h"
|
||||
#include "ir/anf.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using namespace schema;
|
||||
|
@ -35,8 +35,12 @@ class ModelParser {
|
|||
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) {
|
||||
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "parse model to fb failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto func_graph = this->Fb2Anf(meta_graph);
|
||||
delete (meta_graph);
|
||||
delete(meta_graph);
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
|
@ -48,9 +52,10 @@ class ModelParser {
|
|||
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
AnfImporterFromMetaGraphT importer(meta_graph, func_graph);
|
||||
auto ret = importer.Import();
|
||||
if (RET_OK != ret) {
|
||||
MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << ret;
|
||||
auto status = importer.Import();
|
||||
if (RET_OK != status) {
|
||||
MS_LOG(ERROR) << "Import anf_graph from meta_graphT failed, ret: " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
|
|
|
@ -52,7 +52,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
|
|||
for (auto &opDef : graphDefT->nodes) {
|
||||
for (auto pass : this->nodePasses) {
|
||||
status = pass->Run(new GraphNode(graphDefT, opDef.get()));
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) {
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run NodePass failed";
|
||||
return status;
|
||||
} else {
|
||||
|
@ -65,7 +65,7 @@ STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
|
|||
|
||||
for (auto pass : this->graphPasses) {
|
||||
status = pass->Run(graphDefT);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_ERR) {
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run GraphPass failed";
|
||||
return status;
|
||||
} else {
|
||||
|
|
|
@ -31,4 +31,5 @@ add_library(caffe_parser_mid OBJECT
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/caffe_tanh_parser.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/caffe_exp_parser.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/caffe_slice_parser.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/caffe_reduce_parser.cc
|
||||
)
|
||||
|
|
|
@ -38,7 +38,16 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe:
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
// set default params
|
||||
attr->outMaxValue = false;
|
||||
attr->topK = 1;
|
||||
const caffe::ArgMaxParameter argmaxParam = proto.argmax_param();
|
||||
if (argmaxParam.has_out_max_val()) {
|
||||
attr->outMaxValue = argmaxParam.out_max_val();
|
||||
}
|
||||
if (argmaxParam.has_top_k()) {
|
||||
attr->topK = argmaxParam.top_k();
|
||||
}
|
||||
int32_t axisType;
|
||||
int32_t axis = 0;
|
||||
if (!argmaxParam.has_axis()) {
|
||||
|
@ -46,15 +55,9 @@ STATUS CaffeArgMaxParser::Parse(const caffe::LayerParameter &proto, const caffe:
|
|||
} else {
|
||||
axisType = 1;
|
||||
axis = (int64_t)argmaxParam.axis();
|
||||
if (axis == -1) {
|
||||
MS_LOG(ERROR) << "axis with -1 may lead to calculation errors when input less than 4 dims.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
attr->axis = axis;
|
||||
attr->axisType = axisType;
|
||||
attr->outMaxValue = argmaxParam.out_max_val();
|
||||
attr->topK = argmaxParam.top_k();
|
||||
attr->keepDims = true;
|
||||
|
||||
op->name = proto.name();
|
||||
|
|
|
@ -33,18 +33,23 @@ const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
|
|||
|
||||
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
|
||||
int status = ValidateFileStr(modelFile, ".prototxt");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (weightFile.empty()) {
|
||||
MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (ValidateFileStr(weightFile, ".caffemodel") != RET_OK) {
|
||||
status = ValidateFileStr(weightFile, ".caffemodel");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -52,33 +57,40 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
|
|||
TensorCache tensorCache;
|
||||
|
||||
caffe::NetParameter proto;
|
||||
if (ReadProtoFromText((const char *)modelFile.c_str(), &proto) != RET_OK) {
|
||||
status = ReadProtoFromText((const char *)modelFile.c_str(), &proto);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
metaGraph->name = proto.name();
|
||||
|
||||
caffe::NetParameter weight;
|
||||
if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) {
|
||||
status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto status = GetModelInput(proto, &tensorCache);
|
||||
status = GetModelInput(proto, &tensorCache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetModelInput failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParseLayer failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = SetGraphTensorIndex(proto, &tensorCache, metaGraph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
metaGraph->name = GetModelName(modelFile);
|
||||
|
@ -148,7 +160,12 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T
|
|||
}
|
||||
|
||||
for (auto iter : caffeInspector.GetGraphOutput()) {
|
||||
int index = tensorCache->FindTensor(iter);
|
||||
int index = -1;
|
||||
if (splitLayer.find(iter) != splitLayer.end()) {
|
||||
index = tensorCache->FindTensor(splitLayer.find(iter)->second);
|
||||
} else {
|
||||
index = tensorCache->FindTensor(iter);
|
||||
}
|
||||
if (index >= 0) {
|
||||
subGraphDef->outputIndex.emplace_back(index);
|
||||
} else {
|
||||
|
@ -199,26 +216,28 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
|
|||
op->name = layer.name();
|
||||
|
||||
if (layer.type() == "Split") {
|
||||
splitLayer.emplace(layer.name(), layer.bottom(0));
|
||||
for (int j = 0; j < layer.top_size(); ++j) {
|
||||
splitLayer.emplace(layer.top(j), layer.bottom(0));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto status = SetOpInputIdx(layer, op.get(), tensorCache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!";
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
|
||||
auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str());
|
||||
if (nodeParser == nullptr) {
|
||||
MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name();
|
||||
return RET_ERROR;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::vector<schema::TensorT *> weightVec;
|
||||
status = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
|
||||
SetWeightTensor(weightVec, op.get(), tensorCache);
|
||||
|
@ -226,7 +245,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
|
|||
status = SetOpOutputIdx(layer, op.get(), tensorCache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!";
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
|
||||
// op->fmkType = FmkType_CAFFE;
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/caffe/caffe_reduce_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto,
|
||||
const caffe::LayerParameter &weight,
|
||||
schema::CNodeT *op,
|
||||
std::vector<schema::TensorT *> *weightVec) {
|
||||
MS_LOG(DEBUG) << "parse CaffeReduceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const caffe::ReductionParameter reduce_param = proto.reduction_param();
|
||||
if (reduce_param.has_operation()) {
|
||||
switch (reduce_param.operation()) {
|
||||
case caffe::ReductionParameter_ReductionOp_MEAN:
|
||||
attr->mode = schema::ReduceMode_ReduceMean;
|
||||
break;
|
||||
case caffe::ReductionParameter_ReductionOp_SUM:
|
||||
attr->mode = schema::ReduceMode_ReduceSum;
|
||||
break;
|
||||
case caffe::ReductionParameter_ReductionOp_SUMSQ:
|
||||
attr->mode = schema::ReduceMode_ReduceSumSquare;
|
||||
break;
|
||||
case caffe::ReductionParameter_ReductionOp_ASUM:
|
||||
attr->mode = schema::ReduceMode_ReduceASum;
|
||||
default:
|
||||
MS_LOG(ERROR) << "reduce parse params fail, unsupported opration: " << reduce_param.operation();
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
attr->mode = schema::ReduceMode_ReduceSum;
|
||||
}
|
||||
if (reduce_param.has_axis()) {
|
||||
attr->axes = std::vector(1, reduce_param.axis());
|
||||
} else {
|
||||
attr->axes = std::vector(1, 0);
|
||||
}
|
||||
attr->reduceToEnd = true;
|
||||
attr->keepDims = false;
|
||||
op->name = proto.name();
|
||||
op->primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/caffe/caffe_node_parser.h"
|
||||
#include "tools/converter/parser/caffe/caffe_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class CaffeReduceParser : public CaffeNodeParser {
|
||||
public:
|
||||
CaffeReduceParser() : CaffeNodeParser("reduce") {}
|
||||
|
||||
STATUS Parse(const caffe::LayerParameter &proto,
|
||||
const caffe::LayerParameter &weight,
|
||||
schema::CNodeT *op,
|
||||
std::vector<schema::TensorT *> *weightVec) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_REDUCE_PARSER_H
|
||||
|
|
@ -548,8 +548,15 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "mslite don't support tanh now";
|
||||
return RET_ERROR;
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->type = schema::ActivationType_TANH;
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
|
||||
|
|
|
@ -458,14 +458,18 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
|
|||
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
|
||||
int status = ValidateFileStr(modelFile, ".onnx");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
onnx::ModelProto onnx_model;
|
||||
if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) {
|
||||
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
const onnx::GraphProto &onnx_graph = onnx_model.graph();
|
||||
|
@ -475,19 +479,25 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
// find out input names and const names
|
||||
FindGraphInputAndConst(onnx_graph);
|
||||
// set const tensor
|
||||
if (SetGraphConstTensor(onnx_graph, &tensor_cache)) {
|
||||
status = SetGraphConstTensor(onnx_graph, &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphConstTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto dst_graph = std::make_unique<schema::MetaGraphT>();
|
||||
// init onnx model graph input tensor
|
||||
if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
|
||||
status = SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphInputTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
// init onnx model graph output tensor
|
||||
if (SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
|
||||
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
// init op node input/output tensor, and dst_op attr
|
||||
|
@ -499,9 +509,10 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
|
||||
continue;
|
||||
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
|
||||
auto status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache);
|
||||
status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
continue;
|
||||
|
@ -509,9 +520,10 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
|
||||
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
|
||||
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
|
||||
auto status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache);
|
||||
status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
dst_graph->nodes.emplace_back(std::move(dst_op));
|
||||
|
|
|
@ -42,11 +42,29 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|||
attr->format = schema::Format_NCHW;
|
||||
std::vector<int64_t> shape;
|
||||
shape.clear();
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "shape") {
|
||||
for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
|
||||
shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i)));
|
||||
if (onnx_node.input_size() != 2) {
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "shape") {
|
||||
for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
|
||||
shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
onnx::TensorProto input_shape;
|
||||
const auto &shape_name = onnx_node.input(1);
|
||||
for (const auto &it : onnx_graph.initializer()) {
|
||||
if (it.name() == shape_name) {
|
||||
input_shape = it;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (input_shape.int64_data_size() == 0) {
|
||||
MS_LOG(WARNING) << "shape maybe from another op other than const initializer";
|
||||
} else {
|
||||
for (int i = 0; i < input_shape.int64_data_size(); ++i) {
|
||||
shape.push_back(input_shape.int64_data(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,26 +43,10 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (tflite_op->inputs.empty()) {
|
||||
MS_LOG(ERROR) << "the input is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_tensors[data_index];
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the input tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto ndim = data_tensor->shape.size();
|
||||
std::vector<int32_t> axis;
|
||||
axis.reserve(ndim);
|
||||
for (size_t i = 0; i < ndim; i++) {
|
||||
axis.emplace_back(i);
|
||||
}
|
||||
attr->axis = axis;
|
||||
attr->epsilon = 0.0f;
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions();
|
||||
attr->axis = {-1};
|
||||
attr->epsilon = 1e-6f;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_L2Norm;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_lsh_projection_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLshProjectionParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions();
|
||||
switch (tflite_attr->type) {
|
||||
case tflite::LSHProjectionType_SPARSE:
|
||||
attr->type = schema::LshProjectionType_SPARSE;
|
||||
break;
|
||||
case tflite::LSHProjectionType_DENSE:
|
||||
attr->type = schema::LshProjectionType_DENSE;
|
||||
break;
|
||||
default:
|
||||
attr->type = schema::LshProjectionType_UNKNOWN;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_LshProjection;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteLshProjectionParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H
|
||||
|
|
@ -56,11 +56,11 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<
|
|||
if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(),
|
||||
tflite_model_buffer[buffer_idx]->data.size())) {
|
||||
MS_LOG(ERROR) << "memcpy tensor data failed";
|
||||
return RET_ERROR;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "src tensor data is empty";
|
||||
return RET_ERROR;
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -77,7 +77,8 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
|
|||
}
|
||||
|
||||
// change quant param min to 0 to fit ms-lite ops
|
||||
if (tensor->dataType == TypeId::kNumberTypeInt8) {
|
||||
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8
|
||||
&& tensor->dataType == TypeId::kNumberTypeInt8) {
|
||||
quant_param->zeroPoint = quant_param->zeroPoint - 128;
|
||||
}
|
||||
|
||||
|
@ -114,12 +115,13 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str();
|
||||
return RET_NULL_PTR;
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
if (node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId,
|
||||
&tensorsFormat, &tensorsIdMap) != RET_OK) {
|
||||
int status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId,
|
||||
&tensorsFormat, &tensorsIdMap);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
|
||||
return RET_ERROR;
|
||||
return status;
|
||||
}
|
||||
|
||||
sub_graph->nodes.emplace_back(op.release());
|
||||
|
@ -158,7 +160,11 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
|
|||
auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer);
|
||||
auto isConst = (!tensor_buffer->data.empty());
|
||||
if (isConst) {
|
||||
CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get());
|
||||
int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain const tensor failed";
|
||||
return status;
|
||||
}
|
||||
} else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
// set in/out tensor to int8 to fit ms-lite op
|
||||
tensor->dataType = TypeId::kNumberTypeInt8;
|
||||
|
@ -204,6 +210,9 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
|
|||
auto iter = tensorsIdMap.find(id);
|
||||
if (iter != tensorsIdMap.end()) {
|
||||
graph_inputs.push_back(iter->second);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "get graph input failed";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end());
|
||||
|
@ -220,6 +229,9 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
|
|||
auto iter = tensorsIdMap.find(id);
|
||||
if (iter != tensorsIdMap.end()) {
|
||||
graph_outputs.push_back(iter->second);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "get graph output failed";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end());
|
||||
|
@ -306,11 +318,13 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file,
|
|||
auto tflite_model = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (tflite_model->subgraphs.size() != 1) {
|
||||
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs[0];
|
||||
|
@ -318,31 +332,40 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file,
|
|||
auto meta_graph = std::make_unique<schema::MetaGraphT>();
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "new meta graph failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
return nullptr;
|
||||
}
|
||||
meta_graph->name = "MS_model converted by TF-Lite";
|
||||
quantType = quant_type;
|
||||
// convert op
|
||||
if (ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()) != RET_OK) {
|
||||
int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "parse op failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// convert tensor
|
||||
if (ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()) != RET_OK) {
|
||||
status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert tensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// set graph input/output
|
||||
if (GetGraphInfo(tflite_subgraph, meta_graph.get()) != RET_OK) {
|
||||
status = GetGraphInfo(tflite_subgraph, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert tensors failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// update for depthwiseConv
|
||||
if (ConvertGroupDepthwiseOp(meta_graph.get()) != RET_OK) {
|
||||
status = ConvertGroupDepthwiseOp(meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert group depthwise conv failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_RETURN_CODE_H
|
||||
#define LITE_RETURN_CODE_H
|
||||
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ReturnCode {
|
||||
public:
|
||||
~ReturnCode() {}
|
||||
static ReturnCode *GetSingleReturnCode() {
|
||||
static ReturnCode returnCode;
|
||||
return &returnCode;
|
||||
}
|
||||
void UpdateReturnCode(STATUS status) {
|
||||
if (statusCode == RET_OK) {
|
||||
statusCode = status;
|
||||
}
|
||||
}
|
||||
STATUS GetReturnCode() {
|
||||
return statusCode;
|
||||
}
|
||||
private:
|
||||
ReturnCode() { statusCode = RET_OK; }
|
||||
int statusCode;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_RETURN_CODE_H
|
||||
|
|
@ -79,7 +79,7 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
|
|||
return 0;
|
||||
}
|
||||
}
|
||||
void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
|
||||
int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, const CNodePtr &bias_node) {
|
||||
AnfNodePtr conv_bias_node = nullptr;
|
||||
AnfNodePtr conv_weight_node = nullptr;
|
||||
if (conv_node->inputs().size() == kConvNoBiasLen) {
|
||||
|
@ -93,11 +93,12 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c
|
|||
auto kernel_nums = Get_Kenrnel_nums(conv_node);
|
||||
if (kernel_nums <= 0) {
|
||||
MS_LOG(EXCEPTION) << "kernel num less than 0";
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto add_bias_data = new (std::nothrow) float[kernel_nums];
|
||||
if (add_bias_data == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
return;
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
|
||||
CheckIfNodeIsParam(bias_add_weight);
|
||||
|
@ -112,6 +113,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c
|
|||
} else {
|
||||
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed";
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
if (conv_bias_node != nullptr) {
|
||||
|
@ -120,6 +122,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c
|
|||
auto conv_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param);
|
||||
if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) {
|
||||
MS_LOG(EXCEPTION) << "conv_bias_node shape error";
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->tensor_addr());
|
||||
for (int i = 0; i < kernel_nums; i++) {
|
||||
|
@ -133,6 +136,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c
|
|||
conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias");
|
||||
conv_node->add_input(conv_new_bias);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
const BaseRef ConvBiasaddFusion::DefinePattern() const {
|
||||
|
@ -159,7 +163,11 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
}
|
||||
auto conv_node = conv_node_anf->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(conv_node);
|
||||
GenConvNewBias(func_graph, conv_node, add_node);
|
||||
int ret = GenConvNewBias(func_graph, conv_node, add_node);
|
||||
if (ret != lite::RET_OK) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
|
||||
MS_ASSERT(primitive_c != nullptr);
|
||||
auto type = primitive_c->Type();
|
||||
|
@ -180,6 +188,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
primc->SetHasBias(true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported opType, " << type;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
||||
return nullptr;
|
||||
}
|
||||
return conv_node;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
Loading…
Reference in New Issue