serving: support scalar for tensors input
This commit is contained in:
parent
12f3665167
commit
f5d60dad7e
|
@ -148,7 +148,10 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
|
|||
vector<size_t> trans_input;
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
|
||||
[](const int dim) { return static_cast<size_t>(dim); });
|
||||
if (trans_input != parameter_shape) {
|
||||
auto is_scalar_shape = [](const vector<size_t> &shape) {
|
||||
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
|
||||
};
|
||||
if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
|
||||
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
|
||||
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
|
||||
<< ". parameter : " << parameter->DebugString();
|
||||
|
|
|
@ -58,7 +58,7 @@ Status GetPostMessage(struct evhttp_request *const req, std::string *const buf)
|
|||
}
|
||||
}
|
||||
|
||||
Status CheckRequestValid(struct evhttp_request *const http_request) {
|
||||
Status CheckRequestValid(const struct evhttp_request *const http_request) {
|
||||
Status status(SUCCESS);
|
||||
switch (evhttp_request_get_command(http_request)) {
|
||||
case EVHTTP_REQ_POST:
|
||||
|
@ -96,6 +96,60 @@ Status CheckMessageValid(const json &message_info, HTTP_TYPE *const type) {
|
|||
return status;
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
|
||||
std::vector<int64_t> json_shape;
|
||||
const json *tmp_json = &json_array;
|
||||
while (tmp_json->is_array()) {
|
||||
if (tmp_json->empty()) {
|
||||
break;
|
||||
}
|
||||
json_shape.push_back(tmp_json->size());
|
||||
tmp_json = &tmp_json->at(0);
|
||||
}
|
||||
return json_shape;
|
||||
}
|
||||
|
||||
Status GetScalarDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, HTTP_DATA_TYPE type) {
|
||||
Status status(SUCCESS);
|
||||
auto type_name = [](const json &json_data) -> std::string {
|
||||
if (json_data.is_number_integer()) {
|
||||
return "integer";
|
||||
} else if (json_data.is_number_float()) {
|
||||
return "float";
|
||||
}
|
||||
return json_data.type_name();
|
||||
};
|
||||
const json *json_data = &json_data_array;
|
||||
if (json_data_array.is_array()) {
|
||||
if (json_data_array.size() != 1 || json_data_array[0].is_array()) {
|
||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected scalar data is scalar or shape(1) array, "
|
||||
"now array shape is "
|
||||
<< GetJsonArrayShape(json_data_array);
|
||||
MSI_LOG_ERROR << status.StatusMessage();
|
||||
return status;
|
||||
}
|
||||
json_data = &json_data_array.at(0);
|
||||
}
|
||||
if (type == HTTP_DATA_INT) {
|
||||
auto data = reinterpret_cast<int32_t *>(request_tensor->mutable_data());
|
||||
if (!json_data->is_number_integer()) {
|
||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(*json_data);
|
||||
MSI_LOG_ERROR << status.StatusMessage();
|
||||
return status;
|
||||
}
|
||||
data[0] = json_data->get<int32_t>();
|
||||
} else if (type == HTTP_DATA_FLOAT) {
|
||||
auto data = reinterpret_cast<float *>(request_tensor->mutable_data());
|
||||
if (!json_data->is_number_float()) {
|
||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(*json_data);
|
||||
MSI_LOG_ERROR << status.StatusMessage();
|
||||
return status;
|
||||
}
|
||||
data[0] = json_data->get<float>();
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index,
|
||||
HTTP_DATA_TYPE type) {
|
||||
Status status(SUCCESS);
|
||||
|
@ -173,19 +227,6 @@ Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *con
|
|||
return status;
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
|
||||
std::vector<int64_t> json_shape;
|
||||
const json *tmp_json = &json_array;
|
||||
while (tmp_json->is_array()) {
|
||||
if (tmp_json->empty()) {
|
||||
break;
|
||||
}
|
||||
json_shape.push_back(tmp_json->size());
|
||||
tmp_json = &tmp_json->at(0);
|
||||
}
|
||||
return json_shape;
|
||||
}
|
||||
|
||||
Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) {
|
||||
Status status = SUCCESS;
|
||||
auto tensors = message_info.find(HTTP_DATA);
|
||||
|
@ -266,21 +307,26 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
|
|||
const auto &tensor = tensors->at(i);
|
||||
ServingTensor request_tensor(*(request->mutable_data(i)));
|
||||
|
||||
// check data shape
|
||||
auto const &json_shape = GetJsonArrayShape(tensor);
|
||||
if (json_shape != request_tensor.shape()) { // data shape not match
|
||||
status = INFER_STATUS(INVALID_INPUTS)
|
||||
<< "input " << i << " shape is invalid, expected " << request_tensor.shape() << ", given " << json_shape;
|
||||
MSI_LOG_ERROR << status.StatusMessage();
|
||||
return status;
|
||||
}
|
||||
|
||||
auto iter = infer_type2_http_type.find(request_tensor.data_type());
|
||||
if (iter == infer_type2_http_type.end()) {
|
||||
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
|
||||
return status;
|
||||
}
|
||||
HTTP_DATA_TYPE type = iter->second;
|
||||
// check data shape
|
||||
auto const &json_shape = GetJsonArrayShape(tensor);
|
||||
auto is_scalar_shape = [](const std::vector<int64_t> &shape) {
|
||||
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
|
||||
};
|
||||
if (is_scalar_shape(request_tensor.shape())) {
|
||||
return GetScalarDataFromJson(tensor, &request_tensor, type);
|
||||
} else {
|
||||
if (json_shape != request_tensor.shape()) { // data shape not match
|
||||
status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " shape is invalid, expected "
|
||||
<< request_tensor.shape() << ", given " << json_shape;
|
||||
MSI_LOG_ERROR << status.StatusMessage();
|
||||
return status;
|
||||
}
|
||||
size_t depth = 0;
|
||||
size_t data_index = 0;
|
||||
status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
|
||||
|
@ -289,6 +335,7 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
|
|||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue