!7260 [MSLITE] Support assigned input data shapes while running models.

Merge pull request !7260 from wangshaocong/bugfix_master
This commit is contained in:
mindspore-ci-bot 2020-10-16 10:57:00 +08:00 committed by Gitee
commit d6287ae6d8
5 changed files with 51 additions and 5 deletions

View File

@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
std::cout << "CompileGraph failed while running ", model_name.c_str(); std::cout << "CompileGraph failed while running ", model_name.c_str();
return ret; return ret;
} }
if (!flags_->input_shape_list_.empty()) {
std::vector<std::vector<int>> input_shapes;
std::string input_dims_list = flags_->input_shape_list_;
while (!input_dims_list.empty()) {
auto position =
input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
std::string input_dims = input_dims_list.substr(0, position);
std::vector<int> input_shape;
while (!input_dims.empty()) {
auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
std::string dim = input_dims.substr(0, pos);
input_shape.emplace_back(std::stoi(dim));
input_dims = input_dims.substr(pos);
}
input_shapes.emplace_back(input_shape);
input_dims_list = input_dims_list.substr(position);
}
ret = session_->Resize(session_->GetInputs(), input_shapes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Input tensor resize failed.";
std::cout << "Input tensor resize failed.";
return ret;
}
}
model->Free(); model->Free();
ms_inputs_ = session_->GetInputs(); ms_inputs_ = session_->GetInputs();
auto end_prepare_time = GetTimeUs(); auto end_prepare_time = GetTimeUs();

View File

@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType", AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT"); "Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
} }
~BenchmarkFlags() override = default; ~BenchmarkFlags() override = default;
@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
InDataType in_data_type_; InDataType in_data_type_;
std::string in_data_type_in_ = "bin"; std::string in_data_type_in_ = "bin";
int cpu_bind_mode_ = 1; int cpu_bind_mode_ = 1;
std::string input_shape_list_;
// MarkPerformance // MarkPerformance
int loop_count_; int loop_count_;
int num_threads_; int num_threads_;

View File

@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
}
namespace {
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs, std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
const schema::PrimitiveType node_type) { const schema::PrimitiveType node_type) {
std::vector<Tensor *> lite_tensors; std::vector<Tensor *> lite_tensors;
@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> outp
} // namespace } // namespace
STATUS InferShapePass::Run(MetaGraphT *graph) { STATUS InferShapePass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr); MS_ASSERT(graph != nullptr);
for (auto idx : graph->inputIndex) {
auto input_tensor = graph->allTensors[idx].get();
for (auto &dim : input_tensor->dims) {
if (dim == 0) {
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to 32 as a default value.";
dim = DEFAULT_DIM_VALUE;
}
}
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter; auto &node = *iter;
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type); auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);

View File

@ -41,7 +41,14 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
for (const auto &onnx_node_attr : onnx_node.attribute()) { for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "value") { if (attribute_name == "value") {
attr->value = static_cast<int32_t>(onnx_node_attr.i()); if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
auto tensor = onnx_node_attr.t();
if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
attr->value = onnx_node_attr.f();
} else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
attr->value = static_cast<int32_t>(onnx_node_attr.i());
}
}
} }
} }

View File

@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
const auto &attribute_name = onnx_node_attr.name(); const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "kernel_shape") { if (attribute_name == "kernel_shape") {
if (onnx_node_attr.ints_size() == 2) { if (onnx_node_attr.ints_size() == 2) {
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
} }
} }
if (attribute_name == "strides") { if (attribute_name == "strides") {
if (onnx_node_attr.ints_size() == 2) { if (onnx_node_attr.ints_size() == 2) {
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0)); attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1)); attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
} }
} }
if (attribute_name == "auto_pad") { if (attribute_name == "auto_pad") {