From 65da4463c115eccdcad111aca0980c98387f20d0 Mon Sep 17 00:00:00 2001 From: hexia Date: Mon, 20 Jul 2020 16:57:56 +0800 Subject: [PATCH] check model input --- include/inference.h | 5 +- .../session/ascend_inference_session.cc | 77 +++++++++++++++++++ .../session/ascend_inference_session.h | 3 + mindspore/ccsrc/backend/session/session.cc | 6 ++ mindspore/ccsrc/backend/session/session.h | 3 + .../ccsrc/backend/session/session_basic.h | 3 + serving/core/server.cc | 5 ++ 7 files changed, 101 insertions(+), 1 deletion(-) diff --git a/include/inference.h b/include/inference.h index 7e5ee27d49a..9aac37f3237 100644 --- a/include/inference.h +++ b/include/inference.h @@ -25,6 +25,7 @@ namespace mindspore { class FuncGraph; namespace inference { +using VectorForMSTensorPtr = std::vector<std::shared_ptr<inference::MSTensor>>; class MS_API MSSession { public: MSSession() = default; @@ -33,7 +34,9 @@ class MS_API MSSession { virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0; - virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0; + virtual MultiTensor RunGraph(uint32_t graph_id, const VectorForMSTensorPtr &inputs) = 0; + + virtual bool CheckModelInputs(uint32_t graph_id, const VectorForMSTensorPtr &inputs) const = 0; }; std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device); diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc index d251eb20398..0999bfada7e 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + +#include <algorithm> #include "backend/session/ascend_inference_session.h" #include "frontend/operator/ops.h" #include "ir/tensor.h" @@ -85,5 +87,80 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { } return graph_id; } + +bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, + const std::vector<std::shared_ptr<inference::MSTensor> > &inputs) { + MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id; + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernel_graph_inputs = kernel_graph->inputs(); + size_t no_weight_input = 0; + for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { + tensor::TensorPtr tensor = nullptr; + if (!kernel_graph_inputs[i]->isa<Parameter>()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; + continue; + } + auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>(); + if (!AnfAlgo::IsParameterWeight(parameter)) { + // compare input number + if (no_weight_input >= inputs.size()) { + MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size() + << "] less than that of graph."; + return false; + } + auto input = inputs[no_weight_input++]; + if (!CompareInput(input, parameter)) { + MS_LOG(ERROR) << "Please check the input information."; + return false; + } + } + } + return true; +} + +bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTensor> &input, + const ParameterPtr &parameter) { + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(parameter); + // compare dims + auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); + if (input->shape().size() != parameter_shape.size()) { + MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size() + << ", but the parameter dim is " << parameter_shape.size() + << ". 
parameter : " << parameter->DebugString(); + return false; + } + + // compare shape + auto input_shape = input->shape(); + vector<size_t> trans_input; + (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input), + [](const int dim) { return static_cast<size_t>(dim); }); + if (trans_input != parameter_shape) { + MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input) + << ", but the parameter shape is " << PrintInputShape(parameter_shape) + << ". parameter : " << parameter->DebugString(); + return false; + } + + // compare data type + auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); + if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) { + MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type() + << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0) + << ". parameter : " << parameter->DebugString(); + return false; + } + return true; +} + +std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) { + string res = "["; + for (auto dim : shape) { + res += " " + std::to_string(dim); + } + return res + " ]"; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h index 5364ae8d4ee..664aeadee51 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.h +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h @@ -39,6 +39,9 @@ class AscendInferenceSession : public AscendSession { void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const) const; GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override; + bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override; + bool CompareInput(const std::shared_ptr<inference::MSTensor> &input, const ParameterPtr &parameter); + std::string PrintInputShape(std::vector<size_t> shape); }; 
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); } // namespace session diff --git a/mindspore/ccsrc/backend/session/session.cc b/mindspore/ccsrc/backend/session/session.cc index 95484a11132..b5c9c695371 100644 --- a/mindspore/ccsrc/backend/session/session.cc +++ b/mindspore/ccsrc/backend/session/session.cc @@ -204,5 +204,11 @@ int Session::Init(const std::string &device, uint32_t device_id) { return 0; } +bool Session::CheckModelInputs(uint32_t graph_id, + const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const { + MS_ASSERT(session_impl_ != nullptr); + return session_impl_->CheckModelInputs(graph_id, inputs); +} + Session::Session() = default; } // namespace mindspore::inference diff --git a/mindspore/ccsrc/backend/session/session.h b/mindspore/ccsrc/backend/session/session.h index 6ea9cfaa474..0298b3379b2 100644 --- a/mindspore/ccsrc/backend/session/session.h +++ b/mindspore/ccsrc/backend/session/session.h @@ -37,6 +37,9 @@ class Session : public MSSession { MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override; + bool CheckModelInputs(uint32_t graph_id, + const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const override; + int Init(const std::string &device, uint32_t device_id); static void RegAllOp(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index a8ef0a7e1e3..09804dc0df7 100755 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -106,6 +106,9 @@ class SessionBasic { virtual void GetSummaryNodes(KernelGraph *graph); void AssignParamKey(const KernelGraphPtr &kernel_graph); void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const); + virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) { + return true; + } #ifdef ENABLE_DEBUGGER // set debugger diff --git a/serving/core/server.cc b/serving/core/server.cc index 5ba7ad36a7f..0ec2385f94c 100644 --- a/serving/core/server.cc +++ 
b/serving/core/server.cc @@ -67,6 +67,11 @@ Status Session::Predict(const std::vector<std::shared_ptr<inference::MSTensor>> &inputs, inference::Multi std::lock_guard<std::mutex> lock(mutex_); MS_LOG(INFO) << "run Predict"; + if (!session_->CheckModelInputs(graph_id_, inputs)) { + MS_LOG(ERROR) << "Input error."; + return FAILED; + } + *outputs = session_->RunGraph(graph_id_, inputs); MS_LOG(INFO) << "run Predict finished"; return SUCCESS;