From 0d52888fc5a7abe04ff9bb9a8024e19e9832cc4a Mon Sep 17 00:00:00 2001 From: heleiwang Date: Tue, 23 Jun 2020 17:50:05 +0800 Subject: [PATCH] fix misspell and check parameters --- mindspore/ccsrc/dataset/engine/gnn/graph.cc | 29 ++++++++++++++ mindspore/ccsrc/dataset/engine/gnn/graph.h | 2 + mindspore/dataset/engine/validators.py | 42 ++++++++++----------- tests/ut/cpp/dataset/gnn_graph_test.cc | 22 ++++++++++- 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index 10176573973..a143bd4e386 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -149,14 +149,37 @@ Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType return Status::OK(); } +Status Graph::CheckSamplesNum(NodeIdType samples_num) { + NodeIdType all_nodes_number = + std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, + [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); + if ((samples_num < 1) || (samples_num > all_nodes_number)) { + std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + + ", got " + std::to_string(samples_num); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + Status Graph::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, const std::vector &neighbor_types, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), "The sizes of neighbor_nums and neighbor_types are inconsistent."); + for (const auto &num : neighbor_nums) { + RETURN_IF_NOT_OK(CheckSamplesNum(num)); + } + for (const auto &type : neighbor_types) { + if (node_type_map_.find(type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } std::vector> neighbors_vec(node_list.size()); for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr input_node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); neighbors_vec[node_idx].emplace_back(node_list[node_idx]); std::vector input_list = {node_list[node_idx]}; for (size_t i = 0; i < neighbor_nums.size(); ++i) { @@ -204,6 +227,12 @@ Status Graph::NegativeSample(const std::vector &data, const std::uno Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, NodeType neg_neighbor_type, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); + if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector> neighbors_vec; neighbors_vec.resize(node_list.size()); for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index ea103630536..344a6c6bf21 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -226,6 +226,8 @@ class Graph { Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, int32_t samples_num, std::vector *out_samples); + Status CheckSamplesNum(NodeIdType samples_num); + std::string dataset_file_; int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 005f7072aa2..5bfd7656d31 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1110,10 +1110,10 @@ def check_gnn_list_or_ndarray(param, param_name): for m in param: if not isinstance(m, int): raise TypeError( - "Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m))) + "Each member in {0} should be of type int. Got {1}.".format(param_name, type(m))) elif isinstance(param, np.ndarray): if not param.dtype == np.int32: - raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( param_name, param.dtype)) else: raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( @@ -1196,15 +1196,15 @@ def check_gnn_get_sampled_neighbors(method): # check neighbor_nums; required argument neighbor_nums = param_dict.get("neighbor_nums") check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') - if len(neighbor_nums) > 6: - raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( + if not neighbor_nums or len(neighbor_nums) > 6: + raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( 'neighbor_nums', len(neighbor_nums))) # check neighbor_types; required argument neighbor_types = param_dict.get("neighbor_types") check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') - if len(neighbor_nums) > 6: - raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( + if not neighbor_types or len(neighbor_types) > 6: + raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( 'neighbor_types', len(neighbor_types))) if len(neighbor_nums) != len(neighbor_types): @@ -1256,7 +1256,7 @@ def check_gnn_random_walk(method): return new_method -def check_aligned_list(param, param_name, membor_type): +def check_aligned_list(param, param_name, member_type): """Check whether the structure of each member of the list is the same.""" if not isinstance(param, list): @@ -1264,27 +1264,27 @@ def check_aligned_list(param, param_name, membor_type): if not param: raise TypeError( "Parameter {0} or its members are empty".format(param_name)) - membor_have_list = None + member_have_list = None list_len = None - for membor in param: - if isinstance(membor, list): - check_aligned_list(membor, param_name, membor_type) - if membor_have_list not in (None, True): + for member in param: + if isinstance(member, list): + check_aligned_list(member, param_name, member_type) + if member_have_list not in (None, True): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) - if list_len is not None and len(membor) != list_len: + if list_len is not None and len(member) != list_len: raise TypeError("The size of each member of parameter {0} is inconsistent".format( param_name)) - membor_have_list = True - list_len = len(membor) + member_have_list = True + list_len = len(member) else: - if not isinstance(membor, membor_type): - raise TypeError("Each membor in {0} should be of type int. Got {1}.".format( - param_name, type(membor))) - if membor_have_list not in (None, False): + if not isinstance(member, member_type): + raise TypeError("Each member in {0} should be of type int. Got {1}.".format( + param_name, type(member))) + if member_have_list not in (None, False): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) - membor_have_list = False + member_have_list = False def check_gnn_get_node_feature(method): @@ -1300,7 +1300,7 @@ def check_gnn_get_node_feature(method): check_aligned_list(node_list, 'node_list', int) elif isinstance(node_list, np.ndarray): if not node_list.dtype == np.int32: - raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( node_list, node_list.dtype)) else: raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index ce2aca4ffd0..dc74e66b0c0 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -158,6 +158,18 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); + neighbors.reset(); + s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); + + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); + + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); + EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); + neighbors.reset(); s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != @@ -198,9 +210,17 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); + neg_neighbors.reset(); + s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors); + EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); + + neg_neighbors.reset(); + s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors); + EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); + neg_neighbors.reset(); s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); - EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); + EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); } TEST_F(MindDataTestGNNGraph, TestRandomWalk) {