!3672 fix serving input numbers

Merge pull request !3672 from hexia/fix_input_check
This commit is contained in:
mindspore-ci-bot 2020-07-30 09:17:16 +08:00 committed by Gitee
commit 6ea2aa4e73
2 changed files with 43 additions and 20 deletions

View File

@ -94,25 +94,33 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernel_graph_inputs = kernel_graph->inputs();
size_t no_weight_input = 0;
vector<ParameterPtr> paras;
// find parameters of graph inputs
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
tensor::TensorPtr tensor = nullptr;
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
continue;
}
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(parameter)) {
// compare input number
if (no_weight_input >= inputs.size()) {
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
<< "] less than that of graph.";
return false;
}
auto input = inputs[no_weight_input++];
if (!CompareInput(input, parameter)) {
MS_LOG(ERROR) << "Please check the input information.";
return false;
}
paras.push_back(parameter);
}
}
// check inputs
for (size_t i = 0; i < paras.size(); ++i) {
// compare input number
if (paras.size() != inputs.size()) {
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
<< "] but the graph input number is [" << paras.size() << "]";
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
return false;
}
auto input = inputs[no_weight_input++];
if (!CompareInput(input, paras[i])) {
MS_LOG(ERROR) << "Please check the input information.";
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
return false;
}
}
return true;
@ -123,12 +131,6 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
MS_EXCEPTION_IF_NULL(parameter);
// compare dims
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
if (input->shape().size() != parameter_shape.size()) {
MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
<< ", but the parameter dim is " << parameter_shape.size()
<< ". parameter : " << parameter->DebugString();
return false;
}
// compare shape
auto input_shape = input->shape();
@ -153,12 +155,31 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
return true;
}
std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) const {
template <typename T>
std::string AscendInferenceSession::PrintInputShape(std::vector<T> shape) const {
string res = "[";
for (auto dim : shape) {
res += " " + std::to_string(dim);
}
return res + " ]";
}
std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras,
const std::vector<tensor::TensorPtr> &inputs) const {
std::string graph = "graph inputs:{ ";
for (size_t i = 0; i < paras.size(); ++i) {
graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(paras[i], 0).size()) +
", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(paras[i], 0)) + ", data type " +
std::to_string(AnfAlgo::GetSelectKernelBuildInfo(paras[i])->GetOutputDeviceType(0)) + " }";
}
std::string actual = "actual inputs:{ ";
for (size_t i = 0; i < inputs.size(); ++i) {
actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
PrintInputShape(inputs[i]->shape()) + ", data type " + std::to_string(inputs[i]->data_type()) + " }";
}
return graph + " " + actual;
}
} // namespace session
} // namespace mindspore

View File

@ -41,7 +41,9 @@ class AscendInferenceSession : public AscendSession {
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const override;
bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const;
std::string PrintInputShape(std::vector<size_t> shape) const;
template <typename T>
std::string PrintInputShape(std::vector<T> shape) const;
std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
};
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
} // namespace session