!7260 [MSLITE] Support assigned input data shapes while running models.
Merge pull request !7260 from wangshaocong/bugfix_master
This commit is contained in:
commit
d6287ae6d8
|
@ -406,6 +406,30 @@ int Benchmark::RunBenchmark() {
|
|||
std::cout << "CompileGraph failed while running ", model_name.c_str();
|
||||
return ret;
|
||||
}
|
||||
if (!flags_->input_shape_list_.empty()) {
|
||||
std::vector<std::vector<int>> input_shapes;
|
||||
std::string input_dims_list = flags_->input_shape_list_;
|
||||
while (!input_dims_list.empty()) {
|
||||
auto position =
|
||||
input_dims_list.find(";") != input_dims_list.npos ? input_dims_list.find(";") + 1 : input_dims_list.length();
|
||||
std::string input_dims = input_dims_list.substr(0, position);
|
||||
std::vector<int> input_shape;
|
||||
while (!input_dims.empty()) {
|
||||
auto pos = input_dims.find(",") != input_dims.npos ? input_dims.find(",") + 1 : input_dims.length();
|
||||
std::string dim = input_dims.substr(0, pos);
|
||||
input_shape.emplace_back(std::stoi(dim));
|
||||
input_dims = input_dims.substr(pos);
|
||||
}
|
||||
input_shapes.emplace_back(input_shape);
|
||||
input_dims_list = input_dims_list.substr(position);
|
||||
}
|
||||
ret = session_->Resize(session_->GetInputs(), input_shapes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Input tensor resize failed.";
|
||||
std::cout << "Input tensor resize failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
model->Free();
|
||||
ms_inputs_ = session_->GetInputs();
|
||||
auto end_prepare_time = GetTimeUs();
|
||||
|
|
|
@ -70,6 +70,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
|
||||
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
|
||||
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
|
||||
AddFlag(&BenchmarkFlags::input_shape_list_, "inputShapes",
|
||||
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32;1,1,32,32,1", "");
|
||||
}
|
||||
|
||||
~BenchmarkFlags() override = default;
|
||||
|
@ -86,6 +88,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
InDataType in_data_type_;
|
||||
std::string in_data_type_in_ = "bin";
|
||||
int cpu_bind_mode_ = 1;
|
||||
std::string input_shape_list_;
|
||||
// MarkPerformance
|
||||
int loop_count_;
|
||||
int num_threads_;
|
||||
|
|
|
@ -26,6 +26,9 @@ using mindspore::lite::Tensor;
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr int DEFAULT_DIM_VALUE = -1;
|
||||
}
|
||||
namespace {
|
||||
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
|
||||
const schema::PrimitiveType node_type) {
|
||||
std::vector<Tensor *> lite_tensors;
|
||||
|
@ -85,6 +88,15 @@ void FreeTensors(std::vector<Tensor *> input_tensors, std::vector<Tensor *> outp
|
|||
} // namespace
|
||||
STATUS InferShapePass::Run(MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
for (auto idx : graph->inputIndex) {
|
||||
auto input_tensor = graph->allTensors[idx].get();
|
||||
for (auto &dim : input_tensor->dims) {
|
||||
if (dim == 0) {
|
||||
MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to 32 as a default value.";
|
||||
dim = DEFAULT_DIM_VALUE;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
auto &node = *iter;
|
||||
auto input_tensors = ConvertTensorToLiteTensor(graph, node->inputIndex, node->primitive->value.type);
|
||||
|
|
|
@ -41,7 +41,14 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
|
|||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "value") {
|
||||
attr->value = static_cast<int32_t>(onnx_node_attr.i());
|
||||
if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
|
||||
auto tensor = onnx_node_attr.t();
|
||||
if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
|
||||
attr->value = onnx_node_attr.f();
|
||||
} else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
|
||||
attr->value = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -66,14 +66,14 @@ STATUS OnnxPoolParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "kernel_shape") {
|
||||
if (onnx_node_attr.ints_size() == 2) {
|
||||
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->windowH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->windowW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
}
|
||||
}
|
||||
if (attribute_name == "strides") {
|
||||
if (onnx_node_attr.ints_size() == 2) {
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
|
||||
attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
|
||||
}
|
||||
}
|
||||
if (attribute_name == "auto_pad") {
|
||||
|
|
Loading…
Reference in New Issue