// Fallback definitions so this unit also compiles standalone; in the full
// project build these come from the nnacl headers and the guards are no-ops.
#ifndef C4NUM
#define C4NUM 4
#endif
#ifndef NNACL_OK
#define NNACL_OK 0
#endif

// Element-wise minimum of two int32 arrays: output[i] = min(input0[i], input1[i]).
// Uses a NEON-vectorized main loop (C4NUM lanes per step) when available and a
// scalar loop for the remainder. Always returns NNACL_OK.
int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) {
  int index = 0;
#ifdef ENABLE_NEON
  // Use C4NUM for both the bound and the stride; the original mixed a literal
  // 4 with C4NUM, which would silently break if C4NUM ever changed.
  for (; index <= element_size - C4NUM; index += C4NUM) {
    int32x4_t vin0 = vld1q_s32(input0 + index);
    int32x4_t vin1 = vld1q_s32(input1 + index);
    vst1q_s32(output + index, vminq_s32(vin0, vin1));
  }
#endif
  // Scalar tail (or the whole array when NEON is unavailable).
  for (; index < element_size; index++) {
    output[index] = input0[index] < input1[index] ? input0[index] : input1[index];
  }
  return NNACL_OK;
}
index 9e06b912fdb..1b206bae670 100644 --- a/mindspore/lite/src/ops/tensorlist_stack.cc +++ b/mindspore/lite/src/ops/tensorlist_stack.cc @@ -165,6 +165,7 @@ int TensorListStack::InferShape(std::vector inputs_, std::vector output->set_data_type(input0->tensors_data_type()); output_shape_.insert(output_shape_.begin(), input0->ElementsNum()); output->set_shape(output_shape_); + output->set_format(input0->format()); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 8d44c2d7d61..96e511f967a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -169,6 +169,7 @@ void ArithmeticCPUKernel::InitRunFunction() { break; case PrimitiveType_Minimum: arithmetic_run_ = ElementMinimum; + arithmetic_run_int_ = ElementMinimumInt; break; case PrimitiveType_FloorDiv: arithmetic_run_ = ElementFloorDiv; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc index d22fc32234f..d32263dad81 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.cc @@ -30,7 +30,17 @@ namespace mindspore::kernel { int TensorListReserveCPUKernel::Init() { return RET_OK; } int TensorListReserveCPUKernel::Run() { + auto input0 = in_tensors_.at(0); + auto input1 = in_tensors_.at(1); + int num_elements = reinterpret_cast(input1->data_c())[0]; auto output = reinterpret_cast(out_tensors_[0]); + if (output->tensors().size() < (uint32_t)num_elements) { + auto ele_shape_ptr = reinterpret_cast(input0->data_c()); + std::vector > tmp_shape(num_elements, std::vector()); + output->set_element_shape(std::vector(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum())); + output->set_shape(std::vector(1, num_elements)); + 
output->MallocTensorListData(kTypeUnknown, tmp_shape); + } output->set_tensors_data_type(element_dtype_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc index 5a74c14c733..53903af875b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc @@ -83,7 +83,14 @@ int TensorListSetItemCPUKernel::Run() { auto src = input0_->GetTensor(i); auto dst = output0_->GetTensor(i); MS_ASSERT(src != nullptr); - MS_ASSERT(dst != nullptr); + // merge move data will delete tensors + if (dst == nullptr) { + dst = lite::Tensor::CopyTensor(*src, src->data_c() != nullptr); + auto &tensors = output0_->tensors(); + tensors.emplace_back(dst); + continue; + } + if (src->data_type() != kTypeUnknown) { if (src->Size() != dst->Size()) { MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size(); diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index da022e599ed..a91893dfa83 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -288,7 +288,6 @@ int Tensor::set_root_tensor(Tensor *tensor) { this->shape_ = this->root_tensor_->shape_; this->format_ = this->root_tensor_->format_; this->data_type_ = this->root_tensor_->data_type_; - this->allocator_ = this->root_tensor_->allocator_; this->category_ = this->root_tensor_->category_; this->quant_params_ = this->root_tensor_->quant_params_; this->quant_clusters_ = this->root_tensor_->quant_clusters_; diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index d22c97e96e8..35a175bc1b7 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -264,8 +264,7 @@ int GraphDefTransform::Transform(const 
converter::Flags &ctx) { // init old node indecies auto old_nodes = GetGraphNodes(); Optimizer selectOptimizer; - selectOptimizer.AddPass(new (std::nothrow) SelectPass()); - selectOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + selectOptimizer.AddPass(new (std::nothrow) SelectPass(graphDefT)); status = selectOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "Run switch graphPasses Failed"; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.cc index 57cbaaca31f..8e96c86f6a5 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.cc @@ -16,6 +16,7 @@ #include #include +#include #include "tools/converter/legacy_optimizer/graph/select_pass.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" @@ -40,6 +41,52 @@ STATUS SelectPass::Run(mindspore::schema::MetaGraphT *graph) { MS_LOG(ERROR) << "node: " << node->name << "'s select pass failed: " << ret; return ret; } + select_indices_.emplace_back(i); + } + int ret = RemoveSelectNodes(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "remove select nodes failed"; + return ret; + } + return RET_OK; +} + +STATUS SelectPass::RemoveSelectNodes() { + std::sort(select_indices_.begin(), select_indices_.end(), std::greater()); + for (auto select_indice : select_indices_) { + auto &node = graph_->nodes.at(select_indice); + if (node->primitive->value.type != PrimitiveType_Select) { + MS_LOG(ERROR) << "node " << node->name << " is not a select node"; + return RET_ERROR; + } + int subgraph_idx = -1; + for (size_t i = 0; i < graph_->subGraph.size(); i++) { + if (IsContain(graph_->subGraph.at(i)->nodeIndices, select_indice)) { + subgraph_idx = i; + break; + } + } + + if (subgraph_idx == -1) { + MS_LOG(ERROR) << "select node " << node->name << " is not belong to any 
subgraph"; + return RET_ERROR; + } + graph_->nodes.erase(graph_->nodes.begin() + select_indice); + std::vector new_node_indices; + std::copy_if(graph_->subGraph.at(subgraph_idx)->nodeIndices.begin(), + graph_->subGraph.at(subgraph_idx)->nodeIndices.end(), + std::inserter(new_node_indices, new_node_indices.begin()), + [&select_indice](int indice) { return (uint32_t)indice != select_indice; }); + graph_->subGraph.at(subgraph_idx)->nodeIndices = new_node_indices; + for (auto &subgraph : graph_->subGraph) { + std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(), + [&select_indice](uint32_t idx) { + if (idx > select_indice) { + return --idx; + } + return idx; + }); + } } return RET_OK; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.h index bfdc1b215e4..049b4a41a48 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/select_pass.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "tools/common/graph_util.h" #include "tools/converter/optimizer.h" @@ -28,9 +29,14 @@ namespace mindspore { namespace lite { class SelectPass : public GraphPass { public: - SelectPass() = default; + explicit SelectPass(schema::MetaGraphT *graph) : graph_(graph) {} ~SelectPass() override = default; STATUS Run(schema::MetaGraphT *graph) override; + STATUS RemoveSelectNodes(); + + private: + std::vector select_indices_; + schema::MetaGraphT *graph_ = nullptr; }; class SingleSelectPass { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc index b335c3f77a5..9cf08bbda90 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc @@ -251,9 +251,34 @@ 
STATUS SingleSwitchPass::InsertMerge() { second_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(), switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2); + // skip tensor which is not any nodes' inputs to avoid body partial connect to merge input cnode + std::vector skip_input_tensors; + for (auto input : const_input) { + auto real_input = graph_->subGraph.at(second_subgraph_index_)->inputIndices.at(input); + bool skip = true; + for (auto &node : second_graph_nodes_) { + if (IsContain(node->inputIndex, real_input)) { + skip = false; + break; + } + } + if (skip) { + auto &skip_tensor = graph_->allTensors.at(real_input); + int partial_idx = GetSubgraphInputTensorIndex(graph_->subGraph.at(second_subgraph_index_), skip_tensor); + skip_input_tensors.emplace_back(partial_idx); + } + } + // concat body output to merge input - second_partial_node_->outputIndex.assign(merge_node->inputIndex.begin() + merge_node->inputIndex.size() / 2, - merge_node->inputIndex.end()); + second_partial_node_->outputIndex.clear(); + for (uint32_t merge_right_input = 0; merge_right_input < merge_node->inputIndex.size() / 2; merge_right_input++) { + if (!IsContain(skip_input_tensors, merge_right_input)) { + second_partial_node_->outputIndex.emplace_back( + merge_node->inputIndex.at(merge_node->inputIndex.size() / 2 + merge_right_input)); + } else { + second_partial_node_->outputIndex.emplace_back(UINT32_MAX); + } + } graph_->nodes.push_back(std::move(merge_node)); @@ -544,6 +569,13 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche [](std::pair iter) { return iter.second; }); subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end()); + // filter for -1 output index + std::vector new_partial_outputs; + std::copy_if(partial_outputs.begin(), partial_outputs.end(), + std::inserter(new_partial_outputs, new_partial_outputs.begin()), + [](uint32_t output) { return output != UINT32_MAX; }); + 
partial_node->outputIndex = new_partial_outputs; + return RET_OK; }