forked from OSSInnovation/mindspore
!3672 fix serving input numbers
Merge pull request !3672 from hexia/fix_input_check
This commit is contained in:
commit
6ea2aa4e73
|
@ -94,25 +94,33 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto kernel_graph_inputs = kernel_graph->inputs();
|
||||
size_t no_weight_input = 0;
|
||||
vector<ParameterPtr> paras;
|
||||
// find parameters of graph inputs
|
||||
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
|
||||
continue;
|
||||
}
|
||||
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
|
||||
if (!AnfAlgo::IsParameterWeight(parameter)) {
|
||||
// compare input number
|
||||
if (no_weight_input >= inputs.size()) {
|
||||
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
|
||||
<< "] less than that of graph.";
|
||||
return false;
|
||||
}
|
||||
auto input = inputs[no_weight_input++];
|
||||
if (!CompareInput(input, parameter)) {
|
||||
MS_LOG(ERROR) << "Please check the input information.";
|
||||
return false;
|
||||
}
|
||||
paras.push_back(parameter);
|
||||
}
|
||||
}
|
||||
|
||||
// check inputs
|
||||
for (size_t i = 0; i < paras.size(); ++i) {
|
||||
// compare input number
|
||||
if (paras.size() != inputs.size()) {
|
||||
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
|
||||
<< "] but the graph input number is [" << paras.size() << "]";
|
||||
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||
return false;
|
||||
}
|
||||
auto input = inputs[no_weight_input++];
|
||||
if (!CompareInput(input, paras[i])) {
|
||||
MS_LOG(ERROR) << "Please check the input information.";
|
||||
MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -123,12 +131,6 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
|
|||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
// compare dims
|
||||
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
|
||||
if (input->shape().size() != parameter_shape.size()) {
|
||||
MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
|
||||
<< ", but the parameter dim is " << parameter_shape.size()
|
||||
<< ". parameter : " << parameter->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
||||
// compare shape
|
||||
auto input_shape = input->shape();
|
||||
|
@ -153,12 +155,31 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
|
|||
return true;
|
||||
}
|
||||
|
||||
std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) const {
|
||||
template <typename T>
|
||||
std::string AscendInferenceSession::PrintInputShape(std::vector<T> shape) const {
|
||||
string res = "[";
|
||||
for (auto dim : shape) {
|
||||
res += " " + std::to_string(dim);
|
||||
}
|
||||
return res + " ]";
|
||||
}
|
||||
|
||||
std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> ¶s,
|
||||
const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
std::string graph = "graph inputs:{ ";
|
||||
for (size_t i = 0; i < paras.size(); ++i) {
|
||||
graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(paras[i], 0).size()) +
|
||||
", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(paras[i], 0)) + ", data type " +
|
||||
std::to_string(AnfAlgo::GetSelectKernelBuildInfo(paras[i])->GetOutputDeviceType(0)) + " }";
|
||||
}
|
||||
|
||||
std::string actual = "actual inputs:{ ";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
|
||||
PrintInputShape(inputs[i]->shape()) + ", data type " + std::to_string(inputs[i]->data_type()) + " }";
|
||||
}
|
||||
return graph + " " + actual;
|
||||
}
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,7 +41,9 @@ class AscendInferenceSession : public AscendSession {
|
|||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const override;
|
||||
bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const;
|
||||
std::string PrintInputShape(std::vector<size_t> shape) const;
|
||||
template <typename T>
|
||||
std::string PrintInputShape(std::vector<T> shape) const;
|
||||
std::string InputsInfo(const std::vector<ParameterPtr> ¶s, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
};
|
||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||
} // namespace session
|
||||
|
|
Loading…
Reference in New Issue