diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc
index 8cf7b692e3c..77122147ade 100644
--- a/mindspore/lite/tools/benchmark/benchmark.cc
+++ b/mindspore/lite/tools/benchmark/benchmark.cc
@@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
     std::cout << "CompileGraph failed while running ", model_name.c_str();
     return ret;
   }
+  if (!flags_->input_shape_list_.empty()) {
+    std::vector<std::vector<int>> input_shapes;
+    std::string input_dims_list = flags_->input_shape_list_;
+    while (!input_dims_list.empty()) {
+      auto position =
+        input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
+      std::string input_dims = input_dims_list.substr(0, position);
+      std::vector<int> input_shape;
+      while (!input_dims.empty()) {
+        auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
+        std::string dim = input_dims.substr(0, pos);
+        input_shape.emplace_back(std::stoi(dim));
+        input_dims = input_dims.substr(pos);
+      }
+      input_shapes.emplace_back(input_shape);
+      input_dims_list = input_dims_list.substr(position);
+    }
+    ret = session_->Resize(session_->GetInputs(), input_shapes);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Input tensor resize failed.";
+      std::cout << "Input tensor resize failed.";
+      return ret;
+    }
+  }
   model->Free();
   ms_inputs_ = session_->GetInputs();
   auto end_prepare_time = GetTimeUs();
diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h
index 9e90885d1c4..3886bd42cd0 100644
--- a/mindspore/lite/tools/benchmark/benchmark.h
+++ b/mindspore/lite/tools/benchmark/benchmark.h
@@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
             "Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
     AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
+    AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
+            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
   }

   ~BenchmarkFlags() override = default;
@@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
   InDataType in_data_type_;
   std::string in_data_type_in_ = "bin";
   int cpu_bind_mode_ = 1;
+  std::string input_shape_list_;
   // MarkPerformance
   int loop_count_;
   int num_threads_;
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
index fa4858e30c0..0ea1da47ce1 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
@@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
 namespace mindspore {
 namespace lite {
 namespace {
+constexpr int DEFAULT_DIM_VALUE = -1;
+}
+namespace {
 std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
                                                 const schema::PrimitiveType node_type) {
   std::vector<Tensor *> lite_tensors;
@@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> output_tensors) {
 }  // namespace
 STATUS InferShapePass::Run(MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
+  for (auto idx : graph->inputIndex) {
+    auto input_tensor = graph->allTensors[idx].get();
+    for (auto &dim : input_tensor->dims) {
+      if (dim == 0) {
+        MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to 32 as a default value.";
+        dim = DEFAULT_DIM_VALUE;
+      }
+    }
+  }
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
     auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
index 6eb1887e2a7..8908ebcd623 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc
@@ -41,7 +41,14 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "value") {
-      attr->value = static_cast<int32_t>(onnx_node_attr.i());
+      if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
+        auto tensor = onnx_node_attr.t();
+        if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
+          attr->value = onnx_node_attr.f();
+        } else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
+          attr->value = static_cast<int32_t>(onnx_node_attr.i());
+        }
+      }
     }
   }

diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
index 36ccd735cbf..105f6ca3740 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_pool_parser.cc
@@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "kernel_shape") {
       if (onnx_node_attr.ints_size() == 2) {
-        attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0));
-        attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1));
+        attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
+        attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
       }
     }
     if (attribute_name == "strides") {
       if (onnx_node_attr.ints_size() == 2) {
-        attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
-        attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
+        attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
+        attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
       }
     }
     if (attribute_name == "auto_pad") {