forked from mindspore-Ecosystem/mindspore
!1394 Fix comment error and mod parameter check in graphdata
Merge pull request !1394 from heleiwang/fix_comments_error
This commit is contained in:
commit
ad9db524bb
|
@ -20,12 +20,13 @@ import numpy as np
|
||||||
from mindspore._c_dataengine import Graph
|
from mindspore._c_dataengine import Graph
|
||||||
from mindspore._c_dataengine import Tensor
|
from mindspore._c_dataengine import Tensor
|
||||||
|
|
||||||
from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature
|
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \
|
||||||
|
check_gnn_get_node_feature
|
||||||
|
|
||||||
|
|
||||||
class GraphData:
|
class GraphData:
|
||||||
"""
|
"""
|
||||||
Reads th graph dataset used for GNN training from the shared file and database.
|
Reads the graph dataset used for GNN training from the shared file and database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_file (str): One of file names in dataset.
|
dataset_file (str): One of file names in dataset.
|
||||||
|
@ -33,6 +34,7 @@ class GraphData:
|
||||||
(default=None).
|
(default=None).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@check_gnn_graphdata
|
||||||
def __init__(self, dataset_file, num_parallel_workers=None):
|
def __init__(self, dataset_file, num_parallel_workers=None):
|
||||||
self._dataset_file = dataset_file
|
self._dataset_file = dataset_file
|
||||||
if num_parallel_workers is None:
|
if num_parallel_workers is None:
|
||||||
|
@ -45,7 +47,7 @@ class GraphData:
|
||||||
Get all nodes in the graph.
|
Get all nodes in the graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_type (int): Specify the tpye of node.
|
node_type (int): Specify the type of node.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray: array of nodes.
|
numpy.ndarray: array of nodes.
|
||||||
|
@ -67,7 +69,7 @@ class GraphData:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_list (list or numpy.ndarray): The given list of nodes.
|
node_list (list or numpy.ndarray): The given list of nodes.
|
||||||
neighbor_type (int): Specify the tpye of neighbor.
|
neighbor_type (int): Specify the type of neighbor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray: array of nodes.
|
numpy.ndarray: array of nodes.
|
||||||
|
|
|
@ -19,6 +19,7 @@ import inspect as ins
|
||||||
import os
|
import os
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
import numpy as np
|
||||||
from mindspore._c_expression import typing
|
from mindspore._c_expression import typing
|
||||||
from . import samplers
|
from . import samplers
|
||||||
from . import datasets
|
from . import datasets
|
||||||
|
@ -1075,14 +1076,48 @@ def check_split(method):
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_list_or_ndarray(param, param_name):
|
def check_gnn_graphdata(method):
|
||||||
if (not isinstance(param, list)) and (not hasattr(param, 'tolist')):
|
"""check the input arguments of graphdata."""
|
||||||
raise TypeError("Wrong input type for {0}, should be list, got {1}".format(
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(*args, **kwargs):
|
||||||
|
param_dict = make_param_dict(method, args, kwargs)
|
||||||
|
|
||||||
|
# check dataset_file; required argument
|
||||||
|
dataset_file = param_dict.get('dataset_file')
|
||||||
|
if dataset_file is None:
|
||||||
|
raise ValueError("dataset_file is not provided.")
|
||||||
|
check_dataset_file(dataset_file)
|
||||||
|
|
||||||
|
nreq_param_int = ['num_parallel_workers']
|
||||||
|
|
||||||
|
check_param_type(nreq_param_int, param_dict, int)
|
||||||
|
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_gnn_list_or_ndarray(param, param_name):
|
||||||
|
"""Check if the input parameter is list or numpy.ndarray."""
|
||||||
|
|
||||||
|
if isinstance(param, list):
|
||||||
|
for m in param:
|
||||||
|
if not isinstance(m, int):
|
||||||
|
raise TypeError(
|
||||||
|
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
||||||
|
elif isinstance(param, np.ndarray):
|
||||||
|
if not param.dtype == np.int32:
|
||||||
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||||
|
param_name, param.dtype))
|
||||||
|
else:
|
||||||
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||||
param_name, type(param)))
|
param_name, type(param)))
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_get_all_nodes(method):
|
def check_gnn_get_all_nodes(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
|
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(*args, **kwargs):
|
def new_method(*args, **kwargs):
|
||||||
param_dict = make_param_dict(method, args, kwargs)
|
param_dict = make_param_dict(method, args, kwargs)
|
||||||
|
@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method):
|
||||||
param_dict = make_param_dict(method, args, kwargs)
|
param_dict = make_param_dict(method, args, kwargs)
|
||||||
|
|
||||||
# check node_list; required argument
|
# check node_list; required argument
|
||||||
check_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||||
|
|
||||||
# check neighbor_type; required argument
|
# check neighbor_type; required argument
|
||||||
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
|
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
|
||||||
|
@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method):
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
def check_aligned_list(param, param_name):
|
def check_aligned_list(param, param_name, membor_type):
|
||||||
"""Check whether the structure of each member of the list is the same."""
|
"""Check whether the structure of each member of the list is the same."""
|
||||||
|
|
||||||
if not isinstance(param, list):
|
if not isinstance(param, list):
|
||||||
raise TypeError("Parameter {0} is not a list".format(param_name))
|
raise TypeError("Parameter {0} is not a list".format(param_name))
|
||||||
membor_have_list = None
|
membor_have_list = None
|
||||||
list_len = None
|
list_len = None
|
||||||
for membor in param:
|
for membor in param:
|
||||||
if isinstance(membor, list):
|
if isinstance(membor, list):
|
||||||
check_aligned_list(membor, param_name)
|
check_aligned_list(membor, param_name, membor_type)
|
||||||
if membor_have_list not in (None, True):
|
if membor_have_list not in (None, True):
|
||||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||||
param_name))
|
param_name))
|
||||||
|
@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name):
|
||||||
membor_have_list = True
|
membor_have_list = True
|
||||||
list_len = len(membor)
|
list_len = len(membor)
|
||||||
else:
|
else:
|
||||||
|
if not isinstance(membor, membor_type):
|
||||||
|
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
|
||||||
|
param_name, type(membor)))
|
||||||
if membor_have_list not in (None, False):
|
if membor_have_list not in (None, False):
|
||||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||||
param_name))
|
param_name))
|
||||||
|
@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name):
|
||||||
|
|
||||||
def check_gnn_get_node_feature(method):
|
def check_gnn_get_node_feature(method):
|
||||||
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
|
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(*args, **kwargs):
|
def new_method(*args, **kwargs):
|
||||||
param_dict = make_param_dict(method, args, kwargs)
|
param_dict = make_param_dict(method, args, kwargs)
|
||||||
|
|
||||||
# check node_list; required argument
|
# check node_list; required argument
|
||||||
node_list = param_dict.get("node_list")
|
node_list = param_dict.get("node_list")
|
||||||
check_list_or_ndarray(node_list, 'node_list')
|
|
||||||
if isinstance(node_list, list):
|
if isinstance(node_list, list):
|
||||||
check_aligned_list(node_list, 'node_list')
|
check_aligned_list(node_list, 'node_list', int)
|
||||||
|
elif isinstance(node_list, np.ndarray):
|
||||||
|
if not node_list.dtype == np.int32:
|
||||||
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||||
|
node_list, node_list.dtype))
|
||||||
|
else:
|
||||||
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||||
|
'node_list', type(node_list)))
|
||||||
|
|
||||||
# check feature_types; required argument
|
# check feature_types; required argument
|
||||||
check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types')
|
check_gnn_list_or_ndarray(param_dict.get(
|
||||||
|
"feature_types"), 'feature_types')
|
||||||
|
|
||||||
return method(*args, **kwargs)
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
@ -23,8 +24,7 @@ def test_graphdata_getfullneighbor():
|
||||||
g = ds.GraphData(DATASET_FILE, 2)
|
g = ds.GraphData(DATASET_FILE, 2)
|
||||||
nodes = g.get_all_nodes(1)
|
nodes = g.get_all_nodes(1)
|
||||||
assert len(nodes) == 10
|
assert len(nodes) == 10
|
||||||
nodes_list = nodes.tolist()
|
neighbor = g.get_all_neighbors(nodes, 2)
|
||||||
neighbor = g.get_all_neighbors(nodes_list, 2)
|
|
||||||
assert neighbor.shape == (10, 6)
|
assert neighbor.shape == (10, 6)
|
||||||
row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
|
row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
|
||||||
assert row_tensor[0].shape == (10, 6)
|
assert row_tensor[0].shape == (10, 6)
|
||||||
|
@ -60,6 +60,14 @@ def test_graphdata_getnodefeature_input_check():
|
||||||
input_list = [[1, 1], [1, 1]]
|
input_list = [[1, 1], [1, 1]]
|
||||||
g.get_node_feature(input_list, 1)
|
g.get_node_feature(input_list, 1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
input_list = [[1, 0.1], [1, 1]]
|
||||||
|
g.get_node_feature(input_list, 1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
input_list = np.array([[1, 0.1], [1, 1]])
|
||||||
|
g.get_node_feature(input_list, 1)
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
input_list = [[1, 1], [1, 1]]
|
input_list = [[1, 1], [1, 1]]
|
||||||
g.get_node_feature(input_list, ["a"])
|
g.get_node_feature(input_list, ["a"])
|
||||||
|
|
Loading…
Reference in New Issue