forked from mindspore-Ecosystem/mindspore
fix_topkv2_parser
This commit is contained in:
parent
f41ca6b5c6
commit
676e44a130
|
@ -190,7 +190,6 @@ union PrimitiveType {
|
|||
ActivationGrad,
|
||||
PriorBox,
|
||||
SpaceToBatchND,
|
||||
TopKV2,
|
||||
Return,
|
||||
MakeTuple,
|
||||
ToFormat,
|
||||
|
|
|
@ -872,12 +872,6 @@ table SpaceToBatchND {
|
|||
paddings : [int];
|
||||
}
|
||||
|
||||
table TopKV2 {
|
||||
k : [int];
|
||||
sorted : bool = true;
|
||||
}
|
||||
|
||||
|
||||
table MakeTuple {
|
||||
}
|
||||
|
||||
|
|
|
@ -31,14 +31,13 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
|
|||
ASSERT_NE(meta_graph, nullptr);
|
||||
ASSERT_GT(meta_graph->nodes.size(), 0);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type";
|
||||
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopK) << "wrong Op Type";
|
||||
}
|
||||
|
||||
TEST_F(TestTfliteParserTopKV2, AttrValue) {
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
|
||||
std::vector<int> k = {3};
|
||||
ASSERT_EQ(val->k, k);
|
||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopK(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsTopK();
|
||||
ASSERT_EQ(val->k, 3);
|
||||
ASSERT_EQ(val->sorted, true);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,15 +41,17 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TopKV2T> attr(new schema::TopKV2T());
|
||||
std::unique_ptr<schema::TopKT> attr(new schema::TopKT());
|
||||
|
||||
attr->sorted = true;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) {
|
||||
std::vector<int32_t> k;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, k)) {
|
||||
MS_LOG(ERROR) << "get topKV2 -> k failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->k = k.front();
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_TopKV2;
|
||||
op->primitive->value.type = schema::PrimitiveType_TopK;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
|
|
Loading…
Reference in New Issue