This commit is contained in:
xuanyue 2020-08-21 19:24:52 +08:00
parent 492e41a4af
commit bbedc02700
5 changed files with 31 additions and 67 deletions

View File

@ -28,8 +28,6 @@ class TopKCPUKernel : public LiteKernel {
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~TopKCPUKernel() override {
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
free(parameter->topk_node_list_);
}
int Init() override;

View File

@ -55,60 +55,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
}
}
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
bool has_tuple_get_item = false;
std::vector<AnfNodePtr> inputs;
inputs.clear();
inputs.emplace_back(cnode->input(0));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr input_node = cnode->input(i);
if (!input_node->isa<CNode>()) {
inputs.emplace_back(cnode->input(i));
continue;
}
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node);
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
has_tuple_get_item = true;
inputs.emplace_back(tuple_get_item_node->input(1));
AnfNodePtr indexNode = tuple_get_item_node->input(2);
if (!utils::isa<ValueNode>(indexNode)) {
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
return false;
}
auto value_node = utils::cast<ValueNodePtr>(indexNode);
} else {
inputs.emplace_back(cnode->input(i));
}
}
if (has_tuple_get_item) {
cnode->set_inputs(inputs);
}
return true;
}
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) {
MS_ASSERT(meta_graphT != nullptr);
MS_ASSERT(cnode != nullptr);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input_anode = cnode->input(i);
if (!input_anode->isa<CNode>()) {
MS_LOG(ERROR) << "Node of Return's input is not CNode";
return false;
}
auto input_cnode = utils::cast<CNodePtr>(input_anode);
std::string input_name = input_anode->fullname_with_scope();
auto iter = node_id_map_.find(input_name);
if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Could not find output node";
return false;
}
auto graph_output = iter->second;
meta_graphT->outputIndex.emplace_back(graph_output);
}
return true;
}
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveTValue> primitive,
const std::unique_ptr<schema::CNodeT> &dst_node) {
@ -182,6 +128,28 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
}
}
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node) {
MS_ASSERT(nullptr != meta_graph);
MS_ASSERT(nullptr != return_node);
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
if (input_node->isa<CNode>()) {
auto ret = ConvertInputCNode(input_node, return_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "obtain outputs failed";
return;
}
} else {
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
return;
}
}
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
}
}
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
@ -202,24 +170,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
}
RemoveIfMakeTuple(cnode);
auto node = std::make_unique<schema::CNodeT>();
if (primT->value.type == schema::PrimitiveType_Return) {
AddOutPutIfReturn(meta_graphT, cnode);
node->name = "return_node";
SetGraphoutputIndex(cnode, meta_graphT, node.get());
continue;
}
auto node = std::make_unique<schema::CNodeT>();
node->name = cnode->fullname_with_scope();
node->nodeType = schema::NodeType_CNode;
node->name = cnode->fullname_with_scope();
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpInputNode failed";
return nullptr;
}
SetOpOutputNode(cnode, meta_graphT, node.get());
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertQuantParam failed";

View File

@ -36,8 +36,6 @@ class AnfExporter {
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
void RemoveIfMakeTuple(const CNodePtr &cnode);
bool RemoveIfTupleGetItem(const CNodePtr &cnode);
bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode);
protected:
int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode);
@ -46,6 +44,8 @@ class AnfExporter {
int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
void SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node);
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveTValue> primitive,

View File

@ -63,7 +63,7 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_SIGMOID;
attr->type = schema::ActivationType_HSWISH;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {

View File

@ -273,7 +273,7 @@ int TimeProfile::PrintResult(const std::vector<std::string> &title,
columnLenMax.at(i) = printBuf.size();
}
printBuf.resize(columnLenMax.at(i), ' ');
printf("%s", printBuf.c_str());
printf("%s\t", printBuf.c_str());
}
printf("\n");
for (size_t i = 0; i < rows.size(); i++) {