forked from mindspore-Ecosystem/mindspore
!1394 Fix comment error and mod parameter check in graphdata
Merge pull request !1394 from heleiwang/fix_comments_error
This commit is contained in:
commit
ad9db524bb
|
@ -20,12 +20,13 @@ import numpy as np
|
|||
from mindspore._c_dataengine import Graph
|
||||
from mindspore._c_dataengine import Tensor
|
||||
|
||||
from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \
|
||||
check_gnn_get_node_feature
|
||||
|
||||
|
||||
class GraphData:
|
||||
"""
|
||||
Reads th graph dataset used for GNN training from the shared file and database.
|
||||
Reads the graph dataset used for GNN training from the shared file and database.
|
||||
|
||||
Args:
|
||||
dataset_file (str): One of file names in dataset.
|
||||
|
@ -33,6 +34,7 @@ class GraphData:
|
|||
(default=None).
|
||||
"""
|
||||
|
||||
@check_gnn_graphdata
|
||||
def __init__(self, dataset_file, num_parallel_workers=None):
|
||||
self._dataset_file = dataset_file
|
||||
if num_parallel_workers is None:
|
||||
|
@ -45,7 +47,7 @@ class GraphData:
|
|||
Get all nodes in the graph.
|
||||
|
||||
Args:
|
||||
node_type (int): Specify the tpye of node.
|
||||
node_type (int): Specify the type of node.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
@ -67,7 +69,7 @@ class GraphData:
|
|||
|
||||
Args:
|
||||
node_list (list or numpy.ndarray): The given list of nodes.
|
||||
neighbor_type (int): Specify the tpye of neighbor.
|
||||
neighbor_type (int): Specify the type of neighbor.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
|
|
@ -19,6 +19,7 @@ import inspect as ins
|
|||
import os
|
||||
from functools import wraps
|
||||
from multiprocessing import cpu_count
|
||||
import numpy as np
|
||||
from mindspore._c_expression import typing
|
||||
from . import samplers
|
||||
from . import datasets
|
||||
|
@ -1075,14 +1076,48 @@ def check_split(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_list_or_ndarray(param, param_name):
|
||||
if (not isinstance(param, list)) and (not hasattr(param, 'tolist')):
|
||||
raise TypeError("Wrong input type for {0}, should be list, got {1}".format(
|
||||
def check_gnn_graphdata(method):
|
||||
"""check the input arguments of graphdata."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check dataset_file; required argument
|
||||
dataset_file = param_dict.get('dataset_file')
|
||||
if dataset_file is None:
|
||||
raise ValueError("dataset_file is not provided.")
|
||||
check_dataset_file(dataset_file)
|
||||
|
||||
nreq_param_int = ['num_parallel_workers']
|
||||
|
||||
check_param_type(nreq_param_int, param_dict, int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_list_or_ndarray(param, param_name):
|
||||
"""Check if the input parameter is list or numpy.ndarray."""
|
||||
|
||||
if isinstance(param, list):
|
||||
for m in param:
|
||||
if not isinstance(m, int):
|
||||
raise TypeError(
|
||||
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
||||
elif isinstance(param, np.ndarray):
|
||||
if not param.dtype == np.int32:
|
||||
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||
param_name, param.dtype))
|
||||
else:
|
||||
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||
param_name, type(param)))
|
||||
|
||||
|
||||
def check_gnn_get_all_nodes(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method):
|
|||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
check_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||
|
||||
# check neighbor_type; required argument
|
||||
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
|
||||
|
@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_aligned_list(param, param_name):
|
||||
def check_aligned_list(param, param_name, membor_type):
|
||||
"""Check whether the structure of each member of the list is the same."""
|
||||
|
||||
if not isinstance(param, list):
|
||||
raise TypeError("Parameter {0} is not a list".format(param_name))
|
||||
membor_have_list = None
|
||||
list_len = None
|
||||
for membor in param:
|
||||
if isinstance(membor, list):
|
||||
check_aligned_list(membor, param_name)
|
||||
check_aligned_list(membor, param_name, membor_type)
|
||||
if membor_have_list not in (None, True):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
|
@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name):
|
|||
membor_have_list = True
|
||||
list_len = len(membor)
|
||||
else:
|
||||
if not isinstance(membor, membor_type):
|
||||
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
|
||||
param_name, type(membor)))
|
||||
if membor_have_list not in (None, False):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
|
@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name):
|
|||
|
||||
def check_gnn_get_node_feature(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
node_list = param_dict.get("node_list")
|
||||
check_list_or_ndarray(node_list, 'node_list')
|
||||
if isinstance(node_list, list):
|
||||
check_aligned_list(node_list, 'node_list')
|
||||
check_aligned_list(node_list, 'node_list', int)
|
||||
elif isinstance(node_list, np.ndarray):
|
||||
if not node_list.dtype == np.int32:
|
||||
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||
node_list, node_list.dtype))
|
||||
else:
|
||||
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||
'node_list', type(node_list)))
|
||||
|
||||
# check feature_types; required argument
|
||||
check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types')
|
||||
check_gnn_list_or_ndarray(param_dict.get(
|
||||
"feature_types"), 'feature_types')
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
|
@ -23,8 +24,7 @@ def test_graphdata_getfullneighbor():
|
|||
g = ds.GraphData(DATASET_FILE, 2)
|
||||
nodes = g.get_all_nodes(1)
|
||||
assert len(nodes) == 10
|
||||
nodes_list = nodes.tolist()
|
||||
neighbor = g.get_all_neighbors(nodes_list, 2)
|
||||
neighbor = g.get_all_neighbors(nodes, 2)
|
||||
assert neighbor.shape == (10, 6)
|
||||
row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
|
||||
assert row_tensor[0].shape == (10, 6)
|
||||
|
@ -60,6 +60,14 @@ def test_graphdata_getnodefeature_input_check():
|
|||
input_list = [[1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, 1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 0.1], [1, 1]]
|
||||
g.get_node_feature(input_list, 1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = np.array([[1, 0.1], [1, 1]])
|
||||
g.get_node_feature(input_list, 1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, ["a"])
|
||||
|
|
Loading…
Reference in New Issue