This commit is contained in:
hexia 2020-08-20 15:21:17 +08:00
parent d541e261a0
commit a7de095d51
3 changed files with 38 additions and 10 deletions

View File

@ -14,6 +14,8 @@
# ============================================================================
import random
import json
import requests
import grpc
import numpy as np
import ms_service_pb2
@ -61,6 +63,8 @@ def test_bert():
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
segment_ids = np.zeros((2, 32), dtype=np.int32)
input_mask = np.zeros((2, 32), dtype=np.int32)
# grpc visit
channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
stub = ms_service_pb2_grpc.MSServiceStub(channel)
@ -82,17 +86,39 @@ def test_bert():
z.data = input_mask.tobytes()
result = stub.Predict(request)
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
print("ms client received: ")
print(result_np)
grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
print("ms grpc client received: ")
print(grpc_result)
# ms result
net = BertModel(bert_net_cfg, False)
bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
print("bert out: ")
print(bert_out[0])
bert_out_size = len(bert_out)
# compare grpc result
for i in range(bert_out_size):
result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
logger.info("i:{}, result_np:{}, bert_out:{}".
grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
logger.info("i:{}, grpc_result:{}, bert_out:{}".
format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True)
assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True)
# http visit
data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]}
url = "http://127.0.0.1:5501"
input_json = json.dumps(data)
headers = {'Content-type': 'application/json'}
response = requests.post(url, data=input_json, headers=headers)
result = response.text
result = result.replace('\r', '\\r').replace('\n', '\\n')
result_json = json.loads(result, strict=False)
http_result = np.array(result_json['tensor'])
print("ms http client received: ")
print(http_result[0][:200])
# compare http result
for i in range(bert_out_size):
logger.info("i:{}, http_result:{}, bert_out:{}".
format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape))
assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True)

View File

@ -26,8 +26,8 @@ from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig
bert_net_cfg = BertConfig(
batch_size=2,
seq_length=32,
vocab_size=21128,
hidden_size=768,
vocab_size=12,
hidden_size=12,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,

View File

@ -61,13 +61,15 @@ start_service()
echo "$2 faile to start."
fi
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
count=0
while [[ ${result} -ne 1 && ${count} -lt 150 ]]
do
sleep 1
count=$(($count+1))
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
done
if [ ${count} -eq 150 ]