diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc index 9de909c62e2..7c614c47ab3 100644 --- a/mindspore/ccsrc/backend/session/ascend_inference_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -148,7 +148,10 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const vector trans_input; (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input), [](const int dim) { return static_cast(dim); }); - if (trans_input != parameter_shape) { + auto is_scalar_shape = [](const vector &shape) { + return shape.empty() || (shape.size() == 1 && shape[0] == 1); + }; + if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) { MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input) << ", but the parameter shape is " << PrintInputShape(parameter_shape) << ". parameter : " << parameter->DebugString(); diff --git a/serving/core/http_process.cc b/serving/core/http_process.cc index efc9b34cd31..e738dca933a 100644 --- a/serving/core/http_process.cc +++ b/serving/core/http_process.cc @@ -58,7 +58,7 @@ Status GetPostMessage(struct evhttp_request *const req, std::string *const buf) } } -Status CheckRequestValid(struct evhttp_request *const http_request) { +Status CheckRequestValid(const struct evhttp_request *const http_request) { Status status(SUCCESS); switch (evhttp_request_get_command(http_request)) { case EVHTTP_REQ_POST: @@ -96,6 +96,60 @@ Status CheckMessageValid(const json &message_info, HTTP_TYPE *const type) { return status; } +std::vector GetJsonArrayShape(const json &json_array) { + std::vector json_shape; + const json *tmp_json = &json_array; + while (tmp_json->is_array()) { + if (tmp_json->empty()) { + break; + } + json_shape.push_back(tmp_json->size()); + tmp_json = &tmp_json->at(0); + } + return json_shape; +} + +Status GetScalarDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, HTTP_DATA_TYPE type) { + Status status(SUCCESS); + auto type_name = [](const json &json_data) -> std::string { + if (json_data.is_number_integer()) { + return "integer"; + } else if (json_data.is_number_float()) { + return "float"; + } + return json_data.type_name(); + }; + const json *json_data = &json_data_array; + if (json_data_array.is_array()) { + if (json_data_array.size() != 1 || json_data_array[0].is_array()) { + status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected scalar data is scalar or shape(1) array, " + "now array shape is " + << GetJsonArrayShape(json_data_array); + MSI_LOG_ERROR << status.StatusMessage(); + return status; + } + json_data = &json_data_array.at(0); + } + if (type == HTTP_DATA_INT) { + auto data = reinterpret_cast(request_tensor->mutable_data()); + if (!json_data->is_number_integer()) { + status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(*json_data); + MSI_LOG_ERROR << status.StatusMessage(); + return status; + } + data[0] = json_data->get(); + } else if (type == HTTP_DATA_FLOAT) { + auto data = reinterpret_cast(request_tensor->mutable_data()); + if (!json_data->is_number_float()) { + status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(*json_data); + MSI_LOG_ERROR << status.StatusMessage(); + return status; + } + data[0] = json_data->get(); + } + return SUCCESS; +} + Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index, HTTP_DATA_TYPE type) { Status status(SUCCESS); @@ -173,19 +227,6 @@ Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *con return status; } -std::vector GetJsonArrayShape(const json &json_array) { - std::vector json_shape; - const json *tmp_json = &json_array; - while (tmp_json->is_array()) { - if (tmp_json->empty()) { - break; - } - json_shape.push_back(tmp_json->size()); - tmp_json = &tmp_json->at(0); - } - return json_shape; -} - Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) { Status status = SUCCESS; auto tensors = message_info.find(HTTP_DATA); @@ -266,27 +307,33 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con const auto &tensor = tensors->at(i); ServingTensor request_tensor(*(request->mutable_data(i))); - // check data shape - auto const &json_shape = GetJsonArrayShape(tensor); - if (json_shape != request_tensor.shape()) { // data shape not match - status = INFER_STATUS(INVALID_INPUTS) - << "input " << i << " shape is invalid, expected " << request_tensor.shape() << ", given " << json_shape; - MSI_LOG_ERROR << status.StatusMessage(); - return status; - } - auto iter = infer_type2_http_type.find(request_tensor.data_type()); if (iter == infer_type2_http_type.end()) { ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now"); return status; } HTTP_DATA_TYPE type = iter->second; - size_t depth = 0; - size_t data_index = 0; - status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type); - if (status != SUCCESS) { - MSI_LOG_ERROR << "Transfer tensor to predict request failed"; - return status; + // check data shape + auto const &json_shape = GetJsonArrayShape(tensor); + auto is_scalar_shape = [](const std::vector &shape) { + return shape.empty() || (shape.size() == 1 && shape[0] == 1); + }; + if (is_scalar_shape(request_tensor.shape())) { + return GetScalarDataFromJson(tensor, &request_tensor, type); + } else { + if (json_shape != request_tensor.shape()) { // data shape not match + status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " shape is invalid, expected " + << request_tensor.shape() << ", given " << json_shape; + MSI_LOG_ERROR << status.StatusMessage(); + return status; + } + size_t depth = 0; + size_t data_index = 0; + status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Transfer tensor to predict request failed"; + return status; + } } } return status;