!7260 [MSLITE] Support assigned input data shapes while running models.
Merge pull request !7260 from wangshaocong/bugfix_master
This commit is contained in:
commit
d6287ae6d8
|
@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
|
||||||
std::cout << "CompileGraph failed while running ", model_name.c_str();
|
std::cout << "CompileGraph failed while running ", model_name.c_str();
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
if (!flags_->input_shape_list_.empty()) {
|
||||||
|
std::vector<std::vector<int>> input_shapes;
|
||||||
|
std::string input_dims_list = flags_->input_shape_list_;
|
||||||
|
while (!input_dims_list.empty()) {
|
||||||
|
auto position =
|
||||||
|
input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
|
||||||
|
std::string input_dims = input_dims_list.substr(0, position);
|
||||||
|
std::vector<int> input_shape;
|
||||||
|
while (!input_dims.empty()) {
|
||||||
|
auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
|
||||||
|
std::string dim = input_dims.substr(0, pos);
|
||||||
|
input_shape.emplace_back(std::stoi(dim));
|
||||||
|
input_dims = input_dims.substr(pos);
|
||||||
|
}
|
||||||
|
input_shapes.emplace_back(input_shape);
|
||||||
|
input_dims_list = input_dims_list.substr(position);
|
||||||
|
}
|
||||||
|
ret = session_->Resize(session_->GetInputs(), input_shapes);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Input tensor resize failed.";
|
||||||
|
std::cout << "Input tensor resize failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
model->Free();
|
model->Free();
|
||||||
ms_inputs_ = session_->GetInputs();
|
ms_inputs_ = session_->GetInputs();
|
||||||
auto end_prepare_time = GetTimeUs();
|
auto end_prepare_time = GetTimeUs();
|
||||||
|
|
|
@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
||||||
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
|
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
|
||||||
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
|
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
|
||||||
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
|
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
|
||||||
|
AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
|
||||||
|
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
|
||||||
}
|
}
|
||||||
|
|
||||||
~BenchmarkFlags() override = default;
|
~BenchmarkFlags() override = default;
|
||||||
|
@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
||||||
InDataType in_data_type_;
|
InDataType in_data_type_;
|
||||||
std::string in_data_type_in_ = "bin";
|
std::string in_data_type_in_ = "bin";
|
||||||
int cpu_bind_mode_ = 1;
|
int cpu_bind_mode_ = 1;
|
||||||
|
std::string input_shape_list_;
|
||||||
// MarkPerformance
|
// MarkPerformance
|
||||||
int loop_count_;
|
int loop_count_;
|
||||||
int num_threads_;
|
int num_threads_;
|
||||||
|
|
|
@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
namespace {
|
namespace {
|
||||||
|
constexpr int DEFAULT_DIM_VALUE = -1;
|
||||||
|
}
|
||||||
|
namespace {
|
||||||
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
|
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
|
||||||
const schema::PrimitiveType node_type) {
|
const schema::PrimitiveType node_type) {
|
||||||
std::vector<Tensor *> lite_tensors;
|
std::vector<Tensor *> lite_tensors;
|
||||||
|
@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> outp
|
||||||
} // namespace
|
} // namespace
|
||||||
STATUS InferShapePass::Run(MetaGraphT *graph) {
|
STATUS InferShapePass::Run(MetaGraphT *graph) {
|
||||||
MS_ASSERT(graph != nullptr);
|
MS_ASSERT(graph != nullptr);
|
||||||
|
for (auto idx : graph->inputIndex) {
|
||||||
|
auto input_tensor = graph->allTensors[idx].get();
|
||||||
|
for (auto &dim : input_tensor->dims) {
|
||||||
|
if (dim == 0) {
|
||||||
|
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to 32 as a default value.";
|
||||||
|
dim = DEFAULT_DIM_VALUE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||||
auto &node = *iter;
|
auto &node = *iter;
|
||||||
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
|
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
|
||||||
|
|
|
@ -41,7 +41,14 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
|
||||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||||
const auto &attribute_name = onnx_node_attr.name();
|
const auto &attribute_name = onnx_node_attr.name();
|
||||||
if (attribute_name == "value") {
|
if (attribute_name == "value") {
|
||||||
attr->value = static_cast<int32_t>(onnx_node_attr.i());
|
if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
|
||||||
|
auto tensor = onnx_node_attr.t();
|
||||||
|
if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
|
||||||
|
attr->value = onnx_node_attr.f();
|
||||||
|
} else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
|
||||||
|
attr->value = static_cast<int32_t>(onnx_node_attr.i());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
||||||
const auto &attribute_name = onnx_node_attr.name();
|
const auto &attribute_name = onnx_node_attr.name();
|
||||||
if (attribute_name == "kernel_shape") {
|
if (attribute_name == "kernel_shape") {
|
||||||
if (onnx_node_attr.ints_size() == 2) {
|
if (onnx_node_attr.ints_size() == 2) {
|
||||||
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||||
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (attribute_name == "strides") {
|
if (attribute_name == "strides") {
|
||||||
if (onnx_node_attr.ints_size() == 2) {
|
if (onnx_node_attr.ints_size() == 2) {
|
||||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (attribute_name == "auto_pad") {
|
if (attribute_name == "auto_pad") {
|
||||||
|
|
Loading…
Reference in New Issue