forked from mindspore-Ecosystem/mindspore
!11043 encoder no fusion run success
From: @cjh9368 Reviewed-by: @zhanghaibo5 Signed-off-by:
This commit is contained in:
commit
aad0f0561f
|
@ -987,6 +987,22 @@ int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
|
||||||
return NNACL_OK;
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) {
|
||||||
|
int index = 0;
|
||||||
|
#ifdef ENABLE_NEON
|
||||||
|
for (; index <= element_size - 4; index += C4NUM) {
|
||||||
|
int32x4_t vin0 = vld1q_s32(input0 + index);
|
||||||
|
int32x4_t vin1 = vld1q_s32(input1 + index);
|
||||||
|
int32x4_t vout = vminq_s32(vin0, vin1);
|
||||||
|
vst1q_s32(output + index, vout);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; index < element_size; index++) {
|
||||||
|
output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
|
||||||
|
}
|
||||||
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int BroadcastMaximum(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
|
int BroadcastMaximum(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
|
||||||
ArithmeticParameter *param) {
|
ArithmeticParameter *param) {
|
||||||
TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
|
TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
|
||||||
|
|
|
@ -95,6 +95,7 @@ int ElementSquaredDifference(const float *in0, const float *in1, float *out, int
|
||||||
int ElementMaximum(const float *in0, const float *in1, float *out, int size);
|
int ElementMaximum(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementMinimum(const float *in0, const float *in1, float *out, int size);
|
int ElementMinimum(const float *in0, const float *in1, float *out, int size);
|
||||||
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
|
int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
|
||||||
|
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size);
|
||||||
int BroadcastMaximum(const float *in0, const float *in1, float *tile_input0, float *tile_input1, float *out, int size,
|
int BroadcastMaximum(const float *in0, const float *in1, float *tile_input0, float *tile_input1, float *out, int size,
|
||||||
ArithmeticParameter *param);
|
ArithmeticParameter *param);
|
||||||
|
|
||||||
|
|
|
@ -165,6 +165,7 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
|
||||||
output->set_data_type(input0->tensors_data_type());
|
output->set_data_type(input0->tensors_data_type());
|
||||||
output_shape_.insert(output_shape_.begin(), input0->ElementsNum());
|
output_shape_.insert(output_shape_.begin(), input0->ElementsNum());
|
||||||
output->set_shape(output_shape_);
|
output->set_shape(output_shape_);
|
||||||
|
output->set_format(input0->format());
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -169,6 +169,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_Minimum:
|
case PrimitiveType_Minimum:
|
||||||
arithmetic_run_ = ElementMinimum;
|
arithmetic_run_ = ElementMinimum;
|
||||||
|
arithmetic_run_int_ = ElementMinimumInt;
|
||||||
break;
|
break;
|
||||||
case PrimitiveType_FloorDiv:
|
case PrimitiveType_FloorDiv:
|
||||||
arithmetic_run_ = ElementFloorDiv;
|
arithmetic_run_ = ElementFloorDiv;
|
||||||
|
|
|
@ -30,7 +30,17 @@ namespace mindspore::kernel {
|
||||||
int TensorListReserveCPUKernel::Init() { return RET_OK; }
|
int TensorListReserveCPUKernel::Init() { return RET_OK; }
|
||||||
|
|
||||||
int TensorListReserveCPUKernel::Run() {
|
int TensorListReserveCPUKernel::Run() {
|
||||||
|
auto input0 = in_tensors_.at(0);
|
||||||
|
auto input1 = in_tensors_.at(1);
|
||||||
|
int num_elements = reinterpret_cast<int *>(input1->data_c())[0];
|
||||||
auto output = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
|
auto output = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
|
||||||
|
if (output->tensors().size() < (uint32_t)num_elements) {
|
||||||
|
auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());
|
||||||
|
std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>());
|
||||||
|
output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum()));
|
||||||
|
output->set_shape(std::vector<int>(1, num_elements));
|
||||||
|
output->MallocTensorListData(kTypeUnknown, tmp_shape);
|
||||||
|
}
|
||||||
output->set_tensors_data_type(element_dtype_);
|
output->set_tensors_data_type(element_dtype_);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,7 +83,14 @@ int TensorListSetItemCPUKernel::Run() {
|
||||||
auto src = input0_->GetTensor(i);
|
auto src = input0_->GetTensor(i);
|
||||||
auto dst = output0_->GetTensor(i);
|
auto dst = output0_->GetTensor(i);
|
||||||
MS_ASSERT(src != nullptr);
|
MS_ASSERT(src != nullptr);
|
||||||
MS_ASSERT(dst != nullptr);
|
// merge move data will delete tensors
|
||||||
|
if (dst == nullptr) {
|
||||||
|
dst = lite::Tensor::CopyTensor(*src, src->data_c() != nullptr);
|
||||||
|
auto &tensors = output0_->tensors();
|
||||||
|
tensors.emplace_back(dst);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (src->data_type() != kTypeUnknown) {
|
if (src->data_type() != kTypeUnknown) {
|
||||||
if (src->Size() != dst->Size()) {
|
if (src->Size() != dst->Size()) {
|
||||||
MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
|
MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
|
||||||
|
|
|
@ -288,7 +288,6 @@ int Tensor::set_root_tensor(Tensor *tensor) {
|
||||||
this->shape_ = this->root_tensor_->shape_;
|
this->shape_ = this->root_tensor_->shape_;
|
||||||
this->format_ = this->root_tensor_->format_;
|
this->format_ = this->root_tensor_->format_;
|
||||||
this->data_type_ = this->root_tensor_->data_type_;
|
this->data_type_ = this->root_tensor_->data_type_;
|
||||||
this->allocator_ = this->root_tensor_->allocator_;
|
|
||||||
this->category_ = this->root_tensor_->category_;
|
this->category_ = this->root_tensor_->category_;
|
||||||
this->quant_params_ = this->root_tensor_->quant_params_;
|
this->quant_params_ = this->root_tensor_->quant_params_;
|
||||||
this->quant_clusters_ = this->root_tensor_->quant_clusters_;
|
this->quant_clusters_ = this->root_tensor_->quant_clusters_;
|
||||||
|
|
|
@ -264,8 +264,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
// init old node indecies
|
// init old node indecies
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
Optimizer selectOptimizer;
|
Optimizer selectOptimizer;
|
||||||
selectOptimizer.AddPass(new (std::nothrow) SelectPass());
|
selectOptimizer.AddPass(new (std::nothrow) SelectPass(graphDefT));
|
||||||
selectOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
|
||||||
status = selectOptimizer.Run(graphDefT);
|
status = selectOptimizer.Run(graphDefT);
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||||
MS_LOG(ERROR) << "Run switch graphPasses Failed";
|
MS_LOG(ERROR) << "Run switch graphPasses Failed";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <algorithm>
|
||||||
#include "tools/converter/legacy_optimizer/graph/select_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/select_pass.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
@ -40,6 +41,52 @@ STATUS SelectPass::Run(mindspore::schema::MetaGraphT *graph) {
|
||||||
MS_LOG(ERROR) << "node: " << node->name << "'s select pass failed: " << ret;
|
MS_LOG(ERROR) << "node: " << node->name << "'s select pass failed: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
select_indices_.emplace_back(i);
|
||||||
|
}
|
||||||
|
int ret = RemoveSelectNodes();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "remove select nodes failed";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS SelectPass::RemoveSelectNodes() {
|
||||||
|
std::sort(select_indices_.begin(), select_indices_.end(), std::greater<int>());
|
||||||
|
for (auto select_indice : select_indices_) {
|
||||||
|
auto &node = graph_->nodes.at(select_indice);
|
||||||
|
if (node->primitive->value.type != PrimitiveType_Select) {
|
||||||
|
MS_LOG(ERROR) << "node " << node->name << " is not a select node";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
int subgraph_idx = -1;
|
||||||
|
for (size_t i = 0; i < graph_->subGraph.size(); i++) {
|
||||||
|
if (IsContain(graph_->subGraph.at(i)->nodeIndices, select_indice)) {
|
||||||
|
subgraph_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (subgraph_idx == -1) {
|
||||||
|
MS_LOG(ERROR) << "select node " << node->name << " is not belong to any subgraph";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
graph_->nodes.erase(graph_->nodes.begin() + select_indice);
|
||||||
|
std::vector<uint32_t> new_node_indices;
|
||||||
|
std::copy_if(graph_->subGraph.at(subgraph_idx)->nodeIndices.begin(),
|
||||||
|
graph_->subGraph.at(subgraph_idx)->nodeIndices.end(),
|
||||||
|
std::inserter(new_node_indices, new_node_indices.begin()),
|
||||||
|
[&select_indice](int indice) { return (uint32_t)indice != select_indice; });
|
||||||
|
graph_->subGraph.at(subgraph_idx)->nodeIndices = new_node_indices;
|
||||||
|
for (auto &subgraph : graph_->subGraph) {
|
||||||
|
std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
|
||||||
|
[&select_indice](uint32_t idx) {
|
||||||
|
if (idx > select_indice) {
|
||||||
|
return --idx;
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <functional>
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/converter/optimizer.h"
|
#include "tools/converter/optimizer.h"
|
||||||
|
|
||||||
|
@ -28,9 +29,14 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
class SelectPass : public GraphPass {
|
class SelectPass : public GraphPass {
|
||||||
public:
|
public:
|
||||||
SelectPass() = default;
|
explicit SelectPass(schema::MetaGraphT *graph) : graph_(graph) {}
|
||||||
~SelectPass() override = default;
|
~SelectPass() override = default;
|
||||||
STATUS Run(schema::MetaGraphT *graph) override;
|
STATUS Run(schema::MetaGraphT *graph) override;
|
||||||
|
STATUS RemoveSelectNodes();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<uint32_t> select_indices_;
|
||||||
|
schema::MetaGraphT *graph_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SingleSelectPass {
|
class SingleSelectPass {
|
||||||
|
|
|
@ -251,9 +251,34 @@ STATUS SingleSwitchPass::InsertMerge() {
|
||||||
second_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(),
|
second_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(),
|
||||||
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2);
|
switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2);
|
||||||
|
|
||||||
|
// skip tensor which is not any nodes' inputs to avoid body partial connect to merge input cnode
|
||||||
|
std::vector<uint32_t> skip_input_tensors;
|
||||||
|
for (auto input : const_input) {
|
||||||
|
auto real_input = graph_->subGraph.at(second_subgraph_index_)->inputIndices.at(input);
|
||||||
|
bool skip = true;
|
||||||
|
for (auto &node : second_graph_nodes_) {
|
||||||
|
if (IsContain(node->inputIndex, real_input)) {
|
||||||
|
skip = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (skip) {
|
||||||
|
auto &skip_tensor = graph_->allTensors.at(real_input);
|
||||||
|
int partial_idx = GetSubgraphInputTensorIndex(graph_->subGraph.at(second_subgraph_index_), skip_tensor);
|
||||||
|
skip_input_tensors.emplace_back(partial_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// concat body output to merge input
|
// concat body output to merge input
|
||||||
second_partial_node_->outputIndex.assign(merge_node->inputIndex.begin() + merge_node->inputIndex.size() / 2,
|
second_partial_node_->outputIndex.clear();
|
||||||
merge_node->inputIndex.end());
|
for (uint32_t merge_right_input = 0; merge_right_input < merge_node->inputIndex.size() / 2; merge_right_input++) {
|
||||||
|
if (!IsContain(skip_input_tensors, merge_right_input)) {
|
||||||
|
second_partial_node_->outputIndex.emplace_back(
|
||||||
|
merge_node->inputIndex.at(merge_node->inputIndex.size() / 2 + merge_right_input));
|
||||||
|
} else {
|
||||||
|
second_partial_node_->outputIndex.emplace_back(UINT32_MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
graph_->nodes.push_back(std::move(merge_node));
|
graph_->nodes.push_back(std::move(merge_node));
|
||||||
|
|
||||||
|
@ -544,6 +569,13 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
|
||||||
[](std::pair<int, int> iter) { return iter.second; });
|
[](std::pair<int, int> iter) { return iter.second; });
|
||||||
subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end());
|
subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end());
|
||||||
|
|
||||||
|
// filter for -1 output index
|
||||||
|
std::vector<uint32_t> new_partial_outputs;
|
||||||
|
std::copy_if(partial_outputs.begin(), partial_outputs.end(),
|
||||||
|
std::inserter(new_partial_outputs, new_partial_outputs.begin()),
|
||||||
|
[](uint32_t output) { return output != UINT32_MAX; });
|
||||||
|
partial_node->outputIndex = new_partial_outputs;
|
||||||
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue