forked from mindspore-Ecosystem/mindspore
iii
This commit is contained in:
parent
d541e261a0
commit
a7de095d51
|
@ -14,6 +14,8 @@
|
|||
# ============================================================================
|
||||
|
||||
import random
|
||||
import json
|
||||
import requests
|
||||
import grpc
|
||||
import numpy as np
|
||||
import ms_service_pb2
|
||||
|
@ -61,6 +63,8 @@ def test_bert():
|
|||
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
|
||||
segment_ids = np.zeros((2, 32), dtype=np.int32)
|
||||
input_mask = np.zeros((2, 32), dtype=np.int32)
|
||||
|
||||
# grpc visit
|
||||
channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
||||
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
|
||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||
|
@ -82,17 +86,39 @@ def test_bert():
|
|||
z.data = input_mask.tobytes()
|
||||
|
||||
result = stub.Predict(request)
|
||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||
print("ms client received: ")
|
||||
print(result_np)
|
||||
grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||
print("ms grpc client received: ")
|
||||
print(grpc_result)
|
||||
|
||||
# ms result
|
||||
net = BertModel(bert_net_cfg, False)
|
||||
bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
|
||||
print("bert out: ")
|
||||
print(bert_out[0])
|
||||
bert_out_size = len(bert_out)
|
||||
|
||||
# compare grpc result
|
||||
for i in range(bert_out_size):
|
||||
result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
|
||||
logger.info("i:{}, result_np:{}, bert_out:{}".
|
||||
grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
|
||||
logger.info("i:{}, grpc_result:{}, bert_out:{}".
|
||||
format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
|
||||
assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True)
|
||||
assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True)
|
||||
|
||||
# http visit
|
||||
data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]}
|
||||
url = "http://127.0.0.1:5501"
|
||||
input_json = json.dumps(data)
|
||||
headers = {'Content-type': 'application/json'}
|
||||
response = requests.post(url, data=input_json, headers=headers)
|
||||
result = response.text
|
||||
result = result.replace('\r', '\\r').replace('\n', '\\n')
|
||||
result_json = json.loads(result, strict=False)
|
||||
http_result = np.array(result_json['tensor'])
|
||||
print("ms http client received: ")
|
||||
print(http_result[0][:200])
|
||||
|
||||
# compare http result
|
||||
for i in range(bert_out_size):
|
||||
logger.info("i:{}, http_result:{}, bert_out:{}".
|
||||
format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape))
|
||||
assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True)
|
||||
|
|
|
@ -26,8 +26,8 @@ from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig
|
|||
bert_net_cfg = BertConfig(
|
||||
batch_size=2,
|
||||
seq_length=32,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
vocab_size=12,
|
||||
hidden_size=12,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
|
|
|
@ -61,13 +61,15 @@ start_service()
|
|||
echo "$2 faile to start."
|
||||
fi
|
||||
|
||||
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
|
||||
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
||||
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
||||
count=0
|
||||
while [[ ${result} -ne 1 && ${count} -lt 150 ]]
|
||||
do
|
||||
sleep 1
|
||||
count=$(($count+1))
|
||||
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
|
||||
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
||||
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
||||
done
|
||||
|
||||
if [ ${count} -eq 150 ]
|
||||
|
|
Loading…
Reference in New Issue