1. fix generator_dataset hangs
2. fix test_graphdata_distributed.py failing randomly
This commit is contained in:
parent
9a3baf4f6c
commit
4870abc848
|
@ -3217,12 +3217,14 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
|
|||
while True:
|
||||
# Fetch index, block
|
||||
try:
|
||||
idx = idx_queue.get(timeout=10)
|
||||
# Index is generated very fast, so the timeout is very short
|
||||
idx = idx_queue.get(timeout=0.01)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||
except queue.Empty:
|
||||
if eof.is_set() or eoe.is_set():
|
||||
raise Exception("Generator worker receives queue.Empty")
|
||||
return
|
||||
# If eoe or eof is not set, continue to get data from idx_queue
|
||||
continue
|
||||
if idx is None:
|
||||
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
||||
|
@ -3234,10 +3236,17 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
|
|||
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
||||
result = dataset[idx]
|
||||
# Send data, block
|
||||
try:
|
||||
result_queue.put(result)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||
while True:
|
||||
try:
|
||||
result_queue.put(result, timeout=5)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||
except queue.Full:
|
||||
if eof.is_set():
|
||||
return
|
||||
# If eof is not set, continue to put data to result_queue
|
||||
continue
|
||||
break
|
||||
del result, idx
|
||||
if eoe.is_set() and idx_queue.empty():
|
||||
return
|
||||
|
|
|
@ -929,10 +929,10 @@ def check_split(method):
|
|||
|
||||
|
||||
def check_hostname(hostname):
|
||||
if len(hostname) > 255:
|
||||
if not hostname or len(hostname) > 255:
|
||||
return False
|
||||
if hostname[-1] == ".":
|
||||
hostname = hostname[:-1] # strip exactly one dot from the right, if present
|
||||
hostname = hostname[:-1] # strip exactly one dot from the right, if present
|
||||
allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||
return all(allowed.match(x) for x in hostname.split("."))
|
||||
|
||||
|
@ -952,7 +952,7 @@ def check_gnn_graphdata(method):
|
|||
raise ValueError("The hostname is illegal")
|
||||
type_check(working_mode, (str,), "working_mode")
|
||||
if working_mode not in {'local', 'client', 'server'}:
|
||||
raise ValueError("Invalid working mode")
|
||||
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
|
||||
type_check(port, (int,), "port")
|
||||
check_value(port, (1024, 65535), "port")
|
||||
type_check(num_client, (int,), "num_client")
|
||||
|
|
|
@ -23,12 +23,12 @@ from mindspore import log as logger
|
|||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
|
||||
|
||||
def graphdata_startserver():
|
||||
def graphdata_startserver(server_port):
|
||||
"""
|
||||
start graphdata server
|
||||
"""
|
||||
logger.info('test start server.\n')
|
||||
ds.GraphData(DATASET_FILE, 1, 'server')
|
||||
ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
|
||||
|
||||
|
||||
class RandomBatchedSampler(ds.Sampler):
|
||||
|
@ -83,11 +83,13 @@ def test_graphdata_distributed():
|
|||
"""
|
||||
logger.info('test distributed.\n')
|
||||
|
||||
p1 = Process(target=graphdata_startserver)
|
||||
server_port = random.randint(10000, 60000)
|
||||
|
||||
p1 = Process(target=graphdata_startserver, args=(server_port,))
|
||||
p1.start()
|
||||
time.sleep(2)
|
||||
|
||||
g = ds.GraphData(DATASET_FILE, 1, 'client')
|
||||
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
|
||||
nodes = g.get_all_nodes(1)
|
||||
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
|
||||
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
|
||||
|
|
Loading…
Reference in New Issue