Fix misspellings and add parameter checks

This commit is contained in:
heleiwang 2020-06-23 17:50:05 +08:00
parent 5b14292f69
commit 0d52888fc5
4 changed files with 73 additions and 22 deletions

View File

@ -149,14 +149,37 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
return Status::OK();
}
// Validate a requested sample count against the number of nodes held by this graph.
//
// The upper bound is the sum of `size()` over every mapped value in `node_type_map_`.
// @param samples_num - the sample count supplied by the caller
// @return Status::OK() when 1 <= samples_num <= total node count; otherwise an unexpected-error
//     status whose message starts with "Wrong samples number" and reports the valid range and the
//     rejected value.
Status Graph::CheckSamplesNum(NodeIdType samples_num) {
  // Seed the fold with a NodeIdType zero: std::accumulate takes its accumulator type from the
  // initial value, so a bare literal `0` would force every partial sum through `int`.
  NodeIdType all_nodes_number =
    std::accumulate(node_type_map_.begin(), node_type_map_.end(), static_cast<NodeIdType>(0),
                    [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
  if ((samples_num < 1) || (samples_num > all_nodes_number)) {
    std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) +
                          ", got " + std::to_string(samples_num);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  return Status::OK();
}
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
for (const auto &num : neighbor_nums) {
RETURN_IF_NOT_OK(CheckSamplesNum(num));
}
for (const auto &type : neighbor_types) {
if (node_type_map_.find(type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
std::shared_ptr<Node> input_node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node));
neighbors_vec[node_idx].emplace_back(node_list[node_idx]);
std::vector<NodeIdType> input_list = {node_list[node_idx]};
for (size_t i = 0; i < neighbor_nums.size(); ++i) {
@ -204,6 +227,12 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno
Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<std::vector<NodeIdType>> neighbors_vec;
neighbors_vec.resize(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {

View File

@ -226,6 +226,8 @@ class Graph {
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
int32_t samples_num, std::vector<NodeIdType> *out_samples);
// Check that a requested sample count is usable for this graph.
// @param samples_num - requested number of samples; accepted only when it is at least 1 and no
//     larger than the total number of nodes recorded in node_type_map_
// @return Status - Status::OK() when in range, otherwise an unexpected-error status
Status CheckSamplesNum(NodeIdType samples_num);
std::string dataset_file_;
int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_;

View File

@ -1110,10 +1110,10 @@ def check_gnn_list_or_ndarray(param, param_name):
for m in param:
if not isinstance(m, int):
raise TypeError(
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
"Each member in {0} should be of type int. Got {1}.".format(param_name, type(m)))
elif isinstance(param, np.ndarray):
if not param.dtype == np.int32:
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
param_name, param.dtype))
else:
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
@ -1196,15 +1196,15 @@ def check_gnn_get_sampled_neighbors(method):
# check neighbor_nums; required argument
neighbor_nums = param_dict.get("neighbor_nums")
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
if len(neighbor_nums) > 6:
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
if not neighbor_nums or len(neighbor_nums) > 6:
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
'neighbor_nums', len(neighbor_nums)))
# check neighbor_types; required argument
neighbor_types = param_dict.get("neighbor_types")
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
if len(neighbor_nums) > 6:
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
if not neighbor_types or len(neighbor_types) > 6:
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
'neighbor_types', len(neighbor_types)))
if len(neighbor_nums) != len(neighbor_types):
@ -1256,7 +1256,7 @@ def check_gnn_random_walk(method):
return new_method
def check_aligned_list(param, param_name, membor_type):
def check_aligned_list(param, param_name, member_type):
"""Check whether the structure of each member of the list is the same."""
if not isinstance(param, list):
@ -1264,27 +1264,27 @@ def check_aligned_list(param, param_name, membor_type):
if not param:
raise TypeError(
"Parameter {0} or its members are empty".format(param_name))
membor_have_list = None
member_have_list = None
list_len = None
for membor in param:
if isinstance(membor, list):
check_aligned_list(membor, param_name, membor_type)
if membor_have_list not in (None, True):
for member in param:
if isinstance(member, list):
check_aligned_list(member, param_name, member_type)
if member_have_list not in (None, True):
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
param_name))
if list_len is not None and len(membor) != list_len:
if list_len is not None and len(member) != list_len:
raise TypeError("The size of each member of parameter {0} is inconsistent".format(
param_name))
membor_have_list = True
list_len = len(membor)
member_have_list = True
list_len = len(member)
else:
if not isinstance(membor, membor_type):
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
param_name, type(membor)))
if membor_have_list not in (None, False):
if not isinstance(member, member_type):
raise TypeError("Each member in {0} should be of type int. Got {1}.".format(
param_name, type(member)))
if member_have_list not in (None, False):
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
param_name))
membor_have_list = False
member_have_list = False
def check_gnn_get_node_feature(method):
@ -1300,7 +1300,7 @@ def check_gnn_get_node_feature(method):
check_aligned_list(node_list, 'node_list', int)
elif isinstance(node_list, np.ndarray):
if not node_list.dtype == np.int32:
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
node_list, node_list.dtype))
else:
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(

View File

@ -158,6 +158,18 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
@ -198,9 +210,17 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
neg_neighbors.reset();
s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
neg_neighbors.reset();
s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
neg_neighbors.reset();
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
}
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {