fix topk and trans_format_insert_pass

This commit is contained in:
gongdaguo 2021-01-06 10:55:49 +08:00
parent eba1e58140
commit 3f42e6c659
2 changed files with 9 additions and 12 deletions

View File

@ -31,16 +31,10 @@ int DescendCmp(const void *a, const void *b) {
}
int AscendCmp(const void *a, const void *b) {
float sub = ((const TopkNode *)a)->element - ((const TopkNode *)b)->element;
if (sub > 0) {
return 1;
} else if (sub < 0) {
return -1;
}
if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) {
return -1;
} else {
return 1;
} else {
return -1;
}
}
@ -58,10 +52,9 @@ void Topk(float *input_data, float *output_data, int32_t *output_index, TopkPara
top_map[j].element = *(cur_input_data + j);
top_map[j].index = j;
}
if (parameter->sorted_) {
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
} else {
qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmp);
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
if (!parameter->sorted_) {
qsort(top_map, k, sizeof(top_map[0]), AscendCmp);
}
for (int m = 0; m < k; m++) {
cur_output_data[m] = top_map[m].element;

View File

@ -206,6 +206,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
continue;
}
#endif
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
continue;
}
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";