forked from mindspore-Ecosystem/mindspore
!2507 fix misspell and check parameters on graphdata
Merge pull request !2507 from heleiwang/r0.5_fix_misspell
This commit is contained in:
commit
572236bdd7
|
@ -149,14 +149,37 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::CheckSamplesNum(NodeIdType samples_num) {
|
||||
NodeIdType all_nodes_number =
|
||||
std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0,
|
||||
[](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
|
||||
if ((samples_num < 1) || (samples_num > all_nodes_number)) {
|
||||
std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) +
|
||||
", got " + std::to_string(samples_num);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
|
||||
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
|
||||
for (const auto &num : neighbor_nums) {
|
||||
RETURN_IF_NOT_OK(CheckSamplesNum(num));
|
||||
}
|
||||
for (const auto &type : neighbor_types) {
|
||||
if (node_type_map_.find(type) == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid neighbor type:" + std::to_string(type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
|
||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||
std::shared_ptr<Node> input_node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node));
|
||||
neighbors_vec[node_idx].emplace_back(node_list[node_idx]);
|
||||
std::vector<NodeIdType> input_list = {node_list[node_idx]};
|
||||
for (size_t i = 0; i < neighbor_nums.size(); ++i) {
|
||||
|
@ -204,6 +227,12 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno
|
|||
Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
||||
if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIdType>> neighbors_vec;
|
||||
neighbors_vec.resize(node_list.size());
|
||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||
|
@ -266,7 +295,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
|
|||
if (!nodes || nodes->Size() == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
|
||||
TensorRow tensors;
|
||||
for (const auto &f_type : feature_types) {
|
||||
std::shared_ptr<Feature> default_feature;
|
||||
|
|
|
@ -226,6 +226,8 @@ class Graph {
|
|||
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
|
||||
int32_t samples_num, std::vector<NodeIdType> *out_samples);
|
||||
|
||||
Status CheckSamplesNum(NodeIdType samples_num);
|
||||
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
std::mt19937 rnd_;
|
||||
|
|
|
@ -1109,10 +1109,10 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|||
for m in param:
|
||||
if not isinstance(m, int):
|
||||
raise TypeError(
|
||||
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
||||
"Each member in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
||||
elif isinstance(param, np.ndarray):
|
||||
if not param.dtype == np.int32:
|
||||
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
param_name, param.dtype))
|
||||
else:
|
||||
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||
|
@ -1195,15 +1195,15 @@ def check_gnn_get_sampled_neighbors(method):
|
|||
# check neighbor_nums; required argument
|
||||
neighbor_nums = param_dict.get("neighbor_nums")
|
||||
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
|
||||
if len(neighbor_nums) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
||||
if not neighbor_nums or len(neighbor_nums) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
||||
'neighbor_nums', len(neighbor_nums)))
|
||||
|
||||
# check neighbor_types; required argument
|
||||
neighbor_types = param_dict.get("neighbor_types")
|
||||
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
||||
if len(neighbor_nums) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
||||
if not neighbor_types or len(neighbor_types) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
||||
'neighbor_types', len(neighbor_types)))
|
||||
|
||||
if len(neighbor_nums) != len(neighbor_types):
|
||||
|
@ -1255,7 +1255,7 @@ def check_gnn_random_walk(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_aligned_list(param, param_name, membor_type):
|
||||
def check_aligned_list(param, param_name, member_type):
|
||||
"""Check whether the structure of each member of the list is the same."""
|
||||
|
||||
if not isinstance(param, list):
|
||||
|
@ -1263,27 +1263,27 @@ def check_aligned_list(param, param_name, membor_type):
|
|||
if not param:
|
||||
raise TypeError(
|
||||
"Parameter {0} or its members are empty".format(param_name))
|
||||
membor_have_list = None
|
||||
member_have_list = None
|
||||
list_len = None
|
||||
for membor in param:
|
||||
if isinstance(membor, list):
|
||||
check_aligned_list(membor, param_name, membor_type)
|
||||
if membor_have_list not in (None, True):
|
||||
for member in param:
|
||||
if isinstance(member, list):
|
||||
check_aligned_list(member, param_name, member_type)
|
||||
if member_have_list not in (None, True):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
if list_len is not None and len(membor) != list_len:
|
||||
if list_len is not None and len(member) != list_len:
|
||||
raise TypeError("The size of each member of parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
membor_have_list = True
|
||||
list_len = len(membor)
|
||||
member_have_list = True
|
||||
list_len = len(member)
|
||||
else:
|
||||
if not isinstance(membor, membor_type):
|
||||
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
|
||||
param_name, type(membor)))
|
||||
if membor_have_list not in (None, False):
|
||||
if not isinstance(member, member_type):
|
||||
raise TypeError("Each member in {0} should be of type int. Got {1}.".format(
|
||||
param_name, type(member)))
|
||||
if member_have_list not in (None, False):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
membor_have_list = False
|
||||
member_have_list = False
|
||||
|
||||
|
||||
def check_gnn_get_node_feature(method):
|
||||
|
@ -1299,7 +1299,7 @@ def check_gnn_get_node_feature(method):
|
|||
check_aligned_list(node_list, 'node_list', int)
|
||||
elif isinstance(node_list, np.ndarray):
|
||||
if not node_list.dtype == np.int32:
|
||||
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
||||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
node_list, node_list.dtype))
|
||||
else:
|
||||
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
||||
|
|
|
@ -158,6 +158,18 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
|||
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
|
||||
|
@ -198,9 +210,17 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
|||
s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
|
||||
|
||||
neg_neighbors.reset();
|
||||
s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
|
||||
|
||||
neg_neighbors.reset();
|
||||
s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
|
||||
|
||||
neg_neighbors.reset();
|
||||
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
||||
|
|
Loading…
Reference in New Issue