forked from mindspore-Ecosystem/mindspore
iii
This commit is contained in:
parent
d541e261a0
commit
a7de095d51
|
@ -14,6 +14,8 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
import grpc
|
import grpc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ms_service_pb2
|
import ms_service_pb2
|
||||||
|
@ -61,6 +63,8 @@ def test_bert():
|
||||||
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
|
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
|
||||||
segment_ids = np.zeros((2, 32), dtype=np.int32)
|
segment_ids = np.zeros((2, 32), dtype=np.int32)
|
||||||
input_mask = np.zeros((2, 32), dtype=np.int32)
|
input_mask = np.zeros((2, 32), dtype=np.int32)
|
||||||
|
|
||||||
|
# grpc visit
|
||||||
channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
||||||
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
|
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
|
||||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||||
|
@ -82,17 +86,39 @@ def test_bert():
|
||||||
z.data = input_mask.tobytes()
|
z.data = input_mask.tobytes()
|
||||||
|
|
||||||
result = stub.Predict(request)
|
result = stub.Predict(request)
|
||||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||||
print("ms client received: ")
|
print("ms grpc client received: ")
|
||||||
print(result_np)
|
print(grpc_result)
|
||||||
|
|
||||||
|
# ms result
|
||||||
net = BertModel(bert_net_cfg, False)
|
net = BertModel(bert_net_cfg, False)
|
||||||
bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
|
bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
|
||||||
print("bert out: ")
|
print("bert out: ")
|
||||||
print(bert_out[0])
|
print(bert_out[0])
|
||||||
bert_out_size = len(bert_out)
|
bert_out_size = len(bert_out)
|
||||||
|
|
||||||
|
# compare grpc result
|
||||||
for i in range(bert_out_size):
|
for i in range(bert_out_size):
|
||||||
result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
|
grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
|
||||||
logger.info("i:{}, result_np:{}, bert_out:{}".
|
logger.info("i:{}, grpc_result:{}, bert_out:{}".
|
||||||
format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
|
format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
|
||||||
assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True)
|
assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True)
|
||||||
|
|
||||||
|
# http visit
|
||||||
|
data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]}
|
||||||
|
url = "http://127.0.0.1:5501"
|
||||||
|
input_json = json.dumps(data)
|
||||||
|
headers = {'Content-type': 'application/json'}
|
||||||
|
response = requests.post(url, data=input_json, headers=headers)
|
||||||
|
result = response.text
|
||||||
|
result = result.replace('\r', '\\r').replace('\n', '\\n')
|
||||||
|
result_json = json.loads(result, strict=False)
|
||||||
|
http_result = np.array(result_json['tensor'])
|
||||||
|
print("ms http client received: ")
|
||||||
|
print(http_result[0][:200])
|
||||||
|
|
||||||
|
# compare http result
|
||||||
|
for i in range(bert_out_size):
|
||||||
|
logger.info("i:{}, http_result:{}, bert_out:{}".
|
||||||
|
format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape))
|
||||||
|
assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True)
|
||||||
|
|
|
@ -26,8 +26,8 @@ from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig
|
||||||
bert_net_cfg = BertConfig(
|
bert_net_cfg = BertConfig(
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
seq_length=32,
|
seq_length=32,
|
||||||
vocab_size=21128,
|
vocab_size=12,
|
||||||
hidden_size=768,
|
hidden_size=12,
|
||||||
num_hidden_layers=12,
|
num_hidden_layers=12,
|
||||||
num_attention_heads=12,
|
num_attention_heads=12,
|
||||||
intermediate_size=3072,
|
intermediate_size=3072,
|
||||||
|
|
|
@ -61,13 +61,15 @@ start_service()
|
||||||
echo "$2 faile to start."
|
echo "$2 faile to start."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
|
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
||||||
|
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
||||||
count=0
|
count=0
|
||||||
while [[ ${result} -ne 1 && ${count} -lt 150 ]]
|
while [[ ${result} -ne 1 && ${count} -lt 150 ]]
|
||||||
do
|
do
|
||||||
sleep 1
|
sleep 1
|
||||||
count=$(($count+1))
|
count=$(($count+1))
|
||||||
result=`grep -E 'MS Serving grpc listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l`
|
result=`grep -E -A5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
||||||
|
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
||||||
done
|
done
|
||||||
|
|
||||||
if [ ${count} -eq 150 ]
|
if [ ${count} -eq 150 ]
|
||||||
|
|
Loading…
Reference in New Issue