forked from mindspore-Ecosystem/mindspore
fix topk and trans_format_insert_pass
This commit is contained in:
parent
eba1e58140
commit
3f42e6c659
|
@ -31,16 +31,10 @@ int DescendCmp(const void *a, const void *b) {
|
|||
}
|
||||
|
||||
int AscendCmp(const void *a, const void *b) {
|
||||
float sub = ((const TopkNode *)a)->element - ((const TopkNode *)b)->element;
|
||||
if (sub > 0) {
|
||||
return 1;
|
||||
} else if (sub < 0) {
|
||||
return -1;
|
||||
}
|
||||
if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) {
|
||||
return -1;
|
||||
} else {
|
||||
return 1;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,10 +52,9 @@ void Topk(float *input_data, float *output_data, int32_t *output_index, TopkPara
|
|||
top_map[j].element = *(cur_input_data + j);
|
||||
top_map[j].index = j;
|
||||
}
|
||||
if (parameter->sorted_) {
|
||||
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
|
||||
} else {
|
||||
qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmp);
|
||||
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
|
||||
if (!parameter->sorted_) {
|
||||
qsort(top_map, k, sizeof(top_map[0]), AscendCmp);
|
||||
}
|
||||
for (int m = 0; m < k; m++) {
|
||||
cur_output_data[m] = top_map[m].element;
|
||||
|
|
|
@ -206,6 +206,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|||
continue;
|
||||
}
|
||||
#endif
|
||||
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
|
||||
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
|
||||
continue;
|
||||
}
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
|
||||
|
|
Loading…
Reference in New Issue