serving: support scalar for tensors input
This commit is contained in:
parent
12f3665167
commit
f5d60dad7e
|
@ -148,7 +148,10 @@ bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const
|
||||||
vector<size_t> trans_input;
|
vector<size_t> trans_input;
|
||||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
|
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
|
||||||
[](const int dim) { return static_cast<size_t>(dim); });
|
[](const int dim) { return static_cast<size_t>(dim); });
|
||||||
if (trans_input != parameter_shape) {
|
auto is_scalar_shape = [](const vector<size_t> &shape) {
|
||||||
|
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
|
||||||
|
};
|
||||||
|
if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
|
||||||
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
|
MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
|
||||||
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
|
<< ", but the parameter shape is " << PrintInputShape(parameter_shape)
|
||||||
<< ". parameter : " << parameter->DebugString();
|
<< ". parameter : " << parameter->DebugString();
|
||||||
|
|
|
@ -58,7 +58,7 @@ Status GetPostMessage(struct evhttp_request *const req, std::string *const buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CheckRequestValid(struct evhttp_request *const http_request) {
|
Status CheckRequestValid(const struct evhttp_request *const http_request) {
|
||||||
Status status(SUCCESS);
|
Status status(SUCCESS);
|
||||||
switch (evhttp_request_get_command(http_request)) {
|
switch (evhttp_request_get_command(http_request)) {
|
||||||
case EVHTTP_REQ_POST:
|
case EVHTTP_REQ_POST:
|
||||||
|
@ -96,6 +96,60 @@ Status CheckMessageValid(const json &message_info, HTTP_TYPE *const type) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
|
||||||
|
std::vector<int64_t> json_shape;
|
||||||
|
const json *tmp_json = &json_array;
|
||||||
|
while (tmp_json->is_array()) {
|
||||||
|
if (tmp_json->empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
json_shape.push_back(tmp_json->size());
|
||||||
|
tmp_json = &tmp_json->at(0);
|
||||||
|
}
|
||||||
|
return json_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores the single value carried by `json_data_array` into element 0 of `request_tensor`'s buffer.
//
// The JSON input may be a bare value or a one-element array wrapping a non-array value. Any other
// array is rejected with INVALID_INPUTS, and the offending array shape is included in the message.
//   - HTTP_DATA_INT:   the value must satisfy is_number_integer(); it is written as int32_t.
//   - HTTP_DATA_FLOAT: the value must satisfy is_number_float(); it is written as float.
//   - any other type:  nothing is written and SUCCESS is returned.
//
// Returns SUCCESS, or an INVALID_INPUTS status describing the mismatch (the message is also logged
// through MSI_LOG_ERROR).
Status GetScalarDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, HTTP_DATA_TYPE type) {
  Status status(SUCCESS);

  // Name used in error messages; numbers are reported as "integer" or "float" rather than "number".
  auto describe_type = [](const json &value) -> std::string {
    if (value.is_number_integer()) {
      return "integer";
    }
    if (value.is_number_float()) {
      return "float";
    }
    return value.type_name();
  };

  // An array is only acceptable when it wraps exactly one non-array element.
  const bool is_wrapped = json_data_array.is_array();
  if (is_wrapped && (json_data_array.size() != 1 || json_data_array[0].is_array())) {
    status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected scalar data is scalar or shape(1) array, "
                                             "now array shape is "
                                          << GetJsonArrayShape(json_data_array);
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }
  const json &scalar = is_wrapped ? json_data_array.at(0) : json_data_array;

  if (type == HTTP_DATA_INT) {
    auto buffer = reinterpret_cast<int32_t *>(request_tensor->mutable_data());
    if (scalar.is_number_integer()) {
      buffer[0] = scalar.get<int32_t>();
      return SUCCESS;
    }
    status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << describe_type(scalar);
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }

  if (type == HTTP_DATA_FLOAT) {
    auto buffer = reinterpret_cast<float *>(request_tensor->mutable_data());
    if (scalar.is_number_float()) {
      buffer[0] = scalar.get<float>();
      return SUCCESS;
    }
    status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << describe_type(scalar);
    MSI_LOG_ERROR << status.StatusMessage();
    return status;
  }

  // Other data types are not handled by this function; the buffer is left as it was.
  return SUCCESS;
}
|
||||||
|
|
||||||
Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index,
|
Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index,
|
||||||
HTTP_DATA_TYPE type) {
|
HTTP_DATA_TYPE type) {
|
||||||
Status status(SUCCESS);
|
Status status(SUCCESS);
|
||||||
|
@ -173,19 +227,6 @@ Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *con
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
|
|
||||||
std::vector<int64_t> json_shape;
|
|
||||||
const json *tmp_json = &json_array;
|
|
||||||
while (tmp_json->is_array()) {
|
|
||||||
if (tmp_json->empty()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
json_shape.push_back(tmp_json->size());
|
|
||||||
tmp_json = &tmp_json->at(0);
|
|
||||||
}
|
|
||||||
return json_shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) {
|
Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) {
|
||||||
Status status = SUCCESS;
|
Status status = SUCCESS;
|
||||||
auto tensors = message_info.find(HTTP_DATA);
|
auto tensors = message_info.find(HTTP_DATA);
|
||||||
|
@ -266,21 +307,26 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
|
||||||
const auto &tensor = tensors->at(i);
|
const auto &tensor = tensors->at(i);
|
||||||
ServingTensor request_tensor(*(request->mutable_data(i)));
|
ServingTensor request_tensor(*(request->mutable_data(i)));
|
||||||
|
|
||||||
// check data shape
|
|
||||||
auto const &json_shape = GetJsonArrayShape(tensor);
|
|
||||||
if (json_shape != request_tensor.shape()) { // data shape not match
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "input " << i << " shape is invalid, expected " << request_tensor.shape() << ", given " << json_shape;
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto iter = infer_type2_http_type.find(request_tensor.data_type());
|
auto iter = infer_type2_http_type.find(request_tensor.data_type());
|
||||||
if (iter == infer_type2_http_type.end()) {
|
if (iter == infer_type2_http_type.end()) {
|
||||||
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
|
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
HTTP_DATA_TYPE type = iter->second;
|
HTTP_DATA_TYPE type = iter->second;
|
||||||
|
// check data shape
|
||||||
|
auto const &json_shape = GetJsonArrayShape(tensor);
|
||||||
|
auto is_scalar_shape = [](const std::vector<int64_t> &shape) {
|
||||||
|
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
|
||||||
|
};
|
||||||
|
if (is_scalar_shape(request_tensor.shape())) {
|
||||||
|
return GetScalarDataFromJson(tensor, &request_tensor, type);
|
||||||
|
} else {
|
||||||
|
if (json_shape != request_tensor.shape()) { // data shape not match
|
||||||
|
status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " shape is invalid, expected "
|
||||||
|
<< request_tensor.shape() << ", given " << json_shape;
|
||||||
|
MSI_LOG_ERROR << status.StatusMessage();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
size_t depth = 0;
|
size_t depth = 0;
|
||||||
size_t data_index = 0;
|
size_t data_index = 0;
|
||||||
status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
|
status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
|
||||||
|
@ -289,6 +335,7 @@ Status TransTensorToPredictRequest(const json &message_info, PredictRequest *con
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue