forked from mindspore-Ecosystem/mindspore
!4916 fix generator_dataset hangs and test_graphdata_distributed.py failing randomly
Merge pull request !4916 from heleiwang/gnn_fix_bug
Commit ac81886328
@@ -3222,12 +3222,14 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
     while True:
         # Fetch index, block
         try:
-            idx = idx_queue.get(timeout=10)
+            # Index is generated very fast, so the timeout is very short
+            idx = idx_queue.get(timeout=0.01)
         except KeyboardInterrupt:
             raise Exception("Generator worker receives KeyboardInterrupt")
         except queue.Empty:
             if eof.is_set() or eoe.is_set():
-                raise Exception("Generator worker receives queue.Empty")
+                return
+            # If eoe or eof is not set, continue to get data from idx_queue
             continue
         if idx is None:
             # When the queue is out of scope from master process, a None item can be fetched from the queue.
@@ -3239,10 +3241,17 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
         # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
         result = dataset[idx]
         # Send data, block
-        try:
-            result_queue.put(result)
-        except KeyboardInterrupt:
-            raise Exception("Generator worker receives KeyboardInterrupt")
+        while True:
+            try:
+                result_queue.put(result, timeout=5)
+            except KeyboardInterrupt:
+                raise Exception("Generator worker receives KeyboardInterrupt")
+            except queue.Full:
+                if eof.is_set():
+                    return
+                # If eof is not set, continue to put data to result_queue
+                continue
+            break
         del result, idx
         if eoe.is_set() and idx_queue.empty():
             return
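The fix above replaces indefinite blocking on the index and result queues with short-timeout polling: on every queue.Empty or queue.Full the worker re-checks the eoe/eof events and exits cleanly instead of hanging or raising. A minimal self-contained sketch of this polling pattern, using only the standard library (worker_loop, shutdown and the square computation are illustrative stand-ins, not MindSpore code):

import multiprocessing
import queue

def worker_loop(idx_queue, result_queue, shutdown):
    """Illustrative worker: poll with short timeouts and re-check the shutdown flag."""
    while True:
        try:
            # Short timeout so the worker notices shutdown quickly instead of hanging
            idx = idx_queue.get(timeout=0.01)
        except queue.Empty:
            if shutdown.is_set():
                return          # exit cleanly instead of raising
            continue            # flag not set yet, keep polling
        result = idx * idx      # stand-in for dataset[idx]
        while True:
            try:
                result_queue.put(result, timeout=5)
                break
            except queue.Full:
                if shutdown.is_set():
                    return      # consumer is gone, stop producing
                continue

if __name__ == "__main__":
    idx_q = multiprocessing.Queue()
    res_q = multiprocessing.Queue(maxsize=4)
    shutdown = multiprocessing.Event()
    p = multiprocessing.Process(target=worker_loop, args=(idx_q, res_q, shutdown))
    p.start()
    for i in range(8):
        idx_q.put(i)
    print([res_q.get() for _ in range(8)])   # [0, 1, 4, 9, ...]
    shutdown.set()
    p.join()

The very short index timeout keeps the worker responsive to shutdown (indices are generated quickly anyway), while the longer put timeout gives a slow consumer time to drain the result queue before the worker checks eof again.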
@@ -929,7 +929,7 @@ def check_split(method):


 def check_hostname(hostname):
-    if len(hostname) > 255:
+    if not hostname or len(hostname) > 255:
         return False
     if hostname[-1] == ".":
         hostname = hostname[:-1]  # strip exactly one dot from the right, if present

@@ -952,7 +952,7 @@ def check_gnn_graphdata(method):
             raise ValueError("The hostname is illegal")
         type_check(working_mode, (str,), "working_mode")
         if working_mode not in {'local', 'client', 'server'}:
-            raise ValueError("Invalid working mode")
+            raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
         type_check(port, (int,), "port")
         check_value(port, (1024, 65535), "port")
         type_check(num_client, (int,), "num_client")
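The hunk above shows only the first lines of check_hostname; the rest of the function is outside this diff. For context, a common RFC 1123-style completion looks like the sketch below, where the per-label regex is an assumption for illustration rather than code taken from this commit:

import re

def check_hostname(hostname):
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    # Each dot-separated label: 1-63 letters/digits/hyphens, not starting or ending with a hyphen
    allowed = re.compile(r"(?!-)[A-Za-z\d-]{1,63}(?<!-)$")
    return all(allowed.match(label) for label in hostname.split("."))

The added "not hostname" guard matters because an empty string would otherwise fall through to hostname[-1] and raise IndexError instead of returning False.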
@@ -23,12 +23,12 @@ from mindspore import log as logger
 DATASET_FILE = "../data/mindrecord/testGraphData/testdata"


-def graphdata_startserver():
+def graphdata_startserver(server_port):
     """
     start graphdata server
     """
     logger.info('test start server.\n')
-    ds.GraphData(DATASET_FILE, 1, 'server')
+    ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)


 class RandomBatchedSampler(ds.Sampler):
@@ -83,11 +83,13 @@ def test_graphdata_distributed():
     """
     logger.info('test distributed.\n')

-    p1 = Process(target=graphdata_startserver)
+    server_port = random.randint(10000, 60000)
+
+    p1 = Process(target=graphdata_startserver, args=(server_port,))
     p1.start()
     time.sleep(2)

-    g = ds.GraphData(DATASET_FILE, 1, 'client')
+    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
     nodes = g.get_all_nodes(1)
     assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
     row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
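Picking the server port with random.randint(10000, 60000) avoids the fixed-port collisions that made this test fail randomly when several jobs ran on the same host. A stricter variant would verify that the chosen port is actually free before starting the server; a possible helper for that, not part of this commit, is sketched below:

import random
import socket

def pick_free_port(low=10000, high=60000, attempts=100):
    """Try random ports in [low, high] until one can be bound, i.e. is currently free."""
    for _ in range(attempts):
        port = random.randint(low, high)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port          # bind succeeded, so the port is free right now
            except OSError:
                continue             # port in use, try another one
    raise RuntimeError("could not find a free port")

# server_port = pick_free_port()  # hypothetical replacement for random.randint(10000, 60000)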