1. fix generator_dataset hangs

2. fix test_graphdata_distributed.py failing randomly
This commit is contained in:
heleiwang 2020-08-21 15:44:04 +08:00
parent 9a3baf4f6c
commit 4870abc848
3 changed files with 24 additions and 13 deletions

View File

@ -3217,12 +3217,14 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
while True:
# Fetch index, block
try:
idx = idx_queue.get(timeout=10)
# Index is generated very fast, so the timeout is very short
idx = idx_queue.get(timeout=0.01)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty:
if eof.is_set() or eoe.is_set():
raise Exception("Generator worker receives queue.Empty")
return
# If eoe or eof is not set, continue to get data from idx_queue
continue
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
@ -3234,10 +3236,17 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
while True:
try:
result_queue.put(result, timeout=5)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Full:
if eof.is_set():
return
# If eof is not set, continue to put data to result_queue
continue
break
del result, idx
if eoe.is_set() and idx_queue.empty():
return

View File

@ -929,10 +929,10 @@ def check_split(method):
def check_hostname(hostname):
if len(hostname) > 255:
if not hostname or len(hostname) > 255:
return False
if hostname[-1] == ".":
hostname = hostname[:-1] # strip exactly one dot from the right, if present
hostname = hostname[:-1] # strip exactly one dot from the right, if present
allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(x) for x in hostname.split("."))
@ -952,7 +952,7 @@ def check_gnn_graphdata(method):
raise ValueError("The hostname is illegal")
type_check(working_mode, (str,), "working_mode")
if working_mode not in {'local', 'client', 'server'}:
raise ValueError("Invalid working mode")
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
type_check(port, (int,), "port")
check_value(port, (1024, 65535), "port")
type_check(num_client, (int,), "num_client")

View File

@ -23,12 +23,12 @@ from mindspore import log as logger
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
def graphdata_startserver():
def graphdata_startserver(server_port):
"""
start graphdata server
"""
logger.info('test start server.\n')
ds.GraphData(DATASET_FILE, 1, 'server')
ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
class RandomBatchedSampler(ds.Sampler):
@ -83,11 +83,13 @@ def test_graphdata_distributed():
"""
logger.info('test distributed.\n')
p1 = Process(target=graphdata_startserver)
server_port = random.randint(10000, 60000)
p1 = Process(target=graphdata_startserver, args=(server_port,))
p1.start()
time.sleep(2)
g = ds.GraphData(DATASET_FILE, 1, 'client')
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
nodes = g.get_all_nodes(1)
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])