From 447329fb8df62a9d8d205a9be4dd32139928d01a Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Mon, 22 Mar 2021 20:16:18 +0800 Subject: [PATCH] remove param value lite --- mindspore/lite/src/param_value_lite.h | 79 ----- mindspore/lite/test/CMakeLists.txt | 8 +- .../test/common/import_from_meta_graphT.cc | 39 +-- mindspore/lite/test/models_npu.cfg | 2 +- .../test/models_with_multiple_inputs_fp16.cfg | 4 +- .../lite/tools/anf_exporter/anf_exporter.cc | 306 +++++++++++------- .../lite/tools/anf_exporter/anf_exporter.h | 19 +- mindspore/lite/tools/common/tensor_util.cc | 108 ++++++- mindspore/lite/tools/common/tensor_util.h | 14 + .../graph/tensor_quant_pass.cc | 2 +- .../parser/caffe/caffe_model_parser.cc | 35 +- .../parser/onnx/onnx_constant_parser.cc | 20 +- .../onnx/onnx_given_tensor_fill_parser.cc | 47 +-- .../parser/onnx/onnx_model_parser.cc | 102 +++--- .../converter/parser/onnx/onnx_model_parser.h | 5 +- .../converter/parser/tf/tf_model_parser.cc | 244 ++++++++------ .../converter/parser/tf/tf_model_parser.h | 7 +- .../parser/tflite/tflite_model_parser.cc | 67 ++-- .../parser/tflite/tflite_model_parser.h | 3 +- .../converter/quantizer/huffman_encode.cc | 20 +- .../converter/quantizer/huffman_encode.h | 3 +- .../quantizer/post_training_quantizer.cc | 67 ++-- .../converter/quantizer/quantize_util.cc | 73 ++--- .../tools/converter/quantizer/quantize_util.h | 55 ++-- .../tools/converter/quantizer/quantizer.h | 1 - .../converter/quantizer/weight_quantizer.cc | 148 ++++----- .../converter/quantizer/weight_quantizer.h | 21 +- .../lite/tools/optimizer/common/gllo_utils.cc | 242 ++++++-------- .../lite/tools/optimizer/common/gllo_utils.h | 20 +- .../optimizer/fusion/batchmatmul_fusion.cc | 51 ++- .../fusion/constant_folding_fusion.cc | 43 ++- .../optimizer/fusion/conv_biasadd_fusion.cc | 16 +- .../tools/optimizer/fusion/conv_bn_fusion.cc | 29 +- .../optimizer/fusion/conv_conv_fusion.cc | 83 ++--- .../optimizer/fusion/conv_scale_fusion.cc | 9 +- .../optimizer/fusion/conv_transform_fusion.cc | 54 +--- .../optimizer/fusion/conv_transform_fusion.h | 3 +- .../fusion/conv_tuplegetitem_fusion.cc | 1 - .../tools/optimizer/fusion/gelu_fusion.cc | 8 +- .../tools/optimizer/fusion/norm_fusion.cc | 26 +- .../optimizer/fusion/sigmoid_mul_fusion.cc | 1 - .../fusion/tf_bidirection_gru_cf_fusion.h | 1 - .../fusion/tf_bidirection_gru_fusion.cc | 72 ++--- .../fusion/tf_bidirection_gru_fusion.h | 3 +- .../optimizer/fusion/tf_lstm_cell_fusion.cc | 89 +++-- .../optimizer/fusion/tf_lstm_cell_fusion.h | 3 +- .../fusion/tflite_lstm_cell_fusion.cc | 86 +++-- .../fusion/tflite_lstm_cell_fusion.h | 2 +- .../graph/clip_convert_activation_pass.cc | 12 +- .../graph/clip_convert_activation_pass.h | 1 - .../graph/conv1d_weight_expanding_pass.cc | 24 +- .../graph/conv1d_weight_expanding_pass.h | 3 +- .../graph/group_depthwise_op_convert_pass.cc | 17 +- .../graph/group_depthwise_op_convert_pass.h | 1 - .../lite/tools/optimizer/graph/if_pass.h | 1 - .../tools/optimizer/graph/infershape_pass.cc | 132 +++----- .../tools/optimizer/graph/infershape_pass.h | 1 - .../optimizer/graph/inputs_adjust_pass.h | 1 - .../optimizer/graph/mindir_adjust_pass.cc | 77 +---- .../optimizer/graph/mindir_adjust_pass.h | 2 - .../graph/onnx_inputs_adjust_pass.cc | 38 ++- .../optimizer/graph/onnx_pad_adjust_pass.cc | 37 +-- .../optimizer/graph/slice_prepose_pass.cc | 31 +- .../graph/tflite_inputs_adjust_pass.cc | 8 +- .../graph/tflite_inputs_adjust_pass.h | 1 - .../unused_transpose_node_remove_pass.cc | 6 +- .../graph/update_conv2d_param_pass.cc | 8 +- 
.../graph/weight_format_hardcode_pass.cc | 65 ++-- .../graph/weight_format_hardcode_pass.h | 11 +- .../graph/weight_format_transform_pass.cc | 37 ++- .../lite/tools/optimizer/graph/while_pass.h | 1 - 71 files changed, 1344 insertions(+), 1512 deletions(-) delete mode 100644 mindspore/lite/src/param_value_lite.h diff --git a/mindspore/lite/src/param_value_lite.h b/mindspore/lite/src/param_value_lite.h deleted file mode 100644 index 3c2bf75707e..00000000000 --- a/mindspore/lite/src/param_value_lite.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ -#define MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ - -#include -#include -#include -#include -#include "src/tensor.h" - -namespace mindspore { -class ParamValueLite : public Value { - public: - ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} - ~ParamValueLite() override { - if (tensor_addr_ != nullptr) { - auto tensor_mem = reinterpret_cast(tensor_addr_); - delete[](tensor_mem); - tensor_addr_ = nullptr; - tensor_size_ = 0; - } - } - MS_DECLARE_PARENT(ParamValueLite, Value) - size_t tensor_size() const { return tensor_size_; } - void set_tensor_size(const size_t size) { tensor_size_ = size; } - void *tensor_addr() const { return tensor_addr_; } - void set_tensor_addr(void *addr) { tensor_addr_ = addr; } - - std::vector tensor_shape() const { return tensor_shape_; } - void set_tensor_shape(const std::vector &tensor_shape) { tensor_shape_ = tensor_shape; } - - TypeId tensor_type() const { return type_id_; } - void set_tensor_type(const TypeId type_id) { type_id_ = type_id; } - - void SetTensorData(void *addr, const size_t size) { - this->tensor_addr_ = addr; - this->tensor_size_ = size; - } - - int tensor_shape_size() const { - int size = 1; - for (auto val : tensor_shape_) { - size *= val; - } - return size; - } - - bool operator==(const Value &other) const override { return this == &other; } - - int format() const { return this->format_; } - - void set_format(int format) { this->format_ = format; } - - private: - void *tensor_addr_ = nullptr; - size_t tensor_size_ = 0; - int format_ = schema::Format::Format_KCHW; - std::vector tensor_shape_{}; - TypeId type_id_ = TypeId::kNumberTypeFloat32; -}; - -using ParamValueLitePtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_PARAM_VALUE_LITE_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 4b109631f73..66634ae1835 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -152,11 +152,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/common/file_utils.cc ${LITE_DIR}/src/common/utils.cc ${LITE_DIR}/src/common/string_util.cc - ${LITE_DIR}/tools/common/graph_util.cc - ${LITE_DIR}/tools/common/tensor_util.cc - ${LITE_DIR}/tools/common/node_util.cc ${LITE_DIR}/tools/common/flag_parser.cc - ${LITE_DIR}/tools/common/storage.cc 
${LITE_DIR}/tools/benchmark/benchmark.cc ${LITE_DIR}/test/st/benchmark_test.cc ${LITE_DIR}/src/errorcode.cc @@ -271,6 +267,10 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc ${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc ${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc + ${LITE_DIR}/tools/common/graph_util.cc + ${LITE_DIR}/tools/common/tensor_util.cc + ${LITE_DIR}/tools/common/node_util.cc + ${LITE_DIR}/tools/common/storage.cc ) endif() ### train diff --git a/mindspore/lite/test/common/import_from_meta_graphT.cc b/mindspore/lite/test/common/import_from_meta_graphT.cc index 9d4af45e85c..dea4bee1eea 100644 --- a/mindspore/lite/test/common/import_from_meta_graphT.cc +++ b/mindspore/lite/test/common/import_from_meta_graphT.cc @@ -18,12 +18,12 @@ #include #include "schema/inner/model_generated.h" #include "frontend/operator/ops.h" -#include "src/param_value_lite.h" #include "src/common/log_adapter.h" #include "tools/converter/converter_context.h" #include "include/errorcode.h" #include "test/common/import_from_meta_graphT.h" -#include "ir/func_graph.h" +#include "src/common/utils.h" +#include "tools/common/tensor_util.h" namespace mindspore::lite { AnfNodePtr AnfImporterFromMetaGraphT::GetNode(int tensor_id) { @@ -50,41 +50,34 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); auto type_id = static_cast(tensor->dataType); auto type_ptr = TypeIdToType(type_id); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - MS_ASSERT(nullptr != abstract_tensor); - parameter->set_abstract(abstract_tensor); + std::vector shape_vector(shape.begin(), shape.end()); if (!tensor->name.empty()) { parameter->set_name(tensor->name); } else { parameter->set_name("const-" + std::to_string(i)); } - - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(nullptr != param_value); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - param_value->set_format(tensor->format); + tensor::TensorPtr tensor_info = std::make_shared(type_id, shape_vector); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + return RET_ERROR; + } + int status = RET_OK; if (!tensor->data.empty()) { auto size = tensor->data.size(); - char *tensor_data = new (std::nothrow) char[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } + auto tensor_data = static_cast(tensor_info->data_c()); auto ret = memcpy_s(tensor_data, size, tensor->data.data(), size); if (EOK != ret) { MS_LOG(ERROR) << "memcpy_s error"; - delete[] tensor_data; return RET_MEMORY_FAILED; } - param_value->SetTensorData(tensor_data, size); - parameter->set_default_param(param_value); + status = lite::InitParameterFromTensorInfo(parameter, tensor_info); } else if (std::find(meta_graph_->inputIndex.begin(), meta_graph_->inputIndex.end(), i) == meta_graph_->inputIndex.end()) { - parameter->set_default_param(param_value); + status = lite::InitParameterFromTensorInfo(parameter, tensor_info); + } + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return RET_ERROR; } AddNode(i, parameter); } diff --git a/mindspore/lite/test/models_npu.cfg b/mindspore/lite/test/models_npu.cfg index c657bb302c1..150b60a3623 100644 --- 
a/mindspore/lite/test/models_npu.cfg
+++ b/mindspore/lite/test/models_npu.cfg
@@ -67,7 +67,7 @@ ml_video_edit_v10_best_model_nomean_20200723 8
 #hdc_ocr_detect.onnx 30 #too many subgraphs
 ml_edu_kit_hand_detection.onnx 1
 ml_edu_kit_hand_key_position.onnx 2
-ml_video_edit_oneclick_adaptis.pb 2 3
+#ml_video_edit_oneclick_adaptis.pb 2 3
 densenet.tflite 3
 resnet_v2_101_299.tflite 1
 ml_video_edit_enhance.pb 2
diff --git a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
index a50a1513751..c30618a69a8 100644
--- a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
+++ b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
@@ -1,9 +1,9 @@
 ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11
-ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 11
+ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 12.3
 ml_video_edit_img_segment_adaptise.pb;2 40
 ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5
 ml_video_edit_person_divison_video;2 38
-ml_video_edit_oneclick_adaptis.pb;3 6
+ml_video_edit_oneclick_adaptis.pb;3 6.1
 hdc_tb_cn_neg.tflite;3 281
 decoder_step_201217.pb;5 187
 ml_video_edit_art_transfer.onnx;3 3
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 473e8b29eae..3a9656b47b3 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -21,6 +21,8 @@
 #include
 #include
 #include
+#include "tools/converter/converter_flags.h"
+#include "tools/common/tensor_util.h"
 #include "abstract/abstract_value.h"
 #include "mindspore/core/ir/primitive.h"
 #include "ops/fusion/partial_fusion.h"
@@ -32,8 +34,8 @@
 #include "tools/converter/quant_param_holder.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "src/tensor.h"
-#include "src/param_value_lite.h"
 #include "src/common/utils.h"
+#include "ops/op_utils.h"
 #include "tools/common/graph_util.h"
 #include "src/ops/ops_utils.h"
@@ -77,6 +79,51 @@ std::list GetOrderedCNodes(const FuncGraphPtr fg) {
   }
   return cnodes;
 }
+ShapeVector GetShapeVectorFromTensorInfo(const tensor::TensorPtr &tensor_info, size_t *offset) {
+  ShapeVector shape_vector;
+  auto tensor_data = reinterpret_cast<char *>(tensor_info->data_c());
+  std::string shape_str;
+  std::string shape_size_str;
+  *offset = 0;
+  size_t cnt = 0;
+  for (; *offset < tensor_info->Size(); (*offset)++) {
+    if (tensor_data[*offset] == ',') {
+      (*offset)++;
+      break;
+    }
+    shape_size_str.push_back(tensor_data[*offset]);
+  }
+  size_t shape_size = std::stoi(shape_size_str);
+  for (; *offset < tensor_info->Size(); (*offset)++) {
+    if (tensor_data[*offset] == ',') {
+      cnt++;
+      shape_vector.push_back(std::stoi(shape_str));
+      shape_str.clear();
+    } else {
+      shape_str.push_back(tensor_data[*offset]);
+    }
+    if (cnt == shape_size) {
+      (*offset)++;
+      break;
+    }
+  }
+
+  return shape_vector;
+}
+schema::Format GetFormatByFmk(int32_t fmk_type) {
+  switch (fmk_type) {
+    case converter::FmkType_ONNX:
+    case lite::converter::FmkType_CAFFE:
+    case lite::converter::FmkType_MS:
+      return schema::Format_NCHW;
+    case lite::converter::FmkType_TF:
+    case lite::converter::FmkType_TFLITE:
+      return schema::Format_NHWC;
+    default:
+      MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
+      return static_cast<schema::Format>(fmk_type);
+  }
+}
 }  // namespace
 
 void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
@@ -164,9 +211,9 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me
   QuantParamsVector input_quant_params;
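// GetShapeVectorFromTensorInfo above assumes string tensors are packed as an ASCII
// header "<rank>,<dim0>,...,<dimN-1>," followed by the raw payload -- the same layout
// that SetStringTensorInfo in tf_model_parser.cc writes. A minimal sketch with a
// hypothetical buffer (not part of this patch):
//
//   tensor bytes: "2,3,4,<payload bytes...>"
//   size_t offset = 0;
//   ShapeVector shape = GetShapeVectorFromTensorInfo(info, &offset);  // shape == {3, 4}
//   // offset now indexes the first payload byte, so callers copy Size() - offset bytes.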
 QuantParamsVector output_quant_params;
   dst_node->quantType = schema::QuantType_QUANT_NONE;
-  auto quant_param_valueptr = primitive->GetAttr("quant_params");
-  if (quant_param_valueptr != nullptr) {
-    auto quant_param_holder = quant_param_valueptr->cast<QuantParamHolderPtr>();
+  auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
+  if (quant_tensor_info_ptr != nullptr) {
+    auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
     if (quant_param_holder == nullptr) {
       MS_LOG(ERROR) << "quant param is invalid.";
       return RET_ERROR;
@@ -553,160 +600,174 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano
                                        const std::shared_ptr<PrimitiveC> &primitive_c,
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                        schema::CNodeT *output_cnode) {
-  auto paramNode = input_anode->cast<ParameterPtr>();
-  std::string input_name = paramNode->fullname_with_scope();
+  auto param_node = input_anode->cast<ParameterPtr>();
+  std::string input_name = param_node->fullname_with_scope();
   if (node_id_map_.find(input_name) != node_id_map_.end()) {
-    output_cnode->inputIndex.emplace_back(node_id_map_[paramNode->name()]);
+    output_cnode->inputIndex.emplace_back(node_id_map_[param_node->name()]);
     return RET_OK;
   }
-  auto paramTensor = std::make_unique<schema::TensorT>();
-  paramTensor->format = schema::Format_NHWC;
-  paramTensor->name = paramNode->name();
-  auto abstractBase = paramNode->abstract();
-  if (abstractBase == nullptr) {
-    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
+  auto schema_tensor = std::make_unique<schema::TensorT>();
+  schema_tensor->format = GetFormatByFmk(meta_graphT->fmkType);
+  if (schema_tensor->format != schema::Format_NHWC && schema_tensor->format != schema::Format_NCHW) {
+    MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format;
+    return RET_ERROR;
+  }
+  schema_tensor->name = param_node->name();
+  auto abstract_base = param_node->abstract();
+  if (abstract_base == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
     return RET_PARAM_INVALID;
   }
-  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
-    MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name();
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
     return RET_INPUT_TENSOR_ERROR;
   }
-  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
-  auto typePtr = abstractTensor->element()->GetTypeTrack();
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
+  auto typePtr = abstract_tensor->element()->GetTypeTrack();
   MS_ASSERT(typePtr != nullptr);
-  paramTensor->dataType = typePtr->type_id();
-  if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
-    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
+  schema_tensor->dataType = typePtr->type_id();
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
     return RET_PARAM_INVALID;
   }
-  auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
+  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  size_t offset = 0;
+  if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) {
+    shape_vector = GetShapeVectorFromTensorInfo(tensor_info, &offset);
+  }
   std::vector<int32_t> dims;
   (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
                        [](const int64_t &value) { return static_cast<int32_t>(value); });
-  paramTensor->dims = dims;
-  auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
-  if (paramValue != nullptr && paramValue->tensor_size() != 0) {
-    paramTensor->data.resize(paramValue->tensor_size());
-    paramTensor->format = schema::Format(paramValue->format());
-    if (EOK != memcpy_s(paramTensor->data.data(), paramTensor->data.size(), paramValue->tensor_addr(),
-                        paramValue->tensor_size())) {
-      MS_LOG(ERROR) << "memcpy_s failed.";
-      return RET_ERROR;
+  schema_tensor->dims = dims;
+  if (tensor_info != nullptr && tensor_info->Size() != 0) {
+    if (schema_tensor->dataType == kObjectTypeTensorType && shape_vector.empty() &&
+        meta_graphT->fmkType == converter::FmkType_ONNX) {
+      schema_tensor->data.resize(0);
+    } else {
+      schema_tensor->data.resize(tensor_info->Size() - offset);
+      if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(),
+                          static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
+        MS_LOG(ERROR) << "memcpy_s failed.";
+        return RET_ERROR;
+      }
     }
   }
-
-  paramTensor->name = input_name;
+  if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
+    schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
+  }
+  schema_tensor->name = input_name;
   QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr
                                              ? nullptr
                                              : primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
   if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
-      paramTensor->dataType == kNumberTypeInt8) {
-    paramTensor->enableHuffmanCode = true;
+      schema_tensor->dataType == kNumberTypeInt8) {
+    schema_tensor->enableHuffmanCode = true;
   }
   node_id_map_[input_name] = meta_graphT->allTensors.size();
   output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
-  meta_graphT->allTensors.emplace_back(std::move(paramTensor));
+  meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
   return RET_OK;
 }
 
-int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                                const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                                const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
   int ret;
-  auto valueAbstract = valueNode->abstract();
-  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
-  if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
-    MS_LOG(ERROR) << "abstractTensor or abstractTensor->element() is nullptr";
+  auto valueAbstract = value_node->abstract();
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
+  if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
+    MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
     return RET_ERROR;
   }
-  auto typePtr = abstractTensor->element()->GetTypeTrack();
-  (*paramTensor)->dataType = typePtr->type_id();
-  auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
+  auto typePtr = abstract_tensor->element()->GetTypeTrack();
+  (*schema_tensor)->dataType = typePtr->type_id();
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
   std::vector<int32_t> dims;
   (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
                        [](const int64_t &value) { return static_cast<int32_t>(value); });
-  (*paramTensor)->dims = dims;
-  if (train_flag_ && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
-  (*paramTensor)->nodeType = NodeType_ValueNode;
+  (*schema_tensor)->dims = dims;
+  if (train_flag_ && (*schema_tensor)->dims.empty()) (*schema_tensor)->dims = {1};
+
(*schema_tensor)->nodeType = NodeType_ValueNode; auto data = value->cast(); - (*paramTensor)->data.resize(data->Size()); - ret = memcpy_s((*paramTensor)->data.data(), data->Size(), data->data_c(), data->Size()); + (*schema_tensor)->data.resize(data->Size()); + ret = memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size()); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s error."; return RET_ERROR; } - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); return ret; } -int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, +int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, const std::shared_ptr &value, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { int ret; // data of int64 is converted to int32 here. - (*paramTensor)->dataType = kNumberTypeInt32; - (*paramTensor)->dims = {1}; - (*paramTensor)->nodeType = NodeType_ValueNode; + (*schema_tensor)->dataType = kNumberTypeInt32; + (*schema_tensor)->dims = {1}; + (*schema_tensor)->nodeType = NodeType_ValueNode; int real_data = opt::CastToInt(value).front(); - (*paramTensor)->data.resize(sizeof(int32_t)); - ret = memcpy_s((*paramTensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); + (*schema_tensor)->data.resize(sizeof(int32_t)); + ret = memcpy_s((*schema_tensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s error."; return RET_ERROR; } - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); return ret; } -void AnfExporter::ProcessBoolImm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, +void AnfExporter::ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, const std::shared_ptr &value, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { - auto valueAbstract = valueNode->abstract(); + auto valueAbstract = value_node->abstract(); auto abstractScalar = utils::cast(valueAbstract); auto typePtr = abstractScalar->GetTypeTrack(); - (*paramTensor)->dataType = typePtr->type_id(); - (*paramTensor)->dims = {1}; - (*paramTensor)->nodeType = NodeType_ValueNode; + (*schema_tensor)->dataType = typePtr->type_id(); + (*schema_tensor)->dims = {1}; + (*schema_tensor)->nodeType = NodeType_ValueNode; auto data = value->cast(); - (*paramTensor)->data.emplace_back(data->value()); - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + (*schema_tensor)->data.emplace_back(data->value()); + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); } -int AnfExporter::ProcessNumber(const ValueNodePtr &valueNode, schema::TensorT 
*paramTensor, +int AnfExporter::ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { - auto data = valueNode->value()->cast(); - paramTensor->data.resize(sizeof(int)); + auto data = value_node->value()->cast(); + schema_tensor->data.resize(sizeof(int)); int number_type = data->number_type(); - if (EOK != ::memcpy_s(paramTensor->data.data(), sizeof(int), &number_type, sizeof(int))) { + if (EOK != ::memcpy_s(schema_tensor->data.data(), sizeof(int), &number_type, sizeof(int))) { MS_LOG(ERROR) << "memcpy_s failed"; return RET_MEMORY_FAILED; } - paramTensor->dataType = kNumberTypeInt32; - paramTensor->dims = {1}; - paramTensor->nodeType = NodeType_ValueNode; - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + schema_tensor->dataType = kNumberTypeInt32; + schema_tensor->dims = {1}; + schema_tensor->nodeType = NodeType_ValueNode; + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(paramTensor); + meta_graphT->allTensors.emplace_back(schema_tensor); return RET_OK; } -void AnfExporter::ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, +void AnfExporter::ProcessInt(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { - (*paramTensor)->dataType = kNumberTypeInt32; - (*paramTensor)->dims = {1}; - (*paramTensor)->nodeType = NodeType_ValueNode; - (*paramTensor)->data.emplace_back(kNumberTypeInt32); - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + (*schema_tensor)->dataType = kNumberTypeInt32; + (*schema_tensor)->dims = {1}; + (*schema_tensor)->nodeType = NodeType_ValueNode; + (*schema_tensor)->data.emplace_back(kNumberTypeInt32); + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); } -int AnfExporter::ProcessValueSequence(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, +int AnfExporter::ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, const std::shared_ptr &value, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { int ret = RET_OK; - auto valueAbstract = valueNode->abstract(); + auto valueAbstract = value_node->abstract(); auto abstractSequnce = utils::cast(valueAbstract); if (abstractSequnce->isa()) { auto abstractTuple = utils::cast(valueAbstract); @@ -724,72 +785,71 @@ int AnfExporter::ProcessValueSequence(const ValueNodePtr &valueNode, std::unique return RET_ERROR; } } - (*paramTensor)->dataType = kNumberTypeInt32; - (*paramTensor)->dims = {static_cast(shape.size())}; - (*paramTensor)->nodeType = NodeType_ValueNode; - (*paramTensor)->data.resize(shape.size() * sizeof(int)); - ret = memcpy_s((*paramTensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(), + (*schema_tensor)->dataType = kNumberTypeInt32; + (*schema_tensor)->dims = {static_cast(shape.size())}; + (*schema_tensor)->nodeType = NodeType_ValueNode; + (*schema_tensor)->data.resize(shape.size() * sizeof(int)); + ret = memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(), shape.size() * sizeof(int32_t)); if 
(ret != RET_OK) {
-    MS_LOG(ERROR) << "memcpy_s data into paramTensor failed.";
+    MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
     return RET_ERROR;
   }
-  node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
+  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
   output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
-  meta_graphT->allTensors.emplace_back(std::move(*paramTensor));
+  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
   }
   return ret;
 }
 
-int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
-                                       const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
-                                       const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
-  int ret;
-  auto valueLite = std::dynamic_pointer_cast<ParamValueLite>(value);
-  (*paramTensor)->data.resize(valueLite->tensor_size());
-  (*paramTensor)->format = schema::Format(valueLite->format());
-  (*paramTensor)->dataType = valueLite->tensor_type();
-  (*paramTensor)->dims = valueLite->tensor_shape();
-  if (train_flag_ && (*paramTensor)->dims.empty()) {
-    (*paramTensor)->dims = {1};
+int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
+                                   const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
+                                   const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
+  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(value);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "Input value is not a tensor";
+    return RET_INPUT_PARAM_INVALID;
+  }
+  auto ret = UpdateTensorTFromTensorInfo(tensor_info, schema_tensor);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "UpdateTensorTFromTensorInfo failed";
+    return ret;
+  }
+  if (train_flag_ && (*schema_tensor)->dims.empty()) {
+    (*schema_tensor)->dims = {1};
   }
-  ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
-                 valueLite->tensor_size());
-  if (ret != EOK) {
-    MS_LOG(ERROR) << "memcpy_s data into tensor failed.";
-    return RET_ERROR;
-  }
-  node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
+  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
   output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
-  meta_graphT->allTensors.emplace_back(std::move(*paramTensor));
+  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
   return ret;
 }
 
 int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                        schema::CNodeT *output_cnode) {
-  auto valueNode = input_anode->cast<ValueNodePtr>();
-  auto paramTensor = std::make_unique<schema::TensorT>();
-  auto value = valueNode->value();
+  auto value_node = input_anode->cast<ValueNodePtr>();
+  auto schema_tensor = std::make_unique<schema::TensorT>();
+  auto value = value_node->value();
   int ret = RET_OK;
+
   if (train_flag_) {
-    paramTensor->name = valueNode->fullname_with_scope();
+    schema_tensor->name = value_node->fullname_with_scope();
   }
   if (value->isa<tensor::Tensor>()) {
-    ret = ProcessTensor(valueNode, &paramTensor, value, output_cnode, meta_graphT);
+    ret = ProcessTensor(value_node, &schema_tensor, value, output_cnode, meta_graphT);
   } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
-    ret = ProcessInt32OrInt64Imm(valueNode, &paramTensor, value, output_cnode, meta_graphT);
+    ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
   } else if (value->isa<mindspore::BoolImm>()) {
-    ProcessBoolImm(valueNode, &paramTensor, value, output_cnode, meta_graphT);
+    ProcessBoolImm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
   } else if (value->isa<mindspore::Int>()) {
-    ProcessInt(valueNode, &paramTensor, output_cnode, meta_graphT);
+    ProcessInt(value_node, &schema_tensor, output_cnode, meta_graphT);
   } else if (value->isa<mindspore::ValueSequeue>()) {
-    ret = ProcessValueSequence(valueNode, &paramTensor, value, output_cnode, meta_graphT);
+    ret = ProcessValueSequence(value_node, &schema_tensor, value, output_cnode, meta_graphT);
   } else if (value->isa<Number>()) {
-    ret = ProcessNumber(valueNode, paramTensor.release(), output_cnode, meta_graphT);
-  } else if (value->isa<ParamValueLite>()) {
-    ret = ProcessParamValueLite(valueNode, &paramTensor, value, output_cnode, meta_graphT);
+    ret = ProcessNumber(value_node, schema_tensor.release(), output_cnode, meta_graphT);
+  } else if (value->isa<tensor::Tensor>()) {
+    ret = ProcessTensorInfo(value_node, &schema_tensor, value, output_cnode, meta_graphT);
   } else if (value->isa<FuncGraph>()) {
     MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
     return RET_OK;
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h
index 56087c7a6a9..35d2a526b4b 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.h
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h
@@ -25,6 +25,7 @@
 #include "ops/primitive_c.h"
 #include "ir/func_graph.h"
 #include "tools/converter/converter_context.h"
+#include "tools/converter/converter_flags.h"
 
 using mindspore::ops::PrimitiveC;
 
@@ -52,25 +53,25 @@ class AnfExporter {
                              const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
   int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
-  int ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+  int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                     const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                     const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  int ProcessInt32OrInt64Imm(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+  int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                              const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                              const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  void ProcessBoolImm(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+  void ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                       const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                       const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  void ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+  void ProcessInt(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                   schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  int ProcessNumber(const ValueNodePtr &valueNode, schema::TensorT *paramTensor, schema::CNodeT *output_cnode,
+  int ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, schema::CNodeT *output_cnode,
                     const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  int ProcessValueSequence(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
+  int ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                            const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
-  int ProcessParamValueLite(const ValueNodePtr &valueNode, std::unique_ptr<schema::TensorT> *paramTensor,
-                            const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
-                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
+  int ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
+                        const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
+                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
   int
SetGraphInputIndex(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index, const std::unique_ptr &meta_graphT, schema::CNodeT *return_node); diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index c15b3148dd2..ca705654182 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "src/common/utils.h" #include "tools/common/tensor_util.h" +#include "src/common/utils.h" #include "tools/common/graph_util.h" +#include "abstract/utils.h" namespace mindspore::lite { std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor) { @@ -43,6 +44,111 @@ std::unique_ptr CopyQuantParamT(const std::unique_ptr &shape, + TypeId data_type) { + tensor::TensorPtr tensor_info = nullptr; + if (shape.empty() && data_size == mindspore::abstract::TypeIdSize(data_type)) { + ShapeVector scalar_shape = {1}; + tensor_info = std::make_shared(data_type, scalar_shape); + tensor_info->set_shape({}); + } else { + tensor_info = std::make_shared(data_type, shape); + } + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "new tensor init failed"; + return nullptr; + } + if (data_size == 0) { + return tensor_info; + } + if (data == nullptr) { + MS_LOG(ERROR) << "input tensor data is nullptr"; + return nullptr; + } + auto ret = memcpy_s(tensor_info->data_c(), tensor_info->data().nbytes(), data, data_size); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error : " << ret; + return nullptr; + } + return tensor_info; +} + +int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) { + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "tensor info is nullptr."; + return RET_ERROR; + } + if (data == nullptr) { + MS_LOG(ERROR) << "data is nullptr."; + return RET_ERROR; + } + auto ret = memcpy_s(tensor_info->data_c(), tensor_info->data().nbytes(), data, data_size); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error : " << ret; + return RET_ERROR; + } + return RET_OK; +} + +std::unique_ptr CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, + const std::string &tensor_name) { + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Input tensor is nullptr"; + return nullptr; + } + auto schema_tensor = std::make_unique(); + schema_tensor->name = tensor_name; + auto ret = UpdateTensorTFromTensorInfo(tensor_info, &schema_tensor); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init schema tensor failed"; + return nullptr; + } + return schema_tensor; +} + +int UpdateTensorTFromTensorInfo(const tensor::TensorPtr &src_tensor, std::unique_ptr *dst_tensor) { + if (src_tensor == nullptr) { + MS_LOG(ERROR) << "Input tensor info is nullptr"; + return RET_INPUT_PARAM_INVALID; + } + if (dst_tensor == nullptr || *dst_tensor == nullptr) { + MS_LOG(ERROR) << "Input schema tensor is nullptr"; + return RET_INPUT_PARAM_INVALID; + } + auto &schema_tensor = *dst_tensor; + schema_tensor->format = schema::Format_NHWC; + schema_tensor->dataType = src_tensor->data_type(); + auto &shape_vector = src_tensor->shape(); + std::vector dims; + (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), + [](const int64_t &value) { return static_cast(value); }); + schema_tensor->dims = dims; + if (src_tensor->data().data() != nullptr) { + schema_tensor->data.resize(src_tensor->data().nbytes()); + if (EOK != memcpy_s(schema_tensor->data.data(), 
schema_tensor->data.size(), src_tensor->data().data(), + src_tensor->data().nbytes())) { + MS_LOG(ERROR) << "memcpy_s failed."; + return RET_ERROR; + } + } + return RET_OK; +} + +int InitParameterFromTensorInfo(const ParameterPtr ¶m_node, const tensor::TensorPtr &tensor_info) { + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "tensor info is nullptr."; + return RET_ERROR; + } + auto abstract_tensor = tensor_info->ToAbstract(); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create abstract tensor failed."; + return RET_ERROR; + } + param_node->set_abstract(abstract_tensor); + param_node->set_default_param(tensor_info); + return RET_OK; +} + size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); } size_t GetElementSize(const TypeId &dataType) { diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h index 9db927308b0..6c5d18fd5a5 100644 --- a/mindspore/lite/tools/common/tensor_util.h +++ b/mindspore/lite/tools/common/tensor_util.h @@ -20,12 +20,14 @@ #include #include #include +#include #include #include #include #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "ir/dtype/type_id.h" +#include "ir/tensor.h" #include "src/common/utils.h" namespace mindspore { @@ -41,6 +43,18 @@ using schema::Format::Format_NHWC; std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor); +tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector &shape, + TypeId data_type); + +int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size); + +std::unique_ptr CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, + const std::string &tensor_name = ""); + +int UpdateTensorTFromTensorInfo(const tensor::TensorPtr &src_tensor, std::unique_ptr *dst_tensor); + +int InitParameterFromTensorInfo(const ParameterPtr ¶m_node, const tensor::TensorPtr &tensor_info); + size_t GetElementSize(const TensorT &tensor); size_t GetElementSize(const TypeId &dataType); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc index 93f44447d0c..ade075a6675 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc @@ -133,7 +133,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { continue; } if (tensor->quantParams.size() != 1) { // perchannel - MS_LOG(ERROR) << "perchannel doquant is not supported yet"; + MS_LOG(ERROR) << "perchannel do quant is not supported yet"; return RET_ERROR; } // perlayer diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 4e624f6eaaf..1a601fa6154 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -15,7 +15,6 @@ */ #include "tools/converter/parser/caffe/caffe_model_parser.h" #include -#include #include #include #include @@ -23,7 +22,7 @@ #include "tools/converter/parser/caffe/caffe_inspector.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" -#include "src/param_value_lite.h" +#include "tools/common/tensor_util.h" #include "ops/return.h" #include "ops/make_tuple.h" #include "ops/tuple_get_item.h" @@ -350,8 +349,6 @@ STATUS 
CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v
   std::vector<int64_t> shape_vector;
   (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                        [](const int32_t &value) { return static_cast<int64_t>(value); });
-  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
-  parameter->set_abstract(abstract_tensor);
   if (layer.type() == "Convolution" || layer.type() == "Deconvolution") {
     if (i == 0) {
       parameter->set_name(layer.name() + "/weight");
@@ -361,40 +358,34 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v
     } else {
       parameter->set_name(layer.name() + "/input-" + std::to_string(i + layer.top_size()));
     }
-    ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
-    MS_ASSERT(param_value != nullptr);
-    param_value->set_tensor_shape(shape);
-    param_value->set_tensor_type(TypeId::kNumberTypeFloat32);
-    param_value->set_format(schema::Format::Format_NCHW);
     int count = 0;
+    tensor::TensorPtr tensor_info = nullptr;
     if (layer.blobs(i).double_data_size() > 0) {
       count = layer.blobs(i).double_data_size();
       auto buf = std::make_unique<float[]>(count);
       for (int j = 0; j < count; ++j) {
         buf[j] = layer.blobs(i).double_data(j);
       }
-      param_value->set_tensor_addr(buf.release());
+      tensor_info = CreateTensorInfo(buf.get(), count * sizeof(float), shape_vector, TypeId::kNumberTypeFloat32);
     } else {
       count = layer.blobs(i).data_size();
-      auto buf = std::make_unique<float[]>(count);
-      if (buf == nullptr) {
-        MS_LOG(INFO) << "new buffer failed";
-        return RET_NULL_PTR;
-      }
       const float *data_ptr = layer.blobs(i).data().data();
       if (data_ptr == nullptr) {
        MS_LOG(INFO) << "data of origin layer is nullptr";
         return RET_NULL_PTR;
       }
-      if (EOK != ::memcpy_s(buf.get(), count * sizeof(float), data_ptr, count * sizeof(float))) {
-        MS_LOG(ERROR) << "memcpy_s failed.";
-        return RET_ERROR;
-      }
-      param_value->set_tensor_addr(buf.release());
+      tensor_info = CreateTensorInfo(data_ptr, count * sizeof(float), shape_vector, TypeId::kNumberTypeFloat32);
+    }
+    if (tensor_info == nullptr) {
+      MS_LOG(ERROR) << "create tensor info failed";
+      return RET_NULL_PTR;
+    }
+    auto status = InitParameterFromTensorInfo(parameter, tensor_info);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "init parameter from tensor info failed";
+      return RET_ERROR;
     }
-    param_value->set_tensor_size(count * sizeof(float));
-    parameter->set_default_param(param_value);
     const_parameters->emplace_back(parameter);
   }
   return RET_OK;
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc
index 253570c78e9..ddec9ff8097 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_parser.cc
@@ -17,19 +17,13 @@
 #include "tools/converter/parser/onnx/onnx_constant_parser.h"
 #include
 #include
-#include
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
 #include "ops/constant.h"
-#include "src/param_value_lite.h"
+#include "tools/common/tensor_util.h"
 
 namespace mindspore {
 namespace lite {
 STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_tensor, ops::PrimitiveC *prim) {
-  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
-  if (param_value == nullptr) {
-    MS_LOG(ERROR) << "new a paramValueLite failed.";
-    return RET_ERROR;
-  }
   auto data_type =
     OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_const_tensor.data_type()));
   if (data_type == kTypeUnknown) {
@@ -41,14 +35,16 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
   std::vector<int> shape;
   std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                  [](const int64_t &val) { return static_cast<int>(val); });
-  param_value->set_tensor_type(data_type);
-  param_value->set_tensor_shape(shape);
-  param_value->set_format(schema::Format_NCHW);
-  if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, param_value) != RET_OK) {
+  auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "create tensor info failed.";
+    return RET_ERROR;
+  }
+  if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, tensor_info) != RET_OK) {
     MS_LOG(ERROR) << "get value failed.";
     return RET_ERROR;
   }
-  prim->set_attr("const_data", param_value);
+  prim->set_attr("const_data", tensor_info);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc
index a172795242a..81b2deca541 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc
@@ -19,70 +19,45 @@
 #include
 #include
 #include
-#include "src/param_value_lite.h"
+#include "tools/common/tensor_util.h"
 #include "ops/constant.h"
 
 namespace mindspore {
 namespace lite {
 STATUS OnnxGivenTensorFillParser::ParseInt8GivenIntTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
                                                               const std::vector<int> &shape) {
-  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
-
   int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
   auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
                            [](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
   if (iter == onnx_node.attribute().end()) {
     return RET_OK;
   }
+  ShapeVector shape_vector(shape.begin(), shape.end());
   size_t data_size = data_count * sizeof(int64_t) / sizeof(uint8_t);
-  char *param_data = new (std::nothrow) char[data_size];
-  if (param_data == nullptr) {
-    MS_LOG(ERROR) << "new char[] failed";
-    return RET_MEMORY_FAILED;
-  }
-  if (iter->ints().data() == nullptr) {
-    MS_LOG(ERROR) << "origin ints data in onnx is nullptr";
-    delete[] param_data;
-    return RET_NULL_PTR;
-  }
-  if (memcpy_s(param_data, data_size, iter->ints().data(), data_size) != EOK) {
-    MS_LOG(ERROR) << "memcpy data failed.";
-    delete[] param_data;
+  auto tensor_info = CreateTensorInfo(iter->ints().data(), data_size, shape_vector, kNumberTypeInt64);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "Create tensor info failed";
     return RET_ERROR;
   }
-  param_value->set_tensor_shape(shape);
-  param_value->set_format(schema::Format_NUM_OF_FORMAT);
-  param_value->set_tensor_type(kNumberTypeInt64);
-  param_value->SetTensorData(param_data, data_size);
-  prim->set_attr("const_data", param_value);
+  prim->set_attr("const_data", tensor_info);
   return RET_OK;
 }
 
 STATUS OnnxGivenTensorFillParser::ParseInt8GivenTensorFill(const onnx::NodeProto &onnx_node, ops::PrimitiveC *prim,
                                                            const std::vector<int> &shape) {
-  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
-
   int data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
   auto iter = std::find_if(onnx_node.attribute().begin(), onnx_node.attribute().end(),
                            [](const onnx::AttributeProto &attr) { return attr.name() == "values"; });
   if (iter == onnx_node.attribute().end()) {
     return RET_OK;
   }
-  char *param_data = new (std::nothrow) char[data_count]; if
(param_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - if (memcpy_s(param_data, data_count, iter->s().data(), data_count) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - delete[] param_data; + ShapeVector shape_vector(shape.begin(), shape.end()); + auto tensor_info = CreateTensorInfo(iter->s().data(), data_count, shape_vector, kNumberTypeUInt8); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; return RET_ERROR; } - param_value->set_tensor_shape(shape); - param_value->set_format(schema::Format_NUM_OF_FORMAT); - param_value->set_tensor_type(kNumberTypeUInt8); - param_value->SetTensorData(param_data, data_count); - prim->set_attr("const_data", param_value); + prim->set_attr("const_data", tensor_info); return RET_OK; } ops::PrimitiveC *OnnxGivenTensorFillParser::Parse(const onnx::GraphProto &onnx_graph, diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index a39985a3d83..3a031d51de4 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -20,15 +20,16 @@ #include #include #include +#include "tools/optimizer/common/gllo_utils.h" #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" +#include "tools/common/tensor_util.h" #include "ops/return.h" #include "ops/make_tuple.h" #include "ops/tensor_list_stack.h" #include "ops/tuple_get_item.h" #include "ir/func_graph.h" -#include "src/param_value_lite.h" #include "tools/converter/converter_flags.h" namespace mindspore { @@ -209,6 +210,7 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F status = RET_ERROR; continue; } + primitive_c->AddAttr(mindspore::opt::kWeightFormat, MakeValue(Format_NCHW)); status = ConvertOpQuantParams(onnx_node, primitive_c); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; @@ -429,25 +431,14 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncG ext_subgraph_input->set_abstract(outside_input_node->abstract()); ext_subgraph_input->set_name(input_name); if (outside_input_node->isa()) { - auto param_value = outside_input_node->cast()->default_param()->cast(); - auto copy_param_value = std::make_shared(); - auto copy_data = new (std::nothrow) char[param_value->tensor_size()]; - if (copy_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - auto ret = - memcpy_s(copy_data, param_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()); - if (ret != EOK) { - delete[](copy_data); - MS_LOG(ERROR) << "memcpy error: " << ret; + auto tensor_info = outside_input_node->cast()->default_param()->cast(); + auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(), + tensor_info->data_type()); + if (copy_tensor_info == nullptr) { + MS_LOG(ERROR) << "memcpy failed."; return RET_ERROR; } - copy_param_value->set_tensor_shape(param_value->tensor_shape()); - copy_param_value->set_format(param_value->format()); - copy_param_value->set_tensor_type(param_value->tensor_type()); - copy_param_value->SetTensorData(copy_data, param_value->tensor_size()); - ext_subgraph_input->set_default_param(copy_param_value); + ext_subgraph_input->set_default_param(copy_tensor_info); } else { // output inside cnode need make extra 
input graph_inputs->emplace_back(ext_subgraph_input); @@ -675,16 +666,16 @@ STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, Qua MS_LOG(ERROR) << "quant param get failed"; return RET_ERROR; } - auto param_value_lite = quant_parameter_node->default_param()->cast(); - if (param_value_lite == nullptr) { - MS_LOG(ERROR) << "parameterNode's default param is not paramValueLite"; + auto tensor_info = quant_parameter_node->default_param()->cast(); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "parameterNode's default param is not tensor::TensorPtr"; return RET_ERROR; } if (scale_or_not) { - quant_param->scale = *reinterpret_cast(param_value_lite->tensor_addr()); + quant_param->scale = *reinterpret_cast(tensor_info->data_c()); quant_param->inited = true; } else { - quant_param->zeroPoint = *reinterpret_cast(param_value_lite->tensor_addr()); + quant_param->zeroPoint = *reinterpret_cast(tensor_info->data_c()); quant_param->inited = true; } return RET_OK; @@ -704,10 +695,14 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) { return nullptr; } tensor_data[0] = val; - auto param_value = std::make_shared(); - param_value->set_tensor_shape({}); - param_value->SetTensorData(tensor_data, sizeof(int)); - const_node->set_default_param(param_value); + auto tensor_info = CreateTensorInfo(tensor_data, 1 * sizeof(int), {1}, kNumberTypeInt32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + delete[] tensor_data; + return nullptr; + } + delete[] tensor_data; + const_node->set_default_param(tensor_info); return const_node; } @@ -841,10 +836,9 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v auto while_tensor_array_input = anf_root_graph->add_parameter(); std::vector shape_vector; auto abstract_tensor = std::make_shared(kTensorType, shape_vector); - auto param_value = std::make_shared(); - param_value->set_tensor_type(kObjectTypeTensorType); + auto tensor_info = std::make_shared(kObjectTypeTensorType, shape_vector); while_tensor_array_input->set_abstract(abstract_tensor); - while_tensor_array_input->set_default_param(param_value); + while_tensor_array_input->set_default_param(tensor_info); while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray"); root_while_node->add_input(while_tensor_array_input); @@ -1035,33 +1029,18 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const } parameter_node->set_abstract(abstract_tensor); parameter_node->set_name(name); - std::vector shape; - ParamValueLitePtr param_value = std::make_shared(); - if (param_value == nullptr) { - MS_LOG(ERROR) << "new param_value failed"; - return RET_MEMORY_FAILED; - } - param_value->set_tensor_shape(shape); - param_value->set_format(schema::Format_NUM_OF_FORMAT); - param_value->set_tensor_type(type); int data_size = 0; if (type == kNumberTypeFloat32) { data_size = sizeof(float); } else { data_size = sizeof(int64_t); } - auto *tensor_data = new (std::nothrow) char[data_size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - if (memcpy_s(tensor_data, data_size, data, data_size) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - delete[] tensor_data; + auto tensor_info = CreateTensorInfo(data, data_size, {1}, type); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } - param_value->SetTensorData(tensor_data, data_size); - parameter_node->set_default_param(param_value); 
+ parameter_node->set_default_param(tensor_info); anf_nodes_map_.emplace(name, parameter_node); return RET_OK; } @@ -1078,26 +1057,23 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr ¶meter_node, c parameter_node->set_abstract(abstract_tensor); parameter_node->set_name(tensor.name()); - ParamValueLitePtr param_value = std::make_shared(); + auto tensor_info = std::make_shared(data_type, shape_vector); std::vector shape; std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), [](const int64_t &value) { return static_cast(value); }); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(data_type); - param_value->set_format(schema::Format::Format_NCHW); - auto status = CopyOnnxTensorData(tensor, param_value); + auto status = CopyOnnxTensorData(tensor, tensor_info); if (status != RET_OK) { MS_LOG(ERROR) << "copy data failed."; return status; } - parameter_node->set_default_param(param_value); + parameter_node->set_default_param(tensor_info); return RET_OK; } STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, - const ParamValueLitePtr ¶m_value_lite) { - if (param_value_lite == nullptr) { - MS_LOG(ERROR) << "param_value_lite is nullptr."; + const tensor::TensorPtr &tensor_info) { + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "tensor_info is nullptr."; return RET_NULL_PTR; } size_t data_count = 1; @@ -1150,17 +1126,11 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t MS_LOG(ERROR) << "origin data in onnx model is nullptr"; return RET_MEMORY_FAILED; } - char *param_data = new (std::nothrow) char[data_size]; - if (param_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - if (memcpy_s(static_cast(param_data), data_size, onnx_data, data_size) != EOK) { + auto tensor_data = reinterpret_cast(tensor_info->data_c()); + if (memcpy_s(tensor_data, data_size, onnx_data, data_size) != EOK) { MS_LOG(ERROR) << "memcpy_s failed"; - delete[] param_data; return RET_ERROR; } - param_value_lite->SetTensorData(param_data, data_size); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 8deb07d2893..6bc638a1cf0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -31,7 +31,6 @@ #include "tools/converter/model_parser.h" #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" #include "proto/onnx.pb.h" -#include "src/param_value_lite.h" namespace mindspore { namespace lite { @@ -44,8 +43,8 @@ class OnnxModelParser : public ModelParser { FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) override; static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); - static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, - const ParamValueLitePtr ¶m_value_lite); + static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, + const tensor::TensorPtr ¶m_value_lite); private: STATUS InitOriginModel(const std::string &model_file); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 847e1671530..5a05e93a0b5 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -20,7 
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 847e1671530..5a05e93a0b5 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -20,7 +20,6 @@ #include #include "src/common/log_adapter.h" #include "src/common/utils.h" -#include "src/param_value_lite.h" #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" @@ -29,6 +28,7 @@ #include "ops/make_tuple.h" #include "ops/tuple_get_item.h" #include "ir/anf.h" +#include "abstract/utils.h" #include "tools/converter/converter_flags.h" namespace mindspore { @@ -85,12 +85,30 @@ STATUS CheckStrView(std::string_view str_view, uint64_t *scratch) { return RET_OK; } -STATUS GetFloatValue(const tensorflow::TensorProto &tensor_proto, const tensorflow::TensorShapeProto &tensor_shape, - ParamValueLitePtr param_value, int shape_size) { - auto tensor_data = new (std::nothrow) float[shape_size]; +int GetShapeSize(const tensorflow::TensorProto &tensor_proto) { + auto &tensor_shape = tensor_proto.tensor_shape(); + int shape_size = 1; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_size *= tensor_shape.dim(i).size(); + } + return shape_size; +} + +STATUS SetFloatTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) { + auto shape_size = GetShapeSize(tensor_proto); + auto &tensor_shape = tensor_proto.tensor_shape(); + ShapeVector shape_vector{}; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector.push_back(tensor_shape.dim(i).size()); + } + *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor data failed."; + return RET_ERROR; + } + auto tensor_data = reinterpret_cast<float *>((*tensor_info)->data_c()); if (tensor_data == nullptr) { MS_LOG(ERROR) << "new data failed"; - delete[] tensor_data; return RET_ERROR; } @@ -107,17 +125,25 @@ STATUS GetFloatValue(const tensorflow::TensorProto &tensor_proto, const tensorfl return RET_ERROR; } } - auto tensor_size = shape_size * sizeof(float); - param_value->SetTensorData(tensor_data, tensor_size); + return RET_OK; } -STATUS GetInt32Value(const tensorflow::TensorProto &tensor_proto, const tensorflow::TensorShapeProto &tensor_shape, - ParamValueLitePtr param_value, int shape_size) { - auto tensor_data = new (std::nothrow) int[shape_size]; +STATUS SetInt32TensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) { + auto shape_size = GetShapeSize(tensor_proto); + auto &tensor_shape = tensor_proto.tensor_shape(); + ShapeVector shape_vector{}; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector.push_back(tensor_shape.dim(i).size()); + } + *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeInt32); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor data failed."; + return RET_ERROR; + } + auto tensor_data = reinterpret_cast<int *>((*tensor_info)->data_c()); if (tensor_data == nullptr) { MS_LOG(ERROR) << "new data failed"; - delete[] tensor_data; return RET_ERROR; } @@ -134,21 +160,29 @@ STATUS GetInt32Value(const tensorflow::TensorProto &tensor_proto, const tensorfl return RET_ERROR; } } - auto tensor_size = shape_size * sizeof(int); - param_value->SetTensorData(tensor_data, tensor_size); + return RET_OK; } -STATUS GetInt64Value(const tensorflow::TensorProto &tensor_proto, const tensorflow::TensorShapeProto &tensor_shape, - ParamValueLitePtr param_value, int shape_size) { - param_value->set_tensor_type(kNumberTypeInt32); - auto *tensor_data = new (std::nothrow) int[shape_size]; +STATUS SetInt64TensorInfo(const tensorflow::TensorProto &tensor_proto, 
tensor::TensorPtr *tensor_info) { + auto shape_size = GetShapeSize(tensor_proto); + auto &tensor_shape = tensor_proto.tensor_shape(); + ShapeVector shape_vector{}; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector.push_back(tensor_shape.dim(i).size()); + } + *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeInt32); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor data failed."; + return RET_ERROR; + } + auto tensor_data = reinterpret_cast<int *>((*tensor_info)->data_c()); if (tensor_data == nullptr) { MS_LOG(ERROR) << "new data failed"; delete[] tensor_data; return RET_ERROR; } - if (tensor_shape.dim_size() == 0) { // scalar + if (tensor_proto.tensor_shape().dim_size() == 0) { // scalar const auto &origin_data = tensor_proto.int64_val(); for (int i = 0; i < tensor_proto.int64_val_size(); ++i) { if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) { @@ -170,14 +204,84 @@ STATUS GetInt64Value(const tensorflow::TensorProto &tensor_proto, const tensorfl } } } - param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t)); + + return RET_OK; +} + +STATUS SetBoolTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) { + auto shape_size = GetShapeSize(tensor_proto); + auto &tensor_shape = tensor_proto.tensor_shape(); + ShapeVector shape_vector{}; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector.push_back(tensor_shape.dim(i).size()); + } + *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeBool); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor data failed."; + return RET_ERROR; + } + auto tensor_data = reinterpret_cast<int *>((*tensor_info)->data_c()); + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "get tensor data failed"; + return RET_ERROR; + } + + if (tensor_proto.bool_val_size() == 1) { + int value = tensor_proto.bool_val(0); + for (int i = 0; i < shape_size; i++) { + tensor_data[i] = value; + } + } + + return RET_OK; +} + +STATUS SetStringTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) { + auto &tensor_shape = tensor_proto.tensor_shape(); + ShapeVector shape_vector{}; + for (int i = 0; i < tensor_shape.dim_size(); i++) { + shape_vector.push_back(tensor_shape.dim(i).size()); + } + std::string shape_str; + shape_str += std::to_string(shape_vector.size()) + ","; + for (auto &dim : shape_vector) { + shape_str += std::to_string(dim) + ","; } + + auto tensor_data = new (std::nothrow) string; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new string failed"; + return RET_ERROR; + } + if (tensor_proto.string_val_size() == 1) { + *tensor_data = tensor_proto.string_val(0); + } else { + MS_LOG(ERROR) << "string size bigger than one, not support."; + delete tensor_data; + return RET_ERROR; + } + + shape_vector = {static_cast<int64_t>(shape_str.size() + (*tensor_data).size())}; + *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kObjectTypeString); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + delete tensor_data; + return RET_ERROR; + } + auto tensor_info_data = reinterpret_cast<char *>((*tensor_info)->data_c()); + if (memcpy_s(tensor_info_data, shape_str.size(), shape_str.data(), shape_str.size()) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + delete tensor_data; + return RET_ERROR; + } + if (memcpy_s(tensor_info_data + shape_str.size(), (*tensor_data).size(), (*tensor_data).data(), + (*tensor_data).size()) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + delete tensor_data; + return RET_ERROR; + } + + delete tensor_data; return RET_OK; } } // namespace
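[Editor's note] The four Set{Float,Int32,Int64,Bool}TensorInfo helpers above repeat the same preamble: walk tensor_shape(), build a ShapeVector, and allocate through CreateTensorInfo(nullptr, 0, shape, type). A possible follow-up cleanup, sketched under the assumption that CreateTensorInfo keeps the signature used above (AllocateProtoTensor is a hypothetical name, not part of this patch):

STATUS AllocateProtoTensor(const tensorflow::TensorProto &tensor_proto, TypeId type_id,
                           tensor::TensorPtr *tensor_info) {
  // Collect the proto's dimensions into the ShapeVector that CreateTensorInfo expects.
  auto &tensor_shape = tensor_proto.tensor_shape();
  ShapeVector shape_vector{};
  for (int i = 0; i < tensor_shape.dim_size(); i++) {
    shape_vector.push_back(tensor_shape.dim(i).size());
  }
  // nullptr/0 asks CreateTensorInfo to allocate an uninitialized buffer of the right size.
  *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, type_id);
  return *tensor_info == nullptr ? RET_ERROR : RET_OK;
}

Each type-specific helper would then reduce to its element-copy loop.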
-STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, - const ParamValueLitePtr &param_value) { +STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) { if (tensor_proto.variant_val_size() != 1) { MS_LOG(ERROR) << "only support variant_val_size == 1 now"; return RET_ERROR; } @@ -211,23 +315,6 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ tensorflow::TensorShapeProto element_shape_proto; element_shape_proto.ParseFromString(std::string(str_view.data(), str_view.size())); auto dim_size = element_shape_proto.dim_size(); - auto tensor_data = new (std::nothrow) int[dim_size + 2]; // encode element_dtype,shape.size,shape[i]... into data - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; - return RET_ERROR; - } - tensor_data[0] = TensorFlowUtils::GetTFDataType(tensorflow::DataType(element_dtype)); - tensor_data[1] = element_shape_proto.dim_size(); - for (int i = 0; i < dim_size; ++i) { - auto dim = element_shape_proto.dim(i).size(); - if (dim > static_cast<int64_t>(INT32_MAX) || dim < static_cast<int64_t>(INT32_MIN)) { - MS_LOG(ERROR) << "int64 data " << dim << " too big to fit into int32"; - delete[] tensor_data; - return RET_ERROR; - } else { - tensor_data[i + 2] = static_cast<int>(dim); - } - } std::vector<int> tensor_list_data(dim_size + 2); tensor_list_data[0] = TensorFlowUtils::GetTFDataType(tensorflow::DataType(element_dtype)); tensor_list_data[1] = element_shape_proto.dim_size(); @@ -235,7 +322,6 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ auto dim = element_shape_proto.dim(i).size(); if (dim > static_cast<int64_t>(INT32_MAX) || dim < static_cast<int64_t>(INT32_MIN)) { MS_LOG(ERROR) << "int64 data " << dim << " too big to fit into int32"; - delete[] tensor_data; return RET_ERROR; } else { tensor_list_data[i + 2] = static_cast<int>(dim); } @@ -250,51 +336,30 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ } tensor_list_data.insert(tensor_list_data.end(), single_tensor_data.begin(), single_tensor_data.end()); } - auto tensor_data_ptr = new (std::nothrow) int[tensor_list_data.size()]; - if (tensor_data_ptr == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; - return RET_NULL_PTR; + *tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int), + {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType); + if (*tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor data failed."; + return RET_ERROR; } - if (EOK != ::memcpy_s(tensor_data_ptr, tensor_list_data.size() * sizeof(int), tensor_list_data.data(), - tensor_list_data.size() * sizeof(int))) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_NULL_PTR; - } - param_value->SetTensorData(tensor_data_ptr, tensor_list_data.size() * sizeof(int)); return RET_OK; } -STATUS TFModelParser::GetValueFromType(const tensorflow::TensorProto &tensor_proto, - const tensorflow::TensorShapeProto &tensor_shape, ParamValueLitePtr param_value, - const TypeId &type, int shape_size) { +STATUS TFModelParser::SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, + tensor::TensorPtr *tensor_info) { + auto type = (*tensor_info)->data_type(); if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { - return GetFloatValue(tensor_proto, tensor_shape, param_value, shape_size); + return SetFloatTensorInfo(tensor_proto, tensor_info); } else if (type == kNumberTypeInt32 || type == kNumberTypeInt) { - return GetInt32Value(tensor_proto, 
tensor_shape, param_value, shape_size); + return SetInt32TensorInfo(tensor_proto, tensor_info); } else if (type == kNumberTypeInt64) { - return GetInt64Value(tensor_proto, tensor_shape, param_value, shape_size); + return SetInt64TensorInfo(tensor_proto, tensor_info); } else if (type == kNumberTypeBool) { - auto tensor_data = new (std::nothrow) int[shape_size]; - if (tensor_proto.bool_val_size() == 1) { - int value = tensor_proto.bool_val(0); - for (int i = 0; i < shape_size; i++) { - tensor_data[i] = value; - } - } - auto tensor_size = shape_size * sizeof(int); - param_value->SetTensorData(tensor_data, tensor_size); + return SetBoolTensorInfo(tensor_proto, tensor_info); } else if (type == kObjectTypeTensorType) { - return ConvertConstVariant(tensor_proto, param_value); + return ConvertConstVariant(tensor_proto, tensor_info); } else if (type == kObjectTypeString) { - auto tensor_data = new (std::nothrow) string; - if (tensor_proto.string_val_size() == 1) { - *tensor_data = tensor_proto.string_val(0); - } else { - MS_LOG(ERROR) << "string size bigger than one, not support."; - return RET_ERROR; - } - auto tensor_size = (*tensor_data).size(); - param_value->SetTensorData(tensor_data, tensor_size); + return SetStringTensorInfo(tensor_proto, tensor_info); } else { MS_LOG(ERROR) << "Unsupported dataType: " << type; return RET_ERROR; @@ -309,35 +374,25 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co MS_ASSERT(shape_vector != nullptr); const tensorflow::TensorProto &tensor_proto = attr_value.tensor(); const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape(); - int shape_size = 1; shape_vector->clear(); for (int i = 0; i < tensor_shape.dim_size(); i++) { shape_vector->push_back(tensor_shape.dim(i).size()); - shape_size *= tensor_shape.dim(i).size(); } - - auto param_value = std::make_shared(); - if (param_value == nullptr) { - MS_LOG(ERROR) << "param_value is nullptr"; + auto tensor_info = std::make_shared(type, *shape_vector); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "tensor info is nullptr"; return RET_ERROR; } - param_value->set_tensor_type(type); - if (GetValueFromType(tensor_proto, tensor_shape, param_value, type, shape_size) != RET_OK) { - MS_LOG(ERROR) << "get value from type failed."; + auto status = SetTensorInfoFromType(tensor_proto, &tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "set tensor data from type failed."; return RET_ERROR; } - std::vector param_shape(shape_vector->begin(), shape_vector->end()); - param_value->set_tensor_shape(param_shape); - if (TensorFlowUtils::FindAttrValue(node_def, "data_format", const_cast(&attr_value))) { - auto format = mindspore::lite::TensorFlowUtils::ParseNodeFormat(node_def); - if (format == mindspore::Format::NUM_OF_FORMAT) { - MS_LOG(ERROR) << "Do not support data format: " << attr_value.s(); - } - param_value->set_format(format); - } else { - param_value->set_format(schema::Format::Format_NHWC); + status = InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed."; + return RET_ERROR; } - parameter->set_default_param(param_value); return RET_OK; } @@ -365,6 +420,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa MS_LOG(INFO) << "Found value attr, means it has default value"; auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector); if (status != RET_OK) { + MS_LOG(ERROR) << "convert const tensor failed."; return status; } } else 
{ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index a7f6a54b551..5a6b5fbb25d 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -29,7 +29,6 @@ #include "securec/include/securec.h" #include "tools/common/tensor_util.h" #include "tools/converter/model_parser.h" -#include "src/param_value_lite.h" namespace mindspore { namespace lite { @@ -41,12 +40,10 @@ class TFModelParser : public ModelParser { FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); private: - static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr ¶m_value); + static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, std::vector *shape_vector); - static STATUS GetValueFromType(const tensorflow::TensorProto &tensor_proto, - const tensorflow::TensorShapeProto &tensor_shape, ParamValueLitePtr param_value, - const TypeId &type, int shape_size); + static STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, std::unordered_map *anf_node_map); STATUS ConvertGraphInputsAndConsts(const std::map &tf_graph_nodes, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index e43fe9aa927..070a3106b7f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -20,7 +20,6 @@ #include #include #include "tools/converter/converter_flags.h" -#include "src/param_value_lite.h" #include "src/common/file_utils.h" #include "ops/return.h" #include "ops/make_tuple.h" @@ -160,7 +159,7 @@ STATUS TfliteModelParser::ConvertOps() { tensor_name = GetTensorName(i, tflite_op_type, op_name); } auto parameter = func_graph_->add_parameter(); - status = ConvertConstTensor(input_tensor.get(), parameter.get(), tensor_name); + status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; continue; @@ -354,7 +353,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { return RET_OK; } -STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, +STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter, const std::string &tensor_name) { if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null, get const tensor failed."; @@ -366,31 +365,53 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para return RET_NULL_PTR; } const auto &tflite_model_buffers = tflite_model_->buffers; - auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); + auto type_id = GetTfliteDataType(tensor->type); std::vector shape_vector; - (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, 
shape_vector); - parameter->set_abstract(abstract_tensor); - parameter->set_name(tensor_name); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape(tensor->shape); - param_value->set_tensor_type(GetTfliteDataType(tensor->type)); - param_value->set_format(schema::Format::Format_NHWC); const auto &data = tflite_model_buffers.at(tensor->buffer)->data; - if (!data.empty()) { - auto size = data.size(); - char *tensor_data = new (std::nothrow) char[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; + std::string shape_str; + if (data.empty()) { + shape_vector = {}; + } else if (type_id == kObjectTypeString) { + shape_str += std::to_string(tensor->shape.size()) + ","; + for (auto &dim : tensor->shape) { + shape_str += std::to_string(dim) + ","; } - std::memcpy(tensor_data, data.data(), size); - param_value->SetTensorData(tensor_data, size); + shape_vector = {static_cast(shape_str.size() + data.size())}; + } else { + (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); } - parameter->set_default_param(param_value); + + auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, type_id); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "init tensor info failed"; + return RET_NULL_PTR; + } + if (!data.empty()) { + auto tensor_data = reinterpret_cast(tensor_info->data_c()); + if (type_id == kObjectTypeString) { + if (memcpy_s(tensor_data, shape_str.size(), shape_str.data(), shape_str.size()) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + return RET_ERROR; + } + if (memcpy_s(tensor_data + shape_str.size(), data.size(), data.data(), data.size()) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + return RET_ERROR; + } + } else { + if (memcpy_s(tensor_data, tensor_info->Size(), data.data(), data.size()) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + return RET_ERROR; + } + } + } + auto status = InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed."; + return RET_ERROR; + } + parameter->set_name(tensor_name); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 1d800b34898..99e1c0b635a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -41,7 +41,8 @@ class TfliteModelParser : public ModelParser { FuncGraphPtr func_graph_; char *tflite_model_buf_ = nullptr; std::unique_ptr ReadTfliteModel(const char *model_path); - STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter, const std::string &tensor_name); + STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter, + const std::string &tensor_name); STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode); STATUS ConvertOpQuantParams(const tflite::OperatorT *op, ops::PrimitiveC *primitive_c); STATUS ConvertOps(); diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc index 7d4b9846476..ced6ca8873b 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc @@ -15,21 +15,20 @@ */ #include 
"tools/converter/quantizer/huffman_encode.h" -#include #include "src/dequant.h" #include "tools/converter/quantizer/quantize_util.h" namespace mindspore { namespace lite { -STATUS HuffmanEncode::DoHuffmanEncode(const ParamValueLitePtr &weight, const PrimitivePtr &primitive, void *quant_datas, +STATUS HuffmanEncode::DoHuffmanEncode(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, void *quant_datas, const size_t &bit_num) { if (quant_datas == nullptr) { MS_LOG(ERROR) << "quant data is nullptr"; return RET_ERROR; } auto *raw_datas = static_cast(quant_datas); - size_t elem_count = weight->tensor_shape_size(); - size_t packed_size = elem_count * bit_num; + size_t elem_count = weight->DataSize(); + int packed_size = elem_count * bit_num; HuffmanPriorityQueue pq; auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); @@ -47,19 +46,16 @@ STATUS HuffmanEncode::DoHuffmanEncode(const ParamValueLitePtr &weight, const Pri MS_LOG(ERROR) << "DoHuffmanCompress failed"; return status; } - size_t ch_size = huffman_encoded_str_.length(); + int ch_size = huffman_encoded_str_.length(); if (ch_size < packed_size) { - auto encode_data = new (std::nothrow) char[ch_size]; - if (encode_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed."; - return RET_MEMORY_FAILED; + if (ch_size != weight->data().nbytes()) { + MS_LOG(ERROR) << "Data size of weight is error."; + return RET_ERROR; } - if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { + if (memcpy_s(weight->data_c(), weight->data().nbytes(), huffman_encoded_str_.c_str(), ch_size) != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; - delete[] encode_data; return RET_MEMORY_FAILED; } - weight->SetTensorData(encode_data, ch_size); auto quant_param_holder = quant::GetCNodeQuantHolder(primitive); MS_ASSERT(quant_param_holder != nullptr); quant_param_holder->set_enable_huffman_code(true); diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.h b/mindspore/lite/tools/converter/quantizer/huffman_encode.h index cd0090187cc..00dd6b65239 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.h +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.h @@ -29,7 +29,6 @@ #include "schema/inner/model_generated.h" #include "securec/include/securec.h" #include "src/common/log_adapter.h" -#include "src/param_value_lite.h" namespace mindspore { namespace lite { @@ -58,7 +57,7 @@ class HuffmanEncode { ~HuffmanEncode(); - STATUS DoHuffmanEncode(const ParamValueLitePtr &weight, const PrimitivePtr &primitive, void *quant_datas, + STATUS DoHuffmanEncode(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, void *quant_datas, const size_t &bit_num); private: diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index bc60292b429..f9d13631e4f 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -565,8 +565,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter"; return RET_NULL_PTR; } - ParamValueLitePtr paramValue = std::dynamic_pointer_cast(parameter->default_param()); - if (paramValue == nullptr) { + tensor::TensorPtr tensor_info = std::dynamic_pointer_cast(parameter->default_param()); + if (tensor_info == nullptr) { MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get 
value"; return RET_NULL_PTR; } @@ -583,8 +583,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); } } - auto status = - QuantFilter(paramValue, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t, perchanel); + auto status = QuantFilter(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t, + perchanel, kNumberTypeInt8); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; @@ -616,8 +616,8 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv auto bias_parameter_ptr = std::dynamic_pointer_cast(bias); MS_ASSERT(bias_parameter_ptr != nullptr); auto bias_default_param = bias_parameter_ptr->default_param(); - auto bias_param = std::dynamic_pointer_cast(bias_default_param); - MS_ASSERT(bias_parameter_ptr != nullptr); + auto bias_param = std::dynamic_pointer_cast(bias_default_param); + MS_ASSERT(bias_parameter != nullptr); auto quant_param_holder = GetCNodeQuantHolder(primitive); MS_ASSERT(quant_param_holder != nullptr); auto active_weight_quant_params = quant_param_holder->input_quant_params(); @@ -653,7 +653,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv bias_scales.push_back(scaleX * scaleY); } MS_ASSERT(!bias_scales.empty()); - size_t shape_size = bias_param->tensor_shape_size(); + size_t shape_size = bias_param->DataSize(); // set bias quant param std::vector quant_params; @@ -667,17 +667,16 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv // quant bias data std::vector quant_datas(shape_size); - auto *raw_datas = static_cast(bias_param->tensor_addr()); + auto *raw_datas = static_cast(bias_param->data_c()); if (ComputeBiasDataAndQuantParam(bias_scales, input_scales, raw_datas, quant_param_holder, &quant_params, &quant_datas) != RET_OK) { MS_LOG(ERROR) << "compute bias data failed."; return RET_ERROR; } quant_param_holder->AddInputQuantParam(quant_params); - auto ret = - memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas.data(), shape_size * sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed."; + auto ret = SetTensorData(bias_param, quant_datas.data(), shape_size * sizeof(int32_t)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "set tensor data failed."; return RET_ERROR; } // set dtype @@ -1133,11 +1132,11 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con auto bias = cnode->input(3); auto bias_parameter_ptr = std::dynamic_pointer_cast(bias); auto bias_default_param = bias_parameter_ptr->default_param(); - auto bias_param = std::dynamic_pointer_cast(bias_default_param); - int *bias_datas = static_cast(bias_param->tensor_addr()); + auto bias_param = std::dynamic_pointer_cast(bias_default_param); + int *bias_datas = static_cast(bias_param->data_c()); - if (static_cast(bias_param->tensor_shape_size()) != bias_diff.size()) { - MS_LOG(DEBUG) << "unexpected bias data count: " << bias_param->tensor_shape_size() + if (static_cast(bias_param->DataSize()) != bias_diff.size()) { + MS_LOG(DEBUG) << "unexpected bias data count: " << bias_param->DataSize() << " not the same as bias_diff: " << bias_diff.size(); return RET_ERROR; } @@ -1146,7 +1145,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con << " not the same as bias_diff: " << bias_diff.size(); return RET_ERROR; } - for (int i = 0; i < bias_param->tensor_shape_size(); 
i++) { + for (int i = 0; i < bias_param->DataSize(); i++) { auto scale = bias_quant_params[i].scale; if (fabs(scale) <= 0.0f) { MS_LOG(ERROR) << "divisor 'scale' cannot be 0."; @@ -1177,36 +1176,20 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con } ShapeVector shape; shape.push_back(bias_diff.size()); - auto type_ptr = TypeIdToType(kNumberTypeFloat32); - auto abstract_tensor = std::make_shared(type_ptr, shape); - parameter->set_abstract(abstract_tensor); - parameter->set_name("added_" + op_name + "_bias"); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int64_t &value) { return static_cast(value); }); - param_value->set_tensor_shape(shape_vector); - param_value->set_tensor_type(kNumberTypeFloat32); - - auto size = sizeof(float) * bias_diff.size(); - char *tensor_data = new (std::nothrow) char[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_MEMORY_FAILED; - } - STATUS status = memcpy_s(tensor_data, size * sizeof(char), bias_diff.data(), size * sizeof(char)); - if (status != EOK) { - MS_LOG(ERROR) << "memcpy_s error: " << status; - delete[] tensor_data; + auto tensor_info = CreateTensorInfo(bias_diff.data(), sizeof(float) * bias_diff.size(), shape, kNumberTypeFloat32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } - param_value->SetTensorData(tensor_data, size); - parameter->set_default_param(param_value); + auto status = InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return RET_ERROR; + } + parameter->set_name("added_" + op_name + "_bias"); cnode->add_input(parameter); DoBiasQuant(parameter, primitive); - delete[] tensor_data; } else { MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size(); } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 32474b8f388..b799c972b00 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -45,6 +45,7 @@ #include "tools/anf_exporter/anf_exporter.h" #include "tools/converter/quantizer/bitpacking.h" #include "src/common/utils.h" +#include "tools/common/tensor_util.h" #include "abstract/abstract_value.h" #include "securec/include/securec.h" @@ -861,49 +862,39 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { } auto old_cnode = old_cnode_iter->second; auto inputs = cnode->inputs(); - for (size_t i = 0; i < inputs.size(); i++) { - auto input_node = inputs[i]; + for (const auto &input_node : inputs) { if (input_node->isa()) { auto param_node = input_node->cast(); - if (param_node->has_default()) { - ParamValueLitePtr old_param_value = std::static_pointer_cast(param_node->default_param()); - auto new_param_value = std::make_shared(); - - auto copyed_data = malloc(old_param_value->tensor_size()); - if (copyed_data == nullptr) { - MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size(); - return nullptr; - } - memcpy(copyed_data, old_param_value->tensor_addr(), old_param_value->tensor_size()); - - new_param_value->set_tensor_size(old_param_value->tensor_size()); - new_param_value->set_tensor_addr(copyed_data); - 
new_param_value->set_tensor_shape(old_param_value->tensor_shape()); - new_param_value->set_format(old_param_value->format()); - new_param_value->set_tensor_type(old_param_value->tensor_type()); - - param_node->set_default_param(new_param_value); - } - - auto old_abstract_base = param_node->abstract(); - if (!utils::isa(old_abstract_base)) { - MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); + if (!param_node->has_default()) { + MS_LOG(ERROR) << "Param node has no default parameter: " << cnode_name; + return nullptr; + } + auto old_tensor_info = std::static_pointer_cast(param_node->default_param()); + if (old_tensor_info == nullptr) { + MS_LOG(ERROR) << "Default param of param node is not a tensor info:" << cnode_name; + return nullptr; + } + auto new_tensor_info = lite::CreateTensorInfo(old_tensor_info->data().data(), old_tensor_info->data().nbytes(), + old_tensor_info->shape(), old_tensor_info->data_type()); + if (new_tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; + } + auto status = lite::InitParameterFromTensorInfo(param_node, new_tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return nullptr; } - auto old_abstract = utils::cast(old_abstract_base); - auto new_abstract = std::make_shared(old_abstract->element()->GetTypeTrack(), - old_abstract->GetShapeTrack()); - param_node->set_abstract(new_abstract); } } // end inputs loop } // end cnodes loop return new_func_graph; } -void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) { +void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) { MS_ASSERT(node != nullptr); MS_ASSERT(param_node != nullptr); - MS_ASSERT(param_value != nullptr); + MS_ASSERT(tensor_info != nullptr); auto op_name = node->fullname_with_scope(); @@ -917,26 +908,22 @@ void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamVal return; } - *param_value = std::static_pointer_cast((*param_node)->default_param()); - if (*param_value == nullptr) { - MS_LOG(INFO) << "default_param can not cast to ParamValueLite"; + *tensor_info = std::static_pointer_cast((*param_node)->default_param()); + if (*tensor_info == nullptr) { + MS_LOG(INFO) << "default_param can not cast to tensor::Tensor"; return; } } -STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size) { +STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_datas, int new_size, TypeId new_data_type) { MS_ASSERT(weight != nullptr); MS_ASSERT(new_size > 0); - delete[] reinterpret_cast(weight->tensor_addr()); - char *new_tensor_data = new (std::nothrow) char[new_size]; - if (new_tensor_data == nullptr) { - MS_LOG(ERROR) << "new data error"; + weight->set_data_type(new_data_type); + if (new_size != weight->data().nbytes()) { + MS_LOG(ERROR) << "Data size of tensor info is error."; return RET_ERROR; } - memcpy(new_tensor_data, quant_datas, new_size); - - weight->set_tensor_size(new_size); - weight->set_tensor_addr(new_tensor_data); + memcpy(weight->data_c(), quant_datas, new_size); return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index c6ba5c15f4c..f52cd57fdda 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -108,7 +108,7 @@ std::pair 
OutlierMethod(std::vector min_datas, std::vector< std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam); -STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); +STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_datas, int new_size, TypeId new_data_type); void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas, bool channel_at_first, float *desired_max, float *desired_min); @@ -166,13 +166,13 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan } template -STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, +STATUS DoPerChannelQuant(const tensor::TensorPtr &weight, const QuantType &quant_type, std::vector *quant_params, const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means, std::vector *quant_datas, - std::vector *dequant_datas, bool channel_at_first = true) { - auto dims = weight->tensor_shape(); - size_t elem_count = weight->tensor_shape_size(); - auto *raw_datas = static_cast(weight->tensor_addr()); + std::vector *dequant_datas, TypeId quant_data_type, bool channel_at_first = true) { + auto dims = weight->shape(); + size_t elem_count = weight->DataSize(); + auto *raw_datas = static_cast(weight->data_c()); auto channels = dims[0]; if (!channel_at_first) { if (dims.size() != 2) { @@ -253,7 +253,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant } quant_params->emplace_back(quant_param); } - auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T)); + auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T), quant_data_type); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; return RET_ERROR; @@ -262,12 +262,13 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant } template -STATUS DoPerLayerQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, +STATUS DoPerLayerQuant(const tensor::TensorPtr &weight, const QuantType &quant_type, std::vector *quant_params, const int &quant_max, const int &quant_min, - const size_t &bit_num, const bool &k_means, std::vector *quant_datas) { - auto dims = weight->tensor_shape(); - size_t elem_count = weight->tensor_shape_size(); - auto *raw_datas = static_cast(weight->tensor_addr()); + const size_t &bit_num, const bool &k_means, std::vector *quant_datas, + TypeId quant_data_type) { + auto dims = weight->shape(); + size_t elem_count = weight->DataSize(); + auto *raw_datas = static_cast(weight->data_c()); float min = FLT_MAX; float max = -FLT_MIN; for (uint32_t i = 0; i < elem_count; i++) { @@ -293,7 +294,7 @@ STATUS DoPerLayerQuant(const ParamValueLitePtr &weight, const QuantType &quant_t (*quant_datas)[i] = quant_data; } } - auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T)); + auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T), quant_data_type); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; return RET_ERROR; @@ -301,7 +302,7 @@ STATUS DoPerLayerQuant(const ParamValueLitePtr &weight, const QuantType &quant_t return RET_OK; } template -STATUS DoBitPack(const ParamValueLitePtr &weight, const size_t &bit_num, const std::vector &quant_datas) { +STATUS DoBitPack(const tensor::TensorPtr &weight, 
const size_t &bit_num, const std::vector &quant_datas) { if (bit_num != 8 && bit_num != 16) { std::vector data{}; for (size_t i = 0; i < quant_datas.size(); ++i) { @@ -310,7 +311,8 @@ STATUS DoBitPack(const ParamValueLitePtr &weight, const size_t &bit_num, const s if (bit_num > 0 && bit_num < 8) { std::vector pack_data{}; BitPack::BitPacking(bit_num, data, &pack_data); - auto status = UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint8_t)); + auto status = + UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint8_t), kNumberTypeUInt8); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; return RET_ERROR; @@ -318,7 +320,8 @@ STATUS DoBitPack(const ParamValueLitePtr &weight, const size_t &bit_num, const s } else if (bit_num > 8 && bit_num < 16) { std::vector pack_data{}; BitPack::BitPacking(bit_num, data, &pack_data); - auto status = UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint16_t)); + auto status = + UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint16_t), kNumberTypeUInt16); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; return RET_ERROR; @@ -329,11 +332,12 @@ STATUS DoBitPack(const ParamValueLitePtr &weight, const size_t &bit_num, const s } template -STATUS QuantFilter(const ParamValueLitePtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max, - int quant_min, size_t bit_num, bool per_channel, int index = 1, bool k_means = false) { +STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max, + int quant_min, size_t bit_num, bool per_channel, TypeId quant_data_type, int index = 1, + bool k_means = false) { MS_ASSERT(weight != nullptr); MS_ASSERT(primitive != nullptr); - auto dims = weight->tensor_shape(); + auto dims = weight->shape(); if (per_channel) { if (dims.size() <= 1) { MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel"; @@ -342,8 +346,8 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const PrimitivePtr &primitiv } std::vector quant_params; - size_t elem_count = weight->tensor_shape_size(); - auto *raw_data = static_cast(weight->tensor_addr()); + size_t elem_count = weight->DataSize(); + auto *raw_data = static_cast(weight->data_c()); if (raw_data == nullptr) { MS_LOG(ERROR) << "rawDatas is nullptr"; return RET_ERROR; @@ -354,7 +358,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const PrimitivePtr &primitiv int ret = RET_OK; if (per_channel) { bool channel_at_first = true; - if (primitive->name() == ops::kNameMatMul && weight->tensor_shape().size() == 2) { + if (primitive->name() == ops::kNameMatMul && weight->shape().size() == 2) { auto matmul_prim = primitive->cast>(); MS_ASSERT(matmul_prim != nullptr); channel_at_first = @@ -362,7 +366,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const PrimitivePtr &primitiv } // channel at first ret = DoPerChannelQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data, - &dequant_datas, channel_at_first); + &dequant_datas, quant_data_type, channel_at_first); if (ret == RET_CONTINUE) { return ret; } else if (ret != RET_OK) { @@ -370,7 +374,8 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const PrimitivePtr &primitiv return ret; } } else { - ret = DoPerLayerQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data); + ret = DoPerLayerQuant(weight, quant_type, 
&quant_params, quant_max, quant_min, bit_num, k_means, &quant_data, + quant_data_type); if (ret != RET_OK) { MS_LOG(ERROR) << "Do per layer quant failed."; return ret; @@ -422,6 +427,6 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); -void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value); +void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info); } // namespace mindspore::lite::quant #endif diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index 02ad8c767f5..fdb9bc8fe03 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -25,7 +25,6 @@ #include "ir/func_graph.h" #include "ir/anf.h" #include "base/base.h" -#include "src/param_value_lite.h" #include "tools/converter/converter_flags.h" #include "tools/converter/quant_param_holder.h" diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 1b1902c29d0..a6523951a20 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -57,10 +57,10 @@ WeightQuantizer::~WeightQuantizer() { } } -STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, +STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr ¶m_node, const PrimitivePtr &primitive) { // set dtype - param_value->set_tensor_type(type_id_); + tensor_info->set_data_type(type_id_); auto abstract_base = param_node->abstract(); if (abstract_base == nullptr) { MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); @@ -78,7 +78,7 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr return RET_OK; } -STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { +STATUS WeightQuantizer::DoConvQuantize(const CNodePtr &cnode) { auto primitive = GetValueNode(cnode->input(0)); if (primitive == nullptr) { MS_LOG(ERROR) << "primitive is nullptr"; @@ -91,24 +91,25 @@ STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { } ParameterPtr param_node; - ParamValueLitePtr param_value; + tensor::TensorPtr tensor_info; - GetLiteParameter(input_node, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr) { + GetLiteParameter(input_node, ¶m_node, &tensor_info); + if (param_node == nullptr || tensor_info == nullptr) { MS_LOG(ERROR) << "GetLiteParameter error"; return RET_ERROR; } - if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { - MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); + if (tensor_info->data_type() != mindspore::kNumberTypeFloat32) { + MS_LOG(ERROR) << "model weight data type invalid which is " << tensor_info->data_type(); return RET_ERROR; } auto status = RET_ERROR; if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, + type_id_); } else if (type_id_ == kNumberTypeInt16) { - status = - QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + status = QuantFilter(tensor_info, primitive, 
QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, + type_id_); } if (status == RET_CONTINUE) { return RET_OK; @@ -116,7 +117,7 @@ STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - status = SetAbstract(param_value, param_node, primitive); + status = SetAbstract(tensor_info, param_node, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; @@ -124,9 +125,9 @@ STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { return RET_OK; } -STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { +STATUS WeightQuantizer::DoMulQuantize(const CNodePtr &cnode) { auto already_quant = false; - ParamValueLitePtr param_value = nullptr; + tensor::TensorPtr tensor_info = nullptr; ParameterPtr param_node = nullptr; int index = 0; for (size_t i = 1; i < cnode->size(); i++) { @@ -134,18 +135,18 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { if (inputNode->isa()) { param_node = inputNode->cast(); if ((param_node != nullptr) && param_node->has_default()) { - param_value = std::static_pointer_cast(param_node->default_param()); - if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) { - param_value = nullptr; + tensor_info = std::static_pointer_cast(param_node->default_param()); + if ((tensor_info == nullptr) || (tensor_info->Size() == 0) || (tensor_info->data_c() == nullptr)) { + tensor_info = nullptr; continue; - } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || - param_value->tensor_type() == mindspore::kNumberTypeInt16) { + } else if (tensor_info->data_type() == mindspore::kNumberTypeInt8 || + tensor_info->data_type() == mindspore::kNumberTypeInt16) { MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been " << " quantized"; already_quant = true; break; - } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { - param_value = nullptr; + } else if (tensor_info->data_type() != mindspore::kNumberTypeFloat32) { + tensor_info = nullptr; continue; } else { index = i; @@ -159,7 +160,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { return RET_OK; } - if (param_value == nullptr) { + if (tensor_info == nullptr) { MS_LOG(WARNING) << cnode->fullname_with_scope() << " No valid input param node !"; return RET_OK; } @@ -172,11 +173,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { auto status = RET_ERROR; if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, - index - 1); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, + type_id_, index - 1); } else if (type_id_ == kNumberTypeInt16) { - status = QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, - index - 1); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true, + type_id_, index - 1); } if (status == RET_CONTINUE) { return RET_OK; @@ -184,7 +185,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - status = SetAbstract(param_value, param_node, primitive); + status = SetAbstract(tensor_info, param_node, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; @@ -193,7 +194,7 @@ STATUS 
WeightQuantizer::DoMulQuantize(CNodePtr cnode) { return RET_OK; } -STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) { +STATUS WeightQuantizer::DoLstmQuantize(const CNodePtr &cnode) { MS_ASSERT(cnode != nullptr); auto op_name = cnode->fullname_with_scope(); @@ -226,32 +227,32 @@ STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) { return status; } -STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) { +STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) { auto primitive = GetValueNode(cnode->input(0)); MS_ASSERT(primitive != nullptr); auto first_input = cnode->input(1); ParameterPtr param_node; - ParamValueLitePtr param_value; - GetLiteParameter(first_input, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) { + tensor::TensorPtr tensor_info; + GetLiteParameter(first_input, ¶m_node, &tensor_info); + if (param_node == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) { MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight"; return RET_OK; } - if (param_value->tensor_size() / 4 < quant_strategy_->m_weight_size_) { - MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < " + if (tensor_info->Size() / 4 < quant_strategy_->m_weight_size_) { + MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << tensor_info->Size() / 4 << " < " << quant_strategy_->m_weight_size_; return RET_OK; } auto status = RET_ERROR; if (type_id_ == kNumberTypeInt8) { - status = - QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, + type_id_, 0); } else if (type_id_ == kNumberTypeInt16) { - status = - QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + false, type_id_, 0); } if (status == RET_CONTINUE) { return RET_OK; @@ -259,7 +260,7 @@ STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - status = SetAbstract(param_value, param_node, primitive); + status = SetAbstract(tensor_info, param_node, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; @@ -272,28 +273,28 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr auto op_name = cnode->fullname_with_scope(); auto weight_i = cnode->input(index); ParameterPtr param_node; - ParamValueLitePtr param_value; - GetLiteParameter(weight_i, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr) { + + tensor::TensorPtr tensor_info; + GetLiteParameter(weight_i, ¶m_node, &tensor_info); + if (param_node == nullptr || tensor_info == nullptr) { MS_LOG(INFO) << "LSTM input index " << index << " is not weight"; return RET_OK; } - if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { - MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; + if (tensor_info->data_type() != TypeId::kNumberTypeFloat32) { + MS_LOG(WARNING) << "tensor_info tensor type is: " << tensor_info->data_type() << " not quant"; return RET_OK; } - if (param_value->tensor_size() / 4 < quant_strategy_->m_weight_size_) { - 
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " - << quant_strategy_->m_weight_size_; + if (tensor_info->Size() / 4 < quant_strategy_->m_weight_size_) { + MS_LOG(INFO) << op_name << " weight_i cnt: " << tensor_info->Size() / 4 << " < " << quant_strategy_->m_weight_size_; return RET_OK; } auto status = RET_ERROR; if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, - index - 1); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, + type_id_, index - 1); } else if (type_id_ == kNumberTypeInt16) { - status = QuantFilter(param_value, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, index - 1); + status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + false, type_id_, index - 1); } if (status == RET_CONTINUE) { return RET_OK; @@ -301,7 +302,7 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - status = SetAbstract(param_value, param_node, primitive); + status = SetAbstract(tensor_info, param_node, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; @@ -378,7 +379,7 @@ float CompareOutputData(const std::unordered_map &input_node, const std::string &op_name, - ParameterPtr *param_node, ParamValueLitePtr *param_value) { + ParameterPtr *param_node, tensor::TensorPtr *tensor_info) { if (!input_node->isa()) { MS_LOG(WARNING) << op_name << " the second input is not parameter"; return RET_CONTINUE; @@ -485,30 +486,30 @@ STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr &inp MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; return RET_CONTINUE; } - *param_value = std::static_pointer_cast((*param_node)->default_param()); - if (*param_value == nullptr) { + *tensor_info = std::static_pointer_cast((*param_node)->default_param()); + if (*tensor_info == nullptr) { MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; return RET_CONTINUE; } - if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) { + if ((*tensor_info)->data_type() != TypeId::kNumberTypeFloat32) { MS_LOG(WARNING) << op_name << " the second input type is not float"; return RET_CONTINUE; } return RET_OK; } STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, - const ParamValueLitePtr ¶m_value, const PrimitivePtr &primitive) { + const tensor::TensorPtr &tensor_info, const PrimitivePtr &primitive) { int status; type_id_ = TypeId::kNumberTypeInt8; int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); if (type_id_ == TypeId::kNumberTypeInt8) { - status = QuantFilter(param_value, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, - bit_num_t, true); + status = QuantFilter(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, + bit_num_t, true, type_id_); } else if (type_id_ == TypeId::kNumberTypeInt16) { - status = QuantFilter(param_value, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, - bit_num_t, true); + status = QuantFilter(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, + bit_num_t, true, type_id_); } else { MS_LOG(ERROR) << "unexpected type_id_: " 
<< type_id_; return RET_ERROR; @@ -519,7 +520,7 @@ STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m MS_LOG(ERROR) << "quant filter failed."; return RET_ERROR; } - status = SetAbstract(param_value, param_node, primitive); + status = SetAbstract(tensor_info, param_node, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; @@ -542,24 +543,24 @@ STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) { if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { auto input_node = cnode->input(2); ParameterPtr param_node; - ParamValueLitePtr param_value; - status = GetParamNodeAndValue(input_node, op_name, ¶m_node, ¶m_value); + tensor::TensorPtr tensor_info; + status = GetParamNodeAndValue(input_node, op_name, ¶m_node, &tensor_info); if (status == RET_CONTINUE) { continue; } // copy origin data in case to recover - auto *raw_data = static_cast(param_value->tensor_addr()); - auto elem_count = param_value->tensor_shape_size(); + auto *raw_data = static_cast(tensor_info->data_c()); + auto elem_count = tensor_info->DataSize(); std::unique_ptr origin_data(new (std::nothrow) float[elem_count]); - auto ret = memcpy_s(origin_data.get(), sizeof(float) * elem_count, raw_data, param_value->tensor_size()); + auto ret = memcpy_s(origin_data.get(), sizeof(float) * elem_count, raw_data, tensor_info->Size()); if (ret != EOK) { MS_LOG(ERROR) << "memcpy fail: " - << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size(); + << " dst size: " << sizeof(float) * elem_count << " src size: " << tensor_info->Size(); return RET_ERROR; } // 1. try quant for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { - status = TryQuant(bit_num_t, param_node, param_value, primitive); + status = TryQuant(bit_num_t, param_node, tensor_info, primitive); if (status != RET_OK) { MS_LOG(ERROR) << "TryQuant failed."; return RET_ERROR; @@ -616,7 +617,8 @@ STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) { MS_LOG(DEBUG) << "op: " << op_name << " intermediate bit: " << bit_num_t << " mean_error: " << mean_error << " [recover]"; // recover - status = UpdateTensorDataAndSize(param_value, origin_data.get(), sizeof(float) * elem_count); + status = + UpdateTensorDataAndSize(tensor_info, origin_data.get(), sizeof(float) * elem_count, kNumberTypeFloat32); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize fail"; return RET_ERROR; @@ -631,7 +633,7 @@ STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) { return status; } -STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { +STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) { // 0.2 Parse input calib files auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); if (status != RET_OK) { @@ -669,7 +671,7 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { return RET_OK; } -STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) { +STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); for (auto &cnode : func_graph->GetOrderedCnodes()) { auto primitive = GetValueNode>(cnode->input(0)); diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index bc3f0d27e85..111de02303c 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ 
b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -41,10 +41,10 @@ class WeightQuantizer : public Quantizer { ~WeightQuantizer(); STATUS DoQuantize(FuncGraphPtr func_graph) override; - STATUS DoConvQuantize(CNodePtr); - STATUS DoMulQuantize(CNodePtr); - STATUS DoLstmQuantize(CNodePtr cnode); - STATUS DoGatherQuantize(CNodePtr cnode); + STATUS DoConvQuantize(const CNodePtr &); + STATUS DoMulQuantize(const CNodePtr &); + STATUS DoLstmQuantize(const CNodePtr &cnode); + STATUS DoGatherQuantize(const CNodePtr &cnode); STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const PrimitivePtr &primitive, const int &index); @@ -61,16 +61,17 @@ class WeightQuantizer : public Quantizer { std::vector> images_; // multi_input, [[mode_input_0], [model_input_1]...] std::vector> fp32_output_tensors_; - STATUS DoMixedQuant(FuncGraphPtr); - STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, const PrimitivePtr &primitive); - STATUS DoFixedQuant(FuncGraphPtr); - STATUS RunFp32Graph(FuncGraphPtr); + STATUS DoMixedQuant(const FuncGraphPtr &); + STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr ¶m_node, + const PrimitivePtr &primitive); + STATUS DoFixedQuant(const FuncGraphPtr &); + STATUS RunFp32Graph(const FuncGraphPtr &); STATUS DoMixedQuantize(const FuncGraphPtr &func_graph); STATUS CheckImageCnt(); STATUS GetParamNodeAndValue(const std::shared_ptr &input_node, const std::string &op_name, - ParameterPtr *param_node, ParamValueLitePtr *param_value); - STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const ParamValueLitePtr ¶m_value, + ParameterPtr *param_node, tensor::TensorPtr *tensor_info); + STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const tensor::TensorPtr &tensor_info, const PrimitivePtr &primitive); STATUS DoQuantSearch(const FuncGraphPtr &func_graph); }; diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index aacd4c3324d..29cef5187fd 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -23,6 +23,7 @@ #include "Eigen/Core" #include "ops/fusion/conv2d_fusion.h" #include "src/common/common.h" +#include "tools/common/tensor_util.h" #include "frontend/operator/ops.h" #include "backend/optimizer/common/helper.h" @@ -433,33 +434,31 @@ int CheckLeastInputSize(const CNodePtr &node, const int size) { } ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, - const ParamValueLitePtr &weight_tensor) { + const tensor::TensorPtr &weight_tensor) { auto bias_parameter = func_graph->add_parameter(); MS_ASSERT(bias_parameter != nullptr); - std::vector shape = {kernel_num}; - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = - std::make_shared(TypeIdToType(weight_tensor->tensor_type()), shape_vector); - bias_parameter->set_abstract(abstract_tensor); + std::vector shape_vector = {kernel_num}; + auto tensor_info = lite::CreateTensorInfo(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t), shape_vector, + weight_tensor->data_type()); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + return nullptr; + } + auto status = lite::InitParameterFromTensorInfo(bias_parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info 
failed"; + return nullptr; + } - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->SetTensorData(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t)); - param_value->set_format(weight_tensor->format()); - param_value->set_tensor_type(weight_tensor->tensor_type()); - param_value->set_tensor_shape(shape); - bias_parameter->set_default_param(param_value); return bias_parameter; } -ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) { +tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node) { MS_ASSERT(node != nullptr); if (!utils::isa(node)) { if (utils::isa(node)) { auto valueNode = node->cast(); - auto value = std::dynamic_pointer_cast(valueNode->value()); + auto value = std::dynamic_pointer_cast(valueNode->value()); if (value != nullptr) { return value; } @@ -469,8 +468,8 @@ ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) { } auto param = node->cast(); MS_ASSERT(param != nullptr); - auto param_value = std::dynamic_pointer_cast(param->default_param()); - return param_value; + auto tensor_info = std::dynamic_pointer_cast(param->default_param()); + return tensor_info; } AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) { @@ -526,11 +525,11 @@ bool IsParamNode(const BaseRef &n) { return false; } auto param = utils::cast(n)->default_param(); - auto tensor = std::dynamic_pointer_cast(param); + auto tensor = std::dynamic_pointer_cast(param); if (tensor == nullptr) { return false; } - return tensor->tensor_addr() != nullptr; + return tensor->data_c() != nullptr; } bool IsConvNode(const BaseRef &n) { @@ -717,8 +716,8 @@ std::shared_ptr>> GetRealNodeUsedListByOu } return output_node_list; } -STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, - int32_t *filterH, int32_t *filterW) { +STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int64_t *filterK, int64_t *filterC, + int64_t *filterH, int64_t *filterW) { MS_ASSERT(oriDims.size() == 4); std::unordered_map maps = { {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2}, @@ -780,7 +779,7 @@ STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, return RET_OK; } -STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, +STATUS SetFilterDim(const tensor::TensorPtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW) { MS_ASSERT(tensor != nullptr); std::unordered_map maps = { @@ -796,22 +795,22 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3 switch (maps.find(type)->second) { case 1: - tensor->set_tensor_shape({filterH, filterW, filterC, filterK}); + tensor->set_shape({filterH, filterW, filterC, filterK}); break; case 2: - tensor->set_tensor_shape({filterH, filterW, filterK, filterC}); + tensor->set_shape({filterH, filterW, filterK, filterC}); break; case 3: - tensor->set_tensor_shape({filterK, filterC, filterH, filterW}); + tensor->set_shape({filterK, filterC, filterH, filterW}); break; case 4: - tensor->set_tensor_shape({filterC, filterK, filterH, filterW}); + tensor->set_shape({filterC, filterK, filterH, filterW}); break; case 5: - tensor->set_tensor_shape({filterC, filterH, filterW, filterK}); + tensor->set_shape({filterC, filterH, filterW, filterK}); break; case 6: - tensor->set_tensor_shape({filterK, filterH, filterW, filterC}); + tensor->set_shape({filterK, filterH, filterW, filterC}); break; 
default: MS_LOG(ERROR) << "Unsupported transFilterType: " << type; @@ -981,7 +980,7 @@ void TransFilterDataKHWC2CHWK(kTransFilterType type, int32_t filterK, int32_t fi } template -static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, +static STATUS TransFilterData(const tensor::TensorPtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW) { MS_ASSERT(tensor != nullptr); int count = filterH * filterW * filterC * filterK; @@ -995,7 +994,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType return RET_ERROR; } - void *originWeightData = tensor->tensor_addr(); + void *originWeightData = tensor->data_c(); T *weightData = static_cast(originWeightData); if (weightData == nullptr) { @@ -1046,7 +1045,7 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType } } - auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T)); + auto ret = ::memcpy_s(tensor->data_c(), count * sizeof(T), buf.get(), count * sizeof(T)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed: " << ret; return RET_ERROR; @@ -1055,18 +1054,18 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType } template -static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { +static STATUS TransFilterFormat(const tensor::TensorPtr &tensor, kTransFilterType type) { MS_ASSERT(tensor != nullptr); - auto oriDims = tensor->tensor_shape(); + auto oriDims = tensor->shape_c(); if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); return lite::RET_ERROR; } - int32_t filterH; - int32_t filterW; - int32_t filterC; - int32_t filterK; + int64_t filterH; + int64_t filterW; + int64_t filterC; + int64_t filterK; auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW); if (status != lite::RET_OK) { MS_LOG(ERROR) << "GetFilterDim failed: " << status; @@ -1086,7 +1085,7 @@ static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterTyp return lite::RET_OK; } -STATUS TransFilterFormatWithType(const ParamValueLitePtr &tensor, TypeId data_type, +STATUS TransFilterFormatWithType(const tensor::TensorPtr &tensor, TypeId data_type, kTransFilterType trans_filter_type) { if (data_type == kNumberTypeFloat32) { return TransFilterFormat(tensor, trans_filter_type); @@ -1102,17 +1101,16 @@ STATUS TransFilterFormatWithType(const ParamValueLitePtr &tensor, TypeId data_ty } } -STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) { +STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) { if (tensor == nullptr) { return lite::RET_NULL_PTR; } - auto ori_dims = tensor->tensor_shape(); + auto ori_dims = tensor->shape_c(); if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size(); return lite::RET_ERROR; } - auto src_format = tensor->format(); - auto data_type = tensor->tensor_type(); + auto data_type = tensor->data_type(); lite::STATUS status; std::unordered_map khwc_trans_maps = { {schema::Format::Format_KCHW, kKCHW2KHWC}, {schema::Format::Format_CKHW, kCKHW2KHWC}, @@ -1195,42 +1193,42 @@ STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_for } ParameterPtr BuildParameterNode(const FuncGraphPtr 
&func_graph, const AnfNodePtr &node, - const ParamValueLitePtr ¶m_value) { + const tensor::TensorPtr &tensor_info) { MS_ASSERT(func_graph != nullptr); MS_ASSERT(cnode != nullptr); MS_ASSERT(param_value != nullptr); auto param_node = func_graph->add_parameter(); - auto shape = param_value->tensor_shape(); + auto shape = tensor_info->shape(); std::vector shape_vector; std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), [](const int &val) { return static_cast(val); }); - auto data_type = param_value->tensor_type() == kNumberTypeInt64 ? kNumberTypeInt32 : param_value->tensor_type(); - auto abstract_tensor = std::make_shared(TypeIdToType(data_type), shape_vector); - param_node->set_abstract(abstract_tensor); + auto data_type = tensor_info->data_type() == kNumberTypeInt64 ? kNumberTypeInt32 : tensor_info->data_type(); if (utils::isa(node)) { param_node->set_name(node->cast()->fullname_with_scope()); } else if (utils::isa(node)) { param_node->set_name(node->cast()->name()); } - ParamValueLitePtr param_value_new = std::make_shared(); - param_value_new->set_format(param_value->format()); - param_value_new->set_tensor_shape(shape); + auto tensor_info_new = std::make_shared(data_type, shape_vector); + if (tensor_info_new == nullptr) { + MS_LOG(ERROR) << "new tensor::Tensor failed."; + return nullptr; + } size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - if (param_value->tensor_size() == 0) { - if (param_value->tensor_type() == kNumberTypeInt64) { - param_value_new->set_tensor_type(kNumberTypeInt32); + if (tensor_info->Size() == 0) { + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return nullptr; } - param_node->set_default_param(param_value_new); return param_node; } - if (param_value->tensor_type() == kNumberTypeInt64) { - param_value_new->set_tensor_type(kNumberTypeInt32); - auto *tensor_data = new (std::nothrow) int[data_count]; + if (tensor_info->data_type() == kNumberTypeInt64) { + auto *tensor_data = reinterpret_cast(tensor_info_new->data_c()); if (tensor_data == nullptr) { MS_LOG(ERROR) << "new data failed"; return nullptr; } - auto *origin_data = reinterpret_cast(param_value->tensor_addr()); + auto *origin_data = reinterpret_cast(tensor_info->data_c()); for (size_t i = 0; i < data_count; ++i) { if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; @@ -1239,23 +1237,24 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr tensor_data[i] = static_cast(origin_data[i]); } } - param_value_new->SetTensorData(tensor_data, data_count * sizeof(int32_t)); } else { - param_value_new->set_tensor_type(param_value->tensor_type()); - char *tensor_data = new (std::nothrow) char[param_value->tensor_size()]; + tensor_info_new->set_data_type(tensor_info->data_type()); + auto *tensor_data = reinterpret_cast(tensor_info_new->data_c()); if (tensor_data == nullptr) { MS_LOG(ERROR) << "new data failed"; return nullptr; } - if (memcpy_s(tensor_data, param_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()) != - lite::RET_OK) { + if (memcpy_s(tensor_data, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) { MS_LOG(ERROR) << "memcpy data failed."; - delete[] tensor_data; return nullptr; } - param_value_new->SetTensorData(tensor_data, 
param_value->tensor_size()); } - param_node->set_default_param(param_value_new); + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return nullptr; + } + param_node->set_default_param(tensor_info_new); return param_node; } @@ -1264,21 +1263,19 @@ ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const in MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); - - auto type_ptr = TypeIdToType(kNumberTypeInt32); - auto abstract_tensor = std::make_shared(type_ptr); - param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape({1}); - param_value->set_tensor_type(kNumberTypeInt32); + auto tensor_info = lite::CreateTensorInfo(&data, sizeof(int32_t), {1}, kNumberTypeInt32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; + } - char *default_data = new (std::nothrow) char[sizeof(int32_t)]; - *(reinterpret_cast(default_data)) = data; - param_value->SetTensorData(default_data, sizeof(int32_t)); - param_node->set_default_param(param_value); + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return nullptr; + } return param_node; } @@ -1287,29 +1284,21 @@ ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std: MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); - - auto type_ptr = TypeIdToType(kNumberTypeInt32); - std::vector shape_vector{static_cast(data.size())}; - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - std::vector shape{static_cast(data.size())}; - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(kNumberTypeInt32); - - if (!data.empty()) { - char *default_data = new (std::nothrow) char[data.size() * sizeof(int32_t)]; - if (memcpy_s(default_data, data.size() * sizeof(int32_t), data.data(), data.size() * sizeof(int32_t)) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - delete[] default_data; - return nullptr; - } - param_value->SetTensorData(default_data, data.size() * sizeof(int32_t)); + std::vector shape_vector{static_cast(data.size())}; + auto tensor_info = lite::CreateTensorInfo(data.data(), data.size() * sizeof(int32_t), shape_vector, kNumberTypeInt32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; } - param_node->set_default_param(param_value); + + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return nullptr; + } + return param_node; } @@ -1318,39 +1307,28 @@ ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const st MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); + param_node->set_name(node_name); - auto type_ptr = TypeIdToType(kNumberTypeInt32); std::vector shape_vector; shape_vector.push_back(data.size()); shape_vector.push_back(2); - auto abstract_tensor 
= std::make_shared(type_ptr, shape_vector); - param_node->set_abstract(abstract_tensor); - param_node->set_name(node_name); - - ParamValueLitePtr param_value = std::make_shared(); - - MS_ASSERT(param_value != nullptr); - std::vector shape; - shape.push_back(data.size()); - shape.push_back(2); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(kNumberTypeInt32); - std::vector data_1d; for (auto pair : data) { data_1d.insert(data_1d.end(), pair.begin(), pair.end()); } auto size = data_1d.size() * sizeof(int32_t); - char *default_data = new (std::nothrow) char[size]; - if (memcpy_s(default_data, size, data_1d.data(), size) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - delete[] default_data; + auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeInt32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; + } + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return nullptr; } - param_value->SetTensorData(default_data, size); - param_node->set_default_param(param_value); return param_node; } @@ -1359,26 +1337,18 @@ ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const MS_ASSERT(func_graph != nullptr); MS_ASSERT(data.size() != 0); auto param_node = func_graph->add_parameter(); - - auto type_ptr = TypeIdToType(kNumberTypeFloat32); - std::vector shape_vector = {1}; - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - param_node->set_abstract(abstract_tensor); param_node->set_name(node_name); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape({1}); - param_value->set_tensor_type(kNumberTypeFloat32); - - char *default_data = new (std::nothrow) char[sizeof(float)]; - if (memcpy_s(default_data, sizeof(float), &data, sizeof(float)) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - delete[] default_data; + auto tensor_info = lite::CreateTensorInfo(&data, sizeof(float), {1}, kNumberTypeFloat32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; + } + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return nullptr; } - param_value->SetTensorData(default_data, sizeof(float)); - param_node->set_default_param(param_value); return param_node; } } // namespace opt diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 80b7bf797c5..9b4cb989e4f 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -26,7 +26,6 @@ #include "src/common/utils.h" #include "backend/optimizer/common/pattern_engine.h" #include "schema/inner/model_generated.h" -#include "src/param_value_lite.h" #include "tools/converter/converter_context.h" using PrimitiveCPtr = std::shared_ptr; @@ -39,6 +38,7 @@ inline const PrimitivePtr kPrimDivFusion = std::make_shared("DivFusio inline const PrimitivePtr kPrimErf = std::make_shared("Erf"); inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared("make_tuple"); inline const PrimitivePtr kPrimIdentity = std::make_shared("Identity"); +constexpr auto kWeightFormat = "weight_format"; std::vector CastToInt(const ValuePtr &value); std::vector> CastToVec2DInt(const ValuePtr 
&value); @@ -66,7 +66,7 @@ int CheckIfNodeIsParam(const AnfNodePtr &node); int CheckLeastInputSize(const CNodePtr &node, int size); ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, - const ParamValueLitePtr &weight_tensor); + const tensor::TensorPtr &weight_tensor); bool IsParamNode(const BaseRef &n); @@ -88,7 +88,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); -ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); +tensor::TensorPtr GetTensorInfo(const AnfNodePtr &node); AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index); @@ -118,23 +118,23 @@ enum kTransFilterType { kHWKC2KHWC }; -STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, - int32_t *filterH, int32_t *filterW); +STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int64_t *filterK, int64_t *filterC, + int64_t *filterH, int64_t *filterW); -STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, +STATUS SetFilterDim(const tensor::TensorPtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); template -static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, +static STATUS TransFilterData(const tensor::TensorPtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); template -static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); +static lite::STATUS TransFilterFormat(const tensor::TensorPtr &tensor, kTransFilterType type); -STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); +STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format); ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const ParamValueLitePtr ¶m_value); + const tensor::TensorPtr &tensor_info); ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data, const std::string &node_name); diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc index d89cf59f609..f87efb8f797 100644 --- a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -19,8 +19,7 @@ #include #include "ops/mat_mul.h" #include "schema/inner/model_generated.h" -#include "src/param_value_lite.h" -#include "utils/utils.h" +#include "tools/common/tensor_util.h" #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -52,12 +51,12 @@ void *GetInputAddr(const AnfNodePtr &node, size_t input_index) { } if (cnode->input(input_index)->isa()) { auto param_input = cnode->input(input_index)->cast(); - auto param_value = std::dynamic_pointer_cast(param_input->default_param()); - if (param_value == nullptr) { - MS_LOG(ERROR) << "param not paramValueLite"; + auto tensor_info = std::dynamic_pointer_cast(param_input->default_param()); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "param not tensor::Tensor"; return nullptr; } - return param_value->tensor_addr(); + return tensor_info->data_c(); } MS_LOG(ERROR) << "input not parameter"; return 
nullptr; @@ -68,42 +67,36 @@ STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPt auto joint_fullconnect_size = stack_node->inputs().size() - 1; auto fc = stack_node->input(1)->cast(); auto fc_weight = fc->input(2)->cast(); - auto fc_weight_param = std::dynamic_pointer_cast(fc_weight->default_param()); - auto tensor_size = fc_weight_param->tensor_size(); - auto rmatmul_input_shape = fc_weight_param->tensor_shape(); - auto new_tensor_data = new (std::nothrow) int8_t[joint_fullconnect_size * tensor_size]; - if (new_tensor_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; + auto fc_weight_param = std::dynamic_pointer_cast(fc_weight->default_param()); + auto tensor_size = fc_weight_param->Size(); + auto rmatmul_input_shape = fc_weight_param->shape(); + + rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size); + std::vector shape_vector(rmatmul_input_shape.begin(), rmatmul_input_shape.end()); + auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, fc_weight_param->data_type()); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; return RET_ERROR; } for (size_t i = 1; i < joint_fullconnect_size + 1; i++) { auto tensor_addr = GetInputAddr(stack_node->input(i), 2); if (tensor_addr == nullptr) { MS_LOG(ERROR) << "input tensor addr nullptr"; - delete[] new_tensor_data; return RET_ERROR; } - if (EOK != memcpy_s(new_tensor_data + (i - 1) * tensor_size, tensor_size, tensor_addr, tensor_size)) { + if (EOK != memcpy_s(static_cast(tensor_info->data_c()) + (i - 1) * tensor_size, tensor_size, tensor_addr, + tensor_size)) { MS_LOG(ERROR) << "memcpy_s data failed"; - delete[] new_tensor_data; return RET_ERROR; } } - rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size); - auto type_ptr = TypeIdToType(fc_weight_param->tensor_type()); - std::vector shape_vector; - (void)std::transform(rmatmul_input_shape.begin(), rmatmul_input_shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - rmatmul_input->set_abstract(abstract_tensor); + auto status = lite::InitParameterFromTensorInfo(rmatmul_input, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return RET_ERROR; + } rmatmul_input->set_name(stack_node->fullname_with_scope() + "right_parameter"); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape(rmatmul_input_shape); - param_value->set_tensor_type(fc_weight_param->tensor_type()); - param_value->set_format(fc_weight_param->format()); - param_value->SetTensorData(new_tensor_data, joint_fullconnect_size * tensor_size); - rmatmul_input->set_default_param(param_value); + return RET_OK; } } // namespace diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 93928d3a0c3..cc1d6d577c4 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -18,11 +18,11 @@ #include #include #include -#include #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/anf_exporter/anf_exporter.h" #include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" #include "src/common/common.h" #include 
"src/ops/populate/populate_register.h" #include "src/kernel_registry.h" @@ -85,34 +85,27 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { auto parameter = func_graph->add_parameter(); std::vector shape(tensor->shape()); std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto type_id = static_cast(tensor->data_type()); - auto type_ptr = TypeIdToType(type_id); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - parameter->set_abstract(abstract_tensor); + std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - param_value->set_format(tensor->format()); + auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + return nullptr; + } if (tensor->MutableData() != nullptr) { - auto size = tensor->Size(); - auto tensor_data = new (std::nothrow) uint8_t[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; - return nullptr; - } - auto ret = memcpy_s(tensor_data, size, tensor->MutableData(), tensor->Size()); + auto tensor_data = static_cast(tensor_info->data_c()); + auto ret = memcpy_s(tensor_data, tensor->Size(), tensor->MutableData(), tensor->Size()); if (ret != EOK) { - delete[] tensor_data; MS_LOG(ERROR) << "memcpy error: " << ret; return nullptr; } - param_value->SetTensorData(tensor_data, size); } - parameter->set_default_param(param_value); + auto status = lite::InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return nullptr; + } return parameter; } kernel::LiteKernel *GetLiteKernel(std::vector inputs, std::vector *outputs, const CNodePtr &cnode, @@ -203,11 +196,11 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector const std::vector &outputs) { MS_ASSERT(cnode != nullptr); auto prim = GetValueNode(cnode->input(0)); - auto quant_param_valueptr = prim->GetAttr("quant_params"); - if (quant_param_valueptr == nullptr) { + auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); + if (quant_tensor_info_ptr == nullptr) { return lite::RET_OK; } - auto quant_param_holder = quant_param_valueptr->cast(); + auto quant_param_holder = quant_tensor_info_ptr->cast(); if (quant_param_holder == nullptr) { MS_LOG(ERROR) << "quant param is invalid."; return lite::RET_ERROR; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index 7baaaaf26a1..d664b8fde50 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -18,7 +18,7 @@ #include "ops/fusion/add_fusion.h" #include "ops/fusion/conv2d_fusion.h" #include "ops/fusion/conv2d_transpose_fusion.h" -#include "src/param_value_lite.h" +#include "tools/common/tensor_util.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -102,9 +102,9 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co return lite::RET_INVALID_OP_ATTR; } auto add_weight_param = 
bias_add_weight->cast()->default_param(); - auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); - auto add_weight_data = reinterpret_cast(add_weight_tensor->tensor_addr()); - auto add_weight_shape = add_weight_tensor->tensor_shape(); + auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); + auto add_weight_data = reinterpret_cast(add_weight_tensor->data_c()); + auto add_weight_shape = add_weight_tensor->shape(); if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { for (int i = 0; i < kernel_nums; i++) { add_bias_data[i] = *add_weight_data; @@ -122,20 +122,20 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co return lite::RET_INVALID_OP_ATTR; } auto conv_bias_param = conv_bias_node->cast()->default_param(); - auto conv_bias_tensor = std::dynamic_pointer_cast(conv_bias_param); - if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) { + auto conv_bias_tensor = std::dynamic_pointer_cast(conv_bias_param); + if (conv_bias_tensor->shape().empty() || conv_bias_tensor->shape()[0] != kernel_nums) { MS_LOG(ERROR) << "conv_bias_node shape error"; delete[] add_bias_data; return lite::RET_INVALID_OP_ATTR; } - auto conv_bias_data = reinterpret_cast(conv_bias_tensor->tensor_addr()); + auto conv_bias_data = reinterpret_cast(conv_bias_tensor->data_c()); for (int i = 0; i < kernel_nums; i++) { conv_bias_data[i] += add_bias_data[i]; } delete[] add_bias_data; } else { auto conv_weight_param = conv_weight_node->cast()->default_param(); - auto conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + auto conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); conv_node->add_input(conv_new_bias); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc index 42ed168e0fd..5004430d3e7 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -18,7 +18,6 @@ #include #include "ops/batch_norm.h" #include "ops/fused_batch_norm.h" -#include "src/param_value_lite.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -46,8 +45,8 @@ bool IsBatchNode(const BaseRef &n) { void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node, float *trans_scale, float eps, int kernel_num) { auto bn_var_param = bn_var_node->cast()->default_param(); - auto bn_var_tensor = std::dynamic_pointer_cast(bn_var_param); - auto bn_var_data = reinterpret_cast(bn_var_tensor->tensor_addr()); + auto bn_var_tensor = std::dynamic_pointer_cast(bn_var_param); + auto bn_var_data = reinterpret_cast(bn_var_tensor->data_c()); // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) if (memcpy_s(trans_scale, kernel_num * sizeof(float), bn_var_data, kernel_num * sizeof(float)) != EOK) { MS_LOG(ERROR) << "memcpy_s transScale error"; @@ -67,8 +66,8 @@ void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node, } if (bn_scale_node != nullptr) { auto bn_scale_param = bn_scale_node->cast()->default_param(); - auto bn_scale_tensor = std::dynamic_pointer_cast(bn_scale_param); - auto bn_scale_data = reinterpret_cast(bn_scale_tensor->tensor_addr()); + auto 
bn_scale_tensor = std::dynamic_pointer_cast(bn_scale_param); + auto bn_scale_data = reinterpret_cast(bn_scale_tensor->data_c()); // scale/sqrt(variance + eps) for (int32_t i = 0; i < kernel_num; i++) { trans_scale[i] *= bn_scale_data[i]; @@ -78,8 +77,8 @@ void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node, void CalTransBias(const AnfNodePtr &bn_mean_node, const AnfNodePtr &bn_bias_node, const float *trans_scale, float *trans_bias, int kernel_num) { auto bn_mean_param = bn_mean_node->cast()->default_param(); - auto bn_mean_tensor = std::dynamic_pointer_cast(bn_mean_param); - auto bn_mean_data = reinterpret_cast(bn_mean_tensor->tensor_addr()); + auto bn_mean_tensor = std::dynamic_pointer_cast(bn_mean_param); + auto bn_mean_data = reinterpret_cast(bn_mean_tensor->data_c()); // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) // -mean/sqrt(variance + eps) for (int32_t i = 0; i < kernel_num; i++) { @@ -88,8 +87,8 @@ void CalTransBias(const AnfNodePtr &bn_mean_node, const AnfNodePtr &bn_bias_node if (bn_bias_node != nullptr) { auto bn_bias_param = bn_bias_node->cast()->default_param(); - auto bn_bias_tensor = std::dynamic_pointer_cast(bn_bias_param); - auto bn_bias_data = reinterpret_cast(bn_bias_tensor->tensor_addr()); + auto bn_bias_tensor = std::dynamic_pointer_cast(bn_bias_param); + auto bn_bias_data = reinterpret_cast(bn_bias_tensor->data_c()); // -scale*mean/sqrt(variance + eps) + bias for (int32_t i = 0; i < kernel_num; i++) { trans_bias[i] += bn_bias_data[i]; @@ -108,18 +107,18 @@ STATUS CalEstimatedData(const AnfNodePtr &origin_node, const AnfNodePtr &scale_f return RET_ERROR; } auto origin_param = origin_node->cast()->default_param(); - auto origin_tensor = std::dynamic_pointer_cast(origin_param); - auto origin_data = reinterpret_cast(origin_tensor->tensor_addr()); + auto origin_tensor = std::dynamic_pointer_cast(origin_param); + auto origin_data = reinterpret_cast(origin_tensor->data_c()); auto scale_factor_param = scale_factor_node->cast()->default_param(); - auto scale_factor_tensor = std::dynamic_pointer_cast(scale_factor_param); - if (scale_factor_tensor->tensor_shape_size() < 1) { + auto scale_factor_tensor = std::dynamic_pointer_cast(scale_factor_param); + if (scale_factor_tensor->DataSize() < 1) { MS_LOG(ERROR) << "scale factor data size is not equal to 1"; return RET_ERROR; } - auto scale_factor_data = (reinterpret_cast(scale_factor_tensor->tensor_addr()))[0]; + auto scale_factor_data = (reinterpret_cast(scale_factor_tensor->data_c()))[0]; float scale_factor = scale_factor_data == 0 ? 
0 : 1 / scale_factor_data; - for (int i = 0; i < origin_tensor->tensor_shape_size(); i++) { + for (int i = 0; i < origin_tensor->DataSize(); i++) { origin_data[i] = origin_data[i] * scale_factor; } return RET_OK; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc index 169ff0bf35e..a2025a9665d 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc @@ -15,9 +15,9 @@ */ #include "tools/optimizer/fusion/conv_conv_fusion.h" -#include #include #include +#include "tools/common/tensor_util.h" #include "ops/fusion/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" @@ -57,33 +57,33 @@ STATUS GenNewConvBias(const ParameterPtr &down_bias_node, const ParameterPtr &do const ParameterPtr &up_bias_node, const ParameterPtr &new_bias_node) { float *down_bias_data = nullptr; if (down_bias_node != nullptr) { - auto down_bias_param = std::dynamic_pointer_cast(down_bias_node->default_param()); - auto down_bias_shape = down_bias_param->tensor_shape(); + auto down_bias_param = std::dynamic_pointer_cast(down_bias_node->default_param()); + auto down_bias_shape = down_bias_param->shape(); if (down_bias_shape.size() != 1) { MS_LOG(ERROR) << "cur conv_conv fusion only support scalar bias shape"; return RET_FAILED; } - down_bias_data = static_cast(down_bias_param->tensor_addr()); + down_bias_data = static_cast(down_bias_param->data_c()); } - auto up_bias_param = std::dynamic_pointer_cast(up_bias_node->default_param()); - auto up_bias_shape = up_bias_param->tensor_shape(); + auto up_bias_param = std::dynamic_pointer_cast(up_bias_node->default_param()); + auto up_bias_shape = up_bias_param->shape(); if (up_bias_shape.size() != 1) { MS_LOG(ERROR) << "cur conv_conv fusion only support scalar bias shape"; return RET_FAILED; } - auto down_weight_param = std::dynamic_pointer_cast(down_weight_node->default_param()); - auto down_weight_data = static_cast(down_weight_param->tensor_addr()); - auto down_weight_shape = down_weight_param->tensor_shape(); - auto up_bias_data = static_cast(up_bias_param->tensor_addr()); + auto down_weight_param = std::dynamic_pointer_cast(down_weight_node->default_param()); + auto down_weight_data = static_cast(down_weight_param->data_c()); + auto down_weight_shape = down_weight_param->shape(); + auto up_bias_data = static_cast(up_bias_param->data_c()); int new_bias_size = down_weight_shape[0]; - auto new_bias_data = new (std::nothrow) float[new_bias_size]; - if (new_bias_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; + auto tensor_info = lite::CreateTensorInfo(nullptr, 0, {new_bias_size}, up_bias_param->data_type()); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } + auto new_bias_data = static_cast(tensor_info->data_c()); if (memset_s(new_bias_data, new_bias_size * sizeof(float), 0, new_bias_size * sizeof(float)) != EOK) { MS_LOG(ERROR) << "memset_s failed"; - delete[] new_bias_data; return RET_ERROR; } auto up_bias_size = up_bias_shape[0]; @@ -95,43 +95,33 @@ STATUS GenNewConvBias(const ParameterPtr &down_bias_node, const ParameterPtr &do new_bias_data[i] += down_bias_data[i]; } } - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape({new_bias_size}); - param_value->set_tensor_type(up_bias_param->tensor_type()); - param_value->set_format(up_bias_param->format()); - 
param_value->SetTensorData(new_bias_data, sizeof(float) * new_bias_size); + new_bias_node->set_name(down_bias_node->fullname_with_scope()); - new_bias_node->set_default_param(param_value); + new_bias_node->set_default_param(tensor_info); new_bias_node->set_abstract(down_bias_node->abstract()); return RET_OK; } // up weight shape[cout0,h,w,cin0] down weight shape[cout1,1,1,cout0],new weight shape [cout1,h,w,cin0] STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr &up_weight_node, const ParameterPtr &new_weight_node) { - auto down_weight_param = std::dynamic_pointer_cast(down_weight_node->default_param()); - auto down_weight_shape = down_weight_param->tensor_shape(); - auto up_weight_param = std::dynamic_pointer_cast(up_weight_node->default_param()); - auto up_weight_shape = up_weight_param->tensor_shape(); - auto up_weight_data = static_cast(up_weight_param->tensor_addr()); - auto down_weight_data = static_cast(down_weight_param->tensor_addr()); + auto down_weight_param = std::dynamic_pointer_cast(down_weight_node->default_param()); + auto down_weight_shape = down_weight_param->shape(); + auto up_weight_param = std::dynamic_pointer_cast(up_weight_node->default_param()); + auto up_weight_shape = up_weight_param->shape(); + auto up_weight_data = static_cast(up_weight_param->data_c()); + auto down_weight_data = static_cast(down_weight_param->data_c()); int cout0 = up_weight_shape[0]; int cin0 = up_weight_shape[kNHWC_CDim]; int cout1 = down_weight_shape[0]; int window_size = up_weight_shape[kNHWC_WDim] * up_weight_shape[kNHWC_HDim]; auto new_weight_shape = up_weight_shape; new_weight_shape[0] = down_weight_shape[0]; - int size = std::accumulate(new_weight_shape.begin(), new_weight_shape.end(), 1, std::multiplies<>()); - auto new_weight_data = new (std::nothrow) float[size]; - if (new_weight_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; - return RET_ERROR; - } - if (memset_s(new_weight_data, size * sizeof(float), 0, size * sizeof(float)) != EOK) { - MS_LOG(ERROR) << "memset_s failed"; - delete[] new_weight_data; + auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_weight_shape, up_weight_param->data_type()); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } + auto new_weight_data = static_cast(tensor_info->data_c()); for (int i = 0; i < cout1; i++) { auto down_weight_base = i * cout0; auto new_weight_base = i * window_size * cin0; @@ -148,14 +138,9 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr } } } - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape(new_weight_shape); - param_value->set_tensor_type(up_weight_param->tensor_type()); - param_value->set_format(up_weight_param->format()); - param_value->SetTensorData(new_weight_data, sizeof(float) * size); + new_weight_node->set_name(down_weight_node->fullname_with_scope()); - new_weight_node->set_default_param(param_value); + new_weight_node->set_default_param(tensor_info); new_weight_node->set_abstract(down_weight_node->abstract()); return RET_OK; } @@ -230,9 +215,9 @@ const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const A return nullptr; } auto down_weight_parameter = down_conv_cnode->input(kConvWeightIndex)->cast(); - auto down_weight_value = std::dynamic_pointer_cast(down_weight_parameter->default_param()); - auto down_weight_shape = down_weight_value->tensor_shape(); - auto down_weight_type = 
down_weight_value->tensor_type(); + auto down_weight_value = std::dynamic_pointer_cast(down_weight_parameter->default_param()); + auto down_weight_shape = down_weight_value->shape(); + auto down_weight_type = down_weight_value->data_type(); // down conv node filter must 1x1,only support float32 if (down_weight_shape.size() != kNHWC_DIMS || down_weight_type != kNumberTypeFloat32 || (down_weight_shape[kNHWC_HDim] != 1 || down_weight_shape[kNHWC_WDim] != 1)) { @@ -241,9 +226,9 @@ const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const A auto up_conv_cnode = down_conv_cnode->input(1)->cast(); auto up_weight_parameter = up_conv_cnode->input(kConvWeightIndex)->cast(); - auto up_weight_value = std::dynamic_pointer_cast(up_weight_parameter->default_param()); - auto up_weight_shape = up_weight_value->tensor_shape(); - auto up_weight_type = up_weight_value->tensor_type(); + auto up_weight_value = std::dynamic_pointer_cast(up_weight_parameter->default_param()); + auto up_weight_shape = up_weight_value->shape(); + auto up_weight_type = up_weight_value->data_type(); if (up_weight_shape.size() != kNHWC_DIMS || up_weight_type != kNumberTypeFloat32 || (up_weight_shape[kNHWC_HDim] != 1 || up_weight_shape[kNHWC_WDim] != 1)) { return nullptr; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc index ac29c9f736d..dd2ae1140b4 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.cc @@ -16,7 +16,6 @@ #include "tools/optimizer/fusion/conv_scale_fusion.h" #include -#include "src/param_value_lite.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -70,8 +69,8 @@ void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, return; } auto scale_weight_param = scale_weight_node->cast()->default_param(); - auto weight_value = std::dynamic_pointer_cast(scale_weight_param); - auto weight_data = reinterpret_cast(weight_value->tensor_addr()); + auto weight_value = std::dynamic_pointer_cast(scale_weight_param); + auto weight_data = reinterpret_cast(weight_value->data_c()); if (EOK != memcpy_s(trans_scale, kernel_num * sizeof(float), weight_data, kernel_num * sizeof(float))) { MS_LOG(ERROR) << "memcpy_s transScale failed"; @@ -81,8 +80,8 @@ void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, if (scale_bias_node != nullptr) { auto scale_bias_param = scale_bias_node->cast()->default_param(); - auto bias_value = std::dynamic_pointer_cast(scale_bias_param); - auto bias_data = reinterpret_cast(bias_value->tensor_addr()); + auto bias_value = std::dynamic_pointer_cast(scale_bias_param); + auto bias_data = reinterpret_cast(bias_value->data_c()); if (EOK != memcpy_s(trans_bias, kernel_num * sizeof(float), bias_data, kernel_num * sizeof(float))) { MS_LOG(ERROR) << "memcpy_s transScale failed"; lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index eddbd60bcfa..846a971bd55 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -18,7 +18,7 @@ #include #include "ops/fusion/conv2d_fusion.h" #include "ops/fusion/conv2d_transpose_fusion.h" -#include "src/param_value_lite.h" +#include 
"tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -70,14 +70,14 @@ void GenerateNewWeightConv2D(float *dst_weight, const float *conv_weight, const } void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weight, - const ParamValueLitePtr &weight_tensor, FmkType fmk, int group, int kernel_num) { + const tensor::TensorPtr &weight_tensor, FmkType fmk, int group, int kernel_num) { if (dst_weight == nullptr || scale_weight == nullptr || weight_tensor == nullptr) { return; } - auto weight_data = reinterpret_cast(weight_tensor->tensor_addr()); + auto weight_data = reinterpret_cast(weight_tensor->data_c()); if (fmk == lite::converter::FmkType_TF) { - auto cin_group = weight_tensor->tensor_shape()[3] / group; - int area_size = weight_tensor->tensor_shape()[0] * weight_tensor->tensor_shape()[1]; + auto cin_group = weight_tensor->shape()[3] / group; + int area_size = weight_tensor->shape()[0] * weight_tensor->shape()[1]; for (int j = 0; j < area_size; j++) { for (int i = 0; i < kernel_num; ++i) { for (int k = 0; k < cin_group; ++k) { @@ -87,8 +87,8 @@ void GenerateNewWeightConv2DTranspose(float *dst_weight, const float *scale_weig } } } else { - auto cin_group = weight_tensor->tensor_shape()[0] / group; - int area_size = weight_tensor->tensor_shape()[2] * weight_tensor->tensor_shape()[3]; + auto cin_group = weight_tensor->shape()[0] / group; + int area_size = weight_tensor->shape()[2] * weight_tensor->shape()[3]; int cout_size = kernel_num * area_size; for (int k = 0; k < cin_group; ++k) { for (int i = 0; i < kernel_num; ++i) { @@ -197,50 +197,31 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const return; } auto conv_weight_param = conv_weight_node->cast()->default_param(); - auto weight_tensor = std::dynamic_pointer_cast(conv_weight_param); + auto weight_tensor = std::dynamic_pointer_cast(conv_weight_param); if (kernel_num <= 0) { MS_LOG(ERROR) << "kernel num less than 0"; lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR); return; } - auto temp_weight_data = new (std::nothrow) float[weight_tensor->tensor_shape_size()]; - if (temp_weight_data == nullptr) { - MS_LOG(ERROR) << "new ParamValueLite failed"; - lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR); - return; - } - auto new_weight_tensor = std::make_shared(); + auto new_weight_tensor = lite::CreateTensorInfo(weight_tensor->data_c(), weight_tensor->DataSize() * sizeof(float), + weight_tensor->shape(), weight_tensor->data_type()); if (new_weight_tensor == nullptr) { - delete temp_weight_data; - MS_LOG(ERROR) << "new ParamValueLite failed"; + MS_LOG(ERROR) << "create tensor info failed."; return; } - new_weight_tensor->set_tensor_size(weight_tensor->tensor_size()); - new_weight_tensor->set_tensor_shape(weight_tensor->tensor_shape()); - new_weight_tensor->set_tensor_type(weight_tensor->tensor_type()); - new_weight_tensor->set_format(weight_tensor->format()); - auto ret = memcpy_s(temp_weight_data, weight_tensor->tensor_shape_size() * sizeof(float), - weight_tensor->tensor_addr(), weight_tensor->tensor_shape_size() * sizeof(float)); - if (ret != EOK) { - delete temp_weight_data; - MS_LOG(ERROR) << "memcpy_s error:" << ret; - return; - } - new_weight_tensor->SetTensorData(temp_weight_data, new_weight_tensor->tensor_size()); CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale); float *bias_data = nullptr; // conv has bias,bias_flag true bool bias_flag = 
false; if (conv_bias_node != nullptr) { auto conv_bias_param = conv_bias_node->cast()->default_param(); - auto bias_tensor = std::dynamic_pointer_cast(conv_bias_param); - bias_data = reinterpret_cast(bias_tensor->tensor_addr()); + auto bias_tensor = std::dynamic_pointer_cast(conv_bias_param); + bias_data = reinterpret_cast(bias_tensor->data_c()); bias_flag = true; } else { bias_data = new (std::nothrow) float[kernel_num]; if (bias_data == nullptr) { MS_LOG(ERROR) << "tensor_data is nullptr"; - delete temp_weight_data; return; } } @@ -253,7 +234,6 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const auto new_weight_paramter = func_graph->add_parameter(); if (new_weight_paramter == nullptr) { MS_LOG(ERROR) << "new_weight_paramter is nullptr"; - delete temp_weight_data; return; } new_weight_paramter->set_default_param(new_weight_tensor); @@ -262,15 +242,15 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const conv_node->set_input(kConvWeightIndex, new_weight_paramter); } -void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const ParamValueLitePtr &weight_tensor, +void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const tensor::TensorPtr &weight_tensor, int kernel_num, const float *trans_scale) const { MS_ASSERT(weight_data != nullptr); MS_ASSERT(trans_scale != nullptr); - if (weight_tensor->tensor_shape().size() != 4) { + if (weight_tensor->shape().size() != 4) { MS_LOG(ERROR) << "weight tensor shape error"; return; } - auto weight_shape_size = weight_tensor->tensor_shape_size(); + auto weight_shape_size = weight_tensor->DataSize(); auto tmp_weight_data = new (std::nothrow) float[weight_shape_size]; if (tmp_weight_data == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED); @@ -284,7 +264,7 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED); return; } - auto weight_data = reinterpret_cast(weight_tensor->tensor_addr()); + auto weight_data = reinterpret_cast(weight_tensor->data_c()); auto conv_prim = GetValueNode(conv_node->input(0)); MS_ASSERT(conv_prim != nullptr); bool is_depth_wise = diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h index 4c8dc9707dc..07dc18d1b12 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h @@ -20,7 +20,6 @@ #include #include "backend/optimizer/common/optimizer.h" #include "tools/converter/converter_flags.h" -#include "src/param_value_lite.h" using mindspore::lite::converter::FmkType; namespace mindspore::opt { @@ -33,7 +32,7 @@ class ConvTransformFusion : public PatternProcessPass { void GenTransParam(const CNodePtr &, int, float *, float *) const; virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0; void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const; - void CalNewWeightTensor(const CNodePtr &, const ParamValueLitePtr &, int, const float *) const; + void CalNewWeightTensor(const CNodePtr &, const tensor::TensorPtr &, int, const float *) const; void CalNewBiasTensor(float *, int, bool, const float *, const float *) const; void SetFmkType(FmkType type) { this->fmk_type_ = type; } diff --git a/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc 
b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
index 3e351b90bb7..fa875574d25 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
@@ -15,7 +15,6 @@
  */
 #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
 #include
-#include "src/param_value_lite.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
diff --git a/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc
index 3545f41dea2..fc629da05f2 100644
--- a/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/gelu_fusion.cc
@@ -49,17 +49,17 @@ const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &i
   if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
     return value;
   }
-  auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
+  auto param_value_lite = parameter_node->default_param()->cast<tensor::TensorPtr>();
   if (param_value_lite == nullptr) {
     return value;
   }
-  if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
+  if (param_value_lite->data_type() != kNumberTypeFloat32 && param_value_lite->data_type() != kNumberTypeFloat) {
     return value;
   }
-  if (param_value_lite->tensor_size() != sizeof(float)) {
+  if (param_value_lite->Size() != sizeof(float)) {
     return value;
   }
-  return *static_cast<float *>(param_value_lite->tensor_addr());
+  return *static_cast<float *>(param_value_lite->data_c());
 }
 
 const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
diff --git a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
index 6d5d0ff677b..9adef402f12 100644
--- a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc
@@ -18,7 +18,6 @@
 #include "ops/fusion/layer_norm_fusion.h"
 #include "ops/fusion/reduce_fusion.h"
 #include "mindspore/core/ops/instance_norm.h"
-#include "src/param_value_lite.h"
 #include "utils/utils.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
@@ -33,13 +32,12 @@ STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
   if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
     return lite::RET_NOT_SUPPORT;
   }
-  auto axes_value = axes_param->default_param()->cast<ParamValueLitePtr>();
+  auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
   if (axes_value == nullptr) {
     return lite::RET_ERROR;
   }
-  axes->resize(axes_value->tensor_shape()[0]);
-  if (memcpy_s(axes->data(), axes_value->tensor_size(), axes_value->tensor_addr(), axes_value->tensor_size()) ==
-      EOK) {
+  axes->resize(axes_value->shape()[0]);
+  if (memcpy_s(axes->data(), axes_value->Size(), axes_value->data_c(), axes_value->Size()) == EOK) {
     return lite::RET_OK;
   }
 }
@@ -174,8 +172,10 @@ bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type
     return false;
   }
   auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
-  auto beta_tensor = std::dynamic_pointer_cast<ParamValueLite>(beta_param);
-  auto beta_shape = beta_tensor->tensor_shape();
+  auto beta_tensor = std::dynamic_pointer_cast<tensor::Tensor>(beta_param);
+  std::vector<int> beta_shape;
+  std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
+                 [](int64_t val) { return static_cast<int>(val); });
   // gamma
   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
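// The beta hunk above and the matching gamma hunk below repeat one conversion
// introduced by this patch: tensor::Tensor reports its shape as
// std::vector<int64_t>, while the fusion checks still compare std::vector<int>
// shapes. A minimal sketch of that narrowing as a standalone helper (the helper
// name is hypothetical, not part of this patch):
//
//   std::vector<int> ToInt32Shape(const tensor::TensorPtr &tensor) {
//     std::vector<int> shape;
//     std::transform(tensor->shape().begin(), tensor->shape().end(),
//                    std::back_inserter(shape),
//                    [](int64_t val) { return static_cast<int>(val); });
//     return shape;
//   }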
MS_ASSERT(gamma_node != nullptr); @@ -183,8 +183,10 @@ bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type return false; } auto gamma_param = gamma_node->cast()->default_param(); - auto gamma_tensor = std::dynamic_pointer_cast(gamma_param); - auto gamma_shape = gamma_tensor->tensor_shape(); + auto gamma_tensor = std::dynamic_pointer_cast(gamma_param); + std::vector gamma_shape; + std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape), + [](int64_t val) { return static_cast(val); }); // epsilon auto epsilon_node = utils::cast((*equiv)[epsilon_]); MS_ASSERT(epsilon_node != nullptr); @@ -192,9 +194,9 @@ bool NormFusion::CheckPattern(const EquivPtr &equiv, schema::PrimitiveType *type return false; } auto epsilon_param = epsilon_node->cast()->default_param(); - auto epsilon_tensor = std::dynamic_pointer_cast(epsilon_param); - auto epsilon_data = reinterpret_cast(epsilon_tensor->tensor_addr()); - auto epsilon_shape = epsilon_tensor->tensor_shape(); + auto epsilon_tensor = std::dynamic_pointer_cast(epsilon_param); + auto epsilon_data = reinterpret_cast(epsilon_tensor->data_c()); + auto epsilon_shape = epsilon_tensor->shape(); // mean2 std::vector mean2_axes; if (!IsReduceNode(equiv, mean2_, mean2_axes_, &mean2_axes)) { diff --git a/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc index e316f81bb2f..db89351cf4e 100644 --- a/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc @@ -17,7 +17,6 @@ #include #include "ops/fusion/activation.h" #include "ops/op_utils.h" -#include "src/param_value_lite.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h index a54a797fd1c..8f6346c09b2 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h @@ -21,7 +21,6 @@ #include #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" #include "schema/inner/model_generated.h" -#include "src/param_value_lite.h" #include "backend/optimizer/common/optimizer.h" #include "utils/utils.h" #include "include/errorcode.h" diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 420615f0a25..7907ad3890b 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -23,6 +23,7 @@ #include "ops/stack.h" #include "ops/transpose.h" #include "src/common/utils.h" +#include "tools/common/tensor_util.h" #include "utils/utils.h" #include "securec/include/securec.h" @@ -197,7 +198,7 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr return pattern; } -ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const { +tensor::TensorPtr TfBidirectionGruFusion::GetDefaultTensorInfo(const AnfNodePtr ¶meter_anf) const { MS_ASSERT(parameter_anf != nullptr); if (!utils::isa(parameter_anf)) { MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr"; @@ -208,8 +209,8 @@ ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr MS_LOG(DEBUG) << "parameter not have default value"; 
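// For quick reference, the accessor mapping that this helper's callers rely on, with
// the removed ParamValueLite call on the left and its tensor::Tensor replacement on
// the right. Every call named here appears elsewhere in this patch; `t` is a
// hypothetical result of the helper:
//   tensor::TensorPtr t = GetDefaultTensorInfo(node);
//   t->tensor_shape()      -> t->shape()      // ShapeVector, i.e. std::vector<int64_t>
//   t->tensor_type()       -> t->data_type()  // TypeId
//   t->tensor_addr()       -> t->data_c()     // raw pointer to the payload
//   t->tensor_size()       -> t->Size()       // payload size in bytes
//   t->tensor_shape_size() -> t->DataSize()   // element count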
return nullptr; } - auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); - return param_value; + auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); + return tensor_info; } STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, @@ -219,19 +220,19 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k MS_ASSERT(bw_cand_kernel != nullptr); MS_ASSERT(input_size != nullptr); MS_ASSERT(hidden_size != nullptr); - auto fw_cand_kernel_value = GetDefaultParamValue(fw_cand_kernel_anf); + auto fw_cand_kernel_value = GetDefaultTensorInfo(fw_cand_kernel_anf); if (fw_cand_kernel_value == nullptr) { return RET_ERROR; } - auto fw_cand_kernel_shape = fw_cand_kernel_value->tensor_shape(); + auto fw_cand_kernel_shape = fw_cand_kernel_value->shape(); if (fw_cand_kernel_shape.size() != 2) { return RET_ERROR; } - auto bw_cand_kernel_value = GetDefaultParamValue(bw_cand_kernel_anf); + auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf); if (bw_cand_kernel_value == nullptr) { return RET_ERROR; } - auto bw_cand_kernel_shape = bw_cand_kernel_value->tensor_shape(); + auto bw_cand_kernel_shape = bw_cand_kernel_value->shape(); if (bw_cand_kernel_shape.size() != 2) { return RET_ERROR; } @@ -261,32 +262,13 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun } parameter->set_abstract(abstract_tensor); - auto gate_weight_default = std::make_shared<ParamValueLite>(); + auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector); if (gate_weight_default == nullptr) { MS_LOG(ERROR) << "gate_weight_default is nullptr"; return nullptr; } - gate_weight_default->set_tensor_shape(shape); - gate_weight_default->set_tensor_type(type); - gate_weight_default->set_format(schema::Format_NHWC); - int data_len = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); - int data_size = 0; - if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { - data_size = data_len * sizeof(float); - *tensor_data = new (std::nothrow) float[data_len]; - } else if (type == kNumberTypeInt || type == kNumberTypeInt32) { - data_size = data_len * sizeof(int); - *tensor_data = new (std::nothrow) int[data_len]; - } else { - MS_LOG(DEBUG) << "unsupported data type"; - return nullptr; - } - if (*tensor_data == nullptr) { - MS_LOG(ERROR) << "new data failed"; - return nullptr; - } - gate_weight_default->SetTensorData(*tensor_data, data_size); + *tensor_data = gate_weight_default->data_c(); parameter->set_default_param(gate_weight_default); return parameter; } @@ -317,27 +299,27 @@ STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, MS_ASSERT(cand_weight != nullptr); MS_ASSERT(gate_tensor_data != nullptr); MS_ASSERT(recu_tensor_data != nullptr); - const std::vector<int> gate_shape{input_size + hidden_size, hidden_size * 2}; - const std::vector<int> cand_shape{hidden_size * 2, hidden_size}; - auto gate_weight_value = GetDefaultParamValue(gate_weight); + const std::vector<int64_t> gate_shape{input_size + hidden_size, hidden_size * 2}; + const std::vector<int64_t> cand_shape{hidden_size * 2, hidden_size}; + auto gate_weight_value = GetDefaultTensorInfo(gate_weight); if (gate_weight_value == nullptr) { return RET_ERROR; } - auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->tensor_addr()); + auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->data_c()); if (gate_weight_data == nullptr) { return RET_ERROR; } - auto gate_weight_shape = gate_weight_value->tensor_shape(); + auto gate_weight_shape
= gate_weight_value->shape(); - auto cand_weight_value = GetDefaultParamValue(cand_weight); + auto cand_weight_value = GetDefaultTensorInfo(cand_weight); if (cand_weight_value == nullptr) { return RET_ERROR; } - auto cand_weight_data = reinterpret_cast(cand_weight_value->tensor_addr()); + auto cand_weight_data = reinterpret_cast(cand_weight_value->data_c()); if (cand_weight_data == nullptr) { return RET_ERROR; } - auto cand_weight_shape = cand_weight_value->tensor_shape(); + auto cand_weight_shape = cand_weight_value->shape(); if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) { return RET_ERROR; @@ -369,20 +351,20 @@ STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, cons const int hidden_size, float *tensor_data) const { MS_ASSERT(bias != nullptr); MS_ASSERT(tensor_data != nullptr); - std::vector gate_shape{hidden_size * 2}; - std::vector cand_shape{hidden_size}; - auto gate_bias_value = GetDefaultParamValue(gate_bias); + std::vector gate_shape{hidden_size * 2}; + std::vector cand_shape{hidden_size}; + auto gate_bias_value = GetDefaultTensorInfo(gate_bias); if (gate_bias_value == nullptr) { return RET_ERROR; } - auto gate_bias_data = reinterpret_cast(gate_bias_value->tensor_addr()); - auto gate_bias_shape = gate_bias_value->tensor_shape(); - auto cand_bias_value = GetDefaultParamValue(cand_bias); + auto gate_bias_data = reinterpret_cast(gate_bias_value->data_c()); + auto gate_bias_shape = gate_bias_value->shape(); + auto cand_bias_value = GetDefaultTensorInfo(cand_bias); if (cand_bias_value == nullptr) { return RET_ERROR; } - auto cand_bias_data = reinterpret_cast(cand_bias_value->tensor_addr()); - auto cand_bias_shape = cand_bias_value->tensor_shape(); + auto cand_bias_data = reinterpret_cast(cand_bias_value->data_c()); + auto cand_bias_shape = cand_bias_value->shape(); if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) { return RET_ERROR; } @@ -504,6 +486,8 @@ CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &fu std::vector new_node_inputs = {value_node, input, gate_weight, recu_weight, bias, stacked_hidden, input_length}; auto new_node = func_graph->NewCNode(new_node_inputs); + auto prim = GetValueNode(new_node->input(0)); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::NHWC)); new_node->set_fullname_with_scope(base_name); return new_node; } diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h index 0c315149268..57a6f4a0a78 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.h @@ -21,7 +21,6 @@ #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "schema/inner/model_generated.h" -#include "src/param_value_lite.h" #include "backend/optimizer/common/optimizer.h" #include "utils/utils.h" #include "include/errorcode.h" @@ -48,7 +47,7 @@ class TfBidirectionGruFusion : public PatternProcessPass { private: AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const; - ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const; + tensor::TensorPtr GetDefaultTensorInfo(const AnfNodePtr ¶meter_anf) const; lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf, int *input_size, int *hidden_size) const; ParameterPtr AddDefaultParameter(const FuncGraphPtr 
&func_graph, const std::string &name, diff --git a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc index 775da6c4d9d..04f9c75985b 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc @@ -17,7 +17,7 @@ #include #include "ops/lstm.h" #include "src/common/utils.h" -#include "src/param_value_lite.h" +#include "tools/common/tensor_util.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -110,19 +110,10 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi return pattern; } -STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, +STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, const float *const data_ptr, const int hidden_size) const { MS_ASSERT(weight != nullptr); MS_ASSERT(data_ptr != nullptr); - auto default_param = std::make_shared(); - if (default_param == nullptr) { - MS_LOG(ERROR) << "new_default is nullptr"; - return RET_ERROR; - } - default_param->set_tensor_shape(shape); - default_param->set_tensor_type(kNumberTypeFloat32); - default_param->set_format(schema::Format_NHWC); - if (shape.size() != 3) { MS_LOG(ERROR) << "lstm weight shape must have 3 dims"; return RET_ERROR; @@ -141,16 +132,17 @@ STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, } } } - default_param->SetTensorData(tensor_data, param_num * 4); - weight->set_default_param(default_param); - std::vector shape_vector_i(shape.begin(), shape.end()); - auto abstract_tensor_i = std::make_shared(kFloat32, shape_vector_i); - if (abstract_tensor_i == nullptr) { - MS_LOG(ERROR) << "abstract_tensor is nullptr"; - delete[] tensor_data; + auto tensor_info = lite::CreateTensorInfo(tensor_data, param_num * 4, shape, kNumberTypeFloat32); + delete[] tensor_data; + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; + return RET_ERROR; + } + auto status = lite::InitParameterFromTensorInfo(weight, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return RET_ERROR; } - weight->set_abstract(abstract_tensor_i); return RET_OK; } @@ -169,17 +161,17 @@ STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterP MS_LOG(DEBUG) << "weight not have default value"; return RET_ERROR; } - if (!utils::isa(weight_param->default_param())) { - MS_LOG(DEBUG) << "default value is not ParamValueLite"; + if (!utils::isa(weight_param->default_param())) { + MS_LOG(DEBUG) << "default value is not tensor::Tensor"; return RET_FAILED; } - auto origin_tensor = std::dynamic_pointer_cast(weight_param->default_param()); - if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + auto origin_tensor = std::dynamic_pointer_cast(weight_param->default_param()); + if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) { MS_LOG(DEBUG) << "origin_tensor is not float32 type"; return RET_ERROR; } - auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); - auto data_shape = origin_tensor->tensor_shape(); + auto data_ptr = reinterpret_cast(origin_tensor->data_c()); + auto data_shape = origin_tensor->shape(); if (data_shape.size() != 2) { MS_LOG(ERROR) << "weight data shape invalid"; 
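// Aside: the CreateTensorInfo / InitParameterFromTensorInfo pair used in
// SetWeightAbstractAndDefault above is this patch's replacement for the old
// hand-rolled ParamValueLite setup (shape, type, data, abstract, all set by hand).
// A hedged usage sketch, with signatures inferred from the call sites in this
// patch; `param` stands for any ParameterPtr:
std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};
ShapeVector shape = {2, 2};  // tensor::Tensor shapes are std::vector<int64_t>
auto info = lite::CreateTensorInfo(values.data(), values.size() * sizeof(float), shape, kNumberTypeFloat32);
if (info == nullptr || lite::InitParameterFromTensorInfo(param, info) != RET_OK) {
  // on failure the surrounding code logs and returns RET_ERROR
}
// passing (nullptr, 0, ...) instead allocates an uninitialized tensor, as
// GetConcatedParam does further down in this patch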
return RET_ERROR; @@ -194,13 +186,13 @@ STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterP } const auto input_size = data_shape[0] - hidden_size; - std::vector shape_i{1, 4 * hidden_size, input_size}; + std::vector shape_i{1, 4 * hidden_size, input_size}; if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) { MS_LOG(ERROR) << "get weight_i failed"; return RET_ERROR; } - std::vector shape_c{1, 4 * hidden_size, hidden_size}; + std::vector shape_c{1, 4 * hidden_size, hidden_size}; if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) { MS_LOG(ERROR) << "get weight_i failed"; return RET_ERROR; @@ -222,32 +214,23 @@ STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const Para MS_LOG(DEBUG) << "bias not have default value"; return RET_ERROR; } - if (!utils::isa(old_bias_param->default_param())) { - MS_LOG(DEBUG) << "default value is not ParamValueLite"; + if (!utils::isa(old_bias_param->default_param())) { + MS_LOG(DEBUG) << "default value is not tensor::Tensor"; return RET_FAILED; } - auto origin_tensor = std::dynamic_pointer_cast(old_bias_param->default_param()); - if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + auto origin_tensor = std::dynamic_pointer_cast(old_bias_param->default_param()); + if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) { MS_LOG(DEBUG) << "origin_tensor is not float32 type"; return RET_ERROR; } - auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); - auto data_shape = origin_tensor->tensor_shape(); + auto data_ptr = reinterpret_cast(origin_tensor->data_c()); + auto data_shape = origin_tensor->shape(); if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) { MS_LOG(DEBUG) << "bias data shape illegal"; return RET_ERROR; } - std::vector shape{1, 8 * hidden_size}; - - auto default_param = std::make_shared(); - if (default_param == nullptr) { - MS_LOG(ERROR) << "new_default is nullptr"; - return RET_ERROR; - } - default_param->set_tensor_shape(shape); - default_param->set_tensor_type(kNumberTypeFloat32); - default_param->set_format(schema::Format_NHWC); + std::vector shape{1, 8 * hidden_size}; std::unique_ptr tensor_data(new (std::nothrow) float[hidden_size * 8]); auto forget_bias_node = utils::cast((*body_equiv)[forget_bias_]); @@ -256,7 +239,7 @@ STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const Para return RET_ERROR; } float forget_bias_value = 0.0f; - if (GetFloatScalarFromParamValueLite(forget_bias_node, &forget_bias_value) != RET_OK) { + if (GetFloatScalarFromTensorInfo(forget_bias_node, &forget_bias_value) != RET_OK) { return RET_ERROR; } @@ -273,15 +256,19 @@ STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const Para } } } - default_param->SetTensorData(tensor_data.release(), hidden_size * 8 * 4); - new_bias->set_default_param(default_param); - std::vector shape_vector_i(shape.begin(), shape.end()); - auto abstract_tensor_i = std::make_shared(kFloat32, shape_vector_i); - if (abstract_tensor_i == nullptr) { - MS_LOG(ERROR) << "abstract_tensor is nullptr"; + + auto tensor_info = lite::CreateTensorInfo(tensor_data.get(), hidden_size * 8 * 4, shape, kNumberTypeFloat32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } - new_bias->set_abstract(abstract_tensor_i); + + auto status = 
lite::InitParameterFromTensorInfo(new_bias, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; + return RET_ERROR; + } + return RET_OK; } diff --git a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h index 712ff825469..2e4426a196f 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/tf_lstm_cell_fusion.h @@ -22,7 +22,6 @@ #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "backend/optimizer/common/optimizer.h" #include "utils/utils.h" -#include "src/param_value_lite.h" #include "include/errorcode.h" namespace mindspore { @@ -40,7 +39,7 @@ class TfLstmCellFusion : public TfliteLstmCellFusion { lite::STATUS SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i, const ParameterPtr &weight_c, int hidden_size) const; - lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, + lite::STATUS SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector &shape, const float *const data_ptr, const int hidden_size) const; lite::STATUS PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias, const AnfNodePtr &old_bias, const int hidden_size) const; diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc index d7d65a1b12f..357195eef52 100644 --- a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc @@ -14,15 +14,14 @@ * limitations under the License. */ #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" -#include #include +#include #include #include "ops/lstm.h" #include "ops/squeeze.h" #include "ops/tuple_get_item.h" #include "src/common/utils.h" -#include "src/param_value_lite.h" -#include "schema/inner/model_generated.h" +#include "tools/common/tensor_util.h" #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" @@ -51,36 +50,36 @@ bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) { } } // namespace -STATUS TfliteLstmCellFusion::GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const { - if (param_value == nullptr || v == nullptr) { - MS_LOG(ERROR) << "param_value or v is nullptr"; +STATUS TfliteLstmCellFusion::GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) const { + if (tensor_info == nullptr || v == nullptr) { + MS_LOG(ERROR) << "tensor_info or v is nullptr"; return RET_ERROR; } - if (!utils::isa(param_value)) { - MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; + if (!utils::isa(tensor_info)) { + MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr"; return RET_ERROR; } - auto param_ptr = utils::cast(param_value); + auto param_ptr = utils::cast(tensor_info); if (!param_ptr->has_default()) { MS_LOG(DEBUG) << "param not have default"; return RET_ERROR; } auto default_param = param_ptr->default_param(); - if (!utils::isa(default_param)) { - MS_LOG(DEBUG) << "param_value is not ParamValueLitePtr"; + if (!utils::isa(default_param)) { + MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr"; return RET_ERROR; } - auto default_param_ptr = utils::cast(default_param); - auto tensor_shape = default_param_ptr->tensor_shape(); + auto default_param_ptr = utils::cast(default_param); + auto tensor_shape = default_param_ptr->shape(); if 
(!(tensor_shape.size() == 0 || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) { MS_LOG(DEBUG) << "default param is not scalar"; return RET_ERROR; } - if (default_param_ptr->tensor_type() != kNumberTypeFloat32 && default_param_ptr->tensor_type() != kNumberTypeFloat) { + if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) { MS_LOG(DEBUG) << "default param is not float"; return RET_ERROR; } - *v = *(reinterpret_cast(default_param_ptr->tensor_addr())); + *v = *(reinterpret_cast(default_param_ptr->data_c())); return RET_OK; } @@ -278,16 +277,16 @@ bool TfliteLstmCellFusion::CheckBodyGraph(const FuncGraphPtr &func_graph, const MS_ASSERT(hidden_zoneout_new_node != nullptr); float cell_old, cell_new, hidden_old, hidden_new; - if (GetFloatScalarFromParamValueLite(cell_zoneout_old_node, &cell_old) != RET_OK) { + if (GetFloatScalarFromTensorInfo(cell_zoneout_old_node, &cell_old) != RET_OK) { return false; } - if (GetFloatScalarFromParamValueLite(cell_zoneout_new_node, &cell_new) != RET_OK) { + if (GetFloatScalarFromTensorInfo(cell_zoneout_new_node, &cell_new) != RET_OK) { return false; } - if (GetFloatScalarFromParamValueLite(hidden_zoneout_old_node, &hidden_old) != RET_OK) { + if (GetFloatScalarFromTensorInfo(hidden_zoneout_old_node, &hidden_old) != RET_OK) { return false; } - if (GetFloatScalarFromParamValueLite(hidden_zoneout_new_node, &hidden_new) != RET_OK) { + if (GetFloatScalarFromTensorInfo(hidden_zoneout_new_node, &hidden_new) != RET_OK) { return false; } if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) { @@ -313,7 +312,7 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector &par MS_ASSERT(new_param != nullptr); MS_ASSERT(params.size() == 4); std::vector data_ptrs; - std::vector> data_shapes; + std::vector> data_shapes; for (auto ¶m : params) { if (!utils::isa(param)) { MS_LOG(DEBUG) << "param is not Parameter node"; @@ -324,17 +323,17 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector &par MS_LOG(DEBUG) << "param not have default value"; return RET_FAILED; } - if (!utils::isa(param_t->default_param())) { - MS_LOG(DEBUG) << "default value is not ParamValueLite"; + if (!utils::isa(param_t->default_param())) { + MS_LOG(DEBUG) << "default value is not tensor::Tensor"; return RET_FAILED; } - auto origin_tensor = std::dynamic_pointer_cast(param_t->default_param()); - if (origin_tensor->tensor_type() != kNumberTypeFloat32 && origin_tensor->tensor_type() != kNumberTypeFloat) { + auto origin_tensor = std::dynamic_pointer_cast(param_t->default_param()); + if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) { MS_LOG(DEBUG) << "origin_tensor is not float32 type"; return RET_FAILED; } - auto data_ptr = reinterpret_cast(origin_tensor->tensor_addr()); - auto data_shape = origin_tensor->tensor_shape(); + auto data_ptr = reinterpret_cast(origin_tensor->data_c()); + auto data_shape = origin_tensor->shape(); data_ptrs.push_back(data_ptr); data_shapes.push_back(data_shape); } @@ -345,13 +344,7 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector &par return RET_FAILED; } } - auto new_default = std::make_shared(); - if (new_default == nullptr) { - MS_LOG(ERROR) << "new_default is nullptr"; - return RET_ERROR; - } - std::vector new_shape; - float *tensor_data = nullptr; + std::vector new_shape; int step = 0; int data_size = 0; if (is_bias) { @@ -361,23 +354,25 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector 
&par } step = data_shapes[0][0]; data_size = 8 * step; - new_shape = std::vector({1, data_size}); + new_shape = std::vector({1, data_size}); } else { if (data_shapes[0].size() != 2) { MS_LOG(ERROR) << "weight data shape error"; return RET_ERROR; } - new_shape = std::vector({1, data_shapes[0][0] * 4, data_shapes[0][1]}); + new_shape = std::vector({1, data_shapes[0][0] * 4, data_shapes[0][1]}); step = data_shapes[0][0] * data_shapes[0][1]; data_size = 4 * step; } - tensor_data = new (std::nothrow) float[data_size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new data failed"; + auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; } + + auto tensor_data = static_cast(tensor_info->data_c()); for (int i = 0; i < data_size; ++i) { // bias are stored into first 4*hidden_size buffer, the rest is all 0 tensor_data[i] = 0.0f; } @@ -387,23 +382,16 @@ STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector &par auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s error"; - delete[] tensor_data; return RET_ERROR; } } - new_default->set_tensor_shape(new_shape); - new_default->set_tensor_type(kNumberTypeFloat32); - new_default->set_format(schema::Format_NHWC); - new_default->SetTensorData(tensor_data, data_size * sizeof(float)); - new_param->set_default_param(new_default); - std::vector shape_vector(new_shape.begin(), new_shape.end()); - auto abstract_tensor = std::make_shared(kFloat32, shape_vector); - if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "abstract_tensor is nullptr"; + auto status = lite::InitParameterFromTensorInfo(new_param, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return RET_ERROR; } - new_param->set_abstract(abstract_tensor); + return RET_OK; } diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h index f327d4dfd80..ec3fd3bfb69 100644 --- a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.h @@ -50,7 +50,7 @@ class TfliteLstmCellFusion : public PatternProcessPass { VarPtr hidden_zoneout_new_ = nullptr; std::vector while_input_vars_; - lite::STATUS GetFloatScalarFromParamValueLite(const AnfNodePtr ¶m_value, float *v) const; + lite::STATUS GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) const; CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node, const std::vector &axis) const; lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode, diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc index 4e2fb11eb1a..fbe001c9e26 100644 --- a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -54,23 +54,23 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) { } if ((min == -1) && (max == -1)) { if (clip_cnode->size() > kClipMinIndex) { - auto min_param_value = GetLiteParamValue(clip_cnode->input(kClipMinIndex)); - if (min_param_value->tensor_type() != mindspore::kNumberTypeFloat32) { + auto min_tensor_info = 
GetTensorInfo(clip_cnode->input(kClipMinIndex)); + if (min_tensor_info->data_type() != mindspore::kNumberTypeFloat32) { MS_LOG(ERROR) << "Clip param type invalid"; return false; } - min = *reinterpret_cast(min_param_value->tensor_addr()); + min = *reinterpret_cast(min_tensor_info->data_c()); } else { min = FLT_MIN; } if (clip_cnode->size() > kClipMaxIndex) { - auto max_param_value = GetLiteParamValue(clip_cnode->input(kClipMaxIndex)); - if (max_param_value->tensor_type() != mindspore::kNumberTypeFloat32) { + auto max_tensor_info = GetTensorInfo(clip_cnode->input(kClipMaxIndex)); + if (max_tensor_info->data_type() != mindspore::kNumberTypeFloat32) { MS_LOG(ERROR) << "Clip param type invalid"; return false; } - max = *reinterpret_cast(max_param_value->tensor_addr()); + max = *reinterpret_cast(max_tensor_info->data_c()); } else { max = FLT_MAX; } diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h index a5b7968b709..e49705b4ec6 100644 --- a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.h @@ -19,7 +19,6 @@ #include #include "tools/converter/converter_flags.h" #include "backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" using mindspore::lite::converter::FmkType; using mindspore::schema::QuantType; diff --git a/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc b/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc index 8026242f1fd..30697658737 100644 --- a/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.cc @@ -23,13 +23,14 @@ namespace { constexpr size_t kTripleNum = 3; constexpr size_t kConvWeightIndex = 2; } // namespace -lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const ParamValueLitePtr &tensor) { +lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const tensor::TensorPtr &tensor, + const schema::Format &format) { if (tensor == nullptr) { return lite::RET_NULL_PTR; } - auto shape = tensor->tensor_shape(); - std::vector new_shape(shape); - switch (tensor->format()) { + auto shape = tensor->shape(); + std::vector new_shape(shape); + switch (format) { case schema::Format_NCHW: case schema::Format_KCHW: new_shape.insert(new_shape.begin() + 2, 1); @@ -42,7 +43,7 @@ lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const ParamValueLitePt MS_LOG(ERROR) << "Unsupported format."; return RET_ERROR; } - tensor->set_tensor_shape(new_shape); + tensor->set_shape(new_shape); return RET_OK; } @@ -61,14 +62,21 @@ bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) { MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); auto weight_node = conv_cnode->input(kConvWeightIndex); MS_ASSERT(weight_node != nullptr); - auto weight_value = GetLiteParamValue(weight_node); + auto weight_value = GetTensorInfo(weight_node); if (weight_value == nullptr) { MS_LOG(ERROR) << "weight node must be param value."; return false; } + auto prim = GetValueNode(conv_cnode->input(0)); + MS_ASSERT(prim != nullptr); + + schema::Format schema_format = schema::Format::Format_KCHW; + if (prim->GetAttr(opt::kWeightFormat) != nullptr) { + schema_format = static_cast(GetValue(prim->GetAttr(opt::kWeightFormat))); + } // expand weight tensor to 4 dimensions. 
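// Aside: what ExpandFilterShape boils down to: a conv1d weight has rank 3, and a
// singleton spatial axis is inserted at a position chosen by the layout read from
// the primitive's kWeightFormat attribute above. Standalone sketch; the enum is a
// stand-in for schema::Format, and the channels-last insert position is an
// assumption extrapolated from the NCHW/KCHW case shown in this hunk:
#include <cstdint>
#include <vector>

enum class WeightLayout { ChannelsFirst, ChannelsLast };  // ~ KCHW/NCHW vs KHWC/NHWC

std::vector<int64_t> ExpandConv1dWeightShape(std::vector<int64_t> shape, WeightLayout layout) {
  // channels-first gets the new axis at index 2 (K C 1 W); channels-last at index 1 (K 1 W C)
  shape.insert(shape.begin() + (layout == WeightLayout::ChannelsFirst ? 2 : 1), 1);
  return shape;
}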
- if (weight_value->tensor_shape().size() == kTripleNum) { - auto status = ExpandFilterShape(weight_value); + if (weight_value->shape().size() == kTripleNum) { + auto status = ExpandFilterShape(weight_value, schema_format); if (status != RET_OK) { MS_LOG(ERROR) << "Expand filter shape failed."; return false; diff --git a/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h b/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h index 4224a690cf7..8728c31170e 100644 --- a/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h +++ b/mindspore/lite/tools/optimizer/graph/conv1d_weight_expanding_pass.h @@ -22,7 +22,6 @@ #include "backend/optimizer/common/pass.h" #include "tools/optimizer/common/gllo_utils.h" -using mindspore::ParamValueLitePtr; namespace mindspore::opt { class Conv1DWeightExpandingPass : public Pass { public: @@ -31,7 +30,7 @@ class Conv1DWeightExpandingPass : public Pass { bool Run(const FuncGraphPtr &graph) override; private: - lite::STATUS ExpandFilterShape(const ParamValueLitePtr &tensor); + lite::STATUS ExpandFilterShape(const tensor::TensorPtr &tensor, const schema::Format &format); }; } // namespace mindspore::opt #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc index 9d9c9d4c5b1..f6d81837894 100644 --- a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -81,7 +81,7 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); auto weight_node = conv_cnode->input(kConvWeightIndex); MS_ASSERT(weight_node != nullptr); - auto weight_value = GetLiteParamValue(weight_node); + auto weight_value = GetTensorInfo(weight_node); if (weight_value == nullptr) { MS_LOG(ERROR) << "weight node must param value"; return false; @@ -89,19 +89,20 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 || weight_value->tensor_type() == TypeId::kNumberTypeInt8); lite::STATUS status; - schema::Format weight_dst_format = schema::Format::Format_CHWK; - weight_value->set_format(schema::Format_KHWC); - status = TransFilterFormat(weight_value, weight_dst_format); + auto weight_src_format = schema::Format::Format_KHWC; + auto weight_dst_format = schema::Format::Format_CHWK; + + status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format); if (status == RET_OK) { - weight_value->set_format(weight_dst_format); + conv2d_fusion->AddAttr(opt::kWeightFormat, MakeValue(weight_dst_format)); } else { - MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To" + MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_dst_format]) << "To" << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope(); return false; } - auto type_id = static_cast(weight_value->tensor_type()); + auto type_id = static_cast(weight_value->data_type()); auto type_ptr = TypeIdToType(type_id); - auto shape = weight_value->tensor_shape(); + auto shape = weight_value->shape(); std::vector shape_vector; (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); diff 
--git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h index 0d4faf95eb5..ffe1d0df2d9 100644 --- a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.h @@ -18,7 +18,6 @@ #include #include "tools/converter/converter_flags.h" #include "backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" namespace mindspore::opt { class GroupDepthwiseOpConvertPass : public Pass { diff --git a/mindspore/lite/tools/optimizer/graph/if_pass.h b/mindspore/lite/tools/optimizer/graph/if_pass.h index 9afb405312d..5160b4725f3 100644 --- a/mindspore/lite/tools/optimizer/graph/if_pass.h +++ b/mindspore/lite/tools/optimizer/graph/if_pass.h @@ -21,7 +21,6 @@ #include "schema/inner/model_generated.h" #include "tools/converter/converter_flags.h" #include "backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" using mindspore::lite::converter::FmkType; namespace mindspore::opt { diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index fd4fa7e7c39..cdf8964379c 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -19,6 +19,7 @@ #include #include "include/errorcode.h" #include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" #include "src/common/common.h" #include "src/ops/populate/populate_register.h" #include "src/ops/ops_utils.h" @@ -27,16 +28,17 @@ namespace mindspore::opt { namespace { constexpr size_t INITIAL_SIZE = 1024; -ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) { - auto para_value_lite = std::make_shared(); - if (para_value_lite == nullptr) { - MS_LOG(ERROR) << "new ParamValueLite failed"; +tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { + std::vector shape(tensor->shape()); + std::vector shape_vector; + std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), + [](const int32_t &value) { return static_cast(value); }); + auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "new tensor::Tensor failed"; return nullptr; } - para_value_lite->set_tensor_shape(tensor->shape()); - para_value_lite->set_tensor_type(tensor->data_type()); - para_value_lite->set_format(tensor->format()); - return para_value_lite; + return tensor_info; } bool IsSpecialType(const CNodePtr &cnode) { @@ -62,9 +64,9 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li return nullptr; } - auto para_value_lite = NewParamValueLitePtr(tensor); - if (para_value_lite == nullptr) { - MS_LOG(ERROR) << "new ParamValueLite failed"; + auto tensor_info = NewTensorInfo(tensor); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "new tensor::Tensor failed"; return nullptr; } @@ -74,76 +76,23 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li MS_LOG(ERROR) << "cast tensor_list failed"; return nullptr; } - auto tensor_info = new int[tensor_list->element_shape().size() + 2]; - tensor_info[0] = tensor_list->tensors_data_type(); - tensor_info[1] = tensor_list->element_shape().size(); + auto tensor_data = new int[tensor_list->element_shape().size() + 2]; + tensor_data[0] = tensor_list->tensors_data_type(); + tensor_data[1] = tensor_list->element_shape().size(); for (size_t i = 0; i < 
tensor_list->element_shape().size(); ++i) { - tensor_info[i + 2] = tensor_list->element_shape()[i]; + tensor_data[i + 2] = tensor_list->element_shape()[i]; + } + auto status = lite::SetTensorData(tensor_info, tensor_data, tensor_list->element_shape().size() + 2); + delete[] tensor_data; + if (status != RET_OK) { + MS_LOG(ERROR) << "set tensor data failed"; + return nullptr; } - para_value_lite->set_tensor_addr(tensor_info); - para_value_lite->set_tensor_size(tensor_list->element_shape().size() + 2); } - - new_abstract->set_value(para_value_lite); + new_abstract->set_value(tensor_info); return new_abstract; } -STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { - MS_ASSERT(parameter != nullptr); - auto old_abstract = parameter->abstract(); - if (old_abstract == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name(); - return RET_ERROR; - } - if (!utils::isa(old_abstract)) { - MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name(); - return RET_ERROR; - } - auto abstract_tensor = utils::cast(old_abstract); - - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - if (typePtr == nullptr) { - MS_LOG(ERROR) << "typePtr is nullptr"; - return RET_ERROR; - } - - if (!utils::isa(abstract_tensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << parameter->name(); - return RET_ERROR; - } - auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); - std::vector shape; - (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), - [](const int64_t &value) { return static_cast(value); }); - - auto new_abstract = std::make_shared(typePtr, shape_vector); - auto new_value = std::make_shared(); - new_value->set_tensor_shape(shape); // scalar's shape is {} - new_value->set_tensor_type(typePtr->type_id()); - new_value->set_format(schema::Format_NHWC); // default format is NHWC - if (parameter->has_default()) { - auto param_value = std::dynamic_pointer_cast(parameter->default_param()); - new_value->set_format(param_value->format()); - new_value->set_tensor_size(param_value->tensor_size()); - - char *tensor_data = new (std::nothrow) char[new_value->tensor_size()]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed"; - return RET_ERROR; - } - auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()); - if (new_value->tensor_size() != 0 && ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - delete[] tensor_data; - return RET_ERROR; - } - new_value->SetTensorData(tensor_data, new_value->tensor_size()); - } - new_abstract->set_value(new_value); - parameter->set_abstract(new_abstract); - return RET_OK; -} - void InferShapePass::FreeTensors(std::vector *tensors) { for (auto tensor : *tensors) { delete tensor; @@ -178,18 +127,18 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector(abstract); - if (!utils::isa(abstract_tensor->GetValueTrack())) { // input node not complete infershape - MS_LOG(DEBUG) << "Value of abstract is not ParamValueLite, indicate that infershape has failed"; + if (!utils::isa(abstract_tensor->GetValueTrack())) { // input node not complete infershape + MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed"; return RET_ERROR; } - auto param_value_lite = utils::cast(abstract_tensor->GetValueTrack()); + auto param_value_lite = 
utils::cast(abstract_tensor->GetValueTrack()); if (param_value_lite == nullptr) { - MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr"; + MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr"; return RET_ERROR; } std::unique_ptr tensor = nullptr; - if (param_value_lite->tensor_type() != kObjectTypeTensorType) { + if (param_value_lite->data_type() != kObjectTypeTensorType) { tensor = std::make_unique(); } else { tensor = std::make_unique(); @@ -198,29 +147,32 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vectortensor_type() != kObjectTypeTensorType) { - tensor->set_shape(param_value_lite->tensor_shape()); - tensor->set_data_type(param_value_lite->tensor_type()); - tensor->set_format(schema::Format(param_value_lite->format())); + + std::vector shape; + std::transform(param_value_lite->shape().begin(), param_value_lite->shape().end(), std::back_inserter(shape), + [](const int64_t &value) { return static_cast(value); }); + if (param_value_lite->data_type() != kObjectTypeTensorType) { + tensor->set_shape(shape); + tensor->set_data_type(param_value_lite->data_type()); } if (utils::isa(input)) { auto parameter = input->cast(); if (parameter->has_default()) { - auto param_value = std::dynamic_pointer_cast(parameter->default_param()); - if (param_value_lite->tensor_type() != kObjectTypeTensorType) { + auto tensor_info = std::dynamic_pointer_cast(parameter->default_param()); + if (param_value_lite->data_type() != kObjectTypeTensorType) { auto ret = tensor->MallocData(); if (ret != 0) { MS_LOG(ERROR) << "Malloc tensor data failed"; return RET_ERROR; } - ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); + ret = memcpy_s(tensor->MutableData(), tensor->Size(), tensor_info->data_c(), tensor_info->Size()); if (tensor->Size() != 0 && ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } } else { - int *data = reinterpret_cast(param_value->tensor_addr()); + int *data = reinterpret_cast(tensor_info->data_c()); auto tensor_list = reinterpret_cast(tensor.get()); if (tensor_list->Decode(data) != RET_OK) { return RET_ERROR; @@ -349,10 +301,6 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { auto node_list = TopoSort(func_graph->get_return()); for (auto &node : node_list) { if (utils::isa(node)) { - int status = SetParameterAbstract(node->cast()); - if (status != RET_OK) { - return false; - } continue; } if (!utils::isa(node)) { diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index c753f0cecdf..517c98ad4da 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -39,7 +39,6 @@ class InferShapePass : public Pass { abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector *input_tensors); STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors); - STATUS SetParameterAbstract(const ParameterPtr ¶meter); STATUS SetCNodeAbstract(const std::vector &output_tensors, const std::shared_ptr &cnode); int StrIsContain(const std::vector &total, const std::string &aim); int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); diff --git a/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h index 8d3a927478a..dc7d43731cb 100644 --- 
a/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h +++ b/mindspore/lite/tools/optimizer/graph/inputs_adjust_pass.h @@ -21,7 +21,6 @@ #include #include "tools/optimizer/common/gllo_utils.h" #include "backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" #include "include/errorcode.h" using mindspore::lite::STATUS; diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc index 7eb9fac3ec7..fa1263c38f1 100644 --- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc @@ -14,7 +14,6 @@ * limitations under the License. */ #include "tools/optimizer/graph/mindir_adjust_pass.h" -#include #include #include @@ -22,15 +21,14 @@ #include "tools/converter/quant_param_holder.h" #include "tools/converter/quantizer/quantize_util.h" #include "src/common/log_adapter.h" -#include "src/tensor.h" namespace mindspore { namespace opt { namespace { constexpr size_t kDoubleNum = 2; void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &input_size) { - auto quant_param_valueptr = prim->GetAttr("quant_params"); - if (quant_param_valueptr == nullptr) { + auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); + if (quant_tensor_info_ptr == nullptr) { prim->AddAttr("quant_params", std::make_shared()); } auto quant_param_holder = prim->GetAttr("quant_params")->cast(); @@ -63,8 +61,8 @@ void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &in } int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { - auto quant_param_valueptr = prim->GetAttr("quant_params"); - if (quant_param_valueptr == nullptr) { + auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); + if (quant_tensor_info_ptr == nullptr) { prim->AddAttr("quant_params", std::make_shared()); } auto quant_param_holder = prim->GetAttr("quant_params")->cast(); @@ -123,8 +121,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t } int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) { - auto quant_param_valueptr = prim->GetAttr("quant_params"); - if (quant_param_valueptr == nullptr) { + auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); + if (quant_tensor_info_ptr == nullptr) { prim->AddAttr("quant_params", std::make_shared()); } auto quant_param_holder = prim->GetAttr("quant_params")->cast(); @@ -156,8 +154,8 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t } void CheckQuantParams(const PrimitivePtr &prim) { - auto quant_param_valueptr = prim->GetAttr("quant_params"); - if (quant_param_valueptr == nullptr) { + auto quant_tensor_info_ptr = prim->GetAttr("quant_params"); + if (quant_tensor_info_ptr == nullptr) { prim->AddAttr("quant_params", std::make_shared()); } auto quant_param_holder = prim->GetAttr("quant_params")->cast(); @@ -263,61 +261,6 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) { return lite::RET_NO_CHANGE; } -int MindirAdjustPass::ParameterNodeConvert(AnfNodePtr anf_node) { - if (!utils::isa(anf_node)) { - MS_LOG(INFO) << "only parameter node need to convert tensor."; - return lite::RET_NO_CHANGE; - } - auto param_node = anf_node->cast(); - if (!param_node->has_default()) { - MS_LOG(INFO) << "this is graph input, don't need to convert."; - return lite::RET_NO_CHANGE; - } - if (utils::isa(param_node->default_param())) { - MS_LOG(INFO) << "the tensor has been a 
paramvalueLite."; - return lite::RET_NO_CHANGE; - } - ParamValueLitePtr param_value = std::make_shared(); - if (param_value == nullptr) { - MS_LOG(ERROR) << "fail to new a ParamValueLite."; - return lite::RET_ERROR; - } - param_node->set_name(param_node->debug_info()->name()); - auto tensor_info = param_node->default_param()->cast(); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "the node is not a tensor::TensorPtr."; - return lite::RET_ERROR; - } - param_value->set_tensor_size(tensor_info->Size()); - param_value->set_tensor_type(tensor_info->data_type()); - auto tensor_shape = tensor_info->shape(); - std::vector shape; - std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape), - [](int64_t value) { return static_cast(value); }); - param_value->set_tensor_shape(shape); - auto *tensor = new (std::nothrow) lite::Tensor(tensor_info->data_type(), shape); - if (tensor == nullptr) { - MS_LOG(ERROR) << "new a lite::tensor failed, get a nullptr."; - return lite::RET_MEMORY_FAILED; - } - auto *tensor_data_buf = tensor->MutableData(); - if (tensor_data_buf == nullptr) { - MS_LOG(ERROR) << "malloc tensor data failed."; - delete tensor; - return lite::RET_MEMORY_FAILED; - } - if (memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; - delete tensor; - return lite::RET_MEMORY_FAILED; - } - tensor->set_data(nullptr); - param_value->set_tensor_addr(tensor_data_buf); - param_node->set_default_param(param_value); - delete tensor; - return lite::RET_OK; -} - int MindirAdjustPass::ComputeQuantParams(std::shared_ptr anf_node) { if (!utils::isa(anf_node)) { MS_LOG(INFO) << "only cnode need to convert primitive."; @@ -357,9 +300,7 @@ bool MindirAdjustPass::Run(const FuncGraphPtr &graph) { int status = lite::RET_OK; bool success_flag = true; for (auto &node : node_list) { - if (utils::isa(node)) { - status = ParameterNodeConvert(node); - } else if (utils::isa(node)) { + if (utils::isa(node)) { status = ComputeQuantParams(node); } else if (utils::isa(node)) { status = ValueNodeInt64Convert(node); diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h index ab04f430b2f..e1f0a13439b 100644 --- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h +++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h @@ -21,7 +21,6 @@ #include "backend/optimizer/common/pass.h" #include "tools/converter/converter_flags.h" #include "tools/optimizer/common/gllo_utils.h" -#include "src/param_value_lite.h" using mindspore::lite::converter::FmkType; using mindspore::schema::QuantType; @@ -34,7 +33,6 @@ class MindirAdjustPass : public Pass { void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; } int ValueNodeInt64Convert(AnfNodePtr anf_node); void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; } - int ParameterNodeConvert(AnfNodePtr anf_node); int ComputeQuantParams(AnfNodePtr anf_node); bool Run(const FuncGraphPtr &graph) override; diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc index ea9680119c8..85c2b63314e 100644 --- a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc @@ -14,13 +14,11 @@ * limitations under the License. 
*/ #include "tools/optimizer/graph/onnx_inputs_adjust_pass.h" -#include #include #include #include #include #include "ops/fusion/conv2d_fusion.h" -#include "ops/fusion/conv2d_transpose_fusion.h" #include "ops/resize.h" #include "include/errorcode.h" @@ -89,12 +87,12 @@ STATUS OnnxInputAdjustOpPass::ReplaceInt64ParameterNode(const FuncGraphPtr &func MS_LOG(ERROR) << "default data is nullptr."; return lite::RET_NULL_PTR; } - auto param_value = default_value->cast(); - if (param_value == nullptr) { - MS_LOG(ERROR) << "default data is not paramvaluelite."; + auto tensor_info = default_value->cast(); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "default data is not tensor::Tensor."; return lite::RET_NULL_PTR; } - auto param_node_new = BuildParameterNode(func_graph, param_node, param_value); + auto param_node_new = BuildParameterNode(func_graph, param_node, tensor_info); manager->Replace(param_node, param_node_new); } else { // set graph input @@ -151,17 +149,17 @@ STATUS OnnxInputAdjustOpPass::ReplaceConstant(const FuncGraphPtr &func_graph, co MS_LOG(ERROR) << "value is not primitive_c."; return lite::RET_ERROR; } - auto param_value = primitive_c->GetAttr("const_data"); - if (param_value == nullptr) { + auto tensor_info = primitive_c->GetAttr("const_data"); + if (tensor_info == nullptr) { MS_LOG(ERROR) << "constant cnode has no data."; return lite::RET_ERROR; } - auto param_value_lite = param_value->cast(); - if (param_value_lite == nullptr) { - MS_LOG(ERROR) << "valueptr is not paramvalueliteptr."; + auto tensor_info_ptr = tensor_info->cast(); + if (tensor_info_ptr == nullptr) { + MS_LOG(ERROR) << "valueptr is not tensor::Tensorptr."; return lite::RET_ERROR; } - auto param_node = BuildParameterNode(func_graph, cnode, param_value_lite); + auto param_node = BuildParameterNode(func_graph, cnode, tensor_info_ptr); if (param_node == nullptr) { MS_LOG(ERROR) << "convert constant to param node failed."; return lite::RET_ERROR; @@ -199,17 +197,17 @@ STATUS OnnxInputAdjustOpPass::ReplaceTransposeWithGraphInput(const FuncGraphPtr auto perm_anf = cnode->input(2); auto perm_param = perm_anf->cast(); if (perm_param == nullptr || !perm_param->has_default() || - !utils::isa(perm_param->default_param())) { + !utils::isa(perm_param->default_param())) { MS_LOG(DEBUG) << "transpose second input is not parameter node."; return lite::RET_OK; } - auto perm_value = perm_param->default_param()->cast(); - if (perm_value->tensor_shape().empty()) { + auto perm_value = perm_param->default_param()->cast(); + if (perm_value->shape().empty()) { MS_LOG(ERROR) << "transpose second input is invalid."; return lite::RET_ERROR; } - std::vector perm(perm_value->tensor_shape()[0]); - if (memcpy_s(perm.data(), perm_value->tensor_size(), perm_value->tensor_addr(), perm_value->tensor_size()) != EOK) { + std::vector perm(perm_value->shape()[0]); + if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; return lite::RET_ERROR; } @@ -252,12 +250,12 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph, if (param_node == nullptr || !param_node->has_default()) { continue; } - const auto &default_data = param_node->default_param()->cast(); + const auto &default_data = param_node->default_param()->cast(); if (default_data == nullptr) { - MS_LOG(ERROR) << "this input is not a paramValueLite."; + MS_LOG(ERROR) << "this input is not a tensor::Tensor"; return lite::RET_ERROR; } - auto shape = default_data->tensor_shape(); + auto 
shape = default_data->shape(); size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); break; } diff --git a/mindspore/lite/tools/optimizer/graph/onnx_pad_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/onnx_pad_adjust_pass.cc index 380ca5a422c..c2b8cd3f44e 100644 --- a/mindspore/lite/tools/optimizer/graph/onnx_pad_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/onnx_pad_adjust_pass.cc @@ -21,7 +21,7 @@ #include "ops/reshape.h" #include "ops/transpose.h" #include "ops/primitive_c.h" -#include "src/param_value_lite.h" +#include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" namespace mindspore::opt { @@ -32,35 +32,22 @@ ParameterPtr OnnxPadAdjustPass::CreateNewParameter(const FuncGraphPtr &func_grap MS_ASSERT(func_graph != nullptr); MS_ASSERT(data != nullptr); auto parameter = func_graph->add_parameter(); - std::vector shape; - shape.push_back(static_cast(data.size())); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto type_id = static_cast(kNumberTypeInt32); - auto type_ptr = TypeIdToType(type_id); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - parameter->set_abstract(abstract_tensor); - - ParamValueLitePtr param_value = std::make_shared(); - MS_ASSERT(param_value != nullptr); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - param_value->set_format(schema::Format_NCHW); - + ShapeVector shape_vector; + shape_vector.push_back(static_cast(data.size())); size_t size = data.size() * sizeof(int); - auto tensor_data = new (std::nothrow) uint8_t[size]; - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; + + auto tensor_info = lite::CreateTensorInfo(data.data(), size, shape_vector, kNumberTypeInt32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "create tensor info failed."; return nullptr; } - auto ret = memcpy_s(tensor_data, size, data.data(), size); - if (ret != 0) { - MS_LOG(ERROR) << "set tensor data failed."; + + auto status = lite::InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed"; return nullptr; } - param_value->SetTensorData(tensor_data, size); - parameter->set_default_param(param_value); + return parameter; } diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc index 5543b7ace4c..0946d9aa99e 100644 --- a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -44,17 +44,16 @@ std::vector GetSliceBeginAndSize(const CNodePtr &cnode, const int index) { if (node == nullptr) { return content; } - auto paramter_node = node->cast(); - if (paramter_node == nullptr || !paramter_node->has_default() || paramter_node->default_param() == nullptr) { + auto param_node = node->cast(); + if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) { return content; } - auto paramter_value = paramter_node->default_param()->cast(); - if (paramter_value == nullptr) { + auto tensor_info = param_node->default_param()->cast(); + if (tensor_info == nullptr) { return content; } - content.resize(paramter_value->tensor_shape_size()); - if (memcpy_s(content.data(), paramter_value->tensor_shape_size(), paramter_value->tensor_addr(), - 
diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc
index 5543b7ace4c..0946d9aa99e 100644
--- a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc
@@ -44,17 +44,16 @@ std::vector<int> GetSliceBeginAndSize(const CNodePtr &cnode, const int index) {
   if (node == nullptr) {
     return content;
   }
-  auto paramter_node = node->cast<ParameterPtr>();
-  if (paramter_node == nullptr || !paramter_node->has_default() || paramter_node->default_param() == nullptr) {
+  auto param_node = node->cast<ParameterPtr>();
+  if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) {
     return content;
   }
-  auto paramter_value = paramter_node->default_param()->cast<ParamValueLitePtr>();
-  if (paramter_value == nullptr) {
+  auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
+  if (tensor_info == nullptr) {
     return content;
   }
-  content.resize(paramter_value->tensor_shape_size());
-  if (memcpy_s(content.data(), paramter_value->tensor_shape_size(), paramter_value->tensor_addr(),
-               paramter_value->tensor_shape_size()) != EOK) {
+  content.resize(tensor_info->DataSize());
+  if (memcpy_s(content.data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
     MS_LOG(ERROR) << "memcpy data failed.";
     return {};
   }
@@ -91,12 +90,12 @@ std::vector<int64_t> GetDefaultParamShape(const ParameterPtr &param) {
     MS_LOG(ERROR) << "default_param is nullptr";
     return shape_vector;
   }
-  if (!utils::isa<ParamValueLitePtr>(default_param)) {
-    MS_LOG(ERROR) << "default_param is not ParamValueLite";
+  if (!utils::isa<tensor::TensorPtr>(default_param)) {
+    MS_LOG(ERROR) << "default_param is not tensor::Tensor";
     return shape_vector;
   }
-  auto param_value_lite = utils::cast<ParamValueLitePtr>(default_param);
-  auto shape = param_value_lite->tensor_shape();
+  auto param_value_lite = utils::cast<tensor::TensorPtr>(default_param);
+  auto shape = param_value_lite->shape();
   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                  [](const int val) { return static_cast<int64_t>(val); });
   return shape_vector;
@@ -104,8 +103,8 @@ bool IsScalarNode(const AnfNodePtr &nodePtr) {
   if (utils::isa<ParameterPtr>(nodePtr) && nodePtr->cast<ParameterPtr>()->has_default()) {
-    auto tensor = utils::cast<ParamValueLitePtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
-    auto shape = tensor->tensor_shape();
+    auto tensor = utils::cast<tensor::TensorPtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
+    auto shape = tensor->shape();
     if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
       return true;
     }
@@ -158,12 +157,12 @@ std::vector<int> GetTransposePerm(const CNodePtr &node) {
   if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
     return perm;
   }
-  auto perm_value = perm_param->default_param()->cast<ParamValueLitePtr>();
+  auto perm_value = perm_param->default_param()->cast<tensor::TensorPtr>();
   if (perm_value == nullptr) {
     return perm;
   }
-  perm.resize(perm_value->tensor_shape()[0]);
-  if (memcpy_s(perm.data(), perm_value->tensor_size(), perm_value->tensor_addr(), perm_value->tensor_size()) != EOK) {
+  perm.resize(perm_value->shape()[0]);
+  if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) {
     MS_LOG(ERROR) << "memcpy failed.";
     return {};
   }
diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
index c7d500f3917..c92f3dfb118 100644
--- a/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
@@ -106,12 +106,12 @@ STATUS TfliteInputsAdjustPass::ReplaceInt64ParameterNode(const FuncGraphPtr &fun
       MS_LOG(ERROR) << "default data is nullptr.";
       return lite::RET_NULL_PTR;
     }
-    auto param_value = default_value->cast<ParamValueLitePtr>();
-    if (param_value == nullptr) {
-      MS_LOG(ERROR) << "default data is not paramvaluelite.";
+    auto tensor_info = default_value->cast<tensor::TensorPtr>();
+    if (tensor_info == nullptr) {
+      MS_LOG(ERROR) << "default data is not tensor::Tensor.";
       return lite::RET_NULL_PTR;
     }
-    auto param_node_new = BuildParameterNode(func_graph, param_node, param_value);
+    auto param_node_new = BuildParameterNode(func_graph, param_node, tensor_info);
     manager->Replace(param_node, param_node_new);
   } else {
     // set graph input
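One caveat worth flagging for the memcpy_s hunks in slice_prepose_pass.cc above (the same GetTransposePerm copy recurs in unused_transpose_node_remove_pass.cc below): memcpy_s takes the destination capacity as its second argument, yet the new code passes the source tensor's Size() in both positions. That is only safe while the vector was just resized from the same tensor. A bounds-explicit sketch, assuming a 1-D int32 perm tensor; ReadPermChecked is a hypothetical name, not part of this patch:

// Hypothetical bounds-explicit variant: destination capacity is computed from
// the vector itself, so a shape/byte mismatch fails the memcpy_s check
// instead of silently relying on the two sizes agreeing.
std::vector<int> ReadPermChecked(const tensor::TensorPtr &t) {
  std::vector<int> perm;
  if (t == nullptr || t->shape().size() != 1 || t->data_type() != kNumberTypeInt32) {
    return perm;
  }
  perm.resize(t->shape()[0]);
  if (memcpy_s(perm.data(), perm.size() * sizeof(int), t->data_c(), t->Size()) != EOK) {
    return {};
  }
  return perm;
}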
"backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" #include "tools/optimizer/common/gllo_utils.h" namespace mindspore::opt { diff --git a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc index 78b98c5b9a4..b72277825d7 100644 --- a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc @@ -44,12 +44,12 @@ std::vector GetTransposePerm(const CNodePtr &node) { if (!perm_param->has_default() || perm_param->default_param() == nullptr) { return perm; } - auto perm_value = perm_param->default_param()->cast(); + auto perm_value = perm_param->default_param()->cast(); if (perm_value == nullptr) { return perm; } - perm.resize(perm_value->tensor_shape()[0]); - if (memcpy_s(perm.data(), perm_value->tensor_size(), perm_value->tensor_addr(), perm_value->tensor_size()) != EOK) { + perm.resize(perm_value->shape()[0]); + if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) { MS_LOG(ERROR) << "memcpy failed."; return {}; } diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc index 54854b62ebb..d232bc39e42 100644 --- a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc @@ -52,8 +52,8 @@ lite::STATUS UpdateConv2DParamPass::UpdateCommonConv2D(const CNodePtr &cnode) { return lite::RET_NO_CHANGE; } auto default_param = weight_param->default_param(); - auto weight_tensor = std::dynamic_pointer_cast(default_param); - auto weight_shape = weight_tensor->tensor_shape(); + auto weight_tensor = std::dynamic_pointer_cast(default_param); + auto weight_shape = weight_tensor->shape(); std::vector kernel_size = {weight_shape[0], weight_shape[1]}; conv->set_kernel_size(kernel_size); conv->set_in_channel(weight_shape[2]); @@ -75,8 +75,8 @@ lite::STATUS UpdateConv2DParamPass::UpdateDepthWiseConv2D(const CNodePtr &cnode) if (input_node->isa()) { auto param_node = input_node->cast(); auto param = param_node->default_param(); - auto weight = std::dynamic_pointer_cast(param); - conv->set_in_channel(static_cast(weight->tensor_shape().at(0))); + auto weight = std::dynamic_pointer_cast(param); + conv->set_in_channel(static_cast(weight->shape().at(0))); } } return lite::RET_OK; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index 80f2137fd7c..4751e1633e1 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -36,14 +36,19 @@ const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared( void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; } void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; } lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const CNodePtr &conv_node, - const ParamValueLitePtr ¶m_value) const { + const tensor::TensorPtr &tensor_info) const { MS_ASSERT(conv_cnode != nullptr); - MS_ASSERT(param_value != nullptr); + MS_ASSERT(tensor_info != nullptr); + auto prim = GetValueNode(conv_node->input(0)); + if (prim == nullptr) { + MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive."; + return lite::RET_ERROR; + } 
diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
index 80f2137fd7c..4751e1633e1 100644
--- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
@@ -36,14 +36,19 @@ const PrimitivePtr kPrimConv2DBackpropInputFusion = std::make_shared<Primitive>(
 void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
 void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
 lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const CNodePtr &conv_node,
-                                                     const ParamValueLitePtr &param_value) const {
+                                                     const tensor::TensorPtr &tensor_info) const {
   MS_ASSERT(conv_cnode != nullptr);
-  MS_ASSERT(param_value != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
+    return lite::RET_ERROR;
+  }
   switch (quant_type) {
     case schema::QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE:
-      param_value->set_format(schema::Format::Format_KCHW);
+      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
       break;
     default: {
       MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
@@ -55,9 +60,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const CNodePtr &conv_node,
-                                                    const ParamValueLitePtr &param_value) const {
+                                                    const tensor::TensorPtr &tensor_info) const {
   MS_ASSERT(conv_cnode != nullptr);
-  MS_ASSERT(param_value != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
   auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
   if (prim == nullptr) {
     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
     return lite::RET_ERROR;
@@ -70,12 +75,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const CNodePtr &conv_node,
       // sum up from current onnx quant models
       if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
         if (!is_depth_wise) {
-          param_value->set_format(schema::Format::Format_KHWC);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
         } else {
-          param_value->set_format(schema::Format::Format_CHWK);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
         }
       } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
-        param_value->set_format(schema::Format::Format_KCHW);
+        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
       } else {
         MS_LOG(ERROR) << "Unsupported op: " << conv_node->fullname_with_scope();
         return lite::RET_ERROR;
@@ -91,9 +96,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const CNodePtr &conv_node,
       if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion) ||
           CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion)) {
         if (format == schema::Format::Format_NHWC) {
-          param_value->set_format(schema::Format::Format_KHWC);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
         } else {
-          param_value->set_format(schema::Format::Format_KCHW);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
         }
       }
     } break;
@@ -107,9 +112,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
-                                                  const ParamValueLitePtr &param_value) const {
+                                                  const tensor::TensorPtr &tensor_info) const {
   MS_ASSERT(conv_cnode != nullptr);
-  MS_ASSERT(param_value != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
   auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
   if (prim == nullptr) {
     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
@@ -122,7 +127,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       // sum up from current ms quant models
-      param_value->set_format(schema::Format::Format_KCHW);
+      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));
     } break;
     default: {
       MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
@@ -134,9 +139,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node,
-                                                      const ParamValueLitePtr &param_value) const {
+                                                      const tensor::TensorPtr &tensor_info) const {
   MS_ASSERT(conv_cnode != nullptr);
-  MS_ASSERT(param_value != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
   auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
   if (prim == nullptr) {
     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
@@ -150,12 +155,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node,
     case QuantType_QUANT_NONE: {
       if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
         if (!is_depth_wise) {
-          param_value->set_format(schema::Format::Format_KHWC);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KHWC));
         } else {
-          param_value->set_format(schema::Format::Format_CHWK);
+          prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
        }
       } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
-        param_value->set_format(schema::Format::Format_CHWK);
+        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::CHWK));
       }
     } break;
     default: {
@@ -167,9 +172,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node,
-                                                  const ParamValueLitePtr &param_value) const {
+                                                  const tensor::TensorPtr &tensor_info) const {
   MS_ASSERT(conv_cnode != nullptr);
-  MS_ASSERT(param_value != nullptr);
+  MS_ASSERT(tensor_info != nullptr);
   auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
   if (prim == nullptr) {
     MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive.";
@@ -179,13 +184,13 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node,
   if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) {
     {
       if (!is_depth_wise) {
-        param_value->set_format(schema::Format::Format_HWCK);
+        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
       } else {
-        param_value->set_format(schema::Format::Format_HWKC);
+        prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWKC));
       }
     }
   } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) {
-    param_value->set_format(schema::Format::Format_HWCK);
+    prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::HWCK));
   }
   return lite::RET_OK;
 }
@@ -206,27 +211,27 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
     MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
     auto weight_node = conv_cnode->input(kConvWeightIndex);
     MS_ASSERT(weight_node != nullptr);
-    auto param_value = GetLiteParamValue(weight_node);
-    if (param_value == nullptr) {
+    auto tensor_info = GetTensorInfo(weight_node);
+    if (tensor_info == nullptr) {
       MS_LOG(ERROR) << "weight node must param value";
       return false;
     }
     lite::STATUS status;
     switch (fmk_type) {
       case FmkType_CAFFE:
-        status = HardCodeCAFFE(conv_cnode, param_value);
+        status = HardCodeCAFFE(conv_cnode, tensor_info);
         break;
       case FmkType_TFLITE:
-        status = HardCodeTFLITE(conv_cnode, param_value);
+        status = HardCodeTFLITE(conv_cnode, tensor_info);
         break;
       case FmkType_TF:
-        status = HardCodeTF(conv_cnode, param_value);
+        status = HardCodeTF(conv_cnode, tensor_info);
         break;
       case FmkType_ONNX:
-        status = HardCodeONNX(conv_cnode, param_value);
+        status = HardCodeONNX(conv_cnode, tensor_info);
         break;
       case FmkType_MS:
-        status = HardCodeMS(conv_cnode, param_value);
+        status = HardCodeMS(conv_cnode, tensor_info);
         break;
       default:
         MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
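Net effect of this file's hunks: the weight layout no longer lives on the weight value itself but on the consuming primitive, under the opt::kWeightFormat attribute. A round-trip sketch; the MakeValue<int64_t>/GetValue<int64_t> encoding is an assumption kept consistent with the reads in weight_format_transform_pass.cc further below:

// Write side, as in the HardCode* methods above.
auto prim = GetValueNode<PrimitivePtr>(conv_node->input(0));
prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));

// Read side, as in ConvWeightFormatTrans below: the attribute is decoded
// back into a schema::Format before the filter data is transposed.
auto value_ptr = prim->GetAttr(opt::kWeightFormat);
auto src_format = static_cast<schema::Format>(GetValue<int64_t>(value_ptr));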
"Invalid anfnode, which don't have primitive."; @@ -150,12 +155,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node, case QuantType_QUANT_NONE: { if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) { if (!is_depth_wise) { - param_value->set_format(schema::Format::Format_KHWC); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::KHWC)); } else { - param_value->set_format(schema::Format::Format_CHWK); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::CHWK)); } } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { - param_value->set_format(schema::Format::Format_CHWK); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::CHWK)); } } break; default: { @@ -167,9 +172,9 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const CNodePtr &conv_node, } lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node, - const ParamValueLitePtr ¶m_value) const { + const tensor::TensorPtr &tensor_info) const { MS_ASSERT(conv_cnode != nullptr); - MS_ASSERT(param_value != nullptr); + MS_ASSERT(tensor_info != nullptr); auto prim = GetValueNode(conv_node->input(0)); if (prim == nullptr) { MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive."; @@ -179,13 +184,13 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTF(const CNodePtr &conv_node, if (CheckPrimitiveType(conv_node, prim::kPrimConv2DFusion)) { { if (!is_depth_wise) { - param_value->set_format(schema::Format::Format_HWCK); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::HWCK)); } else { - param_value->set_format(schema::Format::Format_HWKC); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::HWKC)); } } } else if (CheckPrimitiveType(conv_node, prim::kPrimConv2dTransposeFusion) && !is_depth_wise) { - param_value->set_format(schema::Format::Format_HWCK); + prim->AddAttr(opt::kWeightFormat, MakeValue(Format::HWCK)); } return lite::RET_OK; } @@ -206,27 +211,27 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); auto weight_node = conv_cnode->input(kConvWeightIndex); MS_ASSERT(weight_node != nullptr); - auto param_value = GetLiteParamValue(weight_node); - if (param_value == nullptr) { + auto tensor_info = GetTensorInfo(weight_node); + if (tensor_info == nullptr) { MS_LOG(ERROR) << "weight node must param value"; return false; } lite::STATUS status; switch (fmk_type) { case FmkType_CAFFE: - status = HardCodeCAFFE(conv_cnode, param_value); + status = HardCodeCAFFE(conv_cnode, tensor_info); break; case FmkType_TFLITE: - status = HardCodeTFLITE(conv_cnode, param_value); + status = HardCodeTFLITE(conv_cnode, tensor_info); break; case FmkType_TF: - status = HardCodeTF(conv_cnode, param_value); + status = HardCodeTF(conv_cnode, tensor_info); break; case FmkType_ONNX: - status = HardCodeONNX(conv_cnode, param_value); + status = HardCodeONNX(conv_cnode, tensor_info); break; case FmkType_MS: - status = HardCodeMS(conv_cnode, param_value); + status = HardCodeMS(conv_cnode, tensor_info); break; default: MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h index a78b02371b8..e78546a7a92 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h @@ -20,7 +20,6 @@ #include 
"schema/inner/model_generated.h" #include "tools/converter/converter_flags.h" #include "backend/optimizer/common/pass.h" -#include "src/param_value_lite.h" using mindspore::lite::converter::FmkType; using mindspore::schema::QuantType; @@ -34,11 +33,11 @@ class WeightFormatHardCodePass : public Pass { bool Run(const FuncGraphPtr &graph) override; private: - lite::STATUS HardCodeCAFFE(const CNodePtr &node, const ParamValueLitePtr ¶m_value) const; - lite::STATUS HardCodeONNX(const CNodePtr &node, const ParamValueLitePtr ¶m_value) const; - lite::STATUS HardCodeMS(const CNodePtr &node, const ParamValueLitePtr ¶m_value) const; - lite::STATUS HardCodeTFLITE(const CNodePtr &node, const ParamValueLitePtr ¶m_value) const; - lite::STATUS HardCodeTF(const CNodePtr &conv_node, const ParamValueLitePtr ¶m_value) const; + lite::STATUS HardCodeCAFFE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const; + lite::STATUS HardCodeONNX(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const; + lite::STATUS HardCodeMS(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const; + lite::STATUS HardCodeTFLITE(const CNodePtr &node, const tensor::TensorPtr &tensor_info) const; + lite::STATUS HardCodeTF(const CNodePtr &conv_node, const tensor::TensorPtr &tensor_info) const; private: QuantType quant_type = schema::QuantType_QUANT_NONE; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc index c83ade69d0f..0232ca0ee7b 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -69,9 +69,7 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu if (!utils::isa(node)) { continue; } - if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || - CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) || - CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) || + if (CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) || CheckPrimitiveType(node, prim::kPrimAdam)) { continue; } @@ -79,6 +77,13 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu auto inputs = cnode->inputs(); if (std::any_of(inputs.begin(), inputs.end(), [&weight_node](const AnfNodePtr &anf_node) { return weight_node == anf_node; })) { + if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || + CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || + CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { + auto prim = GetValueNode(cnode->input(0)); + prim->AddAttr(kWeightFormat, MakeValue(mindspore::KHWC)); + continue; + } adjust_nodes.push_back(cnode); } } @@ -138,9 +143,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr } auto conv_cnode = node->cast(); MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); + auto prim = GetValueNode(conv_cnode->input(0)); + if (prim == nullptr) { + MS_LOG(ERROR) << "Invalid anfnode, which don't have primitive."; + return lite::RET_ERROR; + } auto weight_node = conv_cnode->input(kConvWeightIndex); MS_ASSERT(weight_node != nullptr); - auto weight_value = GetLiteParamValue(weight_node); + auto weight_value = GetTensorInfo(weight_node); if (weight_value == nullptr) { MS_LOG(ERROR) << "weight node must param value"; return false; @@ -148,31 +158,30 
@@ -148,31 +158,30 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
     MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 ||
               weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
     lite::STATUS status;
-    schema::Format src_format = static_cast<schema::Format>(weight_value->format());
+    auto value_ptr = prim->GetAttr(opt::kWeightFormat);
+    auto weight_src_format = static_cast<schema::Format>(GetValue<int64_t>(value_ptr));
     schema::Format weight_dst_format = schema::Format::Format_KHWC;
     if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {
       weight_dst_format = dst_format;
     }
-    status = TransFilterFormat(weight_value, weight_dst_format);
+    status = TransFilterFormat(weight_value, weight_src_format, weight_dst_format);
     if (status == RET_OK) {
-      weight_value->set_format(weight_dst_format);
+      prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(weight_dst_format));
     } else {
-      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To"
+      MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_src_format]) << "To"
                     << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
                     << "quant type:" << quant_type;
       return ERROR;
     }
-    status = HandleWeightSharing(graph, weight_node->cast<ParameterPtr>(), src_format, weight_dst_format);
+    status = HandleWeightSharing(graph, weight_node->cast<ParameterPtr>(), weight_src_format, weight_dst_format);
     if (status != lite::RET_OK) {
       MS_LOG(ERROR) << "handle weight-sharing failed.";
       return false;
     }
-    auto type_id = static_cast<TypeId>(weight_value->tensor_type());
+    auto type_id = static_cast<TypeId>(weight_value->data_type());
     auto type_ptr = TypeIdToType(type_id);
-    auto shape = weight_value->tensor_shape();
-    std::vector<int64_t> shape_vector;
-    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
-                         [](const int32_t &value) { return static_cast<int64_t>(value); });
+    auto shape = weight_value->shape();
+    std::vector<int64_t> shape_vector(shape.begin(), shape.end());
     auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
     weight_node->set_abstract(abstract_tensor);
   }
diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.h b/mindspore/lite/tools/optimizer/graph/while_pass.h
index 595bb41df35..252ecca61e3 100644
--- a/mindspore/lite/tools/optimizer/graph/while_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/while_pass.h
@@ -21,7 +21,6 @@
 #include "schema/inner/model_generated.h"
 #include "tools/converter/converter_flags.h"
 #include "backend/optimizer/common/pass.h"
-#include "src/param_value_lite.h"
 
 using mindspore::lite::converter::FmkType;
 
 namespace mindspore::opt {
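Taken together, the call-site mapping this patch applies across all of these files is small enough to summarize; element count versus byte count is the one place the two APIs disagree in spirit, so it is worth keeping in view when reading the memcpy_s hunks:

// Old ParamValueLite accessor           ->  new tensor::Tensor accessor
// tensor_shape()  (std::vector<int>)    ->  shape()      (ShapeVector, int64_t dims)
// tensor_shape_size()                   ->  DataSize()   (element count)
// tensor_size()                         ->  Size()       (bytes)
// tensor_addr()                         ->  data_c()
// tensor_type()                         ->  data_type()
// set_format() / format()               ->  kWeightFormat attribute on the consuming primitive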