serving: support scalar for tensors input

This commit is contained in:
xuyongfei 2020-09-14 14:18:48 +08:00
parent 12f3665167
commit f5d60dad7e
2 changed files with 80 additions and 30 deletions

View File

@ -148,7 +148,10 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
vector<size_t> trans_input;
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
[](const int dim) { return static_cast<size_t>(dim); });
if (trans_input != parameter_shape) {
auto is_scalar_shape = [](const vector<size_t> &shape) {
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
};
if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
<< ". parameter : " << parameter->DebugString();

View File

@ -58,7 +58,7 @@ Status GetPostMessage(struct evhttp_request *const req, std::string *const buf)
}
}
Status CheckRequestValid(struct evhttp_request *const http_request) {
Status CheckRequestValid(const struct evhttp_request *const http_request) {
Status status(SUCCESS);
switch (evhttp_request_get_command(http_request)) {
case EVHTTP_REQ_POST:
@ -96,6 +96,60 @@ Status CheckMessageValid(const json &message_info, HTTP_TYPE *const type) {
return status;
}
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
std::vector<int64_t> json_shape;
const json *tmp_json = &json_array;
while (tmp_json->is_array()) {
if (tmp_json->empty()) {
break;
}
json_shape.push_back(tmp_json->size());
tmp_json = &tmp_json->at(0);
}
return json_shape;
}
Status GetScalarDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, HTTP_DATA_TYPE type) {
Status status(SUCCESS);
auto type_name = [](const json &json_data) -> std::string {
if (json_data.is_number_integer()) {
return "integer";
} else if (json_data.is_number_float()) {
return "float";
}
return json_data.type_name();
};
const json *json_data = &json_data_array;
if (json_data_array.is_array()) {
if (json_data_array.size() != 1 || json_data_array[0].is_array()) {
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected scalar data is scalar or shape(1) array, "
"now array shape is "
<< GetJsonArrayShape(json_data_array);
MSI_LOG_ERROR << status.StatusMessage();
return status;
}
json_data = &json_data_array.at(0);
}
if (type == HTTP_DATA_INT) {
auto data = reinterpret_cast<int32_t *>(request_tensor->mutable_data());
if (!json_data->is_number_integer()) {
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(*json_data);
MSI_LOG_ERROR << status.StatusMessage();
return status;
}
data[0] = json_data->get<int32_t>();
} else if (type == HTTP_DATA_FLOAT) {
auto data = reinterpret_cast<float *>(request_tensor->mutable_data());
if (!json_data->is_number_float()) {
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(*json_data);
MSI_LOG_ERROR << status.StatusMessage();
return status;
}
data[0] = json_data->get<float>();
}
return SUCCESS;
}
Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index,
HTTP_DATA_TYPE type) {
Status status(SUCCESS);
@ -173,19 +227,6 @@ Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *con
return status;
}
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
std::vector<int64_t> json_shape;
const json *tmp_json = &json_array;
while (tmp_json->is_array()) {
if (tmp_json->empty()) {
break;
}
json_shape.push_back(tmp_json->size());
tmp_json = &tmp_json->at(0);
}
return json_shape;
}
Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) {
Status status = SUCCESS;
auto tensors = message_info.find(HTTP_DATA);
@ -266,21 +307,26 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
const auto &tensor = tensors->at(i);
ServingTensor request_tensor(*(request->mutable_data(i)));
// check data shape
auto const &json_shape = GetJsonArrayShape(tensor);
if (json_shape != request_tensor.shape()) { // data shape not match
status = INFER_STATUS(INVALID_INPUTS)
<< "input " << i << " shape is invalid, expected " << request_tensor.shape() << ", given " << json_shape;
MSI_LOG_ERROR << status.StatusMessage();
return status;
}
auto iter = infer_type2_http_type.find(request_tensor.data_type());
if (iter == infer_type2_http_type.end()) {
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
return status;
}
HTTP_DATA_TYPE type = iter->second;
// check data shape
auto const &json_shape = GetJsonArrayShape(tensor);
auto is_scalar_shape = [](const std::vector<int64_t> &shape) {
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
};
if (is_scalar_shape(request_tensor.shape())) {
return GetScalarDataFromJson(tensor, &request_tensor, type);
} else {
if (json_shape != request_tensor.shape()) { // data shape not match
status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " shape is invalid, expected "
<< request_tensor.shape() << ", given " << json_shape;
MSI_LOG_ERROR << status.StatusMessage();
return status;
}
size_t depth = 0;
size_t data_index = 0;
status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
@ -289,6 +335,7 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
return status;
}
}
}
return status;
}