forked from mindspore-Ecosystem/mindspore
fix bug
This commit is contained in:
parent
492e41a4af
commit
bbedc02700
|
@ -28,8 +28,6 @@ class TopKCPUKernel : public LiteKernel {
|
|||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~TopKCPUKernel() override {
|
||||
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
|
||||
free(parameter->topk_node_list_);
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
|
|
|
@ -55,60 +55,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
|
|||
}
|
||||
}
|
||||
|
||||
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
bool has_tuple_get_item = false;
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.clear();
|
||||
inputs.emplace_back(cnode->input(0));
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
AnfNodePtr input_node = cnode->input(i);
|
||||
if (!input_node->isa<CNode>()) {
|
||||
inputs.emplace_back(cnode->input(i));
|
||||
continue;
|
||||
}
|
||||
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node);
|
||||
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
|
||||
has_tuple_get_item = true;
|
||||
inputs.emplace_back(tuple_get_item_node->input(1));
|
||||
AnfNodePtr indexNode = tuple_get_item_node->input(2);
|
||||
if (!utils::isa<ValueNode>(indexNode)) {
|
||||
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
||||
return false;
|
||||
}
|
||||
auto value_node = utils::cast<ValueNodePtr>(indexNode);
|
||||
} else {
|
||||
inputs.emplace_back(cnode->input(i));
|
||||
}
|
||||
}
|
||||
if (has_tuple_get_item) {
|
||||
cnode->set_inputs(inputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) {
|
||||
MS_ASSERT(meta_graphT != nullptr);
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto input_anode = cnode->input(i);
|
||||
if (!input_anode->isa<CNode>()) {
|
||||
MS_LOG(ERROR) << "Node of Return's input is not CNode";
|
||||
return false;
|
||||
}
|
||||
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
||||
std::string input_name = input_anode->fullname_with_scope();
|
||||
auto iter = node_id_map_.find(input_name);
|
||||
if (iter == node_id_map_.end()) {
|
||||
MS_LOG(ERROR) << "Could not find output node";
|
||||
return false;
|
||||
}
|
||||
auto graph_output = iter->second;
|
||||
meta_graphT->outputIndex.emplace_back(graph_output);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
||||
const std::shared_ptr<PrimitiveTValue> primitive,
|
||||
const std::unique_ptr<schema::CNodeT> &dst_node) {
|
||||
|
@ -182,6 +128,28 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
|
|||
}
|
||||
}
|
||||
|
||||
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *return_node) {
|
||||
MS_ASSERT(nullptr != meta_graph);
|
||||
MS_ASSERT(nullptr != return_node);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
if (input_node->isa<CNode>()) {
|
||||
auto ret = ConvertInputCNode(input_node, return_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain outputs failed";
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
|
||||
return;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
|
||||
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
|
||||
}
|
||||
}
|
||||
|
||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
||||
|
@ -202,24 +170,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
RemoveIfMakeTuple(cnode);
|
||||
|
||||
auto node = std::make_unique<schema::CNodeT>();
|
||||
|
||||
if (primT->value.type == schema::PrimitiveType_Return) {
|
||||
AddOutPutIfReturn(meta_graphT, cnode);
|
||||
node->name = "return_node";
|
||||
SetGraphoutputIndex(cnode, meta_graphT, node.get());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto node = std::make_unique<schema::CNodeT>();
|
||||
node->name = cnode->fullname_with_scope();
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
|
||||
node->name = cnode->fullname_with_scope();
|
||||
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
||||
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetOpInputNode failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SetOpOutputNode(cnode, meta_graphT, node.get());
|
||||
|
||||
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConvertQuantParam failed";
|
||||
|
|
|
@ -36,8 +36,6 @@ class AnfExporter {
|
|||
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *fb_node);
|
||||
void RemoveIfMakeTuple(const CNodePtr &cnode);
|
||||
bool RemoveIfTupleGetItem(const CNodePtr &cnode);
|
||||
bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode);
|
||||
|
||||
protected:
|
||||
int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode);
|
||||
|
@ -46,6 +44,8 @@ class AnfExporter {
|
|||
int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
||||
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||
void SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *return_node);
|
||||
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
|
||||
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
||||
const std::shared_ptr<PrimitiveTValue> primitive,
|
||||
|
|
|
@ -63,7 +63,7 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
attr->type = schema::ActivationType_SIGMOID;
|
||||
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||
attr->type = schema::ActivationType_SIGMOID;
|
||||
attr->type = schema::ActivationType_HSWISH;
|
||||
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
|
|
|
@ -273,7 +273,7 @@ int TimeProfile::PrintResult(const std::vector<std::string> &title,
|
|||
columnLenMax.at(i) = printBuf.size();
|
||||
}
|
||||
printBuf.resize(columnLenMax.at(i), ' ');
|
||||
printf("%s", printBuf.c_str());
|
||||
printf("%s\t", printBuf.c_str());
|
||||
}
|
||||
printf("\n");
|
||||
for (size_t i = 0; i < rows.size(); i++) {
|
||||
|
|
Loading…
Reference in New Issue