From a7de095d51f61a0509124602c938b6d85975ca2b Mon Sep 17 00:00:00 2001 From: hexia Date: Thu, 20 Aug 2020 15:21:17 +0800 Subject: [PATCH] iii --- tests/st/serving/client_example.py | 38 +++++++++++++++++++++++++----- tests/st/serving/generate_model.py | 4 ++-- tests/st/serving/serving.sh | 6 +++-- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/tests/st/serving/client_example.py b/tests/st/serving/client_example.py index 164852acee4..c5b9c7ef46a 100644 --- a/tests/st/serving/client_example.py +++ b/tests/st/serving/client_example.py @@ -14,6 +14,8 @@ # ============================================================================ import random +import json +import requests import grpc import numpy as np import ms_service_pb2 @@ -61,6 +63,8 @@ def test_bert(): input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32) segment_ids = np.zeros((2, 32), dtype=np.int32) input_mask = np.zeros((2, 32), dtype=np.int32) + + # grpc visit channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)]) stub = ms_service_pb2_grpc.MSServiceStub(channel) @@ -82,17 +86,39 @@ def test_bert(): z.data = input_mask.tobytes() result = stub.Predict(request) - result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) - print("ms client received: ") - print(result_np) + grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) + print("ms grpc client received: ") + print(grpc_result) + # ms result net = BertModel(bert_net_cfg, False) bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask)) print("bert out: ") print(bert_out[0]) bert_out_size = len(bert_out) + + # compare grpc result for i in range(bert_out_size): - result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims) - logger.info("i:{}, result_np:{}, bert_out:{}". + grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims) + logger.info("i:{}, grpc_result:{}, bert_out:{}". format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape)) - assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True) + assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True) + + # http visit + data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]} + url = "http://127.0.0.1:5501" + input_json = json.dumps(data) + headers = {'Content-type': 'application/json'} + response = requests.post(url, data=input_json, headers=headers) + result = response.text + result = result.replace('\r', '\\r').replace('\n', '\\n') + result_json = json.loads(result, strict=False) + http_result = np.array(result_json['tensor']) + print("ms http client received: ") + print(http_result[0][:200]) + + # compare http result + for i in range(bert_out_size): + logger.info("i:{}, http_result:{}, bert_out:{}". + format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape)) + assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True) diff --git a/tests/st/serving/generate_model.py b/tests/st/serving/generate_model.py index 2fdee2344a7..807862715d6 100644 --- a/tests/st/serving/generate_model.py +++ b/tests/st/serving/generate_model.py @@ -26,8 +26,8 @@ from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig bert_net_cfg = BertConfig( batch_size=2, seq_length=32, - vocab_size=21128, - hidden_size=768, + vocab_size=12, + hidden_size=12, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, diff --git a/tests/st/serving/serving.sh b/tests/st/serving/serving.sh index 95cae5305f7..1b77db34d57 100644 --- a/tests/st/serving/serving.sh +++ b/tests/st/serving/serving.sh @@ -61,13 +61,15 @@ start_service() echo "$2 faile to start." fi - result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` + result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log | + grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l` count=0 while [[ ${result} -ne 1 && ${count} -lt 150 ]] do sleep 1 count=$(($count+1)) - result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` + result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log | + grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l` done if [ ${count} -eq 150 ]