fix_topkv2_parser

This commit is contained in:
sunsuodong 2020-08-17 16:17:43 +08:00
parent f41ca6b5c6
commit 676e44a130
4 changed files with 9 additions and 15 deletions

View File

@ -190,7 +190,6 @@ union PrimitiveType {
ActivationGrad,
PriorBox,
SpaceToBatchND,
TopKV2,
Return,
MakeTuple,
ToFormat,

View File

@ -872,12 +872,6 @@ table SpaceToBatchND {
paddings : [int];
}
table TopKV2 {
k : [int];
sorted : bool = true;
}
table MakeTuple {
}

View File

@ -31,14 +31,13 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopK) << "wrong Op Type";
}
TEST_F(TestTfliteParserTopKV2, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
std::vector<int> k = {3};
ASSERT_EQ(val->k, k);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopK(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTopK();
ASSERT_EQ(val->k, 3);
ASSERT_EQ(val->sorted, true);
}
} // namespace mindspore

View File

@ -41,15 +41,17 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
std::unique_ptr<schema::TopKV2T> attr(new schema::TopKV2T());
std::unique_ptr<schema::TopKT> attr(new schema::TopKT());
attr->sorted = true;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) {
std::vector<int32_t> k;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, k)) {
MS_LOG(ERROR) << "get topKV2 -> k failed";
return RET_ERROR;
}
attr->k = k.front();
op->primitive->value.type = schema::PrimitiveType_TopKV2;
op->primitive->value.type = schema::PrimitiveType_TopK;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,