forked from mindspore-Ecosystem/mindspore
!38743 fix param check in graph api
Merge pull request !38743 from ms_yan/graph_small_fix
This commit is contained in:
commit
fc712ac363
|
@ -193,7 +193,7 @@ mindspore.dataset.Graph
|
|||
异常:
|
||||
- **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。
|
||||
|
||||
.. py:method:: get_graph_feature(edge_list, feature_types)
|
||||
.. py:method:: get_graph_feature(feature_types)
|
||||
|
||||
依据给定的 `feature_types` 获取存储在Graph中对应的特征。
|
||||
|
||||
|
|
|
@ -819,3 +819,11 @@ def check_dict(data, key_type, value_type, param_name):
|
|||
if not isinstance(value, value_type):
|
||||
raise TypeError("value of '{0}' in parameter {1} should be {2} type, but got: {3}"
|
||||
.format(key, param_name, value_type, type(value)))
|
||||
|
||||
|
||||
def check_feature_shape(data, shape, param_name):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if len(value.shape) != 2 or value.shape[0] != shape:
|
||||
raise ValueError("Shape of item '{0}' in '{1}' should be 2 dimension, and shape of first dimension "
|
||||
"should be: {2}, but got: {3}.".format(key, param_name, shape, value.shape))
|
||||
|
|
|
@ -626,10 +626,16 @@ class Graph(GraphData):
|
|||
if node_feat != dict():
|
||||
num_nodes = node_feat.get(list(node_feat.keys())[0]).shape[0]
|
||||
|
||||
node_type = replace_none(node_type, np.array(["0"] * num_nodes))
|
||||
node_type = np.array(node_type)
|
||||
if node_type is not None:
|
||||
node_type = np.array(node_type)
|
||||
if len(node_type.shape) != 1 or node_type.shape[0] != num_nodes:
|
||||
raise ValueError(
|
||||
"Input 'node_type' should be 1 dimension, and its length should be {}, but got {}.".format(
|
||||
num_nodes, len(node_type)))
|
||||
else:
|
||||
node_type = np.array(["0"] * num_nodes)
|
||||
|
||||
edge_type = replace_none(edge_type, np.array(["0"] * edges.shape[1]))
|
||||
edge_type = np.array(edge_type)
|
||||
|
||||
self._working_mode = working_mode
|
||||
self.data_format = "array"
|
||||
|
|
|
@ -28,7 +28,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
|
|||
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
|
||||
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \
|
||||
check_valid_list_tuple, check_dict
|
||||
check_valid_list_tuple, check_dict, check_feature_shape
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
@ -1861,28 +1861,32 @@ def check_gnn_graph(method):
|
|||
hostname, port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(edges, (list, np.ndarray), "edges")
|
||||
if len(np.array(edges).shape) != 2:
|
||||
raise ValueError(
|
||||
"Input 'edges' should be with 2 dimension, but got {} dimension.".format(len(np.array(edges).shape)))
|
||||
check_dict(node_feat, str, np.ndarray, "node_feat")
|
||||
check_dict(edge_feat, str, np.ndarray, "edge_feat")
|
||||
check_dict(graph_feat, str, np.ndarray, "graph_feat")
|
||||
if node_type:
|
||||
if node_type is not None:
|
||||
type_check(node_type, (list, np.ndarray), "node_type")
|
||||
if edge_type:
|
||||
type_check(edge_type, (None, list, np.ndarray), "edge_type")
|
||||
if not all(isinstance(item, str) for item in list(node_type)):
|
||||
raise TypeError("Type of each element in 'node_type' should be str.")
|
||||
if edge_type is not None:
|
||||
type_check(edge_type, (list, np.ndarray), "edge_type")
|
||||
edge_type = np.array(edge_type)
|
||||
if len(edge_type.shape) != 1 or edge_type.shape[0] != edges.shape[1]:
|
||||
raise ValueError(
|
||||
"Input 'edge_type' should be 1 dimension, and its length should be {}, but got {}.".format(
|
||||
edges.shape[1], edge_type.shape[0]))
|
||||
if not all(isinstance(item, str) for item in list(edge_type)):
|
||||
raise TypeError("Type of each element in 'edge_type' should be str.")
|
||||
|
||||
# check shape of node_feat and edge_feat
|
||||
num_nodes = np.max(edges) + 1
|
||||
if node_feat and isinstance(node_feat, dict):
|
||||
num_nodes = node_feat[list(node_feat.keys())[0]].shape[0]
|
||||
if node_feat:
|
||||
for key, value in node_feat.items():
|
||||
if len(value.shape) != 2 or value.shape[0] != num_nodes:
|
||||
raise ValueError("value of item '{0}' in node_feat should with shape [num_nodes, num_node_features]"
|
||||
"(here num_nodes is: {1}), but got: {2}".format(key, num_nodes, value.shape))
|
||||
if edge_feat:
|
||||
for key, value in edge_feat.items():
|
||||
if len(value.shape) != 2 or value.shape[0] != edges.shape[1]:
|
||||
raise ValueError("value of item '{0}' in edge_feat should with shape [num_edges, num_node_features]"
|
||||
"(here num_edges is: {1}), but got: {2}".format(key, edges.shape[1], value.shape))
|
||||
check_feature_shape(node_feat, num_nodes, "node_feat")
|
||||
check_feature_shape(edge_feat, edges.shape[1], "edge_feat")
|
||||
|
||||
if num_parallel_workers is not None:
|
||||
check_num_parallel_workers(num_parallel_workers)
|
||||
|
@ -1898,6 +1902,7 @@ def check_gnn_graph(method):
|
|||
check_value(num_client, (1, 255), "num_client")
|
||||
type_check(auto_shutdown, (bool,), "auto_shutdown")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue