forked from mindspore-Ecosystem/mindspore
check model input
This commit is contained in:
parent
4e0cfafcf9
commit
65da4463c1
|
@ -25,6 +25,7 @@
|
|||
namespace mindspore {
|
||||
class FuncGraph;
|
||||
namespace inference {
|
||||
using VectorForMSTensorPtr = std::vector<std::shared_ptr<inference::MSTensor>>;
|
||||
class MS_API MSSession {
|
||||
public:
|
||||
MSSession() = default;
|
||||
|
@ -33,7 +34,9 @@ class MS_API MSSession {
|
|||
|
||||
virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
|
||||
|
||||
virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0;
|
||||
virtual MultiTensor RunGraph(uint32_t graph_id, const VectorForMSTensorPtr &inputs) = 0;
|
||||
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const VectorForMSTensorPtr &inputs) const = 0;
|
||||
};
|
||||
|
||||
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "backend/session/ascend_inference_session.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -85,5 +87,80 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
}
|
||||
return graph_id;
|
||||
}
|
||||
|
||||
bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor> > &inputs) {
|
||||
MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto kernel_graph_inputs = kernel_graph->inputs();
|
||||
size_t no_weight_input = 0;
|
||||
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
|
||||
continue;
|
||||
}
|
||||
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
|
||||
if (!AnfAlgo::IsParameterWeight(parameter)) {
|
||||
// compare input number
|
||||
if (no_weight_input >= inputs.size()) {
|
||||
MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
|
||||
<< "] less than that of graph.";
|
||||
return false;
|
||||
}
|
||||
auto input = inputs[no_weight_input++];
|
||||
if (!CompareInput(input, parameter)) {
|
||||
MS_LOG(ERROR) << "Please check the input information.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTensor> &input,
|
||||
const ParameterPtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
// compare dims
|
||||
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
|
||||
if (input->shape().size() != parameter_shape.size()) {
|
||||
MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
|
||||
<< ", but the parameter dim is " << parameter_shape.size()
|
||||
<< ". parameter : " << parameter->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
||||
// compare shape
|
||||
auto input_shape = input->shape();
|
||||
vector<size_t> trans_input;
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
|
||||
[](const int dim) { return static_cast<size_t>(dim); });
|
||||
if (trans_input != parameter_shape) {
|
||||
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
|
||||
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
|
||||
<< ". parameter : " << parameter->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
||||
// compare data type
|
||||
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
|
||||
if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
|
||||
MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
|
||||
<< ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
|
||||
<< ". parameter : " << parameter->DebugString();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) {
|
||||
string res = "[";
|
||||
for (auto dim : shape) {
|
||||
res += " " + std::to_string(dim);
|
||||
}
|
||||
return res + " ]";
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,6 +39,9 @@ class AscendInferenceSession : public AscendSession {
|
|||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
|
||||
bool CompareInput(const std::shared_ptr<inference::MSTensor> &input, const ParameterPtr ¶meter);
|
||||
std::string PrintInputShape(std::vector<size_t> shape);
|
||||
};
|
||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||
} // namespace session
|
||||
|
|
|
@ -204,5 +204,11 @@ int Session::Init(const std::string &device, uint32_t device_id) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
bool Session::CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
return session_impl_->CheckModelInputs(graph_id, inputs);
|
||||
}
|
||||
|
||||
Session::Session() = default;
|
||||
} // namespace mindspore::inference
|
||||
|
|
|
@ -37,6 +37,9 @@ class Session : public MSSession {
|
|||
|
||||
MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
|
||||
|
||||
bool CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const override;
|
||||
|
||||
int Init(const std::string &device, uint32_t device_id);
|
||||
|
||||
static void RegAllOp();
|
||||
|
|
|
@ -106,6 +106,9 @@ class SessionBasic {
|
|||
virtual void GetSummaryNodes(KernelGraph *graph);
|
||||
void AssignParamKey(const KernelGraphPtr &kernel_graph);
|
||||
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// set debugger
|
||||
|
|
|
@ -67,6 +67,11 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
|
|||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
MS_LOG(INFO) << "run Predict";
|
||||
|
||||
if (!session_->CheckModelInputs(graph_id_, inputs)) {
|
||||
MS_LOG(ERROR) << "Input error.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
*outputs = session_->RunGraph(graph_id_, inputs);
|
||||
MS_LOG(INFO) << "run Predict finished";
|
||||
return SUCCESS;
|
||||
|
|
Loading…
Reference in New Issue