fix comments error and modify parameter check

This commit is contained in:
heleiwang 2020-05-23 15:36:05 +08:00
parent 9ceea12636
commit f28f883cac
3 changed files with 72 additions and 15 deletions

View File

@ -20,12 +20,13 @@ import numpy as np
from mindspore._c_dataengine import Graph
from mindspore._c_dataengine import Tensor
from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \
check_gnn_get_node_feature
class GraphData:
"""
Reads th graph dataset used for GNN training from the shared file and database.
Reads the graph dataset used for GNN training from the shared file and database.
Args:
dataset_file (str): One of file names in dataset.
@ -33,6 +34,7 @@ class GraphData:
(default=None).
"""
@check_gnn_graphdata
def __init__(self, dataset_file, num_parallel_workers=None):
self._dataset_file = dataset_file
if num_parallel_workers is None:
@ -45,7 +47,7 @@ class GraphData:
Get all nodes in the graph.
Args:
node_type (int): Specify the tpye of node.
node_type (int): Specify the type of node.
Returns:
numpy.ndarray: array of nodes.
@ -67,7 +69,7 @@ class GraphData:
Args:
node_list (list or numpy.ndarray): The given list of nodes.
neighbor_type (int): Specify the tpye of neighbor.
neighbor_type (int): Specify the type of neighbor.
Returns:
numpy.ndarray: array of nodes.

View File

@ -19,6 +19,7 @@ import inspect as ins
import os
from functools import wraps
from multiprocessing import cpu_count
import numpy as np
from mindspore._c_expression import typing
from . import samplers
from . import datasets
@ -1075,14 +1076,48 @@ def check_split(method):
return new_method
def check_list_or_ndarray(param, param_name):
if (not isinstance(param, list)) and (not hasattr(param, 'tolist')):
raise TypeError("Wrong input type for {0}, should be list, got {1}".format(
def check_gnn_graphdata(method):
"""check the input arguments of graphdata."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
# check dataset_file; required argument
dataset_file = param_dict.get('dataset_file')
if dataset_file is None:
raise ValueError("dataset_file is not provided.")
check_dataset_file(dataset_file)
nreq_param_int = ['num_parallel_workers']
check_param_type(nreq_param_int, param_dict, int)
return method(*args, **kwargs)
return new_method
def check_gnn_list_or_ndarray(param, param_name):
"""Check if the input parameter is list or numpy.ndarray."""
if isinstance(param, list):
for m in param:
if not isinstance(m, int):
raise TypeError(
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
elif isinstance(param, np.ndarray):
if not param.dtype == np.int32:
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
param_name, param.dtype))
else:
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
param_name, type(param)))
def check_gnn_get_all_nodes(method):
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method):
param_dict = make_param_dict(method, args, kwargs)
# check node_list; required argument
check_list_or_ndarray(param_dict.get("node_list"), 'node_list')
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
# check neighbor_type; required argument
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method):
return new_method
def check_aligned_list(param, param_name):
def check_aligned_list(param, param_name, membor_type):
"""Check whether the structure of each member of the list is the same."""
if not isinstance(param, list):
raise TypeError("Parameter {0} is not a list".format(param_name))
membor_have_list = None
list_len = None
for membor in param:
if isinstance(membor, list):
check_aligned_list(membor, param_name)
check_aligned_list(membor, param_name, membor_type)
if membor_have_list not in (None, True):
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
param_name))
@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name):
membor_have_list = True
list_len = len(membor)
else:
if not isinstance(membor, membor_type):
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
param_name, type(membor)))
if membor_have_list not in (None, False):
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
param_name))
@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name):
def check_gnn_get_node_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
# check node_list; required argument
node_list = param_dict.get("node_list")
check_list_or_ndarray(node_list, 'node_list')
if isinstance(node_list, list):
check_aligned_list(node_list, 'node_list')
check_aligned_list(node_list, 'node_list', int)
elif isinstance(node_list, np.ndarray):
if not node_list.dtype == np.int32:
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
node_list, node_list.dtype))
else:
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
'node_list', type(node_list)))
# check feature_types; required argument
check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types')
check_gnn_list_or_ndarray(param_dict.get(
"feature_types"), 'feature_types')
return method(*args, **kwargs)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import pytest
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
@ -23,8 +24,7 @@ def test_graphdata_getfullneighbor():
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
nodes_list = nodes.tolist()
neighbor = g.get_all_neighbors(nodes_list, 2)
neighbor = g.get_all_neighbors(nodes, 2)
assert neighbor.shape == (10, 6)
row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
assert row_tensor[0].shape == (10, 6)
@ -60,6 +60,14 @@ def test_graphdata_getnodefeature_input_check():
input_list = [[1, 1], [1, 1]]
g.get_node_feature(input_list, 1)
with pytest.raises(TypeError):
input_list = [[1, 0.1], [1, 1]]
g.get_node_feature(input_list, 1)
with pytest.raises(TypeError):
input_list = np.array([[1, 0.1], [1, 1]])
g.get_node_feature(input_list, 1)
with pytest.raises(TypeError):
input_list = [[1, 1], [1, 1]]
g.get_node_feature(input_list, ["a"])