forked from mindspore-Ecosystem/mindspore
fix bug
This commit is contained in:
parent
492e41a4af
commit
bbedc02700
|
@ -28,8 +28,6 @@ class TopKCPUKernel : public LiteKernel {
|
||||||
const mindspore::lite::PrimitiveC *primitive)
|
const mindspore::lite::PrimitiveC *primitive)
|
||||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||||
~TopKCPUKernel() override {
|
~TopKCPUKernel() override {
|
||||||
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
|
|
||||||
free(parameter->topk_node_list_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
|
|
|
@ -55,60 +55,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
|
||||||
MS_ASSERT(cnode != nullptr);
|
|
||||||
bool has_tuple_get_item = false;
|
|
||||||
std::vector<AnfNodePtr> inputs;
|
|
||||||
inputs.clear();
|
|
||||||
inputs.emplace_back(cnode->input(0));
|
|
||||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
||||||
AnfNodePtr input_node = cnode->input(i);
|
|
||||||
if (!input_node->isa<CNode>()) {
|
|
||||||
inputs.emplace_back(cnode->input(i));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node);
|
|
||||||
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
|
|
||||||
has_tuple_get_item = true;
|
|
||||||
inputs.emplace_back(tuple_get_item_node->input(1));
|
|
||||||
AnfNodePtr indexNode = tuple_get_item_node->input(2);
|
|
||||||
if (!utils::isa<ValueNode>(indexNode)) {
|
|
||||||
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto value_node = utils::cast<ValueNodePtr>(indexNode);
|
|
||||||
} else {
|
|
||||||
inputs.emplace_back(cnode->input(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (has_tuple_get_item) {
|
|
||||||
cnode->set_inputs(inputs);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) {
|
|
||||||
MS_ASSERT(meta_graphT != nullptr);
|
|
||||||
MS_ASSERT(cnode != nullptr);
|
|
||||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
||||||
auto input_anode = cnode->input(i);
|
|
||||||
if (!input_anode->isa<CNode>()) {
|
|
||||||
MS_LOG(ERROR) << "Node of Return's input is not CNode";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
|
||||||
std::string input_name = input_anode->fullname_with_scope();
|
|
||||||
auto iter = node_id_map_.find(input_name);
|
|
||||||
if (iter == node_id_map_.end()) {
|
|
||||||
MS_LOG(ERROR) << "Could not find output node";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto graph_output = iter->second;
|
|
||||||
meta_graphT->outputIndex.emplace_back(graph_output);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
||||||
const std::shared_ptr<PrimitiveTValue> primitive,
|
const std::shared_ptr<PrimitiveTValue> primitive,
|
||||||
const std::unique_ptr<schema::CNodeT> &dst_node) {
|
const std::unique_ptr<schema::CNodeT> &dst_node) {
|
||||||
|
@ -182,6 +128,28 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||||
|
schema::CNodeT *return_node) {
|
||||||
|
MS_ASSERT(nullptr != meta_graph);
|
||||||
|
MS_ASSERT(nullptr != return_node);
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||||
|
auto input_node = cnode->input(i);
|
||||||
|
if (input_node->isa<CNode>()) {
|
||||||
|
auto ret = ConvertInputCNode(input_node, return_node);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "obtain outputs failed";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) {
|
||||||
|
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
||||||
auto cnodes = func_graph->GetOrderedCnodes();
|
auto cnodes = func_graph->GetOrderedCnodes();
|
||||||
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
||||||
|
@ -202,24 +170,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
RemoveIfMakeTuple(cnode);
|
RemoveIfMakeTuple(cnode);
|
||||||
|
|
||||||
|
auto node = std::make_unique<schema::CNodeT>();
|
||||||
|
|
||||||
if (primT->value.type == schema::PrimitiveType_Return) {
|
if (primT->value.type == schema::PrimitiveType_Return) {
|
||||||
AddOutPutIfReturn(meta_graphT, cnode);
|
node->name = "return_node";
|
||||||
|
SetGraphoutputIndex(cnode, meta_graphT, node.get());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto node = std::make_unique<schema::CNodeT>();
|
|
||||||
node->name = cnode->fullname_with_scope();
|
|
||||||
node->nodeType = schema::NodeType_CNode;
|
node->nodeType = schema::NodeType_CNode;
|
||||||
|
node->name = cnode->fullname_with_scope();
|
||||||
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
|
||||||
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
auto ret = SetOpInputNode(cnode, meta_graphT, node.get());
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "SetOpInputNode failed";
|
MS_LOG(ERROR) << "SetOpInputNode failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
SetOpOutputNode(cnode, meta_graphT, node.get());
|
SetOpOutputNode(cnode, meta_graphT, node.get());
|
||||||
|
|
||||||
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
|
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ConvertQuantParam failed";
|
MS_LOG(ERROR) << "ConvertQuantParam failed";
|
||||||
|
|
|
@ -36,8 +36,6 @@ class AnfExporter {
|
||||||
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||||
schema::CNodeT *fb_node);
|
schema::CNodeT *fb_node);
|
||||||
void RemoveIfMakeTuple(const CNodePtr &cnode);
|
void RemoveIfMakeTuple(const CNodePtr &cnode);
|
||||||
bool RemoveIfTupleGetItem(const CNodePtr &cnode);
|
|
||||||
bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode);
|
int ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode);
|
||||||
|
@ -46,6 +44,8 @@ class AnfExporter {
|
||||||
int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
||||||
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
|
||||||
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
|
||||||
|
void SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||||
|
schema::CNodeT *return_node);
|
||||||
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
|
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
|
||||||
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
|
||||||
const std::shared_ptr<PrimitiveTValue> primitive,
|
const std::shared_ptr<PrimitiveTValue> primitive,
|
||||||
|
|
|
@ -63,7 +63,7 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
||||||
attr->type = schema::ActivationType_SIGMOID;
|
attr->type = schema::ActivationType_SIGMOID;
|
||||||
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
} else if (std::strcmp(node_name, "HardSwish") == 0) {
|
||||||
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
|
||||||
attr->type = schema::ActivationType_SIGMOID;
|
attr->type = schema::ActivationType_HSWISH;
|
||||||
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
|
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
|
||||||
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
||||||
if (tflite_attr == nullptr) {
|
if (tflite_attr == nullptr) {
|
||||||
|
|
|
@ -273,7 +273,7 @@ int TimeProfile::PrintResult(const std::vector<std::string> &title,
|
||||||
columnLenMax.at(i) = printBuf.size();
|
columnLenMax.at(i) = printBuf.size();
|
||||||
}
|
}
|
||||||
printBuf.resize(columnLenMax.at(i), ' ');
|
printBuf.resize(columnLenMax.at(i), ' ');
|
||||||
printf("%s", printBuf.c_str());
|
printf("%s\t", printBuf.c_str());
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
for (size_t i = 0; i < rows.size(); i++) {
|
for (size_t i = 0; i < rows.size(); i++) {
|
||||||
|
|
Loading…
Reference in New Issue