!4916 fix generator_dataset hangs and test_graphdata_distributed.py failing randomly

Merge pull request !4916 from heleiwang/gnn_fix_bug
This commit is contained in:
mindspore-ci-bot 2020-08-22 10:51:01 +08:00 committed by Gitee
commit ac81886328
3 changed files with 24 additions and 13 deletions

View File

@ -3222,12 +3222,14 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
while True: while True:
# Fetch index, block # Fetch index, block
try: try:
idx = idx_queue.get(timeout=10) # Index is generated very fast, so the timeout is very short
idx = idx_queue.get(timeout=0.01)
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty: except queue.Empty:
if eof.is_set() or eoe.is_set(): if eof.is_set() or eoe.is_set():
raise Exception("Generator worker receives queue.Empty") return
# If eoe or eof is not set, continue to get data from idx_queue
continue continue
if idx is None: if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue. # When the queue is out of scope from master process, a None item can be fetched from the queue.
@ -3239,10 +3241,17 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx] result = dataset[idx]
# Send data, block # Send data, block
while True:
try: try:
result_queue.put(result) result_queue.put(result, timeout=5)
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Full:
if eof.is_set():
return
# If eof is not set, continue to put data to result_queue
continue
break
del result, idx del result, idx
if eoe.is_set() and idx_queue.empty(): if eoe.is_set() and idx_queue.empty():
return return

View File

@ -929,7 +929,7 @@ def check_split(method):
def check_hostname(hostname): def check_hostname(hostname):
if len(hostname) > 255: if not hostname or len(hostname) > 255:
return False return False
if hostname[-1] == ".": if hostname[-1] == ".":
hostname = hostname[:-1] # strip exactly one dot from the right, if present hostname = hostname[:-1] # strip exactly one dot from the right, if present
@ -952,7 +952,7 @@ def check_gnn_graphdata(method):
raise ValueError("The hostname is illegal") raise ValueError("The hostname is illegal")
type_check(working_mode, (str,), "working_mode") type_check(working_mode, (str,), "working_mode")
if working_mode not in {'local', 'client', 'server'}: if working_mode not in {'local', 'client', 'server'}:
raise ValueError("Invalid working mode") raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
type_check(port, (int,), "port") type_check(port, (int,), "port")
check_value(port, (1024, 65535), "port") check_value(port, (1024, 65535), "port")
type_check(num_client, (int,), "num_client") type_check(num_client, (int,), "num_client")

View File

@ -23,12 +23,12 @@ from mindspore import log as logger
DATASET_FILE = "../data/mindrecord/testGraphData/testdata" DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
def graphdata_startserver(): def graphdata_startserver(server_port):
""" """
start graphdata server start graphdata server
""" """
logger.info('test start server.\n') logger.info('test start server.\n')
ds.GraphData(DATASET_FILE, 1, 'server') ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
class RandomBatchedSampler(ds.Sampler): class RandomBatchedSampler(ds.Sampler):
@ -83,11 +83,13 @@ def test_graphdata_distributed():
""" """
logger.info('test distributed.\n') logger.info('test distributed.\n')
p1 = Process(target=graphdata_startserver) server_port = random.randint(10000, 60000)
p1 = Process(target=graphdata_startserver, args=(server_port,))
p1.start() p1.start()
time.sleep(2) time.sleep(2)
g = ds.GraphData(DATASET_FILE, 1, 'client') g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
nodes = g.get_all_nodes(1) nodes = g.get_all_nodes(1)
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])